1use roder_api::context::{
2 ContextBlock, ContextBlockKind, ContextPlan, ContextPlanner, ContextPlannerId, ContextQuery,
3};
4use roder_api::retrieval::{
5 RetrievalAvoidance, RetrievalConfidence, RetrievalIntent, RetrievalMode,
6 RetrievalRecommendation, RetrievalRoutePlan,
7};
8use serde::Serialize;
9use serde_json::json;
10use time::OffsetDateTime;
11
/// Stateless context planner that inspects the turn prompt and the blocks
/// produced by context providers, and appends a retrieval-routing hint block
/// (see `route_retrieval`) when it has something to recommend or avoid.
#[derive(Clone, Debug, Default)]
pub struct RetrievalRouterPlanner;
14
15#[async_trait::async_trait]
16impl ContextPlanner for RetrievalRouterPlanner {
17 fn id(&self) -> ContextPlannerId {
18 "retrieval-router".to_string()
19 }
20
21 async fn plan(
22 &self,
23 query: &ContextQuery,
24 mut provider_blocks: Vec<ContextBlock>,
25 ) -> anyhow::Result<ContextPlan> {
26 let plan = route_retrieval(query, &provider_blocks);
27 if !plan.recommended.is_empty() || !plan.avoid.is_empty() {
28 provider_blocks.push(render_retrieval_block(&plan));
29 }
30 provider_blocks.sort_by_key(|block| std::cmp::Reverse(block.priority));
31 Ok(ContextPlan {
32 blocks: provider_blocks,
33 })
34 }
35}
36
37pub fn route_retrieval(
38 query: &ContextQuery,
39 provider_blocks: &[ContextBlock],
40) -> RetrievalRoutePlan {
41 let prompt = query.prompt.to_ascii_lowercase();
42 let mut recommended = Vec::new();
43 let mut avoid = Vec::new();
44 let intent = classify_intent(&prompt);
45 let semantic_ready = provider_blocks.iter().any(|block| {
46 block
47 .metadata
48 .get("source")
49 .and_then(serde_json::Value::as_str)
50 == Some("indexed_semantic_code_search")
51 });
52
53 if looks_like_command_failure(&prompt) {
54 recommended.push(recommend(
55 RetrievalMode::Artifact,
56 "grep_artifact",
57 extract_query(&query.prompt),
58 "command or terminal output failure should start with saved artifact search",
59 RetrievalConfidence::High,
60 ));
61 }
62 if looks_like_capability_lookup(&prompt) {
63 if let Some(item_id) = matching_promoted_capability(provider_blocks, &prompt) {
64 let mut rec = recommend(
65 RetrievalMode::Promotion,
66 "discovery.read",
67 extract_query(&query.prompt),
68 "matching capability is already promoted or warm-cached for this session",
69 RetrievalConfidence::High,
70 );
71 rec.item_id = Some(item_id);
72 recommended.push(rec);
73 }
74 recommended.push(recommend(
75 RetrievalMode::Discovery,
76 "discovery.search",
77 extract_query(&query.prompt),
78 "tool, MCP, skill, command, or plugin capability lookup",
79 RetrievalConfidence::High,
80 ));
81 }
82 if looks_like_capability_execution(&prompt) {
83 recommended.push(recommend(
84 RetrievalMode::Promotion,
85 "discovery.read",
86 extract_query(&query.prompt),
87 "full schema or instructions are needed before capability use",
88 RetrievalConfidence::High,
89 ));
90 }
91 if looks_like_file_name(&query.prompt) {
92 recommended.push(recommend(
93 RetrievalMode::FileName,
94 "glob",
95 extract_query(&query.prompt),
96 "path or filename-shaped prompt",
97 RetrievalConfidence::High,
98 ));
99 }
100 if looks_like_exact_search(&query.prompt) {
101 recommended.push(recommend(
102 RetrievalMode::ExactText,
103 "grep",
104 extract_query(&query.prompt),
105 "exact symbol, path, regex, or error string",
106 RetrievalConfidence::High,
107 ));
108 }
109 if matches!(intent, RetrievalIntent::BroadConcept) {
110 if semantic_ready {
111 recommended.push(recommend(
112 RetrievalMode::SemanticCode,
113 "code_index.search",
114 extract_query(&query.prompt),
115 "conceptual code search with ready semantic index",
116 RetrievalConfidence::Medium,
117 ));
118 } else {
119 recommended.push(recommend(
120 RetrievalMode::ExactText,
121 "grep",
122 extract_query(&query.prompt),
123 "semantic index not observed; start with exact local search fallback",
124 RetrievalConfidence::Medium,
125 ));
126 }
127 }
128 if prompt.contains("history") || prompt.contains("previous turn") || prompt.contains("resume") {
129 recommended.push(recommend(
130 RetrievalMode::History,
131 "history.search",
132 extract_query(&query.prompt),
133 "prior conversation or session recovery",
134 RetrievalConfidence::Medium,
135 ));
136 }
137 if prompt.contains("code") || prompt.contains("repo") || prompt.contains("workspace") {
138 avoid.push(RetrievalAvoidance {
139 mode: RetrievalMode::Web,
140 reason: "local workspace retrieval should be tried before web search".to_string(),
141 });
142 }
143
144 dedupe_recommendations(&mut recommended);
145 RetrievalRoutePlan {
146 route_id: format!("route:{}:{}", query.thread_id, query.turn_id),
147 thread_id: query.thread_id.clone(),
148 turn_id: query.turn_id.clone(),
149 intent,
150 recommended,
151 avoid,
152 timestamp: OffsetDateTime::now_utc(),
153 }
154}
155
156fn render_retrieval_block(plan: &RetrievalRoutePlan) -> ContextBlock {
157 let mut text = format!("Retrieval route intent: {:?}", plan.intent);
158 for (index, rec) in plan.recommended.iter().take(5).enumerate() {
159 text.push_str(&format!(
160 "\n{}. {:?} via `{}` query `{}` - {}",
161 index + 1,
162 rec.mode,
163 rec.tool,
164 truncate(&rec.query, 80),
165 rec.reason
166 ));
167 }
168 if !plan.avoid.is_empty() {
169 let avoid = plan
170 .avoid
171 .iter()
172 .map(|avoid| format!("{:?}: {}", avoid.mode, avoid.reason))
173 .collect::<Vec<_>>()
174 .join("; ");
175 text.push_str(&format!("\nAvoid: {avoid}"));
176 }
177
178 ContextBlock {
179 id: "retrieval-router".to_string(),
180 kind: ContextBlockKind::RetrievalHint,
181 text,
182 priority: 88,
183 token_estimate: None,
184 metadata: json!({
185 "planner": "retrieval-router",
186 "route_id": plan.route_id,
187 "intent": format!("{:?}", plan.intent),
188 "retrievalPlan": serializable(plan),
189 "recommended": serializable(&plan.recommended),
190 "avoid": serializable(&plan.avoid),
191 }),
192 }
193}
194
195fn classify_intent(prompt: &str) -> RetrievalIntent {
196 if prompt.contains("tool")
197 || prompt.contains("mcp")
198 || prompt.contains("skill")
199 || prompt.contains("plugin")
200 {
201 return RetrievalIntent::InspectTool;
202 }
203 if looks_like_command_failure(prompt) {
204 return RetrievalIntent::DebugFailure;
205 }
206 if prompt.contains("usage") || prompt.contains("call sites") || prompt.contains("where used") {
207 return RetrievalIntent::TraceUsage;
208 }
209 if prompt.contains("history") || prompt.contains("previous turn") || prompt.contains("resume") {
210 return RetrievalIntent::RecoverHistory;
211 }
212 if prompt.contains("file") || prompt.contains("path") || prompt.contains("filename") {
213 return RetrievalIntent::FileLookup;
214 }
215 if looks_like_exact_search(prompt) {
216 return RetrievalIntent::FindDefinition;
217 }
218 RetrievalIntent::BroadConcept
219}
220
221fn recommend(
222 mode: RetrievalMode,
223 tool: &str,
224 query: String,
225 reason: &str,
226 confidence: RetrievalConfidence,
227) -> RetrievalRecommendation {
228 RetrievalRecommendation {
229 mode,
230 tool: tool.to_string(),
231 query,
232 reason: reason.to_string(),
233 confidence,
234 item_id: None,
235 }
236}
237
238fn dedupe_recommendations(recommended: &mut Vec<RetrievalRecommendation>) {
239 let mut seen = std::collections::BTreeSet::new();
240 recommended.retain(|rec| seen.insert((rec.mode.clone(), rec.tool.clone())));
241 recommended.truncate(5);
242}
243
/// True when the prompt mentions a capability kind (tool, MCP, skill,
/// command, plugin). Plain case-sensitive substring matching; callers pass
/// lowercased text.
fn looks_like_capability_lookup(prompt: &str) -> bool {
    const CAPABILITY_WORDS: [&str; 5] = ["tool", "mcp", "skill", "command", "plugin"];
    CAPABILITY_WORDS
        .iter()
        .copied()
        .any(|word| prompt.contains(word))
}
251
252fn looks_like_capability_execution(prompt: &str) -> bool {
253 looks_like_capability_lookup(prompt)
254 && (prompt.contains("run")
255 || prompt.contains("use")
256 || prompt.contains("execute")
257 || prompt.contains("call")
258 || prompt.contains("invoke"))
259}
260
/// True when the prompt mentions terminal output or a failure signal
/// (stderr/stdout, exit code, terminal, failed command, panic, stack trace).
/// Case-sensitive substring matching; callers pass lowercased text.
fn looks_like_command_failure(prompt: &str) -> bool {
    const FAILURE_SIGNALS: [&str; 7] = [
        "stderr",
        "stdout",
        "exit code",
        "terminal",
        "command failed",
        "panic",
        "stack trace",
    ];
    FAILURE_SIGNALS
        .iter()
        .copied()
        .any(|signal| prompt.contains(signal))
}
270
/// True when the prompt contains a path separator or one of a few known
/// source/config file extensions. Matching is case-sensitive.
fn looks_like_file_name(prompt: &str) -> bool {
    // ".tsx" is already covered by ".ts"; kept for readability of intent.
    const KNOWN_EXTENSIONS: [&str; 6] = [".rs", ".ts", ".tsx", ".json", ".toml", ".md"];

    if prompt.contains('/') {
        return true;
    }
    KNOWN_EXTENSIONS
        .iter()
        .copied()
        .any(|extension| prompt.contains(extension))
}
280
/// Heuristically decides whether `prompt` names something precise enough for
/// a literal text search (grep) rather than a conceptual one.
///
/// The prompt counts as exact when it contains a Rust-like syntax marker
/// (`::`, `->`, `fn `, `struct `, `enum `), or when any whitespace-separated
/// token of at least 4 bytes looks like an identifier, meaning it:
/// - contains an underscore or an ASCII digit (`snake_case`, `E0502`), or
/// - has an ASCII uppercase letter *after* its first alphanumeric character
///   (`CamelCase`, `ALLCAPS`).
///
/// The word-initial letter is exempt from the uppercase rule so that ordinary
/// sentence capitalization ("Find", "Explain", "Where", "I've") does not turn
/// every natural-language question into an exact-text search.
///
/// The uppercase rule only applies to case-preserving input; lowercased input
/// is judged by the marker, underscore and digit signals alone.
fn looks_like_exact_search(prompt: &str) -> bool {
    const SYNTAX_MARKERS: [&str; 5] = ["::", "->", "fn ", "struct ", "enum "];

    if SYNTAX_MARKERS.iter().any(|marker| prompt.contains(*marker)) {
        return true;
    }

    prompt.split_whitespace().any(|token| {
        if token.len() < 4 {
            return false;
        }
        let has_underscore_or_digit = token
            .chars()
            .any(|ch| ch == '_' || ch.is_ascii_digit());
        // Skip leading punctuation (quotes, backticks, brackets) so the
        // exemption applies to the word's real first letter.
        let word = token.trim_start_matches(|ch: char| !ch.is_ascii_alphanumeric());
        let has_inner_uppercase = word.chars().skip(1).any(|ch| ch.is_ascii_uppercase());
        has_underscore_or_digit || has_inner_uppercase
    })
}
294
295fn matching_promoted_capability(blocks: &[ContextBlock], prompt: &str) -> Option<String> {
296 let prompt_tokens = prompt
297 .split(|c: char| !c.is_ascii_alphanumeric() && c != '_' && c != '-')
298 .map(str::to_ascii_lowercase)
299 .filter(|token| token.len() >= 3)
300 .collect::<Vec<_>>();
301 blocks.iter().find_map(|block| {
302 let source = block
303 .metadata
304 .get("source")
305 .and_then(serde_json::Value::as_str)
306 .unwrap_or_default();
307 if !matches!(
308 source,
309 "promoted_capabilities" | "discovery_promotions" | "warm_cached_capabilities"
310 ) {
311 return None;
312 }
313 let item_id = block
314 .metadata
315 .get("item_id")
316 .and_then(serde_json::Value::as_str)
317 .or_else(|| {
318 block
319 .metadata
320 .get("itemId")
321 .and_then(serde_json::Value::as_str)
322 })?;
323 let haystack = format!("{} {}", item_id, block.text).to_ascii_lowercase();
324 prompt_tokens
325 .iter()
326 .any(|token| haystack.contains(token))
327 .then(|| item_id.to_string())
328 })
329}
330
331fn extract_query(prompt: &str) -> String {
332 truncate(prompt.trim(), 120).to_string()
333}
334
/// Returns the longest prefix of `text` that is at most `max` bytes long and
/// ends on a UTF-8 character boundary (never splitting a code point).
fn truncate(text: &str, max: usize) -> &str {
    if max >= text.len() {
        return text;
    }
    // Index 0 is always a char boundary, so the search cannot come up empty.
    let cut = (0..=max)
        .rev()
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(0);
    &text[..cut]
}
345
346fn serializable<T: Serialize>(value: &T) -> serde_json::Value {
347 serde_json::to_value(value).unwrap_or_else(|_| serde_json::Value::Null)
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal query on a fixed thread/turn around `prompt`.
    fn query(prompt: &str) -> ContextQuery {
        ContextQuery {
            thread_id: "thread-a".to_string(),
            turn_id: "turn-a".to_string(),
            prompt: prompt.to_string(),
            workspace: None,
            token_budget: None,
        }
    }

    /// Runs the router planner for `prompt` over the given provider blocks.
    async fn run_planner(prompt: &str, provider_blocks: Vec<ContextBlock>) -> ContextPlan {
        let planner = RetrievalRouterPlanner;
        let request = query(prompt);
        planner.plan(&request, provider_blocks).await.unwrap()
    }

    #[tokio::test]
    async fn retrieval_router_routes_exact_symbols_to_grep() {
        let plan = run_planner("Find ToolExecutionContext in the repo", Vec::new()).await;

        let hint = plan
            .blocks
            .iter()
            .find(|candidate| candidate.kind == ContextBlockKind::RetrievalHint)
            .unwrap();
        assert!(hint.text.contains("ExactText"));
        assert!(hint.text.contains("`grep`"));
    }

    #[tokio::test]
    async fn retrieval_router_routes_concepts_to_semantic_when_index_ready() {
        let index_block = ContextBlock {
            id: "code-index".to_string(),
            kind: ContextBlockKind::RetrievedDocument,
            text: "Indexed context".to_string(),
            priority: 10,
            token_estimate: None,
            metadata: json!({ "source": "indexed_semantic_code_search" }),
        };

        let plan = run_planner(
            "How does the policy gate choose approvals?",
            vec![index_block],
        )
        .await;

        let top = plan.blocks.first().unwrap();
        assert_eq!(top.kind, ContextBlockKind::RetrievalHint);
        assert!(top.text.contains("SemanticCode"));
        assert!(top.text.contains("code_index.search"));
    }

    #[tokio::test]
    async fn retrieval_router_routes_capability_execution_to_discovery_and_promotion() {
        let plan = run_planner(
            "Use the GitHub MCP issue search tool to find blockers",
            Vec::new(),
        )
        .await;

        let top = plan.blocks.first().unwrap();
        for expected in ["Discovery", "Promotion", "discovery.search", "discovery.read"] {
            assert!(top.text.contains(expected), "missing {expected}");
        }
    }

    #[tokio::test]
    async fn retrieval_router_prefers_promoted_capability_state() {
        let promoted_block = ContextBlock {
            id: "promoted-github".to_string(),
            kind: ContextBlockKind::ToolAvailability,
            text: "GitHub issue search is promoted".to_string(),
            priority: 20,
            token_estimate: None,
            metadata: json!({
                "source": "promoted_capabilities",
                "item_id": "mcp:github/issues.search",
            }),
        };

        let plan = run_planner("Use the GitHub MCP issue search tool", vec![promoted_block]).await;

        let top = plan.blocks.first().unwrap();
        assert!(top.text.contains("already promoted or warm-cached"));
        assert_eq!(
            top.metadata["recommended"][0]["itemId"],
            "mcp:github/issues.search"
        );
    }

    #[tokio::test]
    async fn retrieval_router_routes_command_failures_to_artifacts() {
        let plan = run_planner(
            "A terminal command failed with stderr; inspect the log",
            Vec::new(),
        )
        .await;

        let top = plan.blocks.first().unwrap();
        assert!(top.text.contains("Artifact"));
        assert!(top.text.contains("grep_artifact"));
    }
}