Skip to main content

lash_tool_support/
lib.rs

1use lash_core::{ToolDefinition, ToolFailure, ToolFailureClass, ToolResult};
2use schemars::JsonSchema;
3use serde::de::DeserializeOwned;
4use serde::{Deserialize, Deserializer, Serialize};
5use std::future::Future;
6use std::io::{BufRead, BufReader};
7use std::path::{Component, Path, PathBuf};
8use std::time::{SystemTime, UNIX_EPOCH};
9
10mod static_provider;
11#[cfg(feature = "lashlang")]
12pub use lash_lashlang_runtime::LashlangToolBinding;
13pub use static_provider::{StaticToolExecute, StaticToolProvider};
14
/// Stand-in for `lash_lashlang_runtime::LashlangToolBinding`, compiled only
/// when the `lashlang` feature is disabled (with the feature on, the real type
/// is re-exported instead).
///
/// It exposes `new`, `with_authority_type`, and `with_aliases` so callers
/// compile either way, but it stores nothing: every argument is converted to
/// `String` and immediately discarded.
#[cfg(not(feature = "lashlang"))]
#[derive(Clone, Debug, Default)]
pub struct LashlangToolBinding;

#[cfg(not(feature = "lashlang"))]
impl LashlangToolBinding {
    /// Accept a module path and an operation name, then discard both.
    pub fn new(
        module_path: impl IntoIterator<Item = impl Into<String>>,
        operation: impl Into<String>,
    ) -> Self {
        module_path.into_iter().for_each(|segment| {
            let _: String = segment.into();
        });
        let _: String = operation.into();
        Self
    }

    /// Accept an authority type, then discard it.
    pub fn with_authority_type(self, authority_type: impl Into<String>) -> Self {
        let _: String = authority_type.into();
        self
    }

    /// Accept a set of aliases, then discard them.
    pub fn with_aliases(self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
        aliases.into_iter().for_each(|alias| {
            let _: String = alias.into();
        });
        self
    }
}
43
44pub trait ToolDefinitionLashlangExt {
45    fn with_lashlang_binding(self, lashlang_binding: LashlangToolBinding) -> Self;
46}
47
48#[cfg(feature = "lashlang")]
49impl ToolDefinitionLashlangExt for ToolDefinition {
50    fn with_lashlang_binding(self, lashlang_binding: LashlangToolBinding) -> Self {
51        lash_lashlang_runtime::ToolDefinitionLashlangExt::with_lashlang_binding(
52            self,
53            lashlang_binding,
54        )
55    }
56}
57
58#[cfg(not(feature = "lashlang"))]
59impl ToolDefinitionLashlangExt for ToolDefinition {
60    fn with_lashlang_binding(self, _lashlang_binding: LashlangToolBinding) -> Self {
61        self
62    }
63}
64
65/// Resolve a possibly-relative `path` against `base`, returning a lexically
66/// normalized [`PathBuf`].
67///
68/// Behavior:
69/// - Absolute `path` passes through unchanged (only normalized).
70/// - Relative `path` is joined onto `base`.
71/// - `.` and `..` components are collapsed *lexically* — purely by string
72///   manipulation, without touching the filesystem and without requiring the
73///   path (or its parents) to exist.
74///
75/// Lexical (rather than `std::fs::canonicalize`) resolution is the deliberate
76/// choice for tool path handling: write/patch tools must resolve targets that
77/// do not yet exist on disk, and canonicalization both fails for missing paths
78/// and silently rewrites symlinks. Tools that genuinely need symlink-real-path
79/// resolution for an existence/scope check should use [`canonicalize_under`]
80/// instead and accept that it requires the path to exist.
81pub fn resolve_under(base: &Path, path: &Path) -> PathBuf {
82    let joined = if path.is_absolute() {
83        path.to_path_buf()
84    } else {
85        base.join(path)
86    };
87    normalize_lexical(&joined)
88}
89
/// Lexically collapse `.` and `..` components in `path` without touching the
/// filesystem.
///
/// - `.` components are dropped.
/// - `..` cancels the preceding normal component (`a/b/../c` → `a/c`).
/// - `..` directly after the root is dropped, since nothing lies above the
///   root (`/../a` → `/a`).
/// - Leading `..` components with nothing to cancel are preserved verbatim
///   (`../../a` stays `../../a`), matching `Path::join` intuitions for
///   relative roots.
///
/// A path that collapses completely (e.g. `a/..`) yields an empty [`PathBuf`].
pub fn normalize_lexical(path: &Path) -> PathBuf {
    let mut normalized = PathBuf::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => match normalized.components().next_back() {
                // A real directory name to cancel out.
                Some(Component::Normal(_)) => {
                    normalized.pop();
                }
                // `..` at the root stays at the root.
                Some(Component::RootDir) => {}
                // Empty, a bare prefix (e.g. Windows `C:`), or a run of leading
                // `..`: nothing to cancel, so keep it. `PathBuf::pop` must not be
                // used here, because it would strip a previously kept `..`.
                _ => normalized.push(component.as_os_str()),
            },
            Component::Prefix(_) | Component::RootDir | Component::Normal(_) => {
                normalized.push(component.as_os_str());
            }
        }
    }
    normalized
}
110
111/// Resolve `path` against `base` (via [`resolve_under`]) and then canonicalize
112/// it on disk, resolving symlinks to their real path. Fails if the path does
113/// not exist. Use this only when a tool needs a real, existence-checked path
114/// (e.g. a security/scope decision or distinguishing a file from a directory);
115/// prefer [`resolve_under`] for write/patch targets that may not exist yet.
116pub fn canonicalize_under(base: &Path, path: &Path) -> std::io::Result<PathBuf> {
117    std::fs::canonicalize(resolve_under(base, path))
118}
119
/// Render `path` relative to `base` for display.
///
/// - When `path` lies under `base`, the `base` prefix is stripped
///   (`/repo` + `/repo/src/lib.rs` → `src/lib.rs`).
/// - When `path` equals `base` (stripping leaves nothing), the final component
///   of `path` is shown instead. If there is none (e.g. `/`) or it is not
///   valid UTF-8, `.` is shown.
/// - When `path` is not under `base`, it is shown in full.
///
/// Backslashes are normalized to forward slashes so output is stable across
/// platforms. Other non-UTF-8 content is rendered lossily via [`Path::display`].
pub fn display_relative(base: &Path, path: &Path) -> String {
    let relative = path.strip_prefix(base).unwrap_or(path);
    let display = if relative.as_os_str().is_empty() {
        path.file_name()
            .and_then(|name| name.to_str())
            .unwrap_or(".")
            .to_string()
    } else {
        relative.display().to_string()
    };
    display.replace('\\', "/")
}
139
/// Shared sentence describing the default filesystem-listing behavior.
///
/// Meant for the `ls` and `glob` tool descriptions, so both document hidden-file
/// and `.gitignore` handling in identical wording.
pub const FS_DEFAULTS_PREAMBLE: &str = concat!(
    "By default this includes hidden files ",
    "and respects `.gitignore` only inside Git repos."
);
145
146#[derive(Clone, Debug, Serialize, JsonSchema)]
147pub struct PathEntry {
148    pub path: String,
149    pub kind: String,
150    pub size_bytes: u64,
151    pub lines: Option<u64>,
152    pub modified_at: String,
153}
154
155#[derive(Clone, Debug, Serialize, JsonSchema)]
156pub struct TruncationMeta {
157    pub shown: usize,
158    pub total: usize,
159    pub omitted: usize,
160}
161
162#[derive(Clone, Debug, Serialize, JsonSchema)]
163#[serde(deny_unknown_fields)]
164pub struct FilesystemEntriesOutput {
165    pub items: Vec<PathEntry>,
166    pub truncated: Option<TruncationMeta>,
167}
168
169pub fn invalid_tool_args(message: impl Into<String>) -> ToolResult {
170    ToolResult::failure(ToolFailure::tool(
171        ToolFailureClass::InvalidRequest,
172        "invalid_tool_args",
173        message.into(),
174    ))
175}
176
177pub fn typed_tool_args<Args>(args: &serde_json::Value) -> Result<Args, ToolResult>
178where
179    Args: DeserializeOwned + JsonSchema,
180{
181    serde_json::from_value(args.clone())
182        .map_err(|err| invalid_tool_args(format!("Invalid tool arguments: {err}")))
183}
184
185pub fn typed_tool_ok<Output>(output: Output) -> ToolResult
186where
187    Output: Serialize + JsonSchema,
188{
189    match serde_json::to_value(output) {
190        Ok(value) => ToolResult::ok(value),
191        Err(err) => ToolResult::err_fmt(format_args!("Failed to serialize tool result: {err}")),
192    }
193}
194
195pub async fn execute_typed_tool<Args, Output, F, Fut>(
196    args: &serde_json::Value,
197    execute: F,
198) -> ToolResult
199where
200    Args: DeserializeOwned + JsonSchema,
201    Output: Serialize + JsonSchema,
202    F: FnOnce(Args) -> Fut,
203    Fut: Future<Output = Result<Output, ToolResult>>,
204{
205    let args = match typed_tool_args::<Args>(args) {
206        Ok(args) => args,
207        Err(err) => return err,
208    };
209    match execute(args).await {
210        Ok(output) => typed_tool_ok(output),
211        Err(err) => err,
212    }
213}
214
215pub async fn execute_typed_tool_result<Args, F, Fut>(
216    args: &serde_json::Value,
217    execute: F,
218) -> ToolResult
219where
220    Args: DeserializeOwned + JsonSchema,
221    F: FnOnce(Args) -> Fut,
222    Fut: Future<Output = ToolResult>,
223{
224    let args = match typed_tool_args::<Args>(args) {
225        Ok(args) => args,
226        Err(err) => return err,
227    };
228    execute(args).await
229}
230
231pub fn non_empty_string(value: &str, key: &str) -> Result<(), ToolResult> {
232    if value.is_empty() {
233        Err(invalid_tool_args(format!(
234            "Missing required parameter: {key}"
235        )))
236    } else {
237        Ok(())
238    }
239}
240
/// Default-value helper that returns `true` (presumably used via
/// `#[serde(default = "default_true")]`; no caller is visible in this file).
pub fn default_true() -> bool {
    true
}

/// Default-value helper that returns the current-directory path `"."`.
pub fn default_path_dot() -> String {
    String::from(".")
}
248
249#[derive(Clone, Debug, Deserialize, JsonSchema)]
250#[serde(untagged)]
251pub enum OptionalUsizeArg {
252    Value(usize),
253    NoneString(String),
254    Null(()),
255}
256
257impl OptionalUsizeArg {
258    pub fn into_option(self, key: &str, min: usize) -> Result<Option<usize>, ToolResult> {
259        match self {
260            Self::Value(value) if value >= min => Ok(Some(value)),
261            Self::Value(_) => Err(invalid_tool_args(format!(
262                "Invalid {key}: must be >= {min}, or use null/\"none\" for no cap"
263            ))),
264            Self::NoneString(value) if value.eq_ignore_ascii_case("none") => Ok(None),
265            Self::NoneString(_) => Err(invalid_tool_args(format!(
266                "Invalid {key}: expected int, null, or \"none\""
267            ))),
268            Self::Null(()) => Ok(None),
269        }
270    }
271}
272
273pub fn deserialize_optional_usize_none<'de, D>(deserializer: D) -> Result<Option<usize>, D::Error>
274where
275    D: Deserializer<'de>,
276{
277    #[derive(Deserialize)]
278    #[serde(untagged)]
279    enum OptionalUsize {
280        Int(usize),
281        String(String),
282        Null,
283    }
284
285    match Option::<OptionalUsize>::deserialize(deserializer)? {
286        None | Some(OptionalUsize::Null) => Ok(None),
287        Some(OptionalUsize::Int(value)) => Ok(Some(value)),
288        Some(OptionalUsize::String(value)) if value.eq_ignore_ascii_case("none") => Ok(None),
289        Some(OptionalUsize::String(_)) => Err(serde::de::Error::custom(
290            "expected integer, null, or \"none\"",
291        )),
292    }
293}
294
295pub fn default_ls_depth() -> OptionalUsizeArg {
296    OptionalUsizeArg::Value(3)
297}
298
299pub fn default_ls_limit() -> OptionalUsizeArg {
300    OptionalUsizeArg::Value(500)
301}
302
303pub fn default_glob_limit() -> OptionalUsizeArg {
304    OptionalUsizeArg::Value(100)
305}
306
307/// Extract a required non-empty string arg, or return ToolResult::err.
308pub fn require_str<'a>(args: &'a serde_json::Value, key: &str) -> Result<&'a str, ToolResult> {
309    args.get(key)
310        .and_then(|v| v.as_str())
311        .filter(|s| !s.is_empty())
312        .ok_or_else(|| ToolResult::err_fmt(format_args!("Missing required parameter: {key}")))
313}
314
315/// Parse optional bool arg with a default.
316pub fn parse_optional_bool(
317    args: &serde_json::Value,
318    key: &str,
319    default: bool,
320) -> Result<bool, ToolResult> {
321    match args.get(key) {
322        None => Ok(default),
323        Some(v) if v.is_null() => Ok(default),
324        Some(v) => match v.as_bool() {
325            Some(b) => Ok(b),
326            None => Err(ToolResult::err_fmt(format_args!(
327                "Invalid {key}: expected bool"
328            ))),
329        },
330    }
331}
332
333/// Parse an optional positive integer arg.
334/// Accepts `null` or `"none"` when `allow_none` is true.
335pub fn parse_optional_usize_arg(
336    args: &serde_json::Value,
337    key: &str,
338    default: Option<usize>,
339    allow_none: bool,
340    min: usize,
341) -> Result<Option<usize>, ToolResult> {
342    match args.get(key) {
343        None => Ok(default),
344        Some(v) if v.is_null() => {
345            if allow_none {
346                Ok(None)
347            } else {
348                Err(ToolResult::err_fmt(format_args!(
349                    "Invalid {key}: expected int >= {min}"
350                )))
351            }
352        }
353        Some(v) => {
354            if let Some(s) = v.as_str() {
355                if allow_none && s.eq_ignore_ascii_case("none") {
356                    return Ok(None);
357                }
358                return Err(ToolResult::err_fmt(format_args!(
359                    "Invalid {key}: expected int{}",
360                    if allow_none {
361                        ", null, or \"none\""
362                    } else {
363                        ""
364                    }
365                )));
366            }
367            let n = v.as_u64().ok_or_else(|| {
368                ToolResult::err_fmt(format_args!(
369                    "Invalid {key}: expected int{}",
370                    if allow_none {
371                        ", null, or \"none\""
372                    } else {
373                        ""
374                    }
375                ))
376            })? as usize;
377            if n < min {
378                return Err(ToolResult::err_fmt(format_args!(
379                    "Invalid {key}: must be >= {min}{}",
380                    if allow_none {
381                        ", or use null/\"none\" for no cap"
382                    } else {
383                        ""
384                    }
385                )));
386            }
387            Ok(Some(n))
388        }
389    }
390}
391
392pub fn object_schema(properties: serde_json::Value, required: &[&str]) -> serde_json::Value {
393    serde_json::json!({
394        "type": "object",
395        "properties": properties,
396        "required": required,
397        "additionalProperties": false,
398    })
399}
400
401pub fn path_entry_output_schema() -> serde_json::Value {
402    serde_json::json!({
403        "type": "object",
404        "properties": {
405            "path": { "type": "string" },
406            "kind": { "type": "string", "enum": ["file", "dir", "symlink", "other"] },
407            "size_bytes": { "type": "integer", "minimum": 0 },
408            "lines": {
409                "anyOf": [
410                    { "type": "integer", "minimum": 0 },
411                    { "type": "null" }
412                ]
413            },
414            "modified_at": {
415                "type": "string",
416                "description": "Modification timestamp formatted as RFC3339 UTC."
417            }
418        },
419        "required": ["path", "kind", "size_bytes", "lines", "modified_at"],
420        "additionalProperties": false,
421    })
422}
423
424pub fn filesystem_entries_output_schema() -> serde_json::Value {
425    serde_json::json!({
426        "type": "object",
427        "properties": {
428            "items": {
429                "type": "array",
430                "items": path_entry_output_schema()
431            },
432            "truncated": {
433                "anyOf": [
434                    {
435                        "type": "object",
436                        "properties": {
437                            "shown": { "type": "integer", "minimum": 0 },
438                            "total": { "type": "integer", "minimum": 0 },
439                            "omitted": { "type": "integer", "minimum": 0 }
440                        },
441                        "required": ["shown", "total", "omitted"],
442                        "additionalProperties": false
443                    },
444                    { "type": "null" }
445                ]
446            }
447        },
448        "required": ["items", "truncated"],
449        "additionalProperties": false,
450    })
451}
452
453pub fn lashlang_binding(
454    module_path: impl IntoIterator<Item = impl Into<String>>,
455    operation: impl Into<String>,
456    aliases: &[&str],
457) -> LashlangToolBinding {
458    LashlangToolBinding::new(module_path, operation).with_aliases(aliases.iter().copied())
459}
460
461/// Run blocking filesystem work off the async runtime.
462pub async fn run_blocking<F>(f: F) -> ToolResult
463where
464    F: FnOnce() -> ToolResult + Send + 'static,
465{
466    match tokio::task::spawn_blocking(f).await {
467        Ok(result) => result,
468        Err(e) => ToolResult::err_fmt(format_args!("blocking task failed: {e}")),
469    }
470}
471
472/// Run blocking work off the async runtime and return a typed value.
473pub async fn run_blocking_value<F, T>(f: F) -> Result<T, String>
474where
475    F: FnOnce() -> T + Send + 'static,
476    T: Send + 'static,
477{
478    tokio::task::spawn_blocking(f)
479        .await
480        .map_err(|err| format!("blocking task failed: {err}"))
481}
482
483/// Build a normalized filesystem entry for tool output.
484/// Returns the entry plus raw mtime for optional sorting.
485pub fn build_path_entry(path: &Path, with_lines: bool) -> (PathEntry, SystemTime) {
486    let fallback_mtime = UNIX_EPOCH;
487    let path_str = path.to_string_lossy().to_string();
488
489    let metadata = match std::fs::symlink_metadata(path) {
490        Ok(m) => m,
491        Err(_) => {
492            let entry = PathEntry {
493                path: path_str,
494                kind: "other".to_string(),
495                size_bytes: 0,
496                lines: None,
497                modified_at: format_time_rfc3339(fallback_mtime),
498            };
499            return (entry, fallback_mtime);
500        }
501    };
502
503    let file_type = metadata.file_type();
504    let kind = if file_type.is_symlink() {
505        "symlink"
506    } else if file_type.is_dir() {
507        "dir"
508    } else if file_type.is_file() {
509        "file"
510    } else {
511        "other"
512    };
513
514    let mtime = metadata.modified().unwrap_or(fallback_mtime);
515    let lines = if with_lines && kind == "file" {
516        count_text_lines(path)
517    } else {
518        None
519    };
520
521    let entry = PathEntry {
522        path: path_str,
523        kind: kind.to_string(),
524        size_bytes: metadata.len(),
525        lines,
526        modified_at: format_time_rfc3339(mtime),
527    };
528    (entry, mtime)
529}
530
531pub fn rg_file_list(
532    base: &Path,
533    include_hidden: bool,
534    respect_gitignore: bool,
535    max_depth: Option<usize>,
536    globs: &[String],
537) -> Result<Vec<PathBuf>, ToolResult> {
538    let mut builder = ignore::WalkBuilder::new(base);
539    builder.hidden(!include_hidden).max_depth(max_depth);
540
541    if respect_gitignore {
542        builder.git_ignore(true).git_exclude(true).git_global(true);
543        builder.require_git(true);
544    } else {
545        builder
546            .git_ignore(false)
547            .git_exclude(false)
548            .git_global(false)
549            .ignore(false)
550            .parents(false)
551            .require_git(false);
552    }
553
554    if !globs.is_empty() {
555        let mut override_builder = ignore::overrides::OverrideBuilder::new(base);
556        for glob in globs {
557            override_builder.add(glob).map_err(|err| {
558                ToolResult::err_fmt(format_args!(
559                    "invalid ignore glob for {}: {err}",
560                    base.display()
561                ))
562            })?;
563        }
564
565        let overrides = override_builder.build().map_err(|err| {
566            ToolResult::err_fmt(format_args!(
567                "failed to build ignore globs for {}: {err}",
568                base.display()
569            ))
570        })?;
571        builder.overrides(overrides);
572    }
573
574    let files = builder
575        .build()
576        .filter_map(Result::ok)
577        .filter(|entry| entry.path() != base)
578        .map(ignore::DirEntry::into_path)
579        .collect();
580    Ok(files)
581}
582
583/// Build the standard result envelope returned by filesystem listing tools.
584pub fn filesystem_entries_output(
585    items: Vec<PathEntry>,
586    total_count: usize,
587) -> FilesystemEntriesOutput {
588    let shown = items.len();
589    let truncated = if total_count > shown {
590        Some(TruncationMeta {
591            shown,
592            total: total_count,
593            omitted: total_count - shown,
594        })
595    } else {
596        None
597    };
598    FilesystemEntriesOutput { items, truncated }
599}
600
601pub fn filesystem_entries_result(items: Vec<PathEntry>, total_count: usize) -> serde_json::Value {
602    serde_json::to_value(filesystem_entries_output(items, total_count))
603        .unwrap_or_else(|_| serde_json::json!({ "items": [], "truncated": null }))
604}
605
/// Count the lines of the file at `path`, reading it as UTF-8 text.
///
/// Counts the same way as [`BufRead::lines`]: a final line without a trailing
/// newline still counts, and an empty file has 0 lines. Returns `None` if the
/// file cannot be opened or read, or if any line is not valid UTF-8 (so e.g.
/// many binary files yield `None`).
///
/// One line buffer is reused for the whole file instead of allocating a fresh
/// `String` per line. The whole file is still read.
fn count_text_lines(path: &Path) -> Option<u64> {
    let mut reader = BufReader::new(std::fs::File::open(path).ok()?);
    let mut line = String::new();
    let mut count = 0_u64;
    loop {
        line.clear();
        match reader.read_line(&mut line) {
            Ok(0) => return Some(count),
            Ok(_) => count += 1,
            // I/O failure or invalid UTF-8: not countable as text.
            Err(_) => return None,
        }
    }
}
618
619fn format_time_rfc3339(ts: SystemTime) -> String {
620    chrono::DateTime::<chrono::Utc>::from(ts).to_rfc3339_opts(chrono::SecondsFormat::Secs, true)
621}
622
623/// Generate a compact unified diff between old and new content.
624/// Truncates to `max_lines` lines if the diff is too long.
625pub fn compact_diff(old: &str, new: &str, path: &str, max_lines: usize) -> String {
626    let diff = similar::TextDiff::from_lines(old, new);
627    let unified = diff
628        .unified_diff()
629        .header(&format!("a/{path}"), &format!("b/{path}"))
630        .to_string();
631    if unified.is_empty() {
632        return String::new();
633    }
634    let lines: Vec<&str> = unified.lines().collect();
635    if lines.len() <= max_lines {
636        unified
637    } else {
638        let mut truncated: String = lines[..max_lines].join("\n");
639        truncated.push_str(&format!("\n... ({} more lines)", lines.len() - max_lines));
640        truncated
641    }
642}