perspt_agent/
context_retriever.rs1use anyhow::Result;
7use grep::regex::RegexMatcher;
8use grep::searcher::sinks::UTF8;
9use grep::searcher::Searcher;
10use ignore::WalkBuilder;
11use std::path::{Path, PathBuf};
12
/// A single matching line produced by a workspace search.
#[derive(Debug, Clone)]
pub struct SearchHit {
    /// File containing the match; `ContextRetriever::search` stores it relative to the
    /// working directory when the prefix can be stripped, otherwise as walked.
    pub file: PathBuf,
    /// Line number of the match as reported by the searcher.
    pub line: u32,
    /// Text of the matching line (`ContextRetriever::search` strips trailing whitespace).
    pub content: String,
    /// Column of the match, if known (`ContextRetriever::search` currently leaves this `None`).
    pub column: Option<usize>,
}
25
/// Gathers code context (search hits and file contents) from a working directory,
/// bounded by per-file and overall byte limits.
pub struct ContextRetriever {
    /// Root directory that searches walk and that relative paths are resolved against.
    working_dir: PathBuf,
    /// Maximum number of bytes of a single file returned before it is truncated.
    max_file_bytes: usize,
    /// Byte budget for the file sections assembled by `get_task_context`.
    max_context_bytes: usize,
}
35
36impl ContextRetriever {
37 pub fn new(working_dir: PathBuf) -> Self {
39 Self {
40 working_dir,
41 max_file_bytes: 50 * 1024, max_context_bytes: 100 * 1024, }
44 }
45
46 pub fn with_max_file_bytes(mut self, bytes: usize) -> Self {
48 self.max_file_bytes = bytes;
49 self
50 }
51
52 pub fn with_max_context_bytes(mut self, bytes: usize) -> Self {
54 self.max_context_bytes = bytes;
55 self
56 }
57
58 pub fn search(&self, pattern: &str, max_results: usize) -> Vec<SearchHit> {
61 let mut hits = Vec::new();
62
63 let matcher = match RegexMatcher::new(pattern) {
65 Ok(m) => m,
66 Err(e) => {
67 log::warn!("Invalid search pattern '{}': {}", pattern, e);
68 return hits;
69 }
70 };
71
72 let walker = WalkBuilder::new(&self.working_dir)
74 .hidden(true) .git_ignore(true) .git_global(true) .git_exclude(true) .build();
79
80 let mut searcher = Searcher::new();
81
82 for entry in walker.flatten() {
83 if hits.len() >= max_results {
84 break;
85 }
86
87 let path = entry.path();
88
89 if !path.is_file() {
91 continue;
92 }
93
94 if Self::is_binary_extension(path) {
96 continue;
97 }
98
99 let _ = searcher.search_path(
101 &matcher,
102 path,
103 UTF8(|line_num, line| {
104 if hits.len() < max_results {
105 let relative_path = path
106 .strip_prefix(&self.working_dir)
107 .unwrap_or(path)
108 .to_path_buf();
109
110 hits.push(SearchHit {
111 file: relative_path,
112 line: line_num as u32,
113 content: line.trim_end().to_string(),
114 column: None,
115 });
116 }
117 Ok(hits.len() < max_results)
118 }),
119 );
120 }
121
122 hits
123 }
124
125 pub fn read_file_truncated(&self, path: &Path) -> Result<String> {
127 let full_path = if path.is_absolute() {
128 path.to_path_buf()
129 } else {
130 self.working_dir.join(path)
131 };
132
133 let content = std::fs::read_to_string(&full_path)?;
134
135 if content.len() > self.max_file_bytes {
136 let truncated = &content[..self.max_file_bytes];
137 let last_newline = truncated.rfind('\n').unwrap_or(self.max_file_bytes);
139 Ok(format!(
140 "{}\n\n... [truncated, {} more bytes]",
141 &content[..last_newline],
142 content.len() - last_newline
143 ))
144 } else {
145 Ok(content)
146 }
147 }
148
149 pub fn get_task_context(&self, context_files: &[PathBuf], output_files: &[PathBuf]) -> String {
152 let mut context = String::new();
153 let mut remaining_budget = self.max_context_bytes;
154
155 if !context_files.is_empty() {
157 context.push_str("## Context Files (for reference)\n\n");
158 for file in context_files {
159 if remaining_budget == 0 {
160 break;
161 }
162 if let Ok(content) = self.read_file_truncated(file) {
163 let section = format!("### {}\n```\n{}\n```\n\n", file.display(), content);
164 if section.len() <= remaining_budget {
165 remaining_budget -= section.len();
166 context.push_str(§ion);
167 }
168 }
169 }
170 }
171
172 if !output_files.is_empty() {
174 context.push_str("## Target Files (to modify)\n\n");
175 for file in output_files {
176 if remaining_budget == 0 {
177 break;
178 }
179 let full_path = self.working_dir.join(file);
180 if full_path.exists() {
181 if let Ok(content) = self.read_file_truncated(file) {
182 let section = format!(
183 "### {} (current content)\n```\n{}\n```\n\n",
184 file.display(),
185 content
186 );
187 if section.len() <= remaining_budget {
188 remaining_budget -= section.len();
189 context.push_str(§ion);
190 }
191 }
192 } else {
193 context.push_str(&format!("### {} (new file)\n\n", file.display()));
194 }
195 }
196 }
197
198 context
199 }
200
201 pub fn search_for_context(&self, query: &str, max_results: usize) -> String {
204 let hits = self.search(query, max_results);
205
206 if hits.is_empty() {
207 return String::new();
208 }
209
210 let mut context = format!("## Related Code (search: '{}')\n\n", query);
211
212 for hit in &hits {
213 context.push_str(&format!(
214 "- **{}:{}**: `{}`\n",
215 hit.file.display(),
216 hit.line,
217 hit.content.trim()
218 ));
219 }
220 context.push('\n');
221
222 context
223 }
224
225 fn is_binary_extension(path: &Path) -> bool {
227 match path.extension().and_then(|e| e.to_str()) {
228 Some(ext) => matches!(
229 ext.to_lowercase().as_str(),
230 "png"
231 | "jpg"
232 | "jpeg"
233 | "gif"
234 | "bmp"
235 | "ico"
236 | "webp"
237 | "pdf"
238 | "doc"
239 | "docx"
240 | "xls"
241 | "xlsx"
242 | "ppt"
243 | "pptx"
244 | "zip"
245 | "tar"
246 | "gz"
247 | "bz2"
248 | "7z"
249 | "rar"
250 | "exe"
251 | "dll"
252 | "so"
253 | "dylib"
254 | "a"
255 | "wasm"
256 | "o"
257 | "obj"
258 | "pyc"
259 | "pyo"
260 | "class"
261 | "db"
262 | "sqlite"
263 | "sqlite3"
264 ),
265 None => false,
266 }
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// A pattern that occurs on exactly one line of one file yields exactly one hit
    /// carrying that line's text.
    #[test]
    fn test_search_finds_pattern() {
        let workspace = tempdir().unwrap();
        let source_file = workspace.path().join("test.py");
        fs::write(&source_file, "def hello_world():\n print('Hello')\n").unwrap();

        let hits = ContextRetriever::new(workspace.path().to_path_buf()).search("hello_world", 10);

        assert_eq!(hits.len(), 1);
        assert!(hits[0].content.contains("def hello_world"));
    }

    /// A file far larger than the configured limit comes back with a truncation marker
    /// and a length in the vicinity of the limit.
    #[test]
    fn test_read_file_truncated() {
        let workspace = tempdir().unwrap();
        let big_file = workspace.path().join("large.txt");
        // 10,000 five-byte lines: well above the 1,000-byte limit set below.
        let body = "line\n".repeat(10000);
        fs::write(&big_file, &body).unwrap();

        let retriever =
            ContextRetriever::new(workspace.path().to_path_buf()).with_max_file_bytes(1000);
        let result = retriever.read_file_truncated(&big_file).unwrap();

        assert!(result.contains("truncated"));
        // Limit plus the truncation marker stays comfortably under 2,000 bytes.
        assert!(result.len() < 2000);
    }
}