llm_coding_tools_rig/allowed/
grep.rs1use llm_coding_tools_core::operations::{grep_search, DEFAULT_MAX_LINE_LENGTH};
4use llm_coding_tools_core::path::AllowedPathResolver;
5use llm_coding_tools_core::tool_names;
6use llm_coding_tools_core::{ToolContext, ToolError, ToolOutput};
7use rig::completion::ToolDefinition;
8use rig::tool::Tool;
9use schemars::{schema_for, JsonSchema};
10use serde::Deserialize;
11
/// Number of results requested when the caller does not supply `limit`.
const DEFAULT_LIMIT: usize = 100;
/// Hard ceiling applied to any caller-supplied `limit`.
const MAX_LIMIT: usize = 2_000;

/// Serde default for `GrepArgs::limit` when the field is missing from the input.
fn default_limit() -> Option<usize> {
    DEFAULT_LIMIT.into()
}
18
19#[derive(Debug, Deserialize, JsonSchema)]
21pub struct GrepArgs {
22 pub pattern: String,
24 pub path: String,
26 #[serde(default)]
28 pub include: Option<String>,
29 #[serde(default = "default_limit")]
31 pub limit: Option<usize>,
32}
33
34#[derive(Debug, Clone)]
36pub struct GrepTool<const LINE_NUMBERS: bool = true> {
37 resolver: AllowedPathResolver,
38}
39
40impl<const LINE_NUMBERS: bool> GrepTool<LINE_NUMBERS> {
41 pub fn new(resolver: AllowedPathResolver) -> Self {
47 Self { resolver }
48 }
49}
50
51impl<const LINE_NUMBERS: bool> Tool for GrepTool<LINE_NUMBERS> {
52 const NAME: &'static str = tool_names::GREP;
53
54 type Error = ToolError;
55 type Args = GrepArgs;
56 type Output = ToolOutput;
57
58 async fn definition(&self, _prompt: String) -> ToolDefinition {
59 ToolDefinition {
60 name: <Self as Tool>::NAME.to_string(),
61 description: "Search file contents using regex patterns within allowed directories. \
62 Paths are relative to configured base directories."
63 .to_string(),
64 parameters: serde_json::to_value(schema_for!(GrepArgs))
65 .expect("schema serialization should not fail"),
66 }
67 }
68
69 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
70 let pattern = args.pattern.trim();
71 if pattern.is_empty() {
72 return Err(ToolError::InvalidPattern(
73 "pattern must not be empty".into(),
74 ));
75 }
76
77 let limit = args.limit.unwrap_or(DEFAULT_LIMIT).min(MAX_LIMIT);
78 if limit == 0 {
79 return Err(ToolError::Validation(
80 "limit must be greater than zero".into(),
81 ));
82 }
83
84 let include = args.include.as_deref().and_then(|s| {
85 let trimmed = s.trim();
86 if trimmed.is_empty() {
87 None
88 } else {
89 Some(trimmed)
90 }
91 });
92
93 let result = grep_search(&self.resolver, pattern, include, &args.path, limit)?;
94
95 if result.files.is_empty() {
96 return Ok(ToolOutput::new("No matches found."));
97 }
98
99 let output = result.format::<LINE_NUMBERS>(limit, DEFAULT_MAX_LINE_LENGTH);
100
101 Ok(if result.truncated {
102 ToolOutput::truncated(output)
103 } else {
104 ToolOutput::new(output)
105 })
106 }
107}
108
109impl<const LINE_NUMBERS: bool> ToolContext for GrepTool<LINE_NUMBERS> {
110 const NAME: &'static str = tool_names::GREP;
111
112 fn context(&self) -> &'static str {
113 llm_coding_tools_core::context::GREP_ALLOWED
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[tokio::test]
    async fn finds_matching_content() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("test.txt"), "hello world").unwrap();

        let resolver = AllowedPathResolver::new([dir.path()]).unwrap();
        let tool: GrepTool<true> = GrepTool::new(resolver);
        let result = tool
            .call(GrepArgs {
                pattern: "hello".to_string(),
                path: ".".to_string(),
                include: None,
                limit: None,
            })
            .await
            .unwrap();
        assert!(result.content.contains("Found 1 matches"));
        assert!(result.content.contains("L1: hello world"));
    }

    #[tokio::test]
    async fn reports_no_matches() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("test.txt"), "hello world").unwrap();

        let resolver = AllowedPathResolver::new([dir.path()]).unwrap();
        let tool: GrepTool = GrepTool::new(resolver);
        let result = tool
            .call(GrepArgs {
                pattern: "absent_token".to_string(),
                path: ".".to_string(),
                include: None,
                limit: None,
            })
            .await
            .unwrap();
        assert!(result.content.contains("No matches found."));
    }

    #[tokio::test]
    async fn blank_include_is_ignored() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("test.txt"), "hello world").unwrap();

        let resolver = AllowedPathResolver::new([dir.path()]).unwrap();
        let tool: GrepTool<true> = GrepTool::new(resolver);
        let result = tool
            .call(GrepArgs {
                pattern: "hello".to_string(),
                path: ".".to_string(),
                include: Some("   ".to_string()),
                limit: None,
            })
            .await
            .unwrap();
        assert!(result.content.contains("L1: hello world"));
    }

    #[tokio::test]
    async fn rejects_path_traversal() {
        let dir = TempDir::new().unwrap();
        let resolver = AllowedPathResolver::new([dir.path()]).unwrap();
        let tool: GrepTool = GrepTool::new(resolver);
        let result = tool
            .call(GrepArgs {
                pattern: "test".to_string(),
                path: "../../../etc".to_string(),
                include: None,
                limit: None,
            })
            .await;
        assert!(matches!(result, Err(ToolError::InvalidPath(_))));
    }

    #[tokio::test]
    async fn rejects_empty_pattern() {
        let dir = TempDir::new().unwrap();
        let resolver = AllowedPathResolver::new([dir.path()]).unwrap();
        let tool: GrepTool = GrepTool::new(resolver);
        let result = tool
            .call(GrepArgs {
                pattern: "   ".to_string(),
                path: ".".to_string(),
                include: None,
                limit: None,
            })
            .await;
        assert!(matches!(result, Err(ToolError::InvalidPattern(_))));
    }

    #[tokio::test]
    async fn rejects_zero_limit() {
        let dir = TempDir::new().unwrap();
        let resolver = AllowedPathResolver::new([dir.path()]).unwrap();
        let tool: GrepTool = GrepTool::new(resolver);
        let result = tool
            .call(GrepArgs {
                pattern: "test".to_string(),
                path: ".".to_string(),
                include: None,
                limit: Some(0),
            })
            .await;
        assert!(matches!(result, Err(ToolError::Validation(_))));
    }

    #[tokio::test]
    async fn truncates_long_lines_at_utf8_boundary() {
        let dir = TempDir::new().unwrap();

        // "match_me " (9 bytes) + 1989 'a' = 1998 bytes, so the 3-byte '日' spans
        // bytes 1998..2001. Assuming DEFAULT_MAX_LINE_LENGTH is 2000, a naive byte
        // cut would land inside that character.
        let long_line = format!("match_me {}{}", "a".repeat(1989), "日本語");
        assert!(
            long_line.len() > 2000,
            "test setup: line must exceed DEFAULT_MAX_LINE_LENGTH"
        );

        std::fs::write(dir.path().join("utf8_test.txt"), &long_line).unwrap();

        let resolver = AllowedPathResolver::new([dir.path()]).unwrap();
        let tool: GrepTool<true> = GrepTool::new(resolver);
        let result = tool
            .call(GrepArgs {
                pattern: "match_me".to_string(),
                path: ".".to_string(),
                include: None,
                limit: None,
            })
            .await
            .unwrap();

        // Reaching this point without a panic shows the cut respected char boundaries.
        assert!(result.content.contains("Found 1 matches"));
        assert!(result.content.contains("L1:"));
    }
}