Skip to main content

roder_context/
retrieval_router.rs

1use roder_api::context::{
2    ContextBlock, ContextBlockKind, ContextPlan, ContextPlanner, ContextPlannerId, ContextQuery,
3};
4use roder_api::retrieval::{
5    RetrievalAvoidance, RetrievalConfidence, RetrievalIntent, RetrievalMode,
6    RetrievalRecommendation, RetrievalRoutePlan,
7};
8use serde::Serialize;
9use serde_json::json;
10use time::OffsetDateTime;
11
12#[derive(Debug, Clone, Default)]
13pub struct RetrievalRouterPlanner;
14
15#[async_trait::async_trait]
16impl ContextPlanner for RetrievalRouterPlanner {
17    fn id(&self) -> ContextPlannerId {
18        "retrieval-router".to_string()
19    }
20
21    async fn plan(
22        &self,
23        query: &ContextQuery,
24        mut provider_blocks: Vec<ContextBlock>,
25    ) -> anyhow::Result<ContextPlan> {
26        let plan = route_retrieval(query, &provider_blocks);
27        if !plan.recommended.is_empty() || !plan.avoid.is_empty() {
28            provider_blocks.push(render_retrieval_block(&plan));
29        }
30        provider_blocks.sort_by_key(|block| std::cmp::Reverse(block.priority));
31        Ok(ContextPlan {
32            blocks: provider_blocks,
33        })
34    }
35}
36
37pub fn route_retrieval(
38    query: &ContextQuery,
39    provider_blocks: &[ContextBlock],
40) -> RetrievalRoutePlan {
41    let prompt = query.prompt.to_ascii_lowercase();
42    let mut recommended = Vec::new();
43    let mut avoid = Vec::new();
44    let intent = classify_intent(&prompt);
45    let semantic_ready = provider_blocks.iter().any(|block| {
46        block
47            .metadata
48            .get("source")
49            .and_then(serde_json::Value::as_str)
50            == Some("indexed_semantic_code_search")
51    });
52
53    if looks_like_command_failure(&prompt) {
54        recommended.push(recommend(
55            RetrievalMode::Artifact,
56            "grep_artifact",
57            extract_query(&query.prompt),
58            "command or terminal output failure should start with saved artifact search",
59            RetrievalConfidence::High,
60        ));
61    }
62    if looks_like_capability_lookup(&prompt) {
63        if let Some(item_id) = matching_promoted_capability(provider_blocks, &prompt) {
64            let mut rec = recommend(
65                RetrievalMode::Promotion,
66                "discovery.read",
67                extract_query(&query.prompt),
68                "matching capability is already promoted or warm-cached for this session",
69                RetrievalConfidence::High,
70            );
71            rec.item_id = Some(item_id);
72            recommended.push(rec);
73        }
74        recommended.push(recommend(
75            RetrievalMode::Discovery,
76            "discovery.search",
77            extract_query(&query.prompt),
78            "tool, MCP, skill, command, or plugin capability lookup",
79            RetrievalConfidence::High,
80        ));
81    }
82    if looks_like_capability_execution(&prompt) {
83        recommended.push(recommend(
84            RetrievalMode::Promotion,
85            "discovery.read",
86            extract_query(&query.prompt),
87            "full schema or instructions are needed before capability use",
88            RetrievalConfidence::High,
89        ));
90    }
91    if looks_like_file_name(&query.prompt) {
92        recommended.push(recommend(
93            RetrievalMode::FileName,
94            "glob",
95            extract_query(&query.prompt),
96            "path or filename-shaped prompt",
97            RetrievalConfidence::High,
98        ));
99    }
100    if looks_like_exact_search(&query.prompt) {
101        recommended.push(recommend(
102            RetrievalMode::ExactText,
103            "grep",
104            extract_query(&query.prompt),
105            "exact symbol, path, regex, or error string",
106            RetrievalConfidence::High,
107        ));
108    }
109    if matches!(intent, RetrievalIntent::BroadConcept) {
110        if semantic_ready {
111            recommended.push(recommend(
112                RetrievalMode::SemanticCode,
113                "code_index.search",
114                extract_query(&query.prompt),
115                "conceptual code search with ready semantic index",
116                RetrievalConfidence::Medium,
117            ));
118        } else {
119            recommended.push(recommend(
120                RetrievalMode::ExactText,
121                "grep",
122                extract_query(&query.prompt),
123                "semantic index not observed; start with exact local search fallback",
124                RetrievalConfidence::Medium,
125            ));
126        }
127    }
128    if prompt.contains("history") || prompt.contains("previous turn") || prompt.contains("resume") {
129        recommended.push(recommend(
130            RetrievalMode::History,
131            "history.search",
132            extract_query(&query.prompt),
133            "prior conversation or session recovery",
134            RetrievalConfidence::Medium,
135        ));
136    }
137    if prompt.contains("code") || prompt.contains("repo") || prompt.contains("workspace") {
138        avoid.push(RetrievalAvoidance {
139            mode: RetrievalMode::Web,
140            reason: "local workspace retrieval should be tried before web search".to_string(),
141        });
142    }
143
144    dedupe_recommendations(&mut recommended);
145    RetrievalRoutePlan {
146        route_id: format!("route:{}:{}", query.thread_id, query.turn_id),
147        thread_id: query.thread_id.clone(),
148        turn_id: query.turn_id.clone(),
149        intent,
150        recommended,
151        avoid,
152        timestamp: OffsetDateTime::now_utc(),
153    }
154}
155
156fn render_retrieval_block(plan: &RetrievalRoutePlan) -> ContextBlock {
157    let mut text = format!("Retrieval route intent: {:?}", plan.intent);
158    for (index, rec) in plan.recommended.iter().take(5).enumerate() {
159        text.push_str(&format!(
160            "\n{}. {:?} via `{}` query `{}` - {}",
161            index + 1,
162            rec.mode,
163            rec.tool,
164            truncate(&rec.query, 80),
165            rec.reason
166        ));
167    }
168    if !plan.avoid.is_empty() {
169        let avoid = plan
170            .avoid
171            .iter()
172            .map(|avoid| format!("{:?}: {}", avoid.mode, avoid.reason))
173            .collect::<Vec<_>>()
174            .join("; ");
175        text.push_str(&format!("\nAvoid: {avoid}"));
176    }
177
178    ContextBlock {
179        id: "retrieval-router".to_string(),
180        kind: ContextBlockKind::RetrievalHint,
181        text,
182        priority: 88,
183        token_estimate: None,
184        metadata: json!({
185            "planner": "retrieval-router",
186            "route_id": plan.route_id,
187            "intent": format!("{:?}", plan.intent),
188            "retrievalPlan": serializable(plan),
189            "recommended": serializable(&plan.recommended),
190            "avoid": serializable(&plan.avoid),
191        }),
192    }
193}
194
195fn classify_intent(prompt: &str) -> RetrievalIntent {
196    if prompt.contains("tool")
197        || prompt.contains("mcp")
198        || prompt.contains("skill")
199        || prompt.contains("plugin")
200    {
201        return RetrievalIntent::InspectTool;
202    }
203    if looks_like_command_failure(prompt) {
204        return RetrievalIntent::DebugFailure;
205    }
206    if prompt.contains("usage") || prompt.contains("call sites") || prompt.contains("where used") {
207        return RetrievalIntent::TraceUsage;
208    }
209    if prompt.contains("history") || prompt.contains("previous turn") || prompt.contains("resume") {
210        return RetrievalIntent::RecoverHistory;
211    }
212    if prompt.contains("file") || prompt.contains("path") || prompt.contains("filename") {
213        return RetrievalIntent::FileLookup;
214    }
215    if looks_like_exact_search(prompt) {
216        return RetrievalIntent::FindDefinition;
217    }
218    RetrievalIntent::BroadConcept
219}
220
221fn recommend(
222    mode: RetrievalMode,
223    tool: &str,
224    query: String,
225    reason: &str,
226    confidence: RetrievalConfidence,
227) -> RetrievalRecommendation {
228    RetrievalRecommendation {
229        mode,
230        tool: tool.to_string(),
231        query,
232        reason: reason.to_string(),
233        confidence,
234        item_id: None,
235    }
236}
237
238fn dedupe_recommendations(recommended: &mut Vec<RetrievalRecommendation>) {
239    let mut seen = std::collections::BTreeSet::new();
240    recommended.retain(|rec| seen.insert((rec.mode.clone(), rec.tool.clone())));
241    recommended.truncate(5);
242}
243
/// Returns `true` when `prompt` contains any capability keyword
/// (`tool`, `mcp`, `skill`, `command`, `plugin`) as a case-sensitive substring.
fn looks_like_capability_lookup(prompt: &str) -> bool {
    const CAPABILITY_KEYWORDS: [&str; 5] = ["tool", "mcp", "skill", "command", "plugin"];
    CAPABILITY_KEYWORDS
        .iter()
        .any(|keyword| prompt.contains(*keyword))
}
251
252fn looks_like_capability_execution(prompt: &str) -> bool {
253    looks_like_capability_lookup(prompt)
254        && (prompt.contains("run")
255            || prompt.contains("use")
256            || prompt.contains("execute")
257            || prompt.contains("call")
258            || prompt.contains("invoke"))
259}
260
/// Returns `true` when `prompt` contains any command-failure marker
/// (stderr/stdout, exit code, terminal, command failed, panic, stack trace)
/// as a case-sensitive substring.
fn looks_like_command_failure(prompt: &str) -> bool {
    const FAILURE_MARKERS: [&str; 7] = [
        "stderr",
        "stdout",
        "exit code",
        "terminal",
        "command failed",
        "panic",
        "stack trace",
    ];
    FAILURE_MARKERS
        .iter()
        .any(|marker| prompt.contains(*marker))
}
270
/// Returns `true` when `prompt` contains a `/` or one of the known file
/// extensions (`.rs`, `.ts`, `.tsx`, `.json`, `.toml`, `.md`) anywhere in the
/// text. Matching is case-sensitive.
fn looks_like_file_name(prompt: &str) -> bool {
    const EXTENSIONS: [&str; 6] = [".rs", ".ts", ".tsx", ".json", ".toml", ".md"];
    prompt.contains('/') || EXTENSIONS.iter().any(|extension| prompt.contains(*extension))
}
280
/// Returns `true` when `prompt` looks like an exact-text search target.
///
/// That is the case when it contains a code marker (`::`, `->`, `fn `,
/// `struct `, `enum `), or when any whitespace-separated token of at least
/// four bytes contains an underscore, an ASCII uppercase letter, or an ASCII
/// digit.
fn looks_like_exact_search(prompt: &str) -> bool {
    const CODE_MARKERS: [&str; 5] = ["::", "->", "fn ", "struct ", "enum "];

    if CODE_MARKERS.iter().any(|marker| prompt.contains(*marker)) {
        return true;
    }

    // Only ASCII characters can match, so scanning bytes is equivalent to
    // scanning chars: bytes of multi-byte sequences are never ASCII.
    prompt
        .split_whitespace()
        .filter(|token| token.len() >= 4)
        .any(|token| {
            token
                .bytes()
                .any(|byte| byte == b'_' || byte.is_ascii_uppercase() || byte.is_ascii_digit())
        })
}
294
295fn matching_promoted_capability(blocks: &[ContextBlock], prompt: &str) -> Option<String> {
296    let prompt_tokens = prompt
297        .split(|c: char| !c.is_ascii_alphanumeric() && c != '_' && c != '-')
298        .map(str::to_ascii_lowercase)
299        .filter(|token| token.len() >= 3)
300        .collect::<Vec<_>>();
301    blocks.iter().find_map(|block| {
302        let source = block
303            .metadata
304            .get("source")
305            .and_then(serde_json::Value::as_str)
306            .unwrap_or_default();
307        if !matches!(
308            source,
309            "promoted_capabilities" | "discovery_promotions" | "warm_cached_capabilities"
310        ) {
311            return None;
312        }
313        let item_id = block
314            .metadata
315            .get("item_id")
316            .and_then(serde_json::Value::as_str)
317            .or_else(|| {
318                block
319                    .metadata
320                    .get("itemId")
321                    .and_then(serde_json::Value::as_str)
322            })?;
323        let haystack = format!("{} {}", item_id, block.text).to_ascii_lowercase();
324        prompt_tokens
325            .iter()
326            .any(|token| haystack.contains(token))
327            .then(|| item_id.to_string())
328    })
329}
330
331fn extract_query(prompt: &str) -> String {
332    truncate(prompt.trim(), 120).to_string()
333}
334
/// Returns the longest prefix of `text` that is at most `max` bytes long and
/// ends on a UTF-8 character boundary. Text that already fits is returned
/// unchanged.
fn truncate(text: &str, max: usize) -> &str {
    if max >= text.len() {
        return text;
    }
    // Walk back from `max` to the nearest boundary; index 0 always qualifies.
    let cut = (0..=max)
        .rev()
        .find(|&index| text.is_char_boundary(index))
        .unwrap_or(0);
    &text[..cut]
}
345
346fn serializable<T: Serialize>(value: &T) -> serde_json::Value {
347    serde_json::to_value(value).unwrap_or_else(|_| serde_json::Value::Null)
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a query on a fixed thread/turn that carries only `prompt`.
    fn query(prompt: &str) -> ContextQuery {
        ContextQuery {
            thread_id: "thread-a".to_string(),
            turn_id: "turn-a".to_string(),
            prompt: prompt.to_string(),
            workspace: None,
            token_budget: None,
        }
    }

    /// Runs the planner for `prompt` over `provider_blocks` and returns the
    /// planned blocks, panicking if planning fails.
    async fn planned_blocks(prompt: &str, provider_blocks: Vec<ContextBlock>) -> Vec<ContextBlock> {
        let planner = RetrievalRouterPlanner;
        let request = query(prompt);
        let plan = planner.plan(&request, provider_blocks).await.unwrap();
        plan.blocks
    }

    #[tokio::test]
    async fn retrieval_router_routes_exact_symbols_to_grep() {
        let blocks = planned_blocks("Find ToolExecutionContext in the repo", Vec::new()).await;

        let hint = blocks
            .iter()
            .find(|block| block.kind == ContextBlockKind::RetrievalHint)
            .unwrap();
        for fragment in ["ExactText", "`grep`"] {
            assert!(hint.text.contains(fragment));
        }
    }

    #[tokio::test]
    async fn retrieval_router_routes_concepts_to_semantic_when_index_ready() {
        let semantic_block = ContextBlock {
            id: "code-index".to_string(),
            kind: ContextBlockKind::RetrievedDocument,
            text: "Indexed context".to_string(),
            priority: 10,
            token_estimate: None,
            metadata: json!({ "source": "indexed_semantic_code_search" }),
        };

        let blocks = planned_blocks(
            "How does the policy gate choose approvals?",
            vec![semantic_block],
        )
        .await;

        // The hint outranks the priority-10 provider block, so it sorts first.
        let hint = blocks.first().unwrap();
        assert_eq!(hint.kind, ContextBlockKind::RetrievalHint);
        for fragment in ["SemanticCode", "code_index.search"] {
            assert!(hint.text.contains(fragment));
        }
    }

    #[tokio::test]
    async fn retrieval_router_routes_capability_execution_to_discovery_and_promotion() {
        let blocks = planned_blocks(
            "Use the GitHub MCP issue search tool to find blockers",
            Vec::new(),
        )
        .await;
        let hint = blocks.first().unwrap();

        for fragment in [
            "Discovery",
            "Promotion",
            "discovery.search",
            "discovery.read",
        ] {
            assert!(hint.text.contains(fragment));
        }
    }

    #[tokio::test]
    async fn retrieval_router_prefers_promoted_capability_state() {
        let promoted = ContextBlock {
            id: "promoted-github".to_string(),
            kind: ContextBlockKind::ToolAvailability,
            text: "GitHub issue search is promoted".to_string(),
            priority: 20,
            token_estimate: None,
            metadata: json!({
                "source": "promoted_capabilities",
                "item_id": "mcp:github/issues.search",
            }),
        };

        let blocks = planned_blocks("Use the GitHub MCP issue search tool", vec![promoted]).await;
        let hint = blocks.first().unwrap();

        assert!(hint.text.contains("already promoted or warm-cached"));
        assert_eq!(
            hint.metadata["recommended"][0]["itemId"],
            "mcp:github/issues.search"
        );
    }

    #[tokio::test]
    async fn retrieval_router_routes_command_failures_to_artifacts() {
        let blocks = planned_blocks(
            "A terminal command failed with stderr; inspect the log",
            Vec::new(),
        )
        .await;
        let hint = blocks.first().unwrap();

        for fragment in ["Artifact", "grep_artifact"] {
            assert!(hint.text.contains(fragment));
        }
    }
}