// lean_ctx/tools/ctx_preload.rs

1use crate::core::cache::SessionCache;
2use crate::core::graph_index::ProjectIndex;
3use crate::core::protocol;
4use crate::core::task_relevance::{compute_relevance, parse_task_hints};
5use crate::core::tokens::count_tokens;
6use crate::tools::CrpMode;
7
/// Maximum number of "critical" (score >= 0.5) files summarised by one preload.
const MAX_PRELOAD_FILES: usize = 5;
/// Maximum number of keyword/error lines quoted from each critical file.
const MAX_CRITICAL_LINES: usize = 15;
/// Maximum number of public signatures listed for each critical file.
const SIGNATURES_BUDGET: usize = 10;
11
12pub fn handle(
13    cache: &mut SessionCache,
14    task: &str,
15    path: Option<&str>,
16    crp_mode: CrpMode,
17) -> String {
18    if task.trim().is_empty() {
19        return "ERROR: ctx_preload requires a task description".to_string();
20    }
21
22    let project_root = path.map(|p| p.to_string()).unwrap_or_else(|| {
23        std::env::current_dir()
24            .map(|p| p.to_string_lossy().to_string())
25            .unwrap_or_else(|_| ".".to_string())
26    });
27
28    let mut index = ProjectIndex::load(&project_root).unwrap_or_else(|| {
29        let new_index = ProjectIndex::new(&project_root);
30        let _ = new_index.save();
31        new_index
32    });
33    if index.files.is_empty() {
34        index = ProjectIndex::new(&project_root);
35        let _ = index.save();
36    }
37
38    let (task_files, task_keywords) = parse_task_hints(task);
39    let relevance = compute_relevance(&index, &task_files, &task_keywords);
40
41    let critical: Vec<_> = relevance
42        .iter()
43        .filter(|r| r.score >= 0.5)
44        .take(MAX_PRELOAD_FILES)
45        .collect();
46
47    if critical.is_empty() {
48        return format!(
49            "[task: {task}]\nNo directly relevant files found. Use ctx_overview for project map."
50        );
51    }
52
53    let mut output = Vec::new();
54    output.push(format!("[task: {task}]"));
55
56    let mut total_estimated_saved = 0usize;
57
58    for rel in &critical {
59        let content = match std::fs::read_to_string(&rel.path) {
60            Ok(c) => c,
61            Err(_) => continue,
62        };
63
64        let file_ref = cache.get_file_ref(&rel.path);
65        let short = protocol::shorten_path(&rel.path);
66        let line_count = content.lines().count();
67        let file_tokens = count_tokens(&content);
68
69        let (entry, _) = cache.store(&rel.path, content.clone());
70        let _ = entry;
71
72        let critical_lines = extract_critical_lines(&content, &task_keywords, MAX_CRITICAL_LINES);
73        let sigs = extract_key_signatures(&content, SIGNATURES_BUDGET);
74        let imports = extract_imports(&content);
75
76        output.push(format!(
77            "\nCRITICAL: {file_ref}={short} {line_count}L score={:.1}",
78            rel.score
79        ));
80
81        if !critical_lines.is_empty() {
82            for (line_no, line) in &critical_lines {
83                output.push(format!("  :{line_no} {line}"));
84            }
85        }
86
87        if !imports.is_empty() {
88            output.push(format!("  imports: {}", imports.join(", ")));
89        }
90
91        if !sigs.is_empty() {
92            for sig in &sigs {
93                output.push(format!("  {sig}"));
94            }
95        }
96
97        total_estimated_saved += file_tokens;
98    }
99
100    let context_files: Vec<_> = relevance
101        .iter()
102        .filter(|r| r.score >= 0.2 && r.score < 0.5)
103        .take(10)
104        .collect();
105
106    if !context_files.is_empty() {
107        output.push("\nRELATED:".to_string());
108        for rel in &context_files {
109            let short = protocol::shorten_path(&rel.path);
110            output.push(format!(
111                "  {} mode={} score={:.1}",
112                short, rel.recommended_mode, rel.score
113            ));
114        }
115    }
116
117    let graph_edges: Vec<_> = index
118        .edges
119        .iter()
120        .filter(|e| critical.iter().any(|c| c.path == e.from || c.path == e.to))
121        .take(10)
122        .collect();
123
124    if !graph_edges.is_empty() {
125        output.push("\nGRAPH:".to_string());
126        for edge in &graph_edges {
127            let from_short = protocol::shorten_path(&edge.from);
128            let to_short = protocol::shorten_path(&edge.to);
129            output.push(format!("  {from_short} -> {to_short}"));
130        }
131    }
132
133    let preload_result = output.join("\n");
134    let preload_tokens = count_tokens(&preload_result);
135    let savings = protocol::format_savings(total_estimated_saved, preload_tokens);
136
137    if crp_mode.is_tdd() {
138        format!("{preload_result}\n{savings}")
139    } else {
140        format!("{preload_result}\n\nNext: ctx_read(path, mode=\"full\") for any file above.\n{savings}")
141    }
142}
143
/// Selects up to `max` lines of `content` that look most relevant to a task.
///
/// A non-blank line qualifies when either of these holds:
///
/// * it contains at least one of `keywords` (case-insensitive substring match);
/// * it looks error-related: it contains `Error`, `Err(`, `panic!` or
///   `unwrap()`, or starts with `return Err`.
///
/// Each qualifying line scores one point per matching keyword, plus 2 if it is
/// error-related. Lines are returned highest score first; ties keep file order.
///
/// Returns `(1-based line number, trimmed line text)` pairs.
///
/// Empty or whitespace-only keywords are ignored. `str::contains("")` is
/// always true, so such a keyword would otherwise flag every non-blank line.
fn extract_critical_lines(content: &str, keywords: &[String], max: usize) -> Vec<(usize, String)> {
    let kw_lower: Vec<String> = keywords
        .iter()
        .filter(|k| !k.trim().is_empty())
        .map(|k| k.to_lowercase())
        .collect();

    let mut hits: Vec<(usize, String, usize)> = content
        .lines()
        .enumerate()
        .filter_map(|(i, line)| {
            let trimmed = line.trim();
            if trimmed.is_empty() {
                return None;
            }
            let line_lower = trimmed.to_lowercase();
            let hit_count = kw_lower
                .iter()
                .filter(|kw| line_lower.contains(kw.as_str()))
                .count();

            let is_error = trimmed.contains("Error")
                || trimmed.contains("Err(")
                || trimmed.contains("panic!")
                || trimmed.contains("unwrap()")
                || trimmed.starts_with("return Err");

            if hit_count == 0 && !is_error {
                return None;
            }
            let priority = hit_count + if is_error { 2 } else { 0 };
            Some((i + 1, trimmed.to_string(), priority))
        })
        .collect();

    // `sort_by_key` is stable, so equal-priority lines stay in file order.
    hits.sort_by_key(|hit| std::cmp::Reverse(hit.2));
    hits.truncate(max);
    hits.into_iter().map(|(n, l, _)| (n, l)).collect()
}
180
/// Returns up to `max` public item signatures from `content`, in file order.
///
/// A line counts as a signature when, after trimming, it starts with one of
/// `pub fn`, `pub async fn`, `pub struct`, `pub enum`, `pub trait`, `pub type`
/// or `pub const`. Lines longer than 120 characters are cut to their first
/// 117 characters followed by `...`.
///
/// Length and the cut point are measured in `char`s, not bytes. Slicing a
/// `&str` at a byte offset inside a multi-byte UTF-8 character panics, so a
/// byte-based cut would crash on non-ASCII signatures. For ASCII input, chars
/// and bytes coincide.
fn extract_key_signatures(content: &str, max: usize) -> Vec<String> {
    const SIG_STARTERS: [&str; 7] = [
        "pub fn ",
        "pub async fn ",
        "pub struct ",
        "pub enum ",
        "pub trait ",
        "pub type ",
        "pub const ",
    ];
    const MAX_LEN: usize = 120;
    const KEEP: usize = MAX_LEN - 3;

    content
        .lines()
        .map(str::trim)
        .filter(|line| SIG_STARTERS.iter().any(|s| line.starts_with(s)))
        .take(max)
        .map(|line| {
            if line.chars().count() > MAX_LEN {
                // Byte offset of the (KEEP+1)-th char: always a char boundary.
                let cut = line
                    .char_indices()
                    .nth(KEEP)
                    .map_or(line.len(), |(idx, _)| idx);
                format!("{}...", &line[..cut])
            } else {
                line.to_string()
            }
        })
        .collect()
}
209
/// Collects up to eight import-like lines from `content`, in file order.
///
/// Trimmed lines beginning with `use ` have that prefix and any trailing `;`
/// characters removed. Lines beginning with `import ` or `from ` are kept
/// verbatim (trimmed). All other lines are ignored.
fn extract_imports(content: &str) -> Vec<String> {
    const MAX_IMPORTS: usize = 8;

    let mut found: Vec<String> = Vec::new();
    for raw in content.lines() {
        if found.len() == MAX_IMPORTS {
            break;
        }
        let line = raw.trim();
        let entry = match line.strip_prefix("use ") {
            Some(path) => path.trim_end_matches(';'),
            None if line.starts_with("import ") || line.starts_with("from ") => line,
            None => continue,
        };
        found.push(entry.to_string());
    }
    found
}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned keyword list from string literals.
    fn kw(words: &[&str]) -> Vec<String> {
        words.iter().map(|w| w.to_string()).collect()
    }

    /// A line containing a task keyword is reported.
    #[test]
    fn extract_critical_lines_finds_keywords() {
        let src = "fn main() {\n    let token = validate();\n    return Err(e);\n}\n";
        let found = extract_critical_lines(src, &kw(&["validate"]), 5);
        assert!(!found.is_empty());
        assert!(found.iter().any(|(_, text)| text.contains("validate")));
    }

    /// Error-related lines outrank plain keyword hits.
    #[test]
    fn extract_critical_lines_prioritizes_errors() {
        let src = "fn main() {\n    let x = 1;\n    return Err(\"bad\");\n    let token = validate();\n}\n";
        let found = extract_critical_lines(src, &kw(&["validate"]), 5);
        assert!(found.len() >= 2);
        let (_, top) = &found[0];
        assert!(top.contains("Err"), "errors should be first");
    }

    /// Only `pub` items are picked up, in file order.
    #[test]
    fn extract_key_signatures_finds_pub() {
        let src = "use std::io;\nfn private() {}\npub fn public_one() {}\npub struct Foo {}\n";
        let signatures = extract_key_signatures(src, 10);
        assert_eq!(signatures.len(), 2);
        assert!(signatures[0].contains("pub fn public_one"));
        assert!(signatures[1].contains("pub struct Foo"));
    }

    /// `use` lines are collected with the keyword stripped.
    #[test]
    fn extract_imports_works() {
        let src = "use std::io;\nuse crate::core::cache;\nfn main() {}\n";
        let collected = extract_imports(src);
        assert_eq!(collected.len(), 2);
        assert!(collected[0].contains("std::io"));
    }
}