use super::{RlmChunker, RlmConfig, RlmResult, RlmStats};
use crate::provider::{CompletionRequest, ContentPart, Message, Provider, Role};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Instant;
use tracing::{info, warn};

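/// Tools whose output is eligible for RLM routing.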
fn rlm_eligible_tools() -> HashSet<&'static str> {
    ["read", "glob", "grep", "bash", "search"].iter().copied().collect()
}

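/// Inputs the router needs to decide whether a tool's output should be handed
/// to the RLM pipeline instead of being injected into the main context.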
#[derive(Debug, Clone)]
pub struct RoutingContext {
    pub tool_id: String,
    pub session_id: String,
    pub call_id: Option<String>,
    pub model_context_limit: usize,
    pub current_context_tokens: Option<usize>,
}

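/// Outcome of a routing decision, with a machine-readable reason and the
/// estimated token count of the inspected output.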
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingResult {
    pub should_route: bool,
    pub reason: String,
    pub estimated_tokens: usize,
}

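/// Everything `auto_process` needs to run: the originating tool call, an
/// optional abort signal and progress callback, and the provider/model used
/// for the RLM sub-calls.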
pub struct AutoProcessContext<'a> {
    pub tool_id: &'a str,
    pub tool_args: serde_json::Value,
    pub session_id: &'a str,
    pub abort: Option<tokio::sync::watch::Receiver<bool>>,
    pub on_progress: Option<Box<dyn Fn(ProcessProgress) + Send + Sync>>,
    pub provider: Arc<dyn Provider>,
    pub model: String,
}

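/// Progress update emitted once per RLM iteration via `AutoProcessContext::on_progress`.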
#[derive(Debug, Clone)]
pub struct ProcessProgress {
    pub iteration: usize,
    pub max_iterations: usize,
    pub status: String,
}

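/// Decides when oversized tool output should be routed through RLM processing,
/// and drives that processing (or a truncation fallback) when it is.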
pub struct RlmRouter;

impl RlmRouter {
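    /// Decides whether `output` should be routed through RLM processing.
    ///
    /// `config.mode` can force routing off (`"off"`) or on for eligible tools
    /// (`"always"`); otherwise the output is routed when it exceeds the
    /// configured threshold or would push the projected context past 80% of
    /// the model's limit.
    ///
    /// # Example
    ///
    /// Illustrative sketch; the `output` and `config` values are assumed to
    /// come from the caller:
    ///
    /// ```ignore
    /// let ctx = RoutingContext {
    ///     tool_id: "read".to_string(),
    ///     session_id: "session-1".to_string(),
    ///     call_id: None,
    ///     model_context_limit: 200_000,
    ///     current_context_tokens: Some(50_000),
    /// };
    /// let decision = RlmRouter::should_route(&output, &ctx, &config);
    /// if decision.should_route {
    ///     // hand the output to RlmRouter::auto_process instead of
    ///     // injecting it into the conversation verbatim
    /// }
    /// ```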
    pub fn should_route(output: &str, ctx: &RoutingContext, config: &RlmConfig) -> RoutingResult {
        let estimated_tokens = RlmChunker::estimate_tokens(output);

        if config.mode == "off" {
            return RoutingResult {
                should_route: false,
                reason: "rlm_mode_off".to_string(),
                estimated_tokens,
            };
        }

        if config.mode == "always" {
            if !rlm_eligible_tools().contains(ctx.tool_id.as_str()) {
                return RoutingResult {
                    should_route: false,
                    reason: "tool_not_eligible".to_string(),
                    estimated_tokens,
                };
            }
            return RoutingResult {
                should_route: true,
                reason: "rlm_mode_always".to_string(),
                estimated_tokens,
            };
        }

        if !rlm_eligible_tools().contains(ctx.tool_id.as_str()) {
            return RoutingResult {
                should_route: false,
                reason: "tool_not_eligible".to_string(),
                estimated_tokens,
            };
        }

        let threshold_tokens = (ctx.model_context_limit as f64 * config.threshold) as usize;
        if estimated_tokens > threshold_tokens {
            return RoutingResult {
                should_route: true,
                reason: "exceeds_threshold".to_string(),
                estimated_tokens,
            };
        }

        if let Some(current) = ctx.current_context_tokens {
            let projected_total = current + estimated_tokens;
            if projected_total > (ctx.model_context_limit as f64 * 0.8) as usize {
                return RoutingResult {
                    should_route: true,
                    reason: "would_overflow".to_string(),
                    estimated_tokens,
                };
            }
        }

        RoutingResult {
            should_route: false,
            reason: "within_threshold".to_string(),
            estimated_tokens,
        }
    }

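    /// Truncates `output` to roughly `max_tokens`, keeping the head (~60% of the
    /// budget) and tail (~30%) and inserting an omission marker plus a hint on
    /// how to re-run the analysis through RLM.
    ///
    /// Returns `(text, was_truncated, original_token_estimate)`.
    ///
    /// # Example
    ///
    /// Illustrative sketch with a hypothetical `output` value:
    ///
    /// ```ignore
    /// let (text, was_truncated, original_tokens) =
    ///     RlmRouter::smart_truncate(&output, "bash", &serde_json::json!({}), 8_000);
    /// ```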
    pub fn smart_truncate(
        output: &str,
        tool_id: &str,
        tool_args: &serde_json::Value,
        max_tokens: usize,
    ) -> (String, bool, usize) {
        let estimated_tokens = RlmChunker::estimate_tokens(output);

        if estimated_tokens <= max_tokens {
            return (output.to_string(), false, estimated_tokens);
        }

        info!(
            tool = tool_id,
            original_tokens = estimated_tokens,
            max_tokens,
            "Smart truncating large output"
        );

        let max_chars = max_tokens * 4;
        let head_chars = (max_chars as f64 * 0.6) as usize;
        let tail_chars = (max_chars as f64 * 0.3) as usize;

        let head: String = output.chars().take(head_chars).collect();
        let tail: String = output
            .chars()
            .rev()
            .take(tail_chars)
            .collect::<String>()
            .chars()
            .rev()
            .collect();

        // Saturate so estimator rounding can never push head + tail above the total.
        let omitted_tokens = estimated_tokens
            .saturating_sub(RlmChunker::estimate_tokens(&head))
            .saturating_sub(RlmChunker::estimate_tokens(&tail));
        let rlm_hint = Self::build_rlm_hint(tool_id, tool_args, estimated_tokens);

        let truncated = format!(
            "{}\n\n[... {} tokens truncated ...]\n\n{}\n\n{}",
            head, omitted_tokens, rlm_hint, tail
        );

        (truncated, true, estimated_tokens)
    }

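    /// Builds a tool-specific hint showing how to re-run the analysis through
    /// the `rlm` tool when the output had to be truncated.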
    fn build_rlm_hint(tool_id: &str, args: &serde_json::Value, tokens: usize) -> String {
        let base = format!("⚠️ OUTPUT TOO LARGE ({} tokens). Use RLM for full analysis:", tokens);

        match tool_id {
            "read" => {
                let path = args.get("filePath").and_then(|v| v.as_str()).unwrap_or("...");
                format!("{}\n```\nrlm({{ query: \"Analyze this file\", content_paths: [\"{}\"] }})\n```", base, path)
            }
            "bash" => {
                format!("{}\n```\nrlm({{ query: \"Analyze this command output\", content: \"<paste or use content_paths>\" }})\n```", base)
            }
            "grep" => {
                let pattern = args.get("pattern").and_then(|v| v.as_str()).unwrap_or("...");
                let include = args.get("include").and_then(|v| v.as_str()).unwrap_or("*");
                format!("{}\n```\nrlm({{ query: \"Summarize search results for {}\", content_glob: \"{}\" }})\n```", base, pattern, include)
            }
            _ => {
                format!("{}\n```\nrlm({{ query: \"Summarize this output\", content: \"...\" }})\n```", base)
            }
        }
    }

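    /// Runs the full RLM loop over `output`: compresses very large inputs,
    /// builds a tool-specific query, then iterates model calls until the model
    /// emits a `FINAL("...")` answer or the iteration/sub-call budget is spent.
    /// Falls back to a structured excerpt if no usable answer is produced.
    ///
    /// # Example
    ///
    /// Illustrative sketch; `output`, `provider` (an existing `Arc<dyn Provider>`),
    /// `config`, and the model name are assumed to come from the caller:
    ///
    /// ```ignore
    /// let result = RlmRouter::auto_process(
    ///     &output,
    ///     AutoProcessContext {
    ///         tool_id: "read",
    ///         tool_args: serde_json::json!({ "filePath": "src/main.rs" }),
    ///         session_id: "session-1",
    ///         abort: None,
    ///         on_progress: None,
    ///         provider: provider.clone(),
    ///         model: "example-model".to_string(),
    ///     },
    ///     &config,
    /// )
    /// .await?;
    /// ```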
    pub async fn auto_process(output: &str, ctx: AutoProcessContext<'_>, config: &RlmConfig) -> Result<RlmResult> {
        let start = Instant::now();
        let input_tokens = RlmChunker::estimate_tokens(output);

        info!(
            tool = ctx.tool_id,
            input_tokens,
            model = %ctx.model,
            "RLM: Starting auto-processing"
        );

        let content_type = RlmChunker::detect_content_type(output);
        let content_hints = RlmChunker::get_processing_hints(content_type);

        info!(content_type = ?content_type, tool = ctx.tool_id, "RLM: Content type detected");

        let processed_output = if input_tokens > 50000 {
            RlmChunker::compress(output, 40000, None)
        } else {
            output.to_string()
        };

        let base_query = Self::build_query_for_tool(ctx.tool_id, &ctx.tool_args);
        let query = format!("{}\n\n## Content Analysis Hints\n{}", base_query, content_hints);

        let system_prompt = Self::build_rlm_system_prompt(input_tokens, ctx.tool_id, &query);

        let max_iterations = config.max_iterations;
        let max_subcalls = config.max_subcalls;
        let mut iterations = 0;
        let mut subcalls = 0;
        let mut final_answer: Option<String> = None;

        let exploration = Self::build_exploration_summary(&processed_output, input_tokens);

        let mut conversation = vec![
            Message {
                role: Role::User,
                content: vec![ContentPart::Text {
                    text: format!(
                        "{}\n\nHere is the context exploration:\n```\n{}\n```\n\nNow analyze and answer the query.",
                        system_prompt, exploration
                    ),
                }],
            },
        ];

        for i in 0..max_iterations {
            iterations = i + 1;

            if let Some(ref progress) = ctx.on_progress {
                progress(ProcessProgress {
                    iteration: iterations,
                    max_iterations,
                    status: "running".to_string(),
                });
            }

            if let Some(ref abort) = ctx.abort {
                if *abort.borrow() {
                    warn!("RLM: Processing aborted");
                    break;
                }
            }

            let request = CompletionRequest {
                messages: conversation.clone(),
                tools: Vec::new(),
                model: ctx.model.clone(),
                temperature: Some(0.7),
                top_p: None,
                max_tokens: Some(4000),
                stop: Vec::new(),
            };

            let response = match ctx.provider.complete(request).await {
                Ok(r) => r,
                Err(e) => {
                    warn!(error = %e, iteration = iterations, "RLM: Model call failed");
                    if iterations > 1 {
                        break;
                    }
                    return Ok(Self::fallback_result(output, ctx.tool_id, &ctx.tool_args, input_tokens));
                }
            };

            let response_text: String = response.message.content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            info!(
                iteration = iterations,
                response_len = response_text.len(),
                "RLM: Model response"
            );

            if let Some(answer) = Self::extract_final(&response_text) {
                final_answer = Some(answer);
                break;
            }

            if iterations >= 3 && response_text.len() > 500 && !response_text.contains("```") {
                final_answer = Some(response_text.clone());
                break;
            }

            conversation.push(Message {
                role: Role::Assistant,
                content: vec![ContentPart::Text { text: response_text }],
            });

            conversation.push(Message {
                role: Role::User,
                content: vec![ContentPart::Text {
                    text: "Continue analysis. Call FINAL(\"your answer\") when ready.".to_string(),
                }],
            });

            subcalls += 1;
            if subcalls >= max_subcalls {
                warn!(subcalls, max = max_subcalls, "RLM: Max subcalls reached");
                break;
            }
        }

        if let Some(ref progress) = ctx.on_progress {
            progress(ProcessProgress {
                iteration: iterations,
                max_iterations,
                status: "completed".to_string(),
            });
        }

        let answer = final_answer.unwrap_or_else(|| {
            warn!(iterations, subcalls, "RLM: No FINAL produced, using fallback");
            Self::build_enhanced_fallback(output, ctx.tool_id, &ctx.tool_args, input_tokens)
        });

        let output_tokens = RlmChunker::estimate_tokens(&answer);
        let compression_ratio = input_tokens as f64 / output_tokens.max(1) as f64;
        let elapsed_ms = start.elapsed().as_millis() as u64;

        let result = format!(
            "[RLM: {} → {} tokens | {} iterations | {} sub-calls]\n\n{}",
            input_tokens, output_tokens, iterations, subcalls, answer
        );

        info!(
            input_tokens,
            output_tokens,
            iterations,
            subcalls,
            elapsed_ms,
            compression_ratio = format!("{:.1}", compression_ratio),
            "RLM: Processing complete"
        );

        Ok(RlmResult {
            processed: result,
            stats: RlmStats {
                input_tokens,
                output_tokens,
                iterations,
                subcalls,
                elapsed_ms,
                compression_ratio,
            },
            success: true,
            error: None,
        })
    }

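    /// Extracts the payload of a `FINAL("...")` marker (also accepting single
    /// quotes or backticks) from the model's response, if present and non-empty.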
    fn extract_final(text: &str) -> Option<String> {
        // Accepts FINAL("..."), FINAL('...'), or FINAL(`...`).
        let start_idx = text.find("FINAL")?;
        let after = &text[start_idx..];

        let open_idx = after.find(['"', '\'', '`'])?;
        // `find` returns a byte offset, so slice before taking the quote character.
        let quote_char = after[open_idx..].chars().next()?;
        let content_start = start_idx + open_idx + 1;

        let content = &text[content_start..];
        if let Some(close_idx) = content.find(quote_char) {
            let answer = &content[..close_idx];
            if !answer.is_empty() {
                return Some(answer.to_string());
            }
        }

        None
    }

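    /// Produces the head/tail "exploration" excerpt that seeds the first RLM
    /// message, together with basic size statistics.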
    fn build_exploration_summary(content: &str, input_tokens: usize) -> String {
        let lines: Vec<&str> = content.lines().collect();
        let total_lines = lines.len();

        let head: String = lines.iter().take(30).copied().collect::<Vec<_>>().join("\n");
        let tail: String = lines
            .iter()
            .rev()
            .take(50)
            .collect::<Vec<_>>()
            .into_iter()
            .rev()
            .copied()
            .collect::<Vec<_>>()
            .join("\n");

        format!(
            "=== CONTEXT EXPLORATION ===\n\
             Total: {} chars, {} lines, ~{} tokens\n\n\
             === FIRST 30 LINES ===\n{}\n\n\
             === LAST 50 LINES ===\n{}\n\
             === END EXPLORATION ===",
            content.len(), total_lines, input_tokens, head, tail
        )
    }

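    /// Builds the system-style prompt that frames the analysis task, the
    /// content size, and the requirement to answer via `FINAL(...)`.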
    fn build_rlm_system_prompt(input_tokens: usize, tool_id: &str, query: &str) -> String {
        let context_type = if tool_id == "session_context" { "conversation history" } else { "tool output" };

        format!(
            r#"You are tasked with analyzing large content that cannot fit in a normal context window.

The content is a {} with {} total tokens.

YOUR TASK: {}

## Analysis Strategy

1. First, examine the exploration (head + tail of content) to understand structure
2. Identify the most important information for answering the query
3. Focus on: errors, key decisions, file paths, recent activity
4. Provide a concise but complete answer

When ready, call FINAL("your detailed answer") with your findings.

Be SPECIFIC - include actual file paths, function names, error messages. Generic summaries are not useful."#,
            context_type, input_tokens, query
        )
    }

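    /// Chooses an analysis query tailored to the tool that produced the output.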
    fn build_query_for_tool(tool_id: &str, args: &serde_json::Value) -> String {
        match tool_id {
            "read" => {
                let path = args.get("filePath").and_then(|v| v.as_str()).unwrap_or("unknown");
                format!("Summarize the key contents of file \"{}\". Focus on: structure, main functions/classes, important logic. Be concise.", path)
            }
            "bash" => {
                "Summarize the command output. Extract key information, results, errors, warnings. Be concise.".to_string()
            }
            "grep" => {
                let pattern = args.get("pattern").and_then(|v| v.as_str()).unwrap_or("pattern");
                format!("Summarize search results for \"{}\". Group by file, highlight most relevant matches. Be concise.", pattern)
            }
            "glob" => {
                "Summarize the file listing. Group by directory, highlight important files. Be concise.".to_string()
            }
            "session_context" => {
                r#"You are a CONTEXT MEMORY SYSTEM. Create a BRIEFING for an AI assistant to continue this conversation.

CRITICAL: The assistant will ONLY see your briefing - it has NO memory of the conversation.

## What to Extract

1. **PRIMARY GOAL**: What is the user ultimately trying to achieve?
2. **CURRENT STATE**: What has been accomplished? Current status?
3. **LAST ACTIONS**: What just happened? (last 3-5 tool calls, their results)
4. **ACTIVE FILES**: Which files were modified?
5. **PENDING TASKS**: What remains to be done?
6. **CRITICAL DETAILS**: File paths, error messages, specific values, decisions made
7. **NEXT STEPS**: What should happen next?

Be SPECIFIC with file paths, function names, error messages."#.to_string()
            }
            _ => "Summarize this output concisely, extracting the most important information.".to_string()
        }
    }

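    /// Builds a structured fallback summary when RLM processing finishes
    /// without a usable `FINAL` answer: for session context it extracts file
    /// mentions, recent tool calls, and errors; otherwise it smart-truncates.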
    fn build_enhanced_fallback(output: &str, tool_id: &str, tool_args: &serde_json::Value, input_tokens: usize) -> String {
        let lines: Vec<&str> = output.lines().collect();

        if tool_id == "session_context" {
            let file_matches: Vec<&str> = lines.iter()
                .filter_map(|l| {
                    if l.contains(".ts") || l.contains(".rs") || l.contains(".py") || l.contains(".json") {
                        Some(*l)
                    } else {
                        None
                    }
                })
                .take(15)
                .collect();

            let tool_calls: Vec<&str> = lines.iter()
                .filter(|l| l.contains("[Tool "))
                .take(10)
                .copied()
                .collect();

            let errors: Vec<&str> = lines.iter()
                .filter(|l| l.to_lowercase().contains("error") || l.to_lowercase().contains("failed"))
                .take(5)
                .copied()
                .collect();

            let head: String = lines.iter().take(30).copied().collect::<Vec<_>>().join("\n");
            let tail: String = lines
                .iter()
                .rev()
                .take(80)
                .collect::<Vec<_>>()
                .into_iter()
                .rev()
                .copied()
                .collect::<Vec<_>>()
                .join("\n");

            let mut parts = vec![
                "## Context Summary (Fallback Mode)".to_string(),
                format!("*Original: {} tokens - RLM processing produced insufficient output*", input_tokens),
                String::new(),
            ];

            if !file_matches.is_empty() {
                parts.push(format!("**Files Mentioned:** {}", file_matches.len()));
            }

            if !tool_calls.is_empty() {
                parts.push(format!("**Recent Tool Calls:** {}", tool_calls.join(", ")));
            }

            if !errors.is_empty() {
                parts.push("**Recent Errors:**".to_string());
                for e in errors {
                    parts.push(format!("- {}", e.chars().take(150).collect::<String>()));
                }
            }

            parts.push(String::new());
            parts.push("### Initial Request".to_string());
            parts.push("```".to_string());
            parts.push(head);
            parts.push("```".to_string());
            parts.push(String::new());
            parts.push("### Recent Activity".to_string());
            parts.push("```".to_string());
            parts.push(tail);
            parts.push("```".to_string());

            parts.join("\n")
        } else {
            let (truncated, _, _) = Self::smart_truncate(output, tool_id, tool_args, 8000);
            format!("## Fallback Summary\n*RLM processing failed - showing structured excerpt*\n\n{}", truncated)
        }
    }

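    /// Builds an `RlmResult` marked as failed, wrapping a smart-truncated
    /// excerpt of the original output.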
    fn fallback_result(output: &str, tool_id: &str, tool_args: &serde_json::Value, input_tokens: usize) -> RlmResult {
        let (truncated, _, _) = Self::smart_truncate(output, tool_id, tool_args, 8000);
        let output_tokens = RlmChunker::estimate_tokens(&truncated);

        RlmResult {
            processed: format!("[RLM processing failed, showing truncated output]\n\n{}", truncated),
            stats: RlmStats {
                input_tokens,
                output_tokens,
                iterations: 0,
                subcalls: 0,
                elapsed_ms: 0,
                compression_ratio: input_tokens as f64 / output_tokens.max(1) as f64,
            },
            success: false,
            error: Some("Model call failed".to_string()),
        }
    }
}