llm_coding_tools_rig/allowed/
glob.rs

//! Glob-pattern file-finding tool whose searches are confined by an [`AllowedPathResolver`].
2
3use llm_coding_tools_core::operations::glob_files;
4use llm_coding_tools_core::path::AllowedPathResolver;
5use llm_coding_tools_core::tool_names;
6use llm_coding_tools_core::{GlobOutput, ToolContext, ToolError};
7use rig::completion::ToolDefinition;
8use rig::tool::Tool;
9use schemars::{schema_for, JsonSchema};
10use serde::Deserialize;
11
12/// Arguments for the glob tool.
13#[derive(Debug, Deserialize, JsonSchema)]
14pub struct GlobArgs {
15    /// Glob pattern to match files against (e.g., "**/*.rs", "src/**/*.ts").
16    pub pattern: String,
17    /// Relative directory path to search in (within allowed directories).
18    pub path: String,
19}
20
21/// Tool for finding files matching glob patterns within allowed directories.
22#[derive(Debug, Clone)]
23pub struct GlobTool {
24    resolver: AllowedPathResolver,
25}
26
27impl GlobTool {
28    /// Creates a new glob tool with a shared resolver.
29    ///
30    /// See [`ReadTool::new`](crate::allowed::read::ReadTool::new) for usage example.
31    pub fn new(resolver: AllowedPathResolver) -> Self {
32        Self { resolver }
33    }
34}
35
36impl Tool for GlobTool {
37    const NAME: &'static str = tool_names::GLOB;
38
39    type Error = ToolError;
40    type Args = GlobArgs;
41    type Output = GlobOutput;
42
43    async fn definition(&self, _prompt: String) -> ToolDefinition {
44        ToolDefinition {
45            name: <Self as Tool>::NAME.to_string(),
46            description: "Find files matching a glob pattern within allowed directories. \
47                          Paths are relative to configured base directories."
48                .to_string(),
49            parameters: serde_json::to_value(schema_for!(GlobArgs))
50                .expect("schema serialization should not fail"),
51        }
52    }
53
54    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
55        glob_files(&self.resolver, &args.pattern, &args.path)
56    }
57}
58
59impl ToolContext for GlobTool {
60    const NAME: &'static str = tool_names::GLOB;
61
62    fn context(&self) -> &'static str {
63        llm_coding_tools_core::context::GLOB_ALLOWED
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Builds a `GlobTool` whose only allowed directory is `root`.
    fn tool_rooted_at(root: &TempDir) -> GlobTool {
        let resolver = AllowedPathResolver::new([root.path()]).unwrap();
        GlobTool::new(resolver)
    }

    /// Shorthand for constructing owned `GlobArgs` from string slices.
    fn glob_args(pattern: &str, path: &str) -> GlobArgs {
        GlobArgs {
            pattern: pattern.to_owned(),
            path: path.to_owned(),
        }
    }

    #[tokio::test]
    async fn finds_matching_files() {
        let root = TempDir::new().unwrap();
        let src_dir = root.path().join("src");
        fs::create_dir_all(&src_dir).unwrap();
        // An empty file is enough; only its name is matched.
        fs::write(src_dir.join("lib.rs"), b"").unwrap();

        let output = tool_rooted_at(&root)
            .call(glob_args("**/*.rs", "."))
            .await
            .unwrap();

        let found_lib = output.files.iter().any(|file| file.ends_with("lib.rs"));
        assert!(found_lib, "expected src/lib.rs among the glob results");
    }

    #[tokio::test]
    async fn rejects_path_traversal() {
        let root = TempDir::new().unwrap();

        let outcome = tool_rooted_at(&root)
            .call(glob_args("*.rs", "../../../etc"))
            .await;

        assert!(
            matches!(outcome, Err(ToolError::InvalidPath(_))),
            "escaping the allowed root must be rejected as InvalidPath"
        );
    }
}
104}