Skip to main content

perspt_agent/
context_retriever.rs

1//! Context Retriever
2//!
3//! Uses the grep crate (ripgrep library) for fast code search across the workspace.
4//! Provides context retrieval for LLM prompts while respecting token budgets.
5
6use anyhow::Result;
7use grep::regex::RegexMatcher;
8use grep::searcher::sinks::UTF8;
9use grep::searcher::Searcher;
10use ignore::WalkBuilder;
11use std::path::{Path, PathBuf};
12
/// A single matching line produced by [`ContextRetriever::search`].
///
/// Hits compare by value (`PartialEq`/`Eq`), which makes them easy to
/// assert on in tests and to de-duplicate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SearchHit {
    /// File path (relative to the workspace root when the file lies under it)
    pub file: PathBuf,
    /// Line number (1-indexed)
    pub line: u32,
    /// Content of the matching line, with trailing whitespace removed
    pub content: String,
    /// Column where the match starts (0-indexed); `None` when not computed
    pub column: Option<usize>,
}
25
/// Context retriever for gathering relevant code context.
///
/// Holds the workspace root plus the two byte limits that bound how much
/// text is read per file and how much is assembled in total. The type is
/// cheap to clone and implements `Debug` so it can be logged or embedded
/// in other debuggable structures.
#[derive(Debug, Clone)]
pub struct ContextRetriever {
    /// Workspace root directory
    working_dir: PathBuf,
    /// Maximum bytes to read per file
    max_file_bytes: usize,
    /// Maximum total context bytes
    max_context_bytes: usize,
}
35
36impl ContextRetriever {
37    /// Create a new context retriever
38    pub fn new(working_dir: PathBuf) -> Self {
39        Self {
40            working_dir,
41            max_file_bytes: 50 * 1024,     // 50KB per file
42            max_context_bytes: 100 * 1024, // 100KB total
43        }
44    }
45
46    /// Set max bytes per file
47    pub fn with_max_file_bytes(mut self, bytes: usize) -> Self {
48        self.max_file_bytes = bytes;
49        self
50    }
51
52    /// Set max total context bytes
53    pub fn with_max_context_bytes(mut self, bytes: usize) -> Self {
54        self.max_context_bytes = bytes;
55        self
56    }
57
58    /// Search for a pattern in the workspace using ripgrep
59    /// Respects .gitignore and common ignore patterns
60    pub fn search(&self, pattern: &str, max_results: usize) -> Vec<SearchHit> {
61        let mut hits = Vec::new();
62
63        // Create regex matcher
64        let matcher = match RegexMatcher::new(pattern) {
65            Ok(m) => m,
66            Err(e) => {
67                log::warn!("Invalid search pattern '{}': {}", pattern, e);
68                return hits;
69            }
70        };
71
72        // Walk workspace respecting .gitignore
73        let walker = WalkBuilder::new(&self.working_dir)
74            .hidden(true) // Skip hidden files
75            .git_ignore(true) // Respect .gitignore
76            .git_global(true) // Respect global gitignore
77            .git_exclude(true) // Respect .git/info/exclude
78            .build();
79
80        let mut searcher = Searcher::new();
81
82        for entry in walker.flatten() {
83            if hits.len() >= max_results {
84                break;
85            }
86
87            let path = entry.path();
88
89            // Only search files
90            if !path.is_file() {
91                continue;
92            }
93
94            // Skip binary files by extension
95            if Self::is_binary_extension(path) {
96                continue;
97            }
98
99            // Search the file
100            let _ = searcher.search_path(
101                &matcher,
102                path,
103                UTF8(|line_num, line| {
104                    if hits.len() < max_results {
105                        let relative_path = path
106                            .strip_prefix(&self.working_dir)
107                            .unwrap_or(path)
108                            .to_path_buf();
109
110                        hits.push(SearchHit {
111                            file: relative_path,
112                            line: line_num as u32,
113                            content: line.trim_end().to_string(),
114                            column: None,
115                        });
116                    }
117                    Ok(hits.len() < max_results)
118                }),
119            );
120        }
121
122        hits
123    }
124
125    /// Read a file with truncation if it exceeds max bytes
126    pub fn read_file_truncated(&self, path: &Path) -> Result<String> {
127        let full_path = if path.is_absolute() {
128            path.to_path_buf()
129        } else {
130            self.working_dir.join(path)
131        };
132
133        let content = std::fs::read_to_string(&full_path)?;
134
135        if content.len() > self.max_file_bytes {
136            let truncated = &content[..self.max_file_bytes];
137            // Find last newline to avoid cutting mid-line
138            let last_newline = truncated.rfind('\n').unwrap_or(self.max_file_bytes);
139            Ok(format!(
140                "{}\n\n... [truncated, {} more bytes]",
141                &content[..last_newline],
142                content.len() - last_newline
143            ))
144        } else {
145            Ok(content)
146        }
147    }
148
149    /// Get context for a task based on its context_files and output_files
150    /// Returns a formatted string suitable for LLM prompts
151    pub fn get_task_context(&self, context_files: &[PathBuf], output_files: &[PathBuf]) -> String {
152        let mut context = String::new();
153        let mut remaining_budget = self.max_context_bytes;
154
155        // Add context files (files to read for understanding)
156        if !context_files.is_empty() {
157            context.push_str("## Context Files (for reference)\n\n");
158            for file in context_files {
159                if remaining_budget == 0 {
160                    break;
161                }
162                if let Ok(content) = self.read_file_truncated(file) {
163                    let section = format!("### {}\n```\n{}\n```\n\n", file.display(), content);
164                    if section.len() <= remaining_budget {
165                        remaining_budget -= section.len();
166                        context.push_str(&section);
167                    }
168                }
169            }
170        }
171
172        // Add output files (files to modify - show current state)
173        if !output_files.is_empty() {
174            context.push_str("## Target Files (to modify)\n\n");
175            for file in output_files {
176                if remaining_budget == 0 {
177                    break;
178                }
179                let full_path = self.working_dir.join(file);
180                if full_path.exists() {
181                    if let Ok(content) = self.read_file_truncated(file) {
182                        let section = format!(
183                            "### {} (current content)\n```\n{}\n```\n\n",
184                            file.display(),
185                            content
186                        );
187                        if section.len() <= remaining_budget {
188                            remaining_budget -= section.len();
189                            context.push_str(&section);
190                        }
191                    }
192                } else {
193                    context.push_str(&format!("### {} (new file)\n\n", file.display()));
194                }
195            }
196        }
197
198        context
199    }
200
201    /// Search for relevant code based on a query (e.g., function name, class name)
202    /// Returns formatted context for LLM
203    pub fn search_for_context(&self, query: &str, max_results: usize) -> String {
204        let hits = self.search(query, max_results);
205
206        if hits.is_empty() {
207            return String::new();
208        }
209
210        let mut context = format!("## Related Code (search: '{}')\n\n", query);
211
212        for hit in &hits {
213            context.push_str(&format!(
214                "- **{}:{}**: `{}`\n",
215                hit.file.display(),
216                hit.line,
217                hit.content.trim()
218            ));
219        }
220        context.push('\n');
221
222        context
223    }
224
225    // =========================================================================
226    // PSP-5 Phase 3: Context Provenance & Structural Digests
227    // =========================================================================
228
229    /// PSP-5 Phase 3: Build a restriction map for a node
230    ///
231    /// The restriction map defines the context boundary: what files, digests,
232    /// and summaries a node is allowed to see. Built from the ownership manifest,
233    /// task graph, and parent scope.
234    pub fn build_restriction_map(
235        &self,
236        node: &perspt_core::types::SRBNNode,
237        manifest: &perspt_core::types::OwnershipManifest,
238    ) -> perspt_core::types::RestrictionMap {
239        let mut map = perspt_core::types::RestrictionMap::for_node(node.node_id.clone());
240
241        // Add files owned by this node
242        let owned = manifest.files_owned_by(&node.node_id);
243        map.owned_files = owned.iter().map(|s| s.to_string()).collect();
244
245        // Add output targets (node's primary files)
246        for target in &node.output_targets {
247            let path_str = target.to_string_lossy().to_string();
248            if !map.owned_files.contains(&path_str) {
249                map.owned_files.push(path_str);
250            }
251        }
252
253        // Add context files as sealed interfaces (read-only dependencies)
254        for ctx_file in &node.context_files {
255            map.sealed_interfaces
256                .push(ctx_file.to_string_lossy().to_string());
257        }
258
259        // Apply budget from retriever limits
260        map.budget = perspt_core::types::ContextBudget {
261            byte_limit: self.max_context_bytes,
262            file_count_limit: 20,
263        };
264
265        map
266    }
267
268    /// PSP-5 Phase 3: Assemble a reproducible context package for a node
269    ///
270    /// Builds a complete, bounded context package from the restriction map.
271    /// Prioritizes: owned files (full content) > sealed interfaces (digest or content) > summaries.
272    pub fn assemble_context_package(
273        &self,
274        node: &perspt_core::types::SRBNNode,
275        restriction_map: &perspt_core::types::RestrictionMap,
276    ) -> perspt_core::types::ContextPackage {
277        let mut package = perspt_core::types::ContextPackage::new(node.node_id.clone());
278        package.restriction_map = restriction_map.clone();
279
280        // 1. Include owned files in full (highest priority — node needs these)
281        for file_path in &restriction_map.owned_files {
282            let full_path = self.working_dir.join(file_path);
283            if full_path.exists() {
284                if let Ok(content) = self.read_file_truncated(&full_path) {
285                    if !package.add_file(file_path, content) {
286                        log::warn!(
287                            "Budget exceeded adding owned file '{}' for node '{}'",
288                            file_path,
289                            node.node_id
290                        );
291                        break;
292                    }
293                }
294            }
295        }
296
297        // 2. Include sealed interfaces (prefer digest if budget is tight)
298        for iface_path in &restriction_map.sealed_interfaces {
299            let full_path = self.working_dir.join(iface_path);
300            if full_path.exists() {
301                // Try to include full content if budget allows
302                if let Ok(content) = self.read_file_truncated(&full_path) {
303                    if !package.add_file(iface_path, content) {
304                        // Budget exceeded — compute digest instead
305                        if let Ok(raw) = std::fs::read(&full_path) {
306                            let digest = perspt_core::types::StructuralDigest::from_content(
307                                &node.node_id,
308                                iface_path,
309                                perspt_core::types::ArtifactKind::InterfaceSeal,
310                                &raw,
311                            );
312                            package.add_structural_digest(digest);
313                        }
314                    }
315                }
316            }
317        }
318
319        // 3. Include any pre-existing structural digests from the restriction map
320        for digest in &restriction_map.structural_digests {
321            package.add_structural_digest(digest.clone());
322        }
323
324        // 4. Include summary digests
325        for summary in &restriction_map.summary_digests {
326            package.add_summary_digest(summary.clone());
327        }
328
329        package
330    }
331
332    /// PSP-5 Phase 3: Compute a structural digest for a file
333    pub fn compute_structural_digest(
334        &self,
335        path: &str,
336        artifact_kind: perspt_core::types::ArtifactKind,
337        source_node_id: &str,
338    ) -> Result<perspt_core::types::StructuralDigest> {
339        let full_path = self.working_dir.join(path);
340        let content = std::fs::read(&full_path)?;
341        Ok(perspt_core::types::StructuralDigest::from_content(
342            source_node_id,
343            path,
344            artifact_kind,
345            &content,
346        ))
347    }
348
349    /// PSP-5 Phase 3: Format a context package as text for LLM prompts
350    pub fn format_context_package(&self, package: &perspt_core::types::ContextPackage) -> String {
351        let mut context = String::new();
352
353        // Owned/included files
354        if !package.included_files.is_empty() {
355            context.push_str("## Context Files\n\n");
356            for (path, content) in &package.included_files {
357                context.push_str(&format!("### {}\n```\n{}\n```\n\n", path, content));
358            }
359        }
360
361        // Structural digests (compact representation)
362        if !package.structural_digests.is_empty() {
363            context.push_str("## Structural Dependencies (digests)\n\n");
364            for digest in &package.structural_digests {
365                context.push_str(&format!(
366                    "- {} ({}) from node '{}' [hash: {:02x}{:02x}..]\n",
367                    digest.source_path,
368                    digest.artifact_kind,
369                    digest.source_node_id,
370                    digest.hash[0],
371                    digest.hash[1],
372                ));
373            }
374            context.push('\n');
375        }
376
377        // Summary digests
378        if !package.summary_digests.is_empty() {
379            context.push_str("## Advisory Summaries\n\n");
380            for summary in &package.summary_digests {
381                context.push_str(&format!(
382                    "### {} (from {})\n{}\n\n",
383                    summary.digest_id, summary.source_node_id, summary.summary_text
384                ));
385            }
386        }
387
388        if package.budget_exceeded {
389            context.push_str(
390                "\n> Note: Context budget was exceeded. Some files replaced with structural digests.\n",
391            );
392        }
393
394        context
395    }
396
397    /// Check if a file extension indicates a binary file
398    fn is_binary_extension(path: &Path) -> bool {
399        match path.extension().and_then(|e| e.to_str()) {
400            Some(ext) => matches!(
401                ext.to_lowercase().as_str(),
402                "png"
403                    | "jpg"
404                    | "jpeg"
405                    | "gif"
406                    | "bmp"
407                    | "ico"
408                    | "webp"
409                    | "pdf"
410                    | "doc"
411                    | "docx"
412                    | "xls"
413                    | "xlsx"
414                    | "ppt"
415                    | "pptx"
416                    | "zip"
417                    | "tar"
418                    | "gz"
419                    | "bz2"
420                    | "7z"
421                    | "rar"
422                    | "exe"
423                    | "dll"
424                    | "so"
425                    | "dylib"
426                    | "a"
427                    | "wasm"
428                    | "o"
429                    | "obj"
430                    | "pyc"
431                    | "pyo"
432                    | "class"
433                    | "db"
434                    | "sqlite"
435                    | "sqlite3"
436            ),
437            None => false,
438        }
439    }
440
441    /// PSP-5 Phase 3: Validate a persisted provenance record against the current workspace.
442    ///
443    /// Parses structural digest references from the provenance record and checks
444    /// whether the referenced source files still exist on disk. Returns a list
445    /// of missing file paths — empty means no drift detected.
446    pub fn validate_provenance_record(
447        &self,
448        record: &perspt_store::ContextProvenanceRecord,
449    ) -> Vec<String> {
450        let mut missing = Vec::new();
451
452        // Parse structural_hashes JSON: entries have format "digest_id:hex_hash"
453        // where digest_id is "source_node_id:source_path:artifact_kind".
454        if let Ok(entries) = serde_json::from_str::<Vec<String>>(&record.structural_hashes) {
455            for entry in &entries {
456                // Entry format: "source_node_id:source_path:artifact_kind:hex_hash"
457                // Split and extract source_path (second segment)
458                let parts: Vec<&str> = entry.splitn(4, ':').collect();
459                if parts.len() >= 3 {
460                    // parts[0] = source_node_id, parts[1] = source_path,
461                    // parts[2..] = artifact_kind:hex_hash
462                    let source_path = parts[1];
463                    let full_path = self.working_dir.join(source_path);
464                    if !full_path.exists() {
465                        missing.push(source_path.to_string());
466                    }
467                }
468            }
469        }
470
471        missing
472    }
473
474    // =========================================================================
475    // PSP-5: Project Summary for Existing-Project Context
476    // =========================================================================
477
478    /// Gather a structured project summary for injection into sheafification prompts.
479    ///
480    /// Returns a formatted string describing: detected language plugins,
481    /// dependency manifests, entry points, test locations, and build system.
482    /// Uses the plugin registry and file-system inspection; zero LLM calls.
483    pub fn get_project_summary(&self) -> String {
484        let registry = perspt_core::plugin::PluginRegistry::new();
485        let detected = registry.detect_all(&self.working_dir);
486
487        if detected.is_empty() {
488            return String::new();
489        }
490
491        let mut summary = String::from("## Existing Project Summary\n\n");
492
493        for plugin in &detected {
494            summary.push_str(&format!("**Language/Plugin:** {}\n", plugin.name()));
495        }
496        summary.push('\n');
497
498        // Dependency manifests
499        let manifest_candidates = [
500            "Cargo.toml",
501            "pyproject.toml",
502            "setup.py",
503            "requirements.txt",
504            "package.json",
505            "uv.lock",
506            "Cargo.lock",
507            "poetry.lock",
508        ];
509        let mut found_manifests = Vec::new();
510        for candidate in &manifest_candidates {
511            if self.working_dir.join(candidate).exists() {
512                found_manifests.push(*candidate);
513            }
514        }
515        if !found_manifests.is_empty() {
516            summary.push_str(&format!(
517                "**Dependency manifests:** {}\n",
518                found_manifests.join(", ")
519            ));
520        }
521
522        // Entry points
523        let entry_candidates = [
524            "src/main.rs",
525            "src/lib.rs",
526            "src/main.py",
527            "main.py",
528            "app.py",
529            "__main__.py",
530            "src/index.ts",
531            "src/index.js",
532            "index.ts",
533            "index.js",
534        ];
535        let mut found_entries = Vec::new();
536        for candidate in &entry_candidates {
537            if self.working_dir.join(candidate).exists() {
538                found_entries.push(*candidate);
539            }
540        }
541        if !found_entries.is_empty() {
542            summary.push_str(&format!("**Entry points:** {}\n", found_entries.join(", ")));
543        }
544
545        // Test locations
546        let test_candidates = ["tests/", "test/", "src/tests/", "tests.py", "test_*.py"];
547        let mut found_tests = Vec::new();
548        for candidate in &test_candidates {
549            if self.working_dir.join(candidate).exists() {
550                found_tests.push(*candidate);
551            }
552        }
553        if !found_tests.is_empty() {
554            summary.push_str(&format!("**Test locations:** {}\n", found_tests.join(", ")));
555        }
556
557        // Read key manifest content (truncated) for context
558        for manifest in &found_manifests {
559            if let Ok(content) = self.read_file_truncated(Path::new(manifest)) {
560                // Only include first 2KB of each manifest
561                let truncated = if content.len() > 2048 {
562                    format!("{}...\n[truncated]", &content[..2048])
563                } else {
564                    content
565                };
566                summary.push_str(&format!("\n### {}\n```\n{}\n```\n", manifest, truncated));
567            }
568        }
569
570        summary
571    }
572}
573
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::{tempdir, TempDir};

    /// Retriever rooted at the given temporary workspace, with default limits.
    fn retriever_for(workspace: &TempDir) -> ContextRetriever {
        ContextRetriever::new(workspace.path().to_path_buf())
    }

    /// Minimal actuator-tier node with id `node_1` used by the PSP-5 tests.
    fn sample_node() -> perspt_core::types::SRBNNode {
        perspt_core::types::SRBNNode::new(
            "node_1".to_string(),
            "test goal".to_string(),
            perspt_core::types::ModelTier::Actuator,
        )
    }

    /// Write `contents` to `src/<name>` inside the workspace, creating `src/` first.
    fn write_src_file(workspace: &TempDir, name: &str, contents: &str) {
        let src_dir = workspace.path().join("src");
        fs::create_dir_all(&src_dir).unwrap();
        fs::write(src_dir.join(name), contents).unwrap();
    }

    #[test]
    fn test_search_finds_pattern() {
        let workspace = tempdir().unwrap();
        fs::write(
            workspace.path().join("test.py"),
            "def hello_world():\n    print('Hello')\n",
        )
        .unwrap();

        let hits = retriever_for(&workspace).search("hello_world", 10);

        assert_eq!(hits.len(), 1);
        assert!(hits[0].content.contains("def hello_world"));
    }

    #[test]
    fn test_read_file_truncated() {
        let workspace = tempdir().unwrap();
        let large_file = workspace.path().join("large.txt");
        // ~50KB of short lines, far above the 1000-byte limit set below.
        fs::write(&large_file, "line\n".repeat(10000)).unwrap();

        let retriever = retriever_for(&workspace).with_max_file_bytes(1000);
        let text = retriever.read_file_truncated(&large_file).unwrap();

        assert!(text.contains("truncated"));
        assert!(text.len() < 2000); // Should be truncated + message
    }

    // =========================================================================
    // PSP-5 Phase 3: Restriction Maps & Context Packages
    // =========================================================================

    #[test]
    fn test_build_restriction_map() {
        let workspace = tempdir().unwrap();
        let retriever = retriever_for(&workspace);

        let mut node = sample_node();
        node.output_targets = vec![std::path::PathBuf::from("src/main.rs")];
        node.context_files = vec![std::path::PathBuf::from("src/lib.rs")];

        let mut manifest = perspt_core::types::OwnershipManifest::new();
        for owned in ["src/main.rs", "src/utils.rs"] {
            manifest.assign(
                owned,
                "node_1",
                "rust",
                perspt_core::types::NodeClass::Implementation,
            );
        }

        let map = retriever.build_restriction_map(&node, &manifest);

        assert_eq!(map.node_id, "node_1");
        // Owned files: src/main.rs (from output_targets) + src/utils.rs (from manifest)
        assert!(map.owned_files.contains(&"src/main.rs".to_string()));
        assert!(map.owned_files.contains(&"src/utils.rs".to_string()));
        // Sealed interfaces: src/lib.rs (from context_files)
        assert_eq!(map.sealed_interfaces, vec!["src/lib.rs".to_string()]);
    }

    #[test]
    fn test_assemble_context_package_with_files() {
        let workspace = tempdir().unwrap();
        // A file that the node owns.
        write_src_file(&workspace, "main.rs", "fn main() {}");

        let retriever = retriever_for(&workspace);
        let node = sample_node();

        let mut map = perspt_core::types::RestrictionMap::for_node("node_1".to_string());
        map.owned_files.push("src/main.rs".to_string());
        map.budget.byte_limit = 10 * 1024; // 10KB

        let package = retriever.assemble_context_package(&node, &map);

        assert_eq!(package.node_id, "node_1");
        assert!(package.included_files.contains_key("src/main.rs"));
        assert!(!package.budget_exceeded);
        assert!(package.total_bytes > 0);
    }

    #[test]
    fn test_assemble_context_package_budget_exceeded() {
        let workspace = tempdir().unwrap();
        // A file larger than the budget configured below.
        write_src_file(&workspace, "big.rs", &"x".repeat(500));

        let retriever = retriever_for(&workspace);
        let node = sample_node();

        let mut map = perspt_core::types::RestrictionMap::for_node("node_1".to_string());
        map.owned_files.push("src/big.rs".to_string());
        map.budget.byte_limit = 100; // Very small budget

        let package = retriever.assemble_context_package(&node, &map);
        assert!(package.budget_exceeded);
    }

    #[test]
    fn test_format_context_package_empty() {
        let retriever = ContextRetriever::new(PathBuf::from("."));
        let package = perspt_core::types::ContextPackage::new("node_1".to_string());

        assert!(retriever.format_context_package(&package).is_empty());
    }

    #[test]
    fn test_format_context_package_with_files() {
        let retriever = ContextRetriever::new(PathBuf::from("."));
        let mut package = perspt_core::types::ContextPackage::new("node_1".to_string());
        package.add_file("src/main.rs", "fn main() {}".to_string());

        let rendered = retriever.format_context_package(&package);
        for expected in ["## Context Files", "src/main.rs", "fn main() {}"] {
            assert!(rendered.contains(expected));
        }
    }

    #[test]
    fn test_compute_structural_digest() {
        let workspace = tempdir().unwrap();
        fs::write(workspace.path().join("test.rs"), "fn test() {}").unwrap();

        let digest = retriever_for(&workspace)
            .compute_structural_digest(
                "test.rs",
                perspt_core::types::ArtifactKind::Signature,
                "node_1",
            )
            .unwrap();

        assert_eq!(digest.source_node_id, "node_1");
        assert_eq!(digest.source_path, "test.rs");
        assert_ne!(digest.hash, [0u8; 32]);
    }
}
739}