use super::{RlmChunker, RlmConfig, RlmResult, RlmStats};
use crate::provider::{CompletionRequest, ContentPart, Message, Provider, Role};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Instant;
use tracing::{info, warn};

#[cfg(feature = "functiongemma")]
use crate::cognition::tool_router::{ToolCallRouter, ToolRouterConfig};

use super::tools::rlm_tool_definitions;

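/// Tool IDs whose output is allowed to be routed through RLM processing.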
fn rlm_eligible_tools() -> HashSet<&'static str> {
    ["read", "glob", "grep", "bash", "search"]
        .iter()
        .copied()
        .collect()
}

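/// Metadata about the originating tool call and session used when deciding whether
/// a piece of output should be routed to RLM.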
#[derive(Debug, Clone)]
pub struct RoutingContext {
    pub tool_id: String,
    pub session_id: String,
    pub call_id: Option<String>,
    pub model_context_limit: usize,
    pub current_context_tokens: Option<usize>,
}

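/// Outcome of a routing decision, including the reason and the output's estimated token count.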
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingResult {
    pub should_route: bool,
    pub reason: String,
    pub estimated_tokens: usize,
}

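/// Everything `auto_process` needs besides the raw output: the originating tool call,
/// an optional abort signal, a progress callback, and the provider/model to run RLM on.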
pub struct AutoProcessContext<'a> {
    pub tool_id: &'a str,
    pub tool_args: serde_json::Value,
    pub session_id: &'a str,
    pub abort: Option<tokio::sync::watch::Receiver<bool>>,
    pub on_progress: Option<Box<dyn Fn(ProcessProgress) + Send + Sync>>,
    pub provider: Arc<dyn Provider>,
    pub model: String,
}

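/// Progress update emitted once per RLM iteration via the `on_progress` callback.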
#[derive(Debug, Clone)]
pub struct ProcessProgress {
    pub iteration: usize,
    pub max_iterations: usize,
    pub status: String,
}

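/// Routes oversized tool output into RLM processing: decides when to route, truncates
/// with an `rlm` hint when not routing, and drives the iterative analysis loop when it does.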
pub struct RlmRouter;

impl RlmRouter {
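    /// Decide whether `output` should be routed to RLM, based on the configured mode,
    /// the tool's eligibility, and how the estimated token count compares to the
    /// model's context limit.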
    pub fn should_route(output: &str, ctx: &RoutingContext, config: &RlmConfig) -> RoutingResult {
        let estimated_tokens = RlmChunker::estimate_tokens(output);

        if config.mode == "off" {
            return RoutingResult {
                should_route: false,
                reason: "rlm_mode_off".to_string(),
                estimated_tokens,
            };
        }

        if config.mode == "always" {
            if !rlm_eligible_tools().contains(ctx.tool_id.as_str()) {
                return RoutingResult {
                    should_route: false,
                    reason: "tool_not_eligible".to_string(),
                    estimated_tokens,
                };
            }
            return RoutingResult {
                should_route: true,
                reason: "rlm_mode_always".to_string(),
                estimated_tokens,
            };
        }

        if !rlm_eligible_tools().contains(ctx.tool_id.as_str()) {
            return RoutingResult {
                should_route: false,
                reason: "tool_not_eligible".to_string(),
                estimated_tokens,
            };
        }

        let threshold_tokens = (ctx.model_context_limit as f64 * config.threshold) as usize;
        if estimated_tokens > threshold_tokens {
            return RoutingResult {
                should_route: true,
                reason: "exceeds_threshold".to_string(),
                estimated_tokens,
            };
        }

        if let Some(current) = ctx.current_context_tokens {
            let projected_total = current + estimated_tokens;
            if projected_total > (ctx.model_context_limit as f64 * 0.8) as usize {
                return RoutingResult {
                    should_route: true,
                    reason: "would_overflow".to_string(),
                    estimated_tokens,
                };
            }
        }

        RoutingResult {
            should_route: false,
            reason: "within_threshold".to_string(),
            estimated_tokens,
        }
    }

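    /// Truncate oversized output to roughly `max_tokens`, keeping the head and tail and
    /// embedding a hint that points the model at the `rlm` tool for full analysis.
    /// Returns the (possibly truncated) text, whether truncation occurred, and the
    /// original token estimate.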
    pub fn smart_truncate(
        output: &str,
        tool_id: &str,
        tool_args: &serde_json::Value,
        max_tokens: usize,
    ) -> (String, bool, usize) {
        let estimated_tokens = RlmChunker::estimate_tokens(output);

        if estimated_tokens <= max_tokens {
            return (output.to_string(), false, estimated_tokens);
        }

        info!(
            tool = tool_id,
            original_tokens = estimated_tokens,
            max_tokens,
            "Smart truncating large output"
        );

        let max_chars = max_tokens * 4;
        let head_chars = (max_chars as f64 * 0.6) as usize;
        let tail_chars = (max_chars as f64 * 0.3) as usize;

        let head: String = output.chars().take(head_chars).collect();
        let tail: String = output
            .chars()
            .rev()
            .take(tail_chars)
            .collect::<String>()
            .chars()
            .rev()
            .collect();

        // Saturating subtraction guards against underflow if the head/tail estimates
        // happen to exceed the whole-output estimate.
        let omitted_tokens = estimated_tokens
            .saturating_sub(RlmChunker::estimate_tokens(&head))
            .saturating_sub(RlmChunker::estimate_tokens(&tail));
        let rlm_hint = Self::build_rlm_hint(tool_id, tool_args, estimated_tokens);

        let truncated = format!(
            "{}\n\n[... {} tokens truncated ...]\n\n{}\n\n{}",
            head, omitted_tokens, rlm_hint, tail
        );

        (truncated, true, estimated_tokens)
    }

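    /// Build a tool-specific hint telling the model how to invoke `rlm` on the full output.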
    fn build_rlm_hint(tool_id: &str, args: &serde_json::Value, tokens: usize) -> String {
        let base = format!(
            "⚠️ OUTPUT TOO LARGE ({} tokens). Use RLM for full analysis:",
            tokens
        );

        match tool_id {
            "read" => {
                let path = args
                    .get("filePath")
                    .and_then(|v| v.as_str())
                    .unwrap_or("...");
                format!(
                    "{}\n```\nrlm({{ query: \"Analyze this file\", content_paths: [\"{}\"] }})\n```",
                    base, path
                )
            }
            "bash" => {
                format!(
                    "{}\n```\nrlm({{ query: \"Analyze this command output\", content: \"<paste or use content_paths>\" }})\n```",
                    base
                )
            }
            "grep" => {
                let pattern = args
                    .get("pattern")
                    .and_then(|v| v.as_str())
                    .unwrap_or("...");
                let include = args.get("include").and_then(|v| v.as_str()).unwrap_or("*");
                format!(
                    "{}\n```\nrlm({{ query: \"Summarize search results for {}\", content_glob: \"{}\" }})\n```",
                    base, pattern, include
                )
            }
            _ => {
                format!(
                    "{}\n```\nrlm({{ query: \"Summarize this output\", content: \"...\" }})\n```",
                    base
                )
            }
        }
    }

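    /// Run the full RLM loop over a large tool output: compress it if needed, seed a
    /// REPL with the content, then iterate model calls (dispatching structured tool
    /// calls when present) until a FINAL answer is produced or the iteration/sub-call
    /// budgets are exhausted. Falls back to a structured excerpt if no answer emerges.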
    pub async fn auto_process(
        output: &str,
        ctx: AutoProcessContext<'_>,
        config: &RlmConfig,
    ) -> Result<RlmResult> {
        let start = Instant::now();
        let input_tokens = RlmChunker::estimate_tokens(output);

        info!(
            tool = ctx.tool_id,
            input_tokens,
            model = %ctx.model,
            "RLM: Starting auto-processing"
        );

        #[cfg(feature = "functiongemma")]
        let tool_router: Option<ToolCallRouter> = {
            let cfg = ToolRouterConfig::from_env();
            ToolCallRouter::from_config(&cfg)
                .inspect_err(|e| {
                    tracing::debug!(error = %e, "FunctionGemma router unavailable for RLM router");
                })
                .ok()
                .flatten()
        };

        let tools = rlm_tool_definitions();

        let content_type = RlmChunker::detect_content_type(output);
        let content_hints = RlmChunker::get_processing_hints(content_type);

        info!(content_type = ?content_type, tool = ctx.tool_id, "RLM: Content type detected");

        let processed_output = if input_tokens > 50000 {
            RlmChunker::compress(output, 40000, None)
        } else {
            output.to_string()
        };

        let mut repl =
            super::repl::RlmRepl::new(processed_output.clone(), super::repl::ReplRuntime::Rust);

        let base_query = Self::build_query_for_tool(ctx.tool_id, &ctx.tool_args);
        let query = format!(
            "{}\n\n## Content Analysis Hints\n{}",
            base_query, content_hints
        );

        let system_prompt = Self::build_rlm_system_prompt(input_tokens, ctx.tool_id, &query);

        let max_iterations = config.max_iterations;
        let max_subcalls = config.max_subcalls;
        let mut iterations = 0;
        let mut subcalls = 0;
        let mut final_answer: Option<String> = None;

        let exploration = Self::build_exploration_summary(&processed_output, input_tokens);

        let mut conversation = vec![Message {
            role: Role::User,
            content: vec![ContentPart::Text {
                text: format!(
                    "{}\n\nHere is the context exploration:\n```\n{}\n```\n\nNow analyze and answer the query.",
                    system_prompt, exploration
                ),
            }],
        }];

        for i in 0..max_iterations {
            iterations = i + 1;

            if let Some(ref progress) = ctx.on_progress {
                progress(ProcessProgress {
                    iteration: iterations,
                    max_iterations,
                    status: "running".to_string(),
                });
            }

            if let Some(ref abort) = ctx.abort {
                if *abort.borrow() {
                    warn!("RLM: Processing aborted");
                    break;
                }
            }

            let request = CompletionRequest {
                messages: conversation.clone(),
                tools: tools.clone(),
                model: ctx.model.clone(),
                temperature: Some(0.7),
                top_p: None,
                max_tokens: Some(4000),
                stop: Vec::new(),
            };

            let response = match ctx.provider.complete(request).await {
                Ok(r) => r,
                Err(e) => {
                    warn!(error = %e, iteration = iterations, "RLM: Model call failed");
                    if iterations > 1 {
                        break;
                    }
                    return Ok(Self::fallback_result(
                        output,
                        ctx.tool_id,
                        &ctx.tool_args,
                        input_tokens,
                    ));
                }
            };

            #[cfg(feature = "functiongemma")]
            let response = if let Some(ref router) = tool_router {
                router.maybe_reformat(response, &tools).await
            } else {
                response
            };

            let tool_calls: Vec<(String, String, String)> = response
                .message
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::ToolCall {
                        id,
                        name,
                        arguments,
                    } => Some((id.clone(), name.clone(), arguments.clone())),
                    _ => None,
                })
                .collect();

            if !tool_calls.is_empty() {
                info!(
                    count = tool_calls.len(),
                    iteration = iterations,
                    "RLM router: dispatching structured tool calls"
                );

                conversation.push(Message {
                    role: Role::Assistant,
                    content: response.message.content.clone(),
                });

                let mut tool_results: Vec<ContentPart> = Vec::new();

                for (call_id, name, arguments) in &tool_calls {
                    match super::tools::dispatch_tool_call(name, arguments, &mut repl) {
                        Some(super::tools::RlmToolResult::Final(answer)) => {
                            final_answer = Some(answer);
                            tool_results.push(ContentPart::ToolResult {
                                tool_call_id: call_id.clone(),
                                content: "FINAL received".to_string(),
                            });
                            break;
                        }
                        Some(super::tools::RlmToolResult::Output(out)) => {
                            tool_results.push(ContentPart::ToolResult {
                                tool_call_id: call_id.clone(),
                                content: out,
                            });
                        }
                        None => {
                            tool_results.push(ContentPart::ToolResult {
                                tool_call_id: call_id.clone(),
                                content: format!("Unknown tool: {name}"),
                            });
                        }
                    }
                }

                if !tool_results.is_empty() {
                    conversation.push(Message {
                        role: Role::Tool,
                        content: tool_results,
                    });
                }

                subcalls += 1;
                if final_answer.is_some() || subcalls >= max_subcalls {
                    break;
                }
                continue;
            }

439
440 let response_text: String = response
442 .message
443 .content
444 .iter()
445 .filter_map(|p| match p {
446 ContentPart::Text { text } => Some(text.clone()),
447 _ => None,
448 })
449 .collect::<Vec<_>>()
450 .join("\n");
451
452 info!(
453 iteration = iterations,
454 response_len = response_text.len(),
455 "RLM: Model response (text-only fallback)"
456 );
457
458 if let Some(answer) = Self::extract_final(&response_text) {
460 final_answer = Some(answer);
461 break;
462 }
463
464 if iterations >= 3 && response_text.len() > 500 && !response_text.contains("```") {
466 final_answer = Some(response_text.clone());
468 break;
469 }
470
471 conversation.push(Message {
473 role: Role::Assistant,
474 content: vec![ContentPart::Text {
475 text: response_text,
476 }],
477 });
478
479 conversation.push(Message {
481 role: Role::User,
482 content: vec![ContentPart::Text {
483 text: "Continue analysis. Call FINAL(\"your answer\") when ready.".to_string(),
484 }],
485 });
486
487 subcalls += 1;
488 if subcalls >= max_subcalls {
489 warn!(subcalls, max = max_subcalls, "RLM: Max subcalls reached");
490 break;
491 }
492 }
493
494 if let Some(ref progress) = ctx.on_progress {
495 progress(ProcessProgress {
496 iteration: iterations,
497 max_iterations,
498 status: "completed".to_string(),
499 });
500 }
501
502 let answer = final_answer.unwrap_or_else(|| {
504 warn!(
505 iterations,
506 subcalls, "RLM: No FINAL produced, using fallback"
507 );
508 Self::build_enhanced_fallback(output, ctx.tool_id, &ctx.tool_args, input_tokens)
509 });
510
511 let output_tokens = RlmChunker::estimate_tokens(&answer);
512 let compression_ratio = input_tokens as f64 / output_tokens.max(1) as f64;
513 let elapsed_ms = start.elapsed().as_millis() as u64;
514
515 let result = format!(
516 "[RLM: {} → {} tokens | {} iterations | {} sub-calls]\n\n{}",
517 input_tokens, output_tokens, iterations, subcalls, answer
518 );
519
520 info!(
521 input_tokens,
522 output_tokens,
523 iterations,
524 subcalls,
525 elapsed_ms,
526 compression_ratio = format!("{:.1}", compression_ratio),
527 "RLM: Processing complete"
528 );
529
530 Ok(RlmResult {
531 processed: result,
532 stats: RlmStats {
533 input_tokens,
534 output_tokens: RlmChunker::estimate_tokens(&answer),
535 iterations,
536 subcalls,
537 elapsed_ms,
538 compression_ratio,
539 },
540 success: true,
541 error: None,
542 })
543 }
544
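    /// Extract the answer text from a `FINAL("...")` marker in a free-text model response.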
    fn extract_final(text: &str) -> Option<String> {
        // Locate the FINAL marker, then take the text between the first opening quote
        // that follows it and the matching closing quote.
        let start_idx = text.find("FINAL")?;
        let after = &text[start_idx..];

        let open_idx = after.find(['"', '\'', '`'])?;
        let quote_char = after[open_idx..].chars().next()?;
        let content_start = start_idx + open_idx + 1;

        let content = &text[content_start..];
        let close_idx = content.find(quote_char)?;
        let answer = &content[..close_idx];

        if answer.is_empty() {
            None
        } else {
            Some(answer.to_string())
        }
    }

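    /// Produce a compact "exploration" view of the content (size stats plus the first 30
    /// and last 50 lines) that seeds the model's first pass over the data.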
    fn build_exploration_summary(content: &str, input_tokens: usize) -> String {
        let lines: Vec<&str> = content.lines().collect();
        let total_lines = lines.len();

        let head: String = lines
            .iter()
            .take(30)
            .copied()
            .collect::<Vec<_>>()
            .join("\n");
        let tail: String = lines
            .iter()
            .rev()
            .take(50)
            .collect::<Vec<_>>()
            .into_iter()
            .rev()
            .copied()
            .collect::<Vec<_>>()
            .join("\n");

        format!(
            "=== CONTEXT EXPLORATION ===\n\
             Total: {} chars, {} lines, ~{} tokens\n\n\
             === FIRST 30 LINES ===\n{}\n\n\
             === LAST 50 LINES ===\n{}\n\
             === END EXPLORATION ===",
            content.len(),
            total_lines,
            input_tokens,
            head,
            tail
        )
    }

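    /// Build the system prompt that frames the analysis task, the content size, and the
    /// per-tool query for the RLM model.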
    fn build_rlm_system_prompt(input_tokens: usize, tool_id: &str, query: &str) -> String {
        let context_type = if tool_id == "session_context" {
            "conversation history"
        } else {
            "tool output"
        };

        format!(
            r#"You are tasked with analyzing large content that cannot fit in a normal context window.

The content is a {} with {} total tokens.

YOUR TASK: {}

## Analysis Strategy

1. First, examine the exploration (head + tail of content) to understand structure
2. Identify the most important information for answering the query
3. Focus on: errors, key decisions, file paths, recent activity
4. Provide a concise but complete answer

When ready, call FINAL("your detailed answer") with your findings.

Be SPECIFIC - include actual file paths, function names, error messages. Generic summaries are not useful."#,
            context_type, input_tokens, query
        )
    }

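    /// Build the analysis query tailored to the tool that produced the output.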
    fn build_query_for_tool(tool_id: &str, args: &serde_json::Value) -> String {
        match tool_id {
            "read" => {
                let path = args.get("filePath").and_then(|v| v.as_str()).unwrap_or("unknown");
                format!("Summarize the key contents of file \"{}\". Focus on: structure, main functions/classes, important logic. Be concise.", path)
            }
            "bash" => {
                "Summarize the command output. Extract key information, results, errors, warnings. Be concise.".to_string()
            }
            "grep" => {
                let pattern = args.get("pattern").and_then(|v| v.as_str()).unwrap_or("pattern");
                format!("Summarize search results for \"{}\". Group by file, highlight most relevant matches. Be concise.", pattern)
            }
            "glob" => {
                "Summarize the file listing. Group by directory, highlight important files. Be concise.".to_string()
            }
            "session_context" => {
                r#"You are a CONTEXT MEMORY SYSTEM. Create a BRIEFING for an AI assistant to continue this conversation.

CRITICAL: The assistant will ONLY see your briefing - it has NO memory of the conversation.

## What to Extract

1. **PRIMARY GOAL**: What is the user ultimately trying to achieve?
2. **CURRENT STATE**: What has been accomplished? Current status?
3. **LAST ACTIONS**: What just happened? (last 3-5 tool calls, their results)
4. **ACTIVE FILES**: Which files were modified?
5. **PENDING TASKS**: What remains to be done?
6. **CRITICAL DETAILS**: File paths, error messages, specific values, decisions made
7. **NEXT STEPS**: What should happen next?

Be SPECIFIC with file paths, function names, error messages."#.to_string()
            }
            _ => "Summarize this output concisely, extracting the most important information.".to_string()
        }
    }

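    /// Fallback summary used when the RLM loop finishes without a FINAL answer: for
    /// session context it extracts mentioned files, recent tool calls, and errors plus
    /// head/tail excerpts; for other tools it returns a smart-truncated excerpt.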
    fn build_enhanced_fallback(
        output: &str,
        tool_id: &str,
        tool_args: &serde_json::Value,
        input_tokens: usize,
    ) -> String {
        let lines: Vec<&str> = output.lines().collect();

        if tool_id == "session_context" {
            let file_matches: Vec<&str> = lines
                .iter()
                .filter_map(|l| {
                    if l.contains(".ts")
                        || l.contains(".rs")
                        || l.contains(".py")
                        || l.contains(".json")
                    {
                        Some(*l)
                    } else {
                        None
                    }
                })
                .take(15)
                .collect();

            let tool_calls: Vec<&str> = lines
                .iter()
                .filter(|l| l.contains("[Tool "))
                .take(10)
                .copied()
                .collect();

            let errors: Vec<&str> = lines
                .iter()
                .filter(|l| {
                    l.to_lowercase().contains("error") || l.to_lowercase().contains("failed")
                })
                .take(5)
                .copied()
                .collect();

            let head: String = lines
                .iter()
                .take(30)
                .copied()
                .collect::<Vec<_>>()
                .join("\n");
            let tail: String = lines
                .iter()
                .rev()
                .take(80)
                .collect::<Vec<_>>()
                .into_iter()
                .rev()
                .copied()
                .collect::<Vec<_>>()
                .join("\n");

            let mut parts = vec![
                "## Context Summary (Fallback Mode)".to_string(),
                format!(
                    "*Original: {} tokens - RLM processing produced insufficient output*",
                    input_tokens
                ),
                String::new(),
            ];

            if !file_matches.is_empty() {
                parts.push(format!("**Files Mentioned:** {}", file_matches.len()));
            }

            if !tool_calls.is_empty() {
                parts.push(format!("**Recent Tool Calls:** {}", tool_calls.join(", ")));
            }

            if !errors.is_empty() {
                parts.push("**Recent Errors:**".to_string());
                for e in errors {
                    parts.push(format!("- {}", e.chars().take(150).collect::<String>()));
                }
            }

            parts.push(String::new());
            parts.push("### Initial Request".to_string());
            parts.push("```".to_string());
            parts.push(head);
            parts.push("```".to_string());
            parts.push(String::new());
            parts.push("### Recent Activity".to_string());
            parts.push("```".to_string());
            parts.push(tail);
            parts.push("```".to_string());

            parts.join("\n")
        } else {
            let (truncated, _, _) = Self::smart_truncate(output, tool_id, tool_args, 8000);
            format!(
                "## Fallback Summary\n*RLM processing failed - showing structured excerpt*\n\n{}",
                truncated
            )
        }
    }

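    /// Build an `RlmResult` marked as failed when the first model call errors out,
    /// carrying a smart-truncated excerpt of the original output.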
    fn fallback_result(
        output: &str,
        tool_id: &str,
        tool_args: &serde_json::Value,
        input_tokens: usize,
    ) -> RlmResult {
        let (truncated, _, _) = Self::smart_truncate(output, tool_id, tool_args, 8000);
        let output_tokens = RlmChunker::estimate_tokens(&truncated);

        RlmResult {
            processed: format!(
                "[RLM processing failed, showing truncated output]\n\n{}",
                truncated
            ),
            stats: RlmStats {
                input_tokens,
                output_tokens,
                iterations: 0,
                subcalls: 0,
                elapsed_ms: 0,
                compression_ratio: input_tokens as f64 / output_tokens.max(1) as f64,
            },
            success: false,
            error: Some("Model call failed".to_string()),
        }
    }
}