Skip to main content

roder_context/
entrypoint.rs

1use std::collections::HashSet;
2use std::path::{Path, PathBuf};
3
4use roder_api::context::{
5    ContextBlock, ContextBlockKind, ContextPlan, ContextPlanner, ContextPlannerId, ContextQuery,
6};
7use roder_search::{DEFAULT_MAX_FILE_SIZE, SearchMode, SearchOptions, search_workspace};
8use serde::Serialize;
9use serde_json::json;
10
/// Maximum number of ranked candidates kept after sorting.
const MAX_CANDIDATES: usize = 5;
/// Maximum number of files that are scored during one workspace walk.
const MAX_FILES_SCANNED: usize = 2000;
/// Largest file size in bytes (8 KiB) whose content is read or searched for prompt tokens.
const MAX_CONTENT_BYTES: u64 = 8_192;
14
/// Context planner that adds a block of likely entry-point files to a plan.
#[derive(Clone, Debug)]
pub struct EntrypointContextPlanner {
    /// Workspace root used when a query does not carry its own workspace.
    workspace: PathBuf,
}

impl EntrypointContextPlanner {
    /// Creates a planner whose fallback workspace root is `workspace`.
    pub fn new(workspace: impl Into<PathBuf>) -> Self {
        let workspace = workspace.into();
        Self { workspace }
    }
}
27
28#[async_trait::async_trait]
29impl ContextPlanner for EntrypointContextPlanner {
30    fn id(&self) -> ContextPlannerId {
31        "entrypoint-context-planner".to_string()
32    }
33
34    async fn plan(
35        &self,
36        query: &ContextQuery,
37        mut provider_blocks: Vec<ContextBlock>,
38    ) -> anyhow::Result<ContextPlan> {
39        let workspace = query
40            .workspace
41            .as_deref()
42            .map(Path::new)
43            .unwrap_or(&self.workspace);
44        if !workspace.is_dir() {
45            provider_blocks.sort_by_key(|block| std::cmp::Reverse(block.priority));
46            return Ok(ContextPlan {
47                blocks: provider_blocks,
48            });
49        }
50        let candidates = discover_entrypoints(workspace, &query.prompt)?;
51        if !candidates.is_empty() {
52            provider_blocks.push(render_entrypoint_block(&candidates));
53        }
54        provider_blocks.sort_by_key(|block| std::cmp::Reverse(block.priority));
55        Ok(ContextPlan {
56            blocks: provider_blocks,
57        })
58    }
59}
60
61#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
62struct EntrypointCandidate {
63    path: String,
64    score: i32,
65    reasons: Vec<String>,
66}
67
68fn discover_entrypoints(root: &Path, prompt: &str) -> anyhow::Result<Vec<EntrypointCandidate>> {
69    let tokens = prompt_tokens(prompt);
70    if tokens.is_empty() {
71        return Ok(Vec::new());
72    }
73
74    let changed = changed_paths(root);
75    let search_hits = fresh_search_hits(root, &tokens);
76    let mut candidates = Vec::new();
77    let mut scanned = 0usize;
78    visit_files(root, &mut |path| {
79        if scanned >= MAX_FILES_SCANNED {
80            return Ok(());
81        }
82        scanned += 1;
83        if let Some(candidate) = score_file(root, path, &tokens, &changed, &search_hits)? {
84            candidates.push(candidate);
85        }
86        Ok(())
87    })?;
88    candidates.sort_by(|a, b| b.score.cmp(&a.score).then_with(|| a.path.cmp(&b.path)));
89    candidates.truncate(MAX_CANDIDATES);
90    Ok(candidates)
91}
92
93fn score_file(
94    root: &Path,
95    path: &Path,
96    tokens: &[String],
97    changed: &HashSet<String>,
98    search_hits: &HashSet<String>,
99) -> anyhow::Result<Option<EntrypointCandidate>> {
100    let rel = path
101        .strip_prefix(root)
102        .unwrap_or(path)
103        .to_string_lossy()
104        .replace('\\', "/");
105    let rel_lower = rel.to_ascii_lowercase();
106    let file_type_score = extension_score(path);
107    let mut score = 0;
108    let mut reasons = Vec::new();
109
110    for token in tokens {
111        if rel_lower.contains(token) {
112            score += 8;
113            reasons.push(format!("path matches `{token}`"));
114        }
115    }
116
117    if changed.contains(&rel) {
118        score += 10;
119        reasons.push("recent git change".to_string());
120    }
121
122    if search_hits.contains(&rel) {
123        score += 6;
124        reasons.push("fresh search hit".to_string());
125    }
126
127    if likely_entrypoint_name(path) {
128        score += 4;
129        reasons.push("entrypoint-like filename".to_string());
130    }
131
132    if let Ok(metadata) = std::fs::metadata(path)
133        && metadata.len() <= MAX_CONTENT_BYTES
134        && let Ok(text) = std::fs::read_to_string(path)
135    {
136        let text = text.to_ascii_lowercase();
137        for token in tokens {
138            if text.contains(token) {
139                score += 3;
140                reasons.push(format!("bounded content matches `{token}`"));
141            }
142        }
143    }
144
145    Ok((score > 0).then_some(EntrypointCandidate {
146        path: rel,
147        score: score + file_type_score,
148        reasons,
149    }))
150}
151
152fn render_entrypoint_block(candidates: &[EntrypointCandidate]) -> ContextBlock {
153    let mut text = String::from("Likely entry points:");
154    for (index, candidate) in candidates.iter().enumerate() {
155        let reason = candidate
156            .reasons
157            .first()
158            .map(String::as_str)
159            .unwrap_or("workspace evidence");
160        text.push_str(&format!("\n{}. {} - {}", index + 1, candidate.path, reason));
161    }
162
163    ContextBlock {
164        id: "entrypoint-context-planner".to_string(),
165        kind: ContextBlockKind::EntrypointHint,
166        text,
167        priority: 90,
168        token_estimate: None,
169        metadata: json!({
170            "planner": "entrypoint-context-planner",
171            "candidate_count": candidates.len(),
172            "candidates": candidates,
173            "source": "fresh_filesystem_heuristics",
174        }),
175    }
176}
177
178fn visit_files(
179    root: &Path,
180    visitor: &mut dyn FnMut(&Path) -> anyhow::Result<()>,
181) -> anyhow::Result<()> {
182    if root.is_file() {
183        return visitor(root);
184    }
185    for entry in std::fs::read_dir(root)? {
186        let entry = entry?;
187        let path = entry.path();
188        let name = entry.file_name();
189        let name = name.to_string_lossy();
190        if entry.file_type()?.is_dir() {
191            if matches!(name.as_ref(), ".git" | "target" | "node_modules" | ".roder") {
192                continue;
193            }
194            visit_files(&path, visitor)?;
195        } else {
196            visitor(&path)?;
197        }
198    }
199    Ok(())
200}
201
202fn prompt_tokens(prompt: &str) -> Vec<String> {
203    let mut seen = HashSet::new();
204    prompt
205        .split(|c: char| !c.is_ascii_alphanumeric() && c != '_' && c != '-')
206        .map(|token| token.trim().to_ascii_lowercase())
207        .filter(|token| token.len() >= 3 && !STOP_WORDS.contains(&token.as_str()))
208        .filter(|token| seen.insert(token.clone()))
209        .collect()
210}
211
/// Returns a small bonus based on the file extension of `path`.
///
/// `rs`, `toml` and `md` score 2; `ts`, `tsx`, `js`, `jsx` and `py` score 1;
/// anything else, including a missing or non-UTF-8 extension, scores 0.
/// Matching is case-sensitive.
fn extension_score(path: &Path) -> i32 {
    const PREFERRED: [&str; 3] = ["rs", "toml", "md"];
    const SECONDARY: [&str; 5] = ["ts", "tsx", "js", "jsx", "py"];

    let Some(ext) = path.extension().and_then(|ext| ext.to_str()) else {
        return 0;
    };
    if PREFERRED.contains(&ext) {
        2
    } else if SECONDARY.contains(&ext) {
        1
    } else {
        0
    }
}
219
/// Reports whether the final component of `path` is one of a fixed set of
/// file names treated as entry points.
///
/// The comparison is exact and case-sensitive; a path without a UTF-8 file
/// name yields `false`.
fn likely_entrypoint_name(path: &Path) -> bool {
    const ENTRYPOINT_NAMES: [&str; 7] = [
        "lib.rs",
        "main.rs",
        "mod.rs",
        "runtime.rs",
        "server.rs",
        "index.ts",
        "index.tsx",
    ];

    path.file_name()
        .and_then(|name| name.to_str())
        .is_some_and(|name| ENTRYPOINT_NAMES.contains(&name))
}
229
/// Returns the paths reported by `git -C <root> status --short`.
///
/// This is best-effort: if `git` cannot be spawned or exits unsuccessfully
/// the result is an empty set.
///
/// Each line is expected in the short format `XY <path>`. Two cases are
/// normalised so the stored path can match a workspace-relative file path:
/// - rename/copy entries (`R`/`C` in the status columns) are printed as
///   `old -> new`; only the new path is kept,
/// - surrounding double quotes, which git adds around paths it considers
///   unusual, are removed.
///
/// NOTE(review): backslash escapes inside quoted paths are not decoded, so
/// such paths will still not match — confirm whether that matters.
fn changed_paths(root: &Path) -> HashSet<String> {
    let output = std::process::Command::new("git")
        .arg("-C")
        .arg(root)
        .arg("status")
        .arg("--short")
        .output();
    let stdout = match output {
        Ok(output) if output.status.success() => output.stdout,
        _ => return HashSet::new(),
    };

    String::from_utf8_lossy(&stdout)
        .lines()
        .filter_map(|line| {
            let status = line.get(..2)?;
            let entry = line.get(3..)?;
            // Only rename/copy lines carry the "old -> new" form.
            let current = if status.contains('R') || status.contains('C') {
                entry
                    .rsplit_once(" -> ")
                    .map_or(entry, |(_, renamed_to)| renamed_to)
            } else {
                entry
            };
            Some(current.trim().trim_matches('"').replace('\\', "/"))
        })
        .collect()
}
249
250fn fresh_search_hits(root: &Path, tokens: &[String]) -> HashSet<String> {
251    let mut hits = HashSet::new();
252    for token in tokens.iter().take(6) {
253        let mut options = SearchOptions::new(token.clone())
254            .with_mode(SearchMode::Scan)
255            .case_sensitive(false);
256        options.max_file_size = DEFAULT_MAX_FILE_SIZE.min(MAX_CONTENT_BYTES);
257        let Ok(results) = search_workspace(root, &options) else {
258            continue;
259        };
260        for hit in results.matches.iter().take(100) {
261            hits.insert(hit.path.to_string_lossy().replace('\\', "/"));
262        }
263    }
264    hits
265}
266
/// Lowercase words that are never used as prompt tokens because they are too
/// common to indicate a specific file.
const STOP_WORDS: &[&str] = &[
    // Articles, conjunctions and prepositions.
    "the", "and", "for", "with", "that", "this", "from", "into",
    // Question words.
    "where", "when", "what", "why", "how",
    // Generic request vocabulary.
    "need", "needs", "find", "file", "files", "code", "task", "work",
];
271
#[cfg(test)]
mod tests {
    use super::*;
    use roder_api::context::ContextQuery;

    /// Scratch workspace under the system temp directory.
    ///
    /// The directory is removed when the value is dropped, so a failing
    /// assertion (which unwinds through the test) no longer leaves it behind.
    struct TempWorkspace {
        root: PathBuf,
    }

    impl TempWorkspace {
        /// Creates an empty directory whose name combines `name` with the
        /// current time in nanoseconds.
        fn new(name: &str) -> Self {
            let stamp = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("system clock is before the Unix epoch")
                .as_nanos();
            let root = std::env::temp_dir().join(format!("roder-context-{name}-{stamp}"));
            let _ = std::fs::remove_dir_all(&root);
            std::fs::create_dir_all(&root).expect("create temp workspace");
            Self { root }
        }

        /// Writes `text` to `relative` inside the workspace, creating parent
        /// directories as needed.
        fn write(&self, relative: &str, text: &str) {
            let path = self.root.join(relative);
            let parent = path.parent().expect("joined path has a parent");
            std::fs::create_dir_all(parent).expect("create parent directories");
            std::fs::write(&path, text).expect("write fixture file");
        }
    }

    impl Drop for TempWorkspace {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover temp dir must not fail the test.
            let _ = std::fs::remove_dir_all(&self.root);
        }
    }

    /// Builds a query with fixed ids and no workspace override, so the
    /// planner falls back to the root it was constructed with.
    fn query(prompt: &str) -> ContextQuery {
        ContextQuery {
            thread_id: "thread-a".to_string(),
            turn_id: "turn-a".to_string(),
            prompt: prompt.to_string(),
            workspace: None,
            token_budget: None,
        }
    }

    #[tokio::test]
    async fn entrypoint_planner_puts_relevant_file_in_top_five() {
        let workspace = TempWorkspace::new("entrypoint-top-five");
        workspace.write("crates/roder-core/src/runtime.rs", "fn route_tools() {}\n");
        workspace.write("crates/roder-tools/src/files.rs", "fn read_file() {}\n");
        workspace.write("README.md", "Roder docs\n");
        let planner = EntrypointContextPlanner::new(workspace.root.clone());

        let plan = planner
            .plan(&query("debug runtime tool routing"), Vec::new())
            .await
            .unwrap();

        let block = plan.blocks.first().unwrap();
        assert!(block.text.contains("crates/roder-core/src/runtime.rs"));
        assert_eq!(block.kind, ContextBlockKind::EntrypointHint);
        assert!(block.text.lines().nth(1).unwrap().contains("runtime.rs"));
        assert!(block.metadata["candidate_count"].as_u64().unwrap() <= MAX_CANDIDATES as u64);
    }

    #[tokio::test]
    async fn entrypoint_planner_keeps_output_bounded_for_large_files() {
        let workspace = TempWorkspace::new("entrypoint-bounded");
        workspace.write("src/runtime.rs", &"runtime ".repeat(4_000));
        let planner = EntrypointContextPlanner::new(workspace.root.clone());

        let plan = planner
            .plan(&query("runtime entrypoint"), Vec::new())
            .await
            .unwrap();
        let block = plan.blocks.first().unwrap();

        assert!(block.text.len() < 1_000);
        assert!(!block.text.contains(&"runtime ".repeat(100)));
    }
}