perspt_agent/
context_retriever.rs

//! Context Retriever
//!
//! Uses the `grep` crate (the library behind ripgrep) for fast code search across the workspace.
//! Provides context retrieval for LLM prompts while respecting byte budgets (a rough proxy for token budgets).
5
6use anyhow::Result;
7use grep::regex::RegexMatcher;
8use grep::searcher::sinks::UTF8;
9use grep::searcher::Searcher;
10use ignore::WalkBuilder;
11use std::path::{Path, PathBuf};
12
/// A single matching line found by a workspace search.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SearchHit {
    /// File path (relative to the workspace root when the file lies inside it,
    /// otherwise the path as walked)
    pub file: PathBuf,
    /// Line number (1-indexed)
    pub line: u32,
    /// Content of the matching line, with trailing whitespace removed
    pub content: String,
    /// Column where the match starts (0-indexed); `None` when not computed
    pub column: Option<usize>,
}
25
/// Gathers relevant code context (search hits and file contents) for LLM
/// prompts, bounded by a per-file and a total byte budget.
#[derive(Debug, Clone)]
pub struct ContextRetriever {
    /// Workspace root directory; relative paths are resolved against it
    working_dir: PathBuf,
    /// Maximum bytes to read per file before truncating
    max_file_bytes: usize,
    /// Maximum total context bytes
    max_context_bytes: usize,
}
35
36impl ContextRetriever {
37    /// Create a new context retriever
38    pub fn new(working_dir: PathBuf) -> Self {
39        Self {
40            working_dir,
41            max_file_bytes: 50 * 1024,     // 50KB per file
42            max_context_bytes: 100 * 1024, // 100KB total
43        }
44    }
45
46    /// Set max bytes per file
47    pub fn with_max_file_bytes(mut self, bytes: usize) -> Self {
48        self.max_file_bytes = bytes;
49        self
50    }
51
52    /// Set max total context bytes
53    pub fn with_max_context_bytes(mut self, bytes: usize) -> Self {
54        self.max_context_bytes = bytes;
55        self
56    }
57
58    /// Search for a pattern in the workspace using ripgrep
59    /// Respects .gitignore and common ignore patterns
60    pub fn search(&self, pattern: &str, max_results: usize) -> Vec<SearchHit> {
61        let mut hits = Vec::new();
62
63        // Create regex matcher
64        let matcher = match RegexMatcher::new(pattern) {
65            Ok(m) => m,
66            Err(e) => {
67                log::warn!("Invalid search pattern '{}': {}", pattern, e);
68                return hits;
69            }
70        };
71
72        // Walk workspace respecting .gitignore
73        let walker = WalkBuilder::new(&self.working_dir)
74            .hidden(true) // Skip hidden files
75            .git_ignore(true) // Respect .gitignore
76            .git_global(true) // Respect global gitignore
77            .git_exclude(true) // Respect .git/info/exclude
78            .build();
79
80        let mut searcher = Searcher::new();
81
82        for entry in walker.flatten() {
83            if hits.len() >= max_results {
84                break;
85            }
86
87            let path = entry.path();
88
89            // Only search files
90            if !path.is_file() {
91                continue;
92            }
93
94            // Skip binary files by extension
95            if Self::is_binary_extension(path) {
96                continue;
97            }
98
99            // Search the file
100            let _ = searcher.search_path(
101                &matcher,
102                path,
103                UTF8(|line_num, line| {
104                    if hits.len() < max_results {
105                        let relative_path = path
106                            .strip_prefix(&self.working_dir)
107                            .unwrap_or(path)
108                            .to_path_buf();
109
110                        hits.push(SearchHit {
111                            file: relative_path,
112                            line: line_num as u32,
113                            content: line.trim_end().to_string(),
114                            column: None,
115                        });
116                    }
117                    Ok(hits.len() < max_results)
118                }),
119            );
120        }
121
122        hits
123    }
124
125    /// Read a file with truncation if it exceeds max bytes
126    pub fn read_file_truncated(&self, path: &Path) -> Result<String> {
127        let full_path = if path.is_absolute() {
128            path.to_path_buf()
129        } else {
130            self.working_dir.join(path)
131        };
132
133        let content = std::fs::read_to_string(&full_path)?;
134
135        if content.len() > self.max_file_bytes {
136            let truncated = &content[..self.max_file_bytes];
137            // Find last newline to avoid cutting mid-line
138            let last_newline = truncated.rfind('\n').unwrap_or(self.max_file_bytes);
139            Ok(format!(
140                "{}\n\n... [truncated, {} more bytes]",
141                &content[..last_newline],
142                content.len() - last_newline
143            ))
144        } else {
145            Ok(content)
146        }
147    }
148
149    /// Get context for a task based on its context_files and output_files
150    /// Returns a formatted string suitable for LLM prompts
151    pub fn get_task_context(&self, context_files: &[PathBuf], output_files: &[PathBuf]) -> String {
152        let mut context = String::new();
153        let mut remaining_budget = self.max_context_bytes;
154
155        // Add context files (files to read for understanding)
156        if !context_files.is_empty() {
157            context.push_str("## Context Files (for reference)\n\n");
158            for file in context_files {
159                if remaining_budget == 0 {
160                    break;
161                }
162                if let Ok(content) = self.read_file_truncated(file) {
163                    let section = format!("### {}\n```\n{}\n```\n\n", file.display(), content);
164                    if section.len() <= remaining_budget {
165                        remaining_budget -= section.len();
166                        context.push_str(&section);
167                    }
168                }
169            }
170        }
171
172        // Add output files (files to modify - show current state)
173        if !output_files.is_empty() {
174            context.push_str("## Target Files (to modify)\n\n");
175            for file in output_files {
176                if remaining_budget == 0 {
177                    break;
178                }
179                let full_path = self.working_dir.join(file);
180                if full_path.exists() {
181                    if let Ok(content) = self.read_file_truncated(file) {
182                        let section = format!(
183                            "### {} (current content)\n```\n{}\n```\n\n",
184                            file.display(),
185                            content
186                        );
187                        if section.len() <= remaining_budget {
188                            remaining_budget -= section.len();
189                            context.push_str(&section);
190                        }
191                    }
192                } else {
193                    context.push_str(&format!("### {} (new file)\n\n", file.display()));
194                }
195            }
196        }
197
198        context
199    }
200
201    /// Search for relevant code based on a query (e.g., function name, class name)
202    /// Returns formatted context for LLM
203    pub fn search_for_context(&self, query: &str, max_results: usize) -> String {
204        let hits = self.search(query, max_results);
205
206        if hits.is_empty() {
207            return String::new();
208        }
209
210        let mut context = format!("## Related Code (search: '{}')\n\n", query);
211
212        for hit in &hits {
213            context.push_str(&format!(
214                "- **{}:{}**: `{}`\n",
215                hit.file.display(),
216                hit.line,
217                hit.content.trim()
218            ));
219        }
220        context.push('\n');
221
222        context
223    }
224
225    /// Check if a file extension indicates a binary file
226    fn is_binary_extension(path: &Path) -> bool {
227        match path.extension().and_then(|e| e.to_str()) {
228            Some(ext) => matches!(
229                ext.to_lowercase().as_str(),
230                "png"
231                    | "jpg"
232                    | "jpeg"
233                    | "gif"
234                    | "bmp"
235                    | "ico"
236                    | "webp"
237                    | "pdf"
238                    | "doc"
239                    | "docx"
240                    | "xls"
241                    | "xlsx"
242                    | "ppt"
243                    | "pptx"
244                    | "zip"
245                    | "tar"
246                    | "gz"
247                    | "bz2"
248                    | "7z"
249                    | "rar"
250                    | "exe"
251                    | "dll"
252                    | "so"
253                    | "dylib"
254                    | "a"
255                    | "wasm"
256                    | "o"
257                    | "obj"
258                    | "pyc"
259                    | "pyo"
260                    | "class"
261                    | "db"
262                    | "sqlite"
263                    | "sqlite3"
264            ),
265            None => false,
266        }
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::PathBuf;
    use tempfile::tempdir;

    #[test]
    fn test_search_finds_pattern() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.py");
        fs::write(&file_path, "def hello_world():\n    print('Hello')\n").unwrap();

        let retriever = ContextRetriever::new(dir.path().to_path_buf());
        let hits = retriever.search("hello_world", 10);

        assert_eq!(hits.len(), 1);
        assert!(hits[0].content.contains("def hello_world"));
        // Paths are reported relative to the workspace root; lines are 1-indexed.
        assert_eq!(hits[0].file, PathBuf::from("test.py"));
        assert_eq!(hits[0].line, 1);
    }

    #[test]
    fn test_search_caps_results_and_tolerates_bad_patterns() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("a.txt"), "foo\nfoo\nfoo\n").unwrap();

        let retriever = ContextRetriever::new(dir.path().to_path_buf());

        assert_eq!(retriever.search("foo", 2).len(), 2);
        assert!(retriever.search("foo", 0).is_empty());
        // An invalid regex yields no hits instead of panicking.
        assert!(retriever.search("(", 10).is_empty());
    }

    #[test]
    fn test_read_file_truncated() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("large.txt");
        let content = "line\n".repeat(10000); // 50_000 bytes
        fs::write(&file_path, &content).unwrap();

        let retriever = ContextRetriever::new(dir.path().to_path_buf()).with_max_file_bytes(1000);

        let result = retriever.read_file_truncated(&file_path).unwrap();
        assert!(result.contains("truncated"));
        assert!(result.len() < 2000); // Should be truncated + message
        // The last newline within the first 1000 bytes is at index 999, so the
        // remaining 50_000 - 999 bytes are reported as dropped.
        assert!(result.ends_with("... [truncated, 49001 more bytes]"));
    }

    #[test]
    fn test_get_task_context_includes_context_and_marks_new_outputs() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("ctx.rs"), "fn ctx() {}\n").unwrap();

        let retriever = ContextRetriever::new(dir.path().to_path_buf());
        let context = retriever
            .get_task_context(&[PathBuf::from("ctx.rs")], &[PathBuf::from("new.rs")]);

        assert!(context.contains("## Context Files (for reference)"));
        assert!(context.contains("### ctx.rs\n```\nfn ctx() {}\n"));
        assert!(context.contains("### new.rs (new file)"));
    }
}