Skip to main content

lash_tool_support/
lib.rs

1use lash_core::{ToolDefinition, ToolFailure, ToolFailureClass, ToolResult};
2use schemars::JsonSchema;
3use serde::de::DeserializeOwned;
4use serde::{Deserialize, Deserializer, Serialize};
5use std::future::Future;
6use std::path::{Component, Path, PathBuf};
7
8mod static_provider;
9#[cfg(feature = "lashlang")]
10pub use lash_lashlang_runtime::LashlangToolBinding;
11pub use static_provider::{StaticToolExecute, StaticToolProvider};
12
/// No-op stand-in for the real `LashlangToolBinding` when the `lashlang`
/// feature is disabled, so callers can build bindings unconditionally.
#[cfg(not(feature = "lashlang"))]
#[derive(Clone, Debug, Default)]
pub struct LashlangToolBinding;

#[cfg(not(feature = "lashlang"))]
impl LashlangToolBinding {
    /// Accepts the same arguments as the real constructor. Every argument is
    /// still converted into a `String` and then dropped, so caller code
    /// type-checks identically in both feature configurations.
    pub fn new(
        module_path: impl IntoIterator<Item = impl Into<String>>,
        operation: impl Into<String>,
    ) -> Self {
        module_path.into_iter().for_each(|segment| {
            let _: String = segment.into();
        });
        let _: String = operation.into();
        Self
    }

    /// Converts and discards `authority_type`; returns the stub unchanged.
    pub fn with_authority_type(self, authority_type: impl Into<String>) -> Self {
        let _: String = authority_type.into();
        self
    }

    /// Converts and discards every alias; returns the stub unchanged.
    pub fn with_aliases(self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
        for alias in aliases {
            let _: String = alias.into();
        }
        self
    }
}
41
42pub trait ToolDefinitionLashlangExt {
43    fn with_lashlang_binding(self, lashlang_binding: LashlangToolBinding) -> Self;
44}
45
#[cfg(feature = "lashlang")]
impl ToolDefinitionLashlangExt for ToolDefinition {
    /// Forwards to `lash_lashlang_runtime`'s extension trait of the same name.
    /// The fully-qualified path is required because both traits provide a
    /// method called `with_lashlang_binding` on `ToolDefinition`.
    fn with_lashlang_binding(self, lashlang_binding: LashlangToolBinding) -> Self {
        <Self as lash_lashlang_runtime::ToolDefinitionLashlangExt>::with_lashlang_binding(
            self,
            lashlang_binding,
        )
    }
}
55
56#[cfg(not(feature = "lashlang"))]
57impl ToolDefinitionLashlangExt for ToolDefinition {
58    fn with_lashlang_binding(self, _lashlang_binding: LashlangToolBinding) -> Self {
59        self
60    }
61}
62
63/// Resolve a possibly-relative `path` against `base`, returning a lexically
64/// normalized [`PathBuf`].
65///
66/// Behavior:
67/// - Absolute `path` passes through unchanged (only normalized).
68/// - Relative `path` is joined onto `base`.
69/// - `.` and `..` components are collapsed *lexically* — purely by string
70///   manipulation, without touching the filesystem and without requiring the
71///   path (or its parents) to exist.
72///
73/// Lexical (rather than `std::fs::canonicalize`) resolution is the deliberate
74/// choice for tool path handling: write/patch tools must resolve targets that
75/// do not yet exist on disk, and canonicalization both fails for missing paths
76/// and silently rewrites symlinks. Tools that genuinely need symlink-real-path
77/// resolution for an existence/scope check should use [`canonicalize_under`]
78/// instead and accept that it requires the path to exist.
79pub fn resolve_under(base: &Path, path: &Path) -> PathBuf {
80    let joined = if path.is_absolute() {
81        path.to_path_buf()
82    } else {
83        base.join(path)
84    };
85    normalize_lexical(&joined)
86}
87
/// Lexically collapse `.` and `..` components in `path` without touching the
/// filesystem.
///
/// `.` components are dropped. A `..` component is handled as follows:
/// - after a normal component it cancels that component (`a/b/..` → `a`);
/// - directly under a root it is dropped, because the parent of the root is
///   the root itself (`/..` → `/`, `/a/../../b` → `/b`);
/// - when there is nothing left to cancel in a relative path it is kept
///   verbatim, and consecutive leading `..` components accumulate
///   (`../..` → `../..`, `a/../../b` → `../b`). The same applies after a bare
///   drive prefix such as `C:` on Windows, which is drive-relative rather
///   than rooted.
///
/// A relative path that cancels out completely yields an empty path
/// (`a/..` → ``).
pub fn normalize_lexical(path: &Path) -> PathBuf {
    let mut normalized = PathBuf::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => {
                // Decide based on the last component kept so far. A blind
                // `PathBuf::pop` is wrong here: it also "succeeds" on a trailing
                // `..`, which would erase it instead of stacking another `..`.
                let last = normalized.components().next_back();
                match last {
                    Some(Component::Normal(_)) => {
                        normalized.pop();
                    }
                    Some(Component::RootDir) => {}
                    Some(Component::ParentDir | Component::Prefix(_) | Component::CurDir)
                    | None => {
                        normalized.push(component.as_os_str());
                    }
                }
            }
            Component::Prefix(_) | Component::RootDir | Component::Normal(_) => {
                normalized.push(component.as_os_str());
            }
        }
    }
    normalized
}
108
109/// Resolve `path` against `base` (via [`resolve_under`]) and then canonicalize
110/// it on disk, resolving symlinks to their real path. Fails if the path does
111/// not exist. Use this only when a tool needs a real, existence-checked path
112/// (e.g. a security/scope decision or distinguishing a file from a directory);
113/// prefer [`resolve_under`] for write/patch targets that may not exist yet.
114pub fn canonicalize_under(base: &Path, path: &Path) -> std::io::Result<PathBuf> {
115    std::fs::canonicalize(resolve_under(base, path))
116}
117
/// Render `path` for display relative to `base`.
///
/// - If `path` is under `base`, the remainder after `base` is returned
///   (base `/repo`, path `/repo/src/lib.rs` → `src/lib.rs`).
/// - If `path` is not under `base`, the full `path` is returned.
/// - If the relative remainder is empty (e.g. `path == base`), the final
///   component of `path` is returned instead, or `"."` when it has none
///   (e.g. `/`).
///
/// Non-UTF-8 components are rendered lossily (U+FFFD). Backslashes are
/// normalized to forward slashes so output is stable across platforms.
pub fn display_relative(base: &Path, path: &Path) -> String {
    let relative = path.strip_prefix(base).unwrap_or(path);
    let display = if relative.as_os_str().is_empty() {
        // Lossy rather than `to_str`, so a non-UTF-8 name still shows up
        // instead of degrading to ".".
        path.file_name().map_or_else(
            || ".".to_string(),
            |name| name.to_string_lossy().into_owned(),
        )
    } else {
        relative.display().to_string()
    };
    display.replace('\\', "/")
}
137
/// Shared sentence describing the default filesystem discovery behavior
/// (hidden entries, `.git`, and `node_modules` excluded; ignore files honored).
pub const FS_DEFAULTS_PREAMBLE: &str = concat!(
    "By default this excludes hidden entries, `.git`, and `node_modules`, ",
    "and respects ignore files."
);
140
141#[derive(Clone, Debug, Serialize, JsonSchema)]
142pub struct TruncationMeta {
143    pub shown: usize,
144    pub total: usize,
145    pub omitted: usize,
146}
147
148pub fn invalid_tool_args(message: impl Into<String>) -> ToolResult {
149    ToolResult::failure(ToolFailure::tool(
150        ToolFailureClass::InvalidRequest,
151        "invalid_tool_args",
152        message.into(),
153    ))
154}
155
156pub fn typed_tool_args<Args>(args: &serde_json::Value) -> Result<Args, ToolResult>
157where
158    Args: DeserializeOwned + JsonSchema,
159{
160    serde_json::from_value(args.clone())
161        .map_err(|err| invalid_tool_args(format!("Invalid tool arguments: {err}")))
162}
163
164pub fn typed_tool_ok<Output>(output: Output) -> ToolResult
165where
166    Output: Serialize + JsonSchema,
167{
168    match serde_json::to_value(output) {
169        Ok(value) => ToolResult::ok(value),
170        Err(err) => ToolResult::err_fmt(format_args!("Failed to serialize tool result: {err}")),
171    }
172}
173
174pub async fn execute_typed_tool<Args, Output, F, Fut>(
175    args: &serde_json::Value,
176    execute: F,
177) -> ToolResult
178where
179    Args: DeserializeOwned + JsonSchema,
180    Output: Serialize + JsonSchema,
181    F: FnOnce(Args) -> Fut,
182    Fut: Future<Output = Result<Output, ToolResult>>,
183{
184    let args = match typed_tool_args::<Args>(args) {
185        Ok(args) => args,
186        Err(err) => return err,
187    };
188    match execute(args).await {
189        Ok(output) => typed_tool_ok(output),
190        Err(err) => err,
191    }
192}
193
194pub async fn execute_typed_tool_result<Args, F, Fut>(
195    args: &serde_json::Value,
196    execute: F,
197) -> ToolResult
198where
199    Args: DeserializeOwned + JsonSchema,
200    F: FnOnce(Args) -> Fut,
201    Fut: Future<Output = ToolResult>,
202{
203    let args = match typed_tool_args::<Args>(args) {
204        Ok(args) => args,
205        Err(err) => return err,
206    };
207    execute(args).await
208}
209
210pub fn non_empty_string(value: &str, key: &str) -> Result<(), ToolResult> {
211    if value.is_empty() {
212        Err(invalid_tool_args(format!(
213            "Missing required parameter: {key}"
214        )))
215    } else {
216        Ok(())
217    }
218}
219
/// Default for optional path arguments: the current directory, `"."`.
pub fn default_path_dot() -> String {
    String::from(".")
}
223
224#[derive(Clone, Debug, Deserialize, JsonSchema)]
225#[serde(untagged)]
226pub enum OptionalUsizeArg {
227    Value(usize),
228    NoneString(String),
229    Null(()),
230}
231
232impl OptionalUsizeArg {
233    pub fn into_option(self, key: &str, min: usize) -> Result<Option<usize>, ToolResult> {
234        match self {
235            Self::Value(value) if value >= min => Ok(Some(value)),
236            Self::Value(_) => Err(invalid_tool_args(format!(
237                "Invalid {key}: must be >= {min}, or use null/\"none\" for no cap"
238            ))),
239            Self::NoneString(value) if value.eq_ignore_ascii_case("none") => Ok(None),
240            Self::NoneString(_) => Err(invalid_tool_args(format!(
241                "Invalid {key}: expected int, null, or \"none\""
242            ))),
243            Self::Null(()) => Ok(None),
244        }
245    }
246}
247
248pub fn deserialize_optional_usize_none<'de, D>(deserializer: D) -> Result<Option<usize>, D::Error>
249where
250    D: Deserializer<'de>,
251{
252    #[derive(Deserialize)]
253    #[serde(untagged)]
254    enum OptionalUsize {
255        Int(usize),
256        String(String),
257        Null,
258    }
259
260    match Option::<OptionalUsize>::deserialize(deserializer)? {
261        None | Some(OptionalUsize::Null) => Ok(None),
262        Some(OptionalUsize::Int(value)) => Ok(Some(value)),
263        Some(OptionalUsize::String(value)) if value.eq_ignore_ascii_case("none") => Ok(None),
264        Some(OptionalUsize::String(_)) => Err(serde::de::Error::custom(
265            "expected integer, null, or \"none\"",
266        )),
267    }
268}
269
270pub fn default_glob_limit() -> OptionalUsizeArg {
271    OptionalUsizeArg::Value(100)
272}
273
274/// Extract a required non-empty string arg, or return ToolResult::err.
275pub fn require_str<'a>(args: &'a serde_json::Value, key: &str) -> Result<&'a str, ToolResult> {
276    args.get(key)
277        .and_then(|v| v.as_str())
278        .filter(|s| !s.is_empty())
279        .ok_or_else(|| ToolResult::err_fmt(format_args!("Missing required parameter: {key}")))
280}
281
282/// Parse optional bool arg with a default.
283pub fn parse_optional_bool(
284    args: &serde_json::Value,
285    key: &str,
286    default: bool,
287) -> Result<bool, ToolResult> {
288    match args.get(key) {
289        None => Ok(default),
290        Some(v) if v.is_null() => Ok(default),
291        Some(v) => match v.as_bool() {
292            Some(b) => Ok(b),
293            None => Err(ToolResult::err_fmt(format_args!(
294                "Invalid {key}: expected bool"
295            ))),
296        },
297    }
298}
299
300/// Parse an optional positive integer arg.
301/// Accepts `null` or `"none"` when `allow_none` is true.
302pub fn parse_optional_usize_arg(
303    args: &serde_json::Value,
304    key: &str,
305    default: Option<usize>,
306    allow_none: bool,
307    min: usize,
308) -> Result<Option<usize>, ToolResult> {
309    match args.get(key) {
310        None => Ok(default),
311        Some(v) if v.is_null() => {
312            if allow_none {
313                Ok(None)
314            } else {
315                Err(ToolResult::err_fmt(format_args!(
316                    "Invalid {key}: expected int >= {min}"
317                )))
318            }
319        }
320        Some(v) => {
321            if let Some(s) = v.as_str() {
322                if allow_none && s.eq_ignore_ascii_case("none") {
323                    return Ok(None);
324                }
325                return Err(ToolResult::err_fmt(format_args!(
326                    "Invalid {key}: expected int{}",
327                    if allow_none {
328                        ", null, or \"none\""
329                    } else {
330                        ""
331                    }
332                )));
333            }
334            let n = v.as_u64().ok_or_else(|| {
335                ToolResult::err_fmt(format_args!(
336                    "Invalid {key}: expected int{}",
337                    if allow_none {
338                        ", null, or \"none\""
339                    } else {
340                        ""
341                    }
342                ))
343            })? as usize;
344            if n < min {
345                return Err(ToolResult::err_fmt(format_args!(
346                    "Invalid {key}: must be >= {min}{}",
347                    if allow_none {
348                        ", or use null/\"none\" for no cap"
349                    } else {
350                        ""
351                    }
352                )));
353            }
354            Ok(Some(n))
355        }
356    }
357}
358
359pub fn object_schema(properties: serde_json::Value, required: &[&str]) -> serde_json::Value {
360    serde_json::json!({
361        "type": "object",
362        "properties": properties,
363        "required": required,
364        "additionalProperties": false,
365    })
366}
367
368pub fn lashlang_binding(
369    module_path: impl IntoIterator<Item = impl Into<String>>,
370    operation: impl Into<String>,
371    aliases: &[&str],
372) -> LashlangToolBinding {
373    LashlangToolBinding::new(module_path, operation).with_aliases(aliases.iter().copied())
374}
375
376/// Run blocking filesystem work off the async runtime.
377pub async fn run_blocking<F>(f: F) -> ToolResult
378where
379    F: FnOnce() -> ToolResult + Send + 'static,
380{
381    match tokio::task::spawn_blocking(f).await {
382        Ok(result) => result,
383        Err(e) => ToolResult::err_fmt(format_args!("blocking task failed: {e}")),
384    }
385}
386
387/// Run blocking work off the async runtime and return a typed value.
388pub async fn run_blocking_value<F, T>(f: F) -> Result<T, String>
389where
390    F: FnOnce() -> T + Send + 'static,
391    T: Send + 'static,
392{
393    tokio::task::spawn_blocking(f)
394        .await
395        .map_err(|err| format!("blocking task failed: {err}"))
396}
397
398pub fn rg_file_list(
399    base: &Path,
400    show_hidden_entries: bool,
401    respect_ignore_files: bool,
402    max_depth: Option<usize>,
403    globs: &[String],
404) -> Result<Vec<PathBuf>, ToolResult> {
405    if is_default_excluded_entry(base) {
406        return Ok(Vec::new());
407    }
408
409    let mut builder = ignore::WalkBuilder::new(base);
410    builder
411        .hidden(!show_hidden_entries)
412        .max_depth(max_depth)
413        .filter_entry(|entry| !is_default_excluded_entry(entry.path()));
414
415    if respect_ignore_files {
416        builder.git_ignore(true).git_exclude(true).git_global(true);
417        builder.require_git(true);
418    } else {
419        builder
420            .git_ignore(false)
421            .git_exclude(false)
422            .git_global(false)
423            .ignore(false)
424            .parents(false)
425            .require_git(false);
426    }
427
428    if !globs.is_empty() {
429        let mut override_builder = ignore::overrides::OverrideBuilder::new(base);
430        for glob in globs {
431            override_builder.add(glob).map_err(|err| {
432                ToolResult::err_fmt(format_args!(
433                    "invalid ignore glob for {}: {err}",
434                    base.display()
435                ))
436            })?;
437        }
438
439        let overrides = override_builder.build().map_err(|err| {
440            ToolResult::err_fmt(format_args!(
441                "failed to build ignore globs for {}: {err}",
442                base.display()
443            ))
444        })?;
445        builder.overrides(overrides);
446    }
447
448    let files = builder
449        .build()
450        .filter_map(Result::ok)
451        .filter(|entry| entry.path() != base)
452        .filter(|entry| !is_default_excluded_entry(entry.path()))
453        .map(ignore::DirEntry::into_path)
454        .collect();
455    Ok(files)
456}
457
/// True when the final component of `path` is `.git` or `node_modules`, the
/// entries excluded from discovery by default.
fn is_default_excluded_entry(path: &Path) -> bool {
    matches!(
        path.file_name(),
        Some(name) if name == ".git" || name == "node_modules"
    )
}
464
465/// Generate a compact unified diff between old and new content.
466/// Truncates to `max_lines` lines if the diff is too long.
467pub fn compact_diff(old: &str, new: &str, path: &str, max_lines: usize) -> String {
468    let diff = similar::TextDiff::from_lines(old, new);
469    let unified = diff
470        .unified_diff()
471        .header(&format!("a/{path}"), &format!("b/{path}"))
472        .to_string();
473    if unified.is_empty() {
474        return String::new();
475    }
476    let lines: Vec<&str> = unified.lines().collect();
477    if lines.len() <= max_lines {
478        unified
479    } else {
480        let mut truncated: String = lines[..max_lines].join("\n");
481        truncated.push_str(&format!("\n... ({} more lines)", lines.len() - max_lines));
482        truncated
483    }
484}