llm_coding_tools_rig/allowed/
grep.rs

1//! Grep content search tool using [`AllowedPathResolver`].
2
3use llm_coding_tools_core::operations::{grep_search, DEFAULT_MAX_LINE_LENGTH};
4use llm_coding_tools_core::path::AllowedPathResolver;
5use llm_coding_tools_core::tool_names;
6use llm_coding_tools_core::{ToolContext, ToolError, ToolOutput};
7use rig::completion::ToolDefinition;
8use rig::tool::Tool;
9use schemars::{schema_for, JsonSchema};
10use serde::Deserialize;
11
/// Number of matches returned when the caller does not specify a limit.
const DEFAULT_LIMIT: usize = 100;
/// Largest limit honoured by the tool; larger requests are clamped down to this value.
const MAX_LIMIT: usize = 2000;

/// Serde default provider for [`GrepArgs::limit`].
///
/// Returns [`DEFAULT_LIMIT`] wrapped in `Some`, so a missing `limit` field
/// deserializes to the documented default.
fn default_limit() -> Option<usize> {
    Option::from(DEFAULT_LIMIT)
}
18
19/// Arguments for the grep tool.
20#[derive(Debug, Deserialize, JsonSchema)]
21pub struct GrepArgs {
22    /// Regex pattern to search for in file contents.
23    pub pattern: String,
24    /// Relative directory path to search in (within allowed directories).
25    pub path: String,
26    /// Optional file glob filter (e.g., "*.rs", "*.{ts,tsx}").
27    #[serde(default)]
28    pub include: Option<String>,
29    /// Maximum number of matches to return (default: 100, max: 2000).
30    #[serde(default = "default_limit")]
31    pub limit: Option<usize>,
32}
33
34/// Tool for searching file contents within allowed directories.
35#[derive(Debug, Clone)]
36pub struct GrepTool<const LINE_NUMBERS: bool = true> {
37    resolver: AllowedPathResolver,
38}
39
40impl<const LINE_NUMBERS: bool> GrepTool<LINE_NUMBERS> {
41    /// Creates a new grep tool with a shared resolver.
42    ///
43    /// See [`ReadTool::new`] for usage example.
44    ///
45    /// [`ReadTool::new`]: super::ReadTool::new
46    pub fn new(resolver: AllowedPathResolver) -> Self {
47        Self { resolver }
48    }
49}
50
51impl<const LINE_NUMBERS: bool> Tool for GrepTool<LINE_NUMBERS> {
52    const NAME: &'static str = tool_names::GREP;
53
54    type Error = ToolError;
55    type Args = GrepArgs;
56    type Output = ToolOutput;
57
58    async fn definition(&self, _prompt: String) -> ToolDefinition {
59        ToolDefinition {
60            name: <Self as Tool>::NAME.to_string(),
61            description: "Search file contents using regex patterns within allowed directories. \
62                          Paths are relative to configured base directories."
63                .to_string(),
64            parameters: serde_json::to_value(schema_for!(GrepArgs))
65                .expect("schema serialization should not fail"),
66        }
67    }
68
69    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
70        let pattern = args.pattern.trim();
71        if pattern.is_empty() {
72            return Err(ToolError::InvalidPattern(
73                "pattern must not be empty".into(),
74            ));
75        }
76
77        let limit = args.limit.unwrap_or(DEFAULT_LIMIT).min(MAX_LIMIT);
78        if limit == 0 {
79            return Err(ToolError::Validation(
80                "limit must be greater than zero".into(),
81            ));
82        }
83
84        let include = args.include.as_deref().and_then(|s| {
85            let trimmed = s.trim();
86            if trimmed.is_empty() {
87                None
88            } else {
89                Some(trimmed)
90            }
91        });
92
93        let result = grep_search(&self.resolver, pattern, include, &args.path, limit)?;
94
95        if result.files.is_empty() {
96            return Ok(ToolOutput::new("No matches found."));
97        }
98
99        let output = result.format::<LINE_NUMBERS>(limit, DEFAULT_MAX_LINE_LENGTH);
100
101        Ok(if result.truncated {
102            ToolOutput::truncated(output)
103        } else {
104            ToolOutput::new(output)
105        })
106    }
107}
108
109impl<const LINE_NUMBERS: bool> ToolContext for GrepTool<LINE_NUMBERS> {
110    const NAME: &'static str = tool_names::GREP;
111
112    fn context(&self) -> &'static str {
113        llm_coding_tools_core::context::GREP_ALLOWED
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Runs a line-numbered grep tool whose only allowed directory is `root`,
    /// leaving `include` and `limit` unset.
    async fn grep_in(root: &TempDir, pattern: &str, path: &str) -> Result<ToolOutput, ToolError> {
        let resolver = AllowedPathResolver::new([root.path()]).unwrap();
        let tool: GrepTool<true> = GrepTool::new(resolver);
        tool.call(GrepArgs {
            pattern: pattern.to_string(),
            path: path.to_string(),
            include: None,
            limit: None,
        })
        .await
    }

    #[tokio::test]
    async fn finds_matching_content() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("test.txt"), "hello world").unwrap();

        let output = grep_in(&dir, "hello", ".").await.unwrap();

        assert!(output.content.contains("Found 1 matches"));
        assert!(output.content.contains("L1: hello world"));
    }

    #[tokio::test]
    async fn rejects_path_traversal() {
        let dir = TempDir::new().unwrap();

        let outcome = grep_in(&dir, "test", "../../../etc").await;

        assert!(matches!(outcome, Err(ToolError::InvalidPath(_))));
    }

    #[tokio::test]
    async fn rejects_empty_pattern() {
        let dir = TempDir::new().unwrap();

        // Whitespace-only patterns count as empty.
        let outcome = grep_in(&dir, "   ", ".").await;

        assert!(matches!(outcome, Err(ToolError::InvalidPattern(_))));
    }

    #[tokio::test]
    async fn truncates_long_lines_at_utf8_boundary() {
        let dir = TempDir::new().unwrap();

        // Build a single line longer than 2000 bytes whose multibyte characters
        // straddle the 2000-byte mark: a 9-byte ASCII prefix plus 1989 `a`s gives
        // 1998 ASCII bytes, followed by "日本語" (3 chars, 9 bytes) for 2007 bytes.
        // Byte offset 2000 therefore falls inside a multibyte sequence, so a naive
        // byte-index cut there would not land on a char boundary.
        let mut long_line = String::from("match_me ");
        long_line.push_str(&"a".repeat(1989));
        long_line.push_str("日本語");
        assert!(
            long_line.len() > 2000,
            "test setup: line must exceed MAX_LINE_LENGTH"
        );

        std::fs::write(dir.path().join("utf8_test.txt"), &long_line).unwrap();

        let output = grep_in(&dir, "match_me", ".").await.unwrap();

        // Reaching this point means the call did not panic; `content` being a
        // `String` already guarantees the rendered text is valid UTF-8.
        assert!(output.content.contains("Found 1 matches"));
        assert!(output.content.contains("L1:"));
    }
}