llm_coding_tools_rig/allowed/
glob.rs1use llm_coding_tools_core::operations::glob_files;
4use llm_coding_tools_core::path::AllowedPathResolver;
5use llm_coding_tools_core::tool_names;
6use llm_coding_tools_core::{GlobOutput, ToolContext, ToolError};
7use rig::completion::ToolDefinition;
8use rig::tool::Tool;
9use schemars::{schema_for, JsonSchema};
10use serde::Deserialize;
11
12#[derive(Debug, Deserialize, JsonSchema)]
14pub struct GlobArgs {
15 pub pattern: String,
17 pub path: String,
19}
20
21#[derive(Debug, Clone)]
23pub struct GlobTool {
24 resolver: AllowedPathResolver,
25}
26
27impl GlobTool {
28 pub fn new(resolver: AllowedPathResolver) -> Self {
32 Self { resolver }
33 }
34}
35
36impl Tool for GlobTool {
37 const NAME: &'static str = tool_names::GLOB;
38
39 type Error = ToolError;
40 type Args = GlobArgs;
41 type Output = GlobOutput;
42
43 async fn definition(&self, _prompt: String) -> ToolDefinition {
44 ToolDefinition {
45 name: <Self as Tool>::NAME.to_string(),
46 description: "Find files matching a glob pattern within allowed directories. \
47 Paths are relative to configured base directories."
48 .to_string(),
49 parameters: serde_json::to_value(schema_for!(GlobArgs))
50 .expect("schema serialization should not fail"),
51 }
52 }
53
54 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
55 glob_files(&self.resolver, &args.pattern, &args.path)
56 }
57}
58
59impl ToolContext for GlobTool {
60 const NAME: &'static str = tool_names::GLOB;
61
62 fn context(&self) -> &'static str {
63 llm_coding_tools_core::context::GLOB_ALLOWED
64 }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::{self, File};
    use tempfile::TempDir;

    /// Builds a `GlobTool` whose only allowed directory is `dir`.
    fn tool_rooted_at(dir: &TempDir) -> GlobTool {
        let resolver = AllowedPathResolver::new([dir.path()]).unwrap();
        GlobTool::new(resolver)
    }

    /// Shorthand for building tool arguments from string slices.
    fn glob_args(pattern: &str, path: &str) -> GlobArgs {
        GlobArgs {
            pattern: pattern.to_string(),
            path: path.to_string(),
        }
    }

    #[tokio::test]
    async fn finds_matching_files() {
        let dir = TempDir::new().unwrap();
        let src_dir = dir.path().join("src");
        fs::create_dir_all(&src_dir).unwrap();
        File::create(src_dir.join("lib.rs")).unwrap();

        let output = tool_rooted_at(&dir)
            .call(glob_args("**/*.rs", "."))
            .await
            .unwrap();

        let found_lib = output.files.iter().any(|file| file.ends_with("lib.rs"));
        assert!(found_lib);
    }

    #[tokio::test]
    async fn rejects_path_traversal() {
        let dir = TempDir::new().unwrap();

        let outcome = tool_rooted_at(&dir)
            .call(glob_args("*.rs", "../../../etc"))
            .await;

        assert!(matches!(outcome, Err(ToolError::InvalidPath(_))));
    }
}