use super::{RlmChunker, RlmConfig, RlmResult, RlmStats};
use crate::provider::{CompletionRequest, ContentPart, Message, Provider, Role};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Instant;
use tracing::{info, warn};

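/// Tool IDs whose output is eligible for RLM routing.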
fn rlm_eligible_tools() -> HashSet<&'static str> {
    ["read", "glob", "grep", "bash", "search"]
        .iter()
        .copied()
        .collect()
}

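/// Inputs the router uses to decide whether a tool output should be handed
/// to RLM instead of being injected directly into the model context.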
#[derive(Debug, Clone)]
pub struct RoutingContext {
    pub tool_id: String,
    pub session_id: String,
    pub call_id: Option<String>,
    pub model_context_limit: usize,
    pub current_context_tokens: Option<usize>,
}

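/// Routing decision, the reason for it, and the token estimate it was based on.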
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingResult {
    pub should_route: bool,
    pub reason: String,
    pub estimated_tokens: usize,
}

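/// Everything `auto_process` needs besides the raw output: the originating
/// tool call, an optional abort signal, a progress callback, and the
/// provider/model used for RLM sub-calls.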
pub struct AutoProcessContext<'a> {
    pub tool_id: &'a str,
    pub tool_args: serde_json::Value,
    pub session_id: &'a str,
    pub abort: Option<tokio::sync::watch::Receiver<bool>>,
    pub on_progress: Option<Box<dyn Fn(ProcessProgress) + Send + Sync>>,
    pub provider: Arc<dyn Provider>,
    pub model: String,
}

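/// Progress update emitted once per RLM iteration via `on_progress`.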
#[derive(Debug, Clone)]
pub struct ProcessProgress {
    pub iteration: usize,
    pub max_iterations: usize,
    pub status: String,
}

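/// Stateless helper that decides when large tool outputs should be routed to
/// RLM, truncates them when routing is not available, and drives the
/// iterative RLM analysis loop.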
pub struct RlmRouter;

impl RlmRouter {
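    /// Decides whether `output` should be routed to RLM based on the
    /// configured mode (`"off"`, `"always"`, or threshold-based), tool
    /// eligibility, and the projected context usage.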
    pub fn should_route(output: &str, ctx: &RoutingContext, config: &RlmConfig) -> RoutingResult {
        let estimated_tokens = RlmChunker::estimate_tokens(output);

        if config.mode == "off" {
            return RoutingResult {
                should_route: false,
                reason: "rlm_mode_off".to_string(),
                estimated_tokens,
            };
        }

        if config.mode == "always" {
            if !rlm_eligible_tools().contains(ctx.tool_id.as_str()) {
                return RoutingResult {
                    should_route: false,
                    reason: "tool_not_eligible".to_string(),
                    estimated_tokens,
                };
            }
            return RoutingResult {
                should_route: true,
                reason: "rlm_mode_always".to_string(),
                estimated_tokens,
            };
        }

        if !rlm_eligible_tools().contains(ctx.tool_id.as_str()) {
            return RoutingResult {
                should_route: false,
                reason: "tool_not_eligible".to_string(),
                estimated_tokens,
            };
        }

        let threshold_tokens = (ctx.model_context_limit as f64 * config.threshold) as usize;
        if estimated_tokens > threshold_tokens {
            return RoutingResult {
                should_route: true,
                reason: "exceeds_threshold".to_string(),
                estimated_tokens,
            };
        }

        if let Some(current) = ctx.current_context_tokens {
            let projected_total = current + estimated_tokens;
            if projected_total > (ctx.model_context_limit as f64 * 0.8) as usize {
                return RoutingResult {
                    should_route: true,
                    reason: "would_overflow".to_string(),
                    estimated_tokens,
                };
            }
        }

        RoutingResult {
            should_route: false,
            reason: "within_threshold".to_string(),
            estimated_tokens,
        }
    }

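    /// Head/tail truncation used when RLM routing is not applied: keeps
    /// roughly 60% of the token budget from the start and 30% from the end,
    /// and embeds a hint showing how to re-run the analysis via the `rlm`
    /// tool. Returns `(text, was_truncated, original_token_estimate)`.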
    pub fn smart_truncate(
        output: &str,
        tool_id: &str,
        tool_args: &serde_json::Value,
        max_tokens: usize,
    ) -> (String, bool, usize) {
        let estimated_tokens = RlmChunker::estimate_tokens(output);

        if estimated_tokens <= max_tokens {
            return (output.to_string(), false, estimated_tokens);
        }

        info!(
            tool = tool_id,
            original_tokens = estimated_tokens,
            max_tokens,
            "Smart truncating large output"
        );

        let max_chars = max_tokens * 4;
        let head_chars = (max_chars as f64 * 0.6) as usize;
        let tail_chars = (max_chars as f64 * 0.3) as usize;

        let head: String = output.chars().take(head_chars).collect();
        let tail: String = output
            .chars()
            .rev()
            .take(tail_chars)
            .collect::<String>()
            .chars()
            .rev()
            .collect();

        // Saturate to guard against estimator rounding pushing the sum past the total.
        let omitted_tokens = estimated_tokens
            .saturating_sub(RlmChunker::estimate_tokens(&head))
            .saturating_sub(RlmChunker::estimate_tokens(&tail));
        let rlm_hint = Self::build_rlm_hint(tool_id, tool_args, estimated_tokens);

        let truncated = format!(
            "{}\n\n[... {} tokens truncated ...]\n\n{}\n\n{}",
            head, omitted_tokens, rlm_hint, tail
        );

        (truncated, true, estimated_tokens)
    }

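    /// Builds the per-tool hint appended to truncated output, showing how to
    /// invoke the `rlm` tool on the same content.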
    fn build_rlm_hint(tool_id: &str, args: &serde_json::Value, tokens: usize) -> String {
        let base = format!(
            "⚠️ OUTPUT TOO LARGE ({} tokens). Use RLM for full analysis:",
            tokens
        );

        match tool_id {
            "read" => {
                let path = args
                    .get("filePath")
                    .and_then(|v| v.as_str())
                    .unwrap_or("...");
                format!(
                    "{}\n```\nrlm({{ query: \"Analyze this file\", content_paths: [\"{}\"] }})\n```",
                    base, path
                )
            }
            "bash" => {
                format!(
                    "{}\n```\nrlm({{ query: \"Analyze this command output\", content: \"<paste or use content_paths>\" }})\n```",
                    base
                )
            }
            "grep" => {
                let pattern = args
                    .get("pattern")
                    .and_then(|v| v.as_str())
                    .unwrap_or("...");
                let include = args.get("include").and_then(|v| v.as_str()).unwrap_or("*");
                format!(
                    "{}\n```\nrlm({{ query: \"Summarize search results for {}\", content_glob: \"{}\" }})\n```",
                    base, pattern, include
                )
            }
            _ => {
                format!(
                    "{}\n```\nrlm({{ query: \"Summarize this output\", content: \"...\" }})\n```",
                    base
                )
            }
        }
    }

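    /// Runs the full RLM loop on `output`: detects the content type,
    /// pre-compresses very large inputs, builds a tool-specific query, and
    /// iterates model calls until a `FINAL("...")` answer is produced or the
    /// iteration/sub-call budget is exhausted. Falls back to a structured
    /// excerpt when no usable answer is obtained.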
    pub async fn auto_process(
        output: &str,
        ctx: AutoProcessContext<'_>,
        config: &RlmConfig,
    ) -> Result<RlmResult> {
        let start = Instant::now();
        let input_tokens = RlmChunker::estimate_tokens(output);

        info!(
            tool = ctx.tool_id,
            input_tokens,
            model = %ctx.model,
            "RLM: Starting auto-processing"
        );

        let content_type = RlmChunker::detect_content_type(output);
        let content_hints = RlmChunker::get_processing_hints(content_type);

        info!(content_type = ?content_type, tool = ctx.tool_id, "RLM: Content type detected");

        // Pre-compress very large inputs before building the exploration view.
        let processed_output = if input_tokens > 50000 {
            RlmChunker::compress(output, 40000, None)
        } else {
            output.to_string()
        };

        let base_query = Self::build_query_for_tool(ctx.tool_id, &ctx.tool_args);
        let query = format!(
            "{}\n\n## Content Analysis Hints\n{}",
            base_query, content_hints
        );

        let system_prompt = Self::build_rlm_system_prompt(input_tokens, ctx.tool_id, &query);

        let max_iterations = config.max_iterations;
        let max_subcalls = config.max_subcalls;
        let mut iterations = 0;
        let mut subcalls = 0;
        let mut final_answer: Option<String> = None;

        let exploration = Self::build_exploration_summary(&processed_output, input_tokens);

        let mut conversation = vec![Message {
            role: Role::User,
            content: vec![ContentPart::Text {
                text: format!(
                    "{}\n\nHere is the context exploration:\n```\n{}\n```\n\nNow analyze and answer the query.",
                    system_prompt, exploration
                ),
            }],
        }];

        for i in 0..max_iterations {
            iterations = i + 1;

            if let Some(ref progress) = ctx.on_progress {
                progress(ProcessProgress {
                    iteration: iterations,
                    max_iterations,
                    status: "running".to_string(),
                });
            }

            // Stop early if the caller signalled an abort.
            if let Some(ref abort) = ctx.abort {
                if *abort.borrow() {
                    warn!("RLM: Processing aborted");
                    break;
                }
            }

            let request = CompletionRequest {
                messages: conversation.clone(),
                tools: Vec::new(),
                model: ctx.model.clone(),
                temperature: Some(0.7),
                top_p: None,
                max_tokens: Some(4000),
                stop: Vec::new(),
            };

            let response = match ctx.provider.complete(request).await {
                Ok(r) => r,
                Err(e) => {
                    warn!(error = %e, iteration = iterations, "RLM: Model call failed");
                    // On later iterations, exit the loop and let the enhanced
                    // fallback run; on the first, return the truncated fallback.
                    if iterations > 1 {
                        break;
                    }
                    return Ok(Self::fallback_result(
                        output,
                        ctx.tool_id,
                        &ctx.tool_args,
                        input_tokens,
                    ));
                }
            };

            let response_text: String = response
                .message
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            info!(
                iteration = iterations,
                response_len = response_text.len(),
                "RLM: Model response"
            );

            if let Some(answer) = Self::extract_final(&response_text) {
                final_answer = Some(answer);
                break;
            }

            // After a few iterations, accept a substantial prose response as
            // the answer even without an explicit FINAL call.
            if iterations >= 3 && response_text.len() > 500 && !response_text.contains("```") {
                final_answer = Some(response_text.clone());
                break;
            }

            conversation.push(Message {
                role: Role::Assistant,
                content: vec![ContentPart::Text {
                    text: response_text,
                }],
            });

            conversation.push(Message {
                role: Role::User,
                content: vec![ContentPart::Text {
                    text: "Continue analysis. Call FINAL(\"your answer\") when ready.".to_string(),
                }],
            });

            subcalls += 1;
            if subcalls >= max_subcalls {
                warn!(subcalls, max = max_subcalls, "RLM: Max subcalls reached");
                break;
            }
        }

        if let Some(ref progress) = ctx.on_progress {
            progress(ProcessProgress {
                iteration: iterations,
                max_iterations,
                status: "completed".to_string(),
            });
        }

        let answer = final_answer.unwrap_or_else(|| {
            warn!(
                iterations,
                subcalls,
                "RLM: No FINAL produced, using fallback"
            );
            Self::build_enhanced_fallback(output, ctx.tool_id, &ctx.tool_args, input_tokens)
        });

        let output_tokens = RlmChunker::estimate_tokens(&answer);
        let compression_ratio = input_tokens as f64 / output_tokens.max(1) as f64;
        let elapsed_ms = start.elapsed().as_millis() as u64;

        let result = format!(
            "[RLM: {} → {} tokens | {} iterations | {} sub-calls]\n\n{}",
            input_tokens, output_tokens, iterations, subcalls, answer
        );

        info!(
            input_tokens,
            output_tokens,
            iterations,
            subcalls,
            elapsed_ms,
            compression_ratio = format!("{:.1}", compression_ratio),
            "RLM: Processing complete"
        );

        Ok(RlmResult {
            processed: result,
            stats: RlmStats {
                input_tokens,
                output_tokens,
                iterations,
                subcalls,
                elapsed_ms,
                compression_ratio,
            },
            success: true,
            error: None,
        })
    }

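    /// Extracts the answer from a `FINAL("...")` call in the model response,
    /// accepting double quotes, single quotes, or backticks.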
    fn extract_final(text: &str) -> Option<String> {
        // Locate the FINAL(...) call, then take everything between the first
        // opening quote (", ', or `) after it and the matching closing quote.
        let start_idx = text.find("FINAL")?;
        let after = &text[start_idx..];

        let open_idx = after.find(['"', '\'', '`'])?;
        let quote_char = after[open_idx..].chars().next()?;
        let content_start = start_idx + open_idx + 1;

        let content = &text[content_start..];
        if let Some(close_idx) = content.find(quote_char) {
            let answer = &content[..close_idx];
            if !answer.is_empty() {
                return Some(answer.to_string());
            }
        }

        None
    }

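    /// Builds the head/tail "exploration" view that seeds the RLM
    /// conversation: overall size plus the first 30 and last 50 lines.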
    fn build_exploration_summary(content: &str, input_tokens: usize) -> String {
        let lines: Vec<&str> = content.lines().collect();
        let total_lines = lines.len();

        let head: String = lines
            .iter()
            .take(30)
            .copied()
            .collect::<Vec<_>>()
            .join("\n");
        let tail: String = lines
            .iter()
            .rev()
            .take(50)
            .collect::<Vec<_>>()
            .into_iter()
            .rev()
            .copied()
            .collect::<Vec<_>>()
            .join("\n");

        format!(
            "=== CONTEXT EXPLORATION ===\n\
             Total: {} chars, {} lines, ~{} tokens\n\n\
             === FIRST 30 LINES ===\n{}\n\n\
             === LAST 50 LINES ===\n{}\n\
             === END EXPLORATION ===",
            content.len(),
            total_lines,
            input_tokens,
            head,
            tail
        )
    }

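    /// Builds the system prompt for the RLM analysis loop, embedding the
    /// token count, the kind of content being analyzed, and the
    /// tool-specific query.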
    fn build_rlm_system_prompt(input_tokens: usize, tool_id: &str, query: &str) -> String {
        let context_type = if tool_id == "session_context" {
            "conversation history"
        } else {
            "tool output"
        };

        format!(
            r#"You are tasked with analyzing large content that cannot fit in a normal context window.

The content is a {} with {} total tokens.

YOUR TASK: {}

## Analysis Strategy

1. First, examine the exploration (head + tail of content) to understand structure
2. Identify the most important information for answering the query
3. Focus on: errors, key decisions, file paths, recent activity
4. Provide a concise but complete answer

When ready, call FINAL("your detailed answer") with your findings.

Be SPECIFIC - include actual file paths, function names, error messages. Generic summaries are not useful."#,
            context_type, input_tokens, query
        )
    }

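    /// Returns the analysis query tailored to the tool that produced the
    /// output: a file summary for `read`, grouped matches for `grep`, a
    /// continuation briefing for `session_context`, and so on.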
    fn build_query_for_tool(tool_id: &str, args: &serde_json::Value) -> String {
        match tool_id {
            "read" => {
                let path = args.get("filePath").and_then(|v| v.as_str()).unwrap_or("unknown");
                format!("Summarize the key contents of file \"{}\". Focus on: structure, main functions/classes, important logic. Be concise.", path)
            }
            "bash" => {
                "Summarize the command output. Extract key information, results, errors, warnings. Be concise.".to_string()
            }
            "grep" => {
                let pattern = args.get("pattern").and_then(|v| v.as_str()).unwrap_or("pattern");
                format!("Summarize search results for \"{}\". Group by file, highlight most relevant matches. Be concise.", pattern)
            }
            "glob" => {
                "Summarize the file listing. Group by directory, highlight important files. Be concise.".to_string()
            }
            "session_context" => {
                r#"You are a CONTEXT MEMORY SYSTEM. Create a BRIEFING for an AI assistant to continue this conversation.

CRITICAL: The assistant will ONLY see your briefing - it has NO memory of the conversation.

## What to Extract

1. **PRIMARY GOAL**: What is the user ultimately trying to achieve?
2. **CURRENT STATE**: What has been accomplished? Current status?
3. **LAST ACTIONS**: What just happened? (last 3-5 tool calls, their results)
4. **ACTIVE FILES**: Which files were modified?
5. **PENDING TASKS**: What remains to be done?
6. **CRITICAL DETAILS**: File paths, error messages, specific values, decisions made
7. **NEXT STEPS**: What should happen next?

Be SPECIFIC with file paths, function names, error messages."#.to_string()
            }
            _ => "Summarize this output concisely, extracting the most important information.".to_string()
        }
    }

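    /// Fallback summary used when the RLM loop finishes without a usable
    /// answer. For `session_context` it extracts file mentions, recent tool
    /// calls, and errors plus a head/tail excerpt; for other tools it defers
    /// to `smart_truncate`.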
    fn build_enhanced_fallback(
        output: &str,
        tool_id: &str,
        tool_args: &serde_json::Value,
        input_tokens: usize,
    ) -> String {
        let lines: Vec<&str> = output.lines().collect();

        if tool_id == "session_context" {
            let file_matches: Vec<&str> = lines
                .iter()
                .filter_map(|l| {
                    if l.contains(".ts")
                        || l.contains(".rs")
                        || l.contains(".py")
                        || l.contains(".json")
                    {
                        Some(*l)
                    } else {
                        None
                    }
                })
                .take(15)
                .collect();

            let tool_calls: Vec<&str> = lines
                .iter()
                .filter(|l| l.contains("[Tool "))
                .take(10)
                .copied()
                .collect();

            let errors: Vec<&str> = lines
                .iter()
                .filter(|l| {
                    l.to_lowercase().contains("error") || l.to_lowercase().contains("failed")
                })
                .take(5)
                .copied()
                .collect();

            let head: String = lines
                .iter()
                .take(30)
                .copied()
                .collect::<Vec<_>>()
                .join("\n");
            let tail: String = lines
                .iter()
                .rev()
                .take(80)
                .collect::<Vec<_>>()
                .into_iter()
                .rev()
                .copied()
                .collect::<Vec<_>>()
                .join("\n");

            let mut parts = vec![
                "## Context Summary (Fallback Mode)".to_string(),
                format!(
                    "*Original: {} tokens - RLM processing produced insufficient output*",
                    input_tokens
                ),
                String::new(),
            ];

            if !file_matches.is_empty() {
                parts.push(format!("**Files Mentioned:** {}", file_matches.len()));
            }

            if !tool_calls.is_empty() {
                parts.push(format!("**Recent Tool Calls:** {}", tool_calls.join(", ")));
            }

            if !errors.is_empty() {
                parts.push("**Recent Errors:**".to_string());
                for e in errors {
                    parts.push(format!("- {}", e.chars().take(150).collect::<String>()));
                }
            }

            parts.push(String::new());
            parts.push("### Initial Request".to_string());
            parts.push("```".to_string());
            parts.push(head);
            parts.push("```".to_string());
            parts.push(String::new());
            parts.push("### Recent Activity".to_string());
            parts.push("```".to_string());
            parts.push(tail);
            parts.push("```".to_string());

            parts.join("\n")
        } else {
            let (truncated, _, _) = Self::smart_truncate(output, tool_id, tool_args, 8000);
            format!(
                "## Fallback Summary\n*RLM processing failed - showing structured excerpt*\n\n{}",
                truncated
            )
        }
    }

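    /// Builds the `RlmResult` returned when the model call itself fails:
    /// a smart-truncated excerpt marked `success: false` with zeroed loop stats.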
    fn fallback_result(
        output: &str,
        tool_id: &str,
        tool_args: &serde_json::Value,
        input_tokens: usize,
    ) -> RlmResult {
        let (truncated, _, _) = Self::smart_truncate(output, tool_id, tool_args, 8000);
        let output_tokens = RlmChunker::estimate_tokens(&truncated);

        RlmResult {
            processed: format!(
                "[RLM processing failed, showing truncated output]\n\n{}",
                truncated
            ),
            stats: RlmStats {
                input_tokens,
                output_tokens,
                iterations: 0,
                subcalls: 0,
                elapsed_ms: 0,
                compression_ratio: input_tokens as f64 / output_tokens.max(1) as f64,
            },
            success: false,
            error: Some("Model call failed".to_string()),
        }
    }
}