use super::{RlmChunker, RlmConfig, RlmResult, RlmStats};
use crate::provider::{CompletionRequest, ContentPart, Message, Provider, Role};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Instant;
use tracing::{info, warn};

#[cfg(feature = "functiongemma")]
use crate::cognition::tool_router::{ToolCallRouter, ToolRouterConfig};

use super::tools::rlm_tool_definitions;
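
/// Tool IDs whose large outputs are eligible for RLM routing.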
fn rlm_eligible_tools() -> HashSet<&'static str> {
    ["read", "glob", "grep", "bash", "search"]
        .iter()
        .copied()
        .collect()
}
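
/// Metadata about the originating tool call and session, used to decide
/// whether the tool's output should be routed through RLM.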
#[derive(Debug, Clone)]
pub struct RoutingContext {
    pub tool_id: String,
    pub session_id: String,
    pub call_id: Option<String>,
    pub model_context_limit: usize,
    pub current_context_tokens: Option<usize>,
}
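
/// Outcome of a routing decision: whether to route, why, and the estimated
/// token count of the tool output.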
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingResult {
    pub should_route: bool,
    pub reason: String,
    pub estimated_tokens: usize,
}
pub struct AutoProcessContext<'a> {
    pub tool_id: &'a str,
    pub tool_args: serde_json::Value,
    pub session_id: &'a str,
    pub abort: Option<tokio::sync::watch::Receiver<bool>>,
    pub on_progress: Option<Box<dyn Fn(ProcessProgress) + Send + Sync>>,
    pub provider: Arc<dyn Provider>,
    pub model: String,
}
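
/// Progress update reported once per RLM iteration via `on_progress`.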
#[derive(Debug, Clone)]
pub struct ProcessProgress {
    pub iteration: usize,
    pub max_iterations: usize,
    pub status: String,
}
pub struct RlmRouter;

impl RlmRouter {
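    /// Decide whether `output` should be routed through RLM, based on the
    /// configured mode, the calling tool's eligibility, and token thresholds.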
    pub fn should_route(output: &str, ctx: &RoutingContext, config: &RlmConfig) -> RoutingResult {
        let estimated_tokens = RlmChunker::estimate_tokens(output);

        // Explicit kill switch: never route.
        if config.mode == "off" {
            return RoutingResult {
                should_route: false,
                reason: "rlm_mode_off".to_string(),
                estimated_tokens,
            };
        }

        // "always" routes every eligible tool call regardless of size.
        if config.mode == "always" {
            if !rlm_eligible_tools().contains(ctx.tool_id.as_str()) {
                return RoutingResult {
                    should_route: false,
                    reason: "tool_not_eligible".to_string(),
                    estimated_tokens,
                };
            }
            return RoutingResult {
                should_route: true,
                reason: "rlm_mode_always".to_string(),
                estimated_tokens,
            };
        }

        // Default mode: only eligible tools are considered.
        if !rlm_eligible_tools().contains(ctx.tool_id.as_str()) {
            return RoutingResult {
                should_route: false,
                reason: "tool_not_eligible".to_string(),
                estimated_tokens,
            };
        }

        // Route when the output alone exceeds the configured fraction of the
        // model's context window.
        let threshold_tokens = (ctx.model_context_limit as f64 * config.threshold) as usize;
        if estimated_tokens > threshold_tokens {
            return RoutingResult {
                should_route: true,
                reason: "exceeds_threshold".to_string(),
                estimated_tokens,
            };
        }

        // Route when appending the output would push the session past ~80% of
        // the context window.
        if let Some(current) = ctx.current_context_tokens {
            let projected_total = current + estimated_tokens;
            if projected_total > (ctx.model_context_limit as f64 * 0.8) as usize {
                return RoutingResult {
                    should_route: true,
                    reason: "would_overflow".to_string(),
                    estimated_tokens,
                };
            }
        }

        RoutingResult {
            should_route: false,
            reason: "within_threshold".to_string(),
            estimated_tokens,
        }
    }
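
    /// Head/tail truncate `output` to roughly `max_tokens`, inserting an
    /// elision marker and an RLM usage hint in the middle. Returns the
    /// (possibly truncated) text, whether truncation happened, and the
    /// original token estimate.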
    pub fn smart_truncate(
        output: &str,
        tool_id: &str,
        tool_args: &serde_json::Value,
        max_tokens: usize,
    ) -> (String, bool, usize) {
        let estimated_tokens = RlmChunker::estimate_tokens(output);

        if estimated_tokens <= max_tokens {
            return (output.to_string(), false, estimated_tokens);
        }

        info!(
            tool = tool_id,
            original_tokens = estimated_tokens,
            max_tokens,
            "Smart truncating large output"
        );

        // ~4 chars per token; keep 60% of the budget from the head and 30%
        // from the tail, leaving room for the elision marker and RLM hint.
        let max_chars = max_tokens * 4;
        let head_chars = (max_chars as f64 * 0.6) as usize;
        let tail_chars = (max_chars as f64 * 0.3) as usize;

        let head: String = output.chars().take(head_chars).collect();
        let tail: String = output
            .chars()
            .rev()
            .take(tail_chars)
            .collect::<String>()
            .chars()
            .rev()
            .collect();

        // saturating_sub guards against the head/tail estimates rounding above
        // the original total.
        let omitted_tokens = estimated_tokens
            .saturating_sub(RlmChunker::estimate_tokens(&head))
            .saturating_sub(RlmChunker::estimate_tokens(&tail));
        let rlm_hint = Self::build_rlm_hint(tool_id, tool_args, estimated_tokens);

        let truncated = format!(
            "{}\n\n[... {} tokens truncated ...]\n\n{}\n\n{}",
            head, omitted_tokens, rlm_hint, tail
        );

        (truncated, true, estimated_tokens)
    }
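
    /// Build a per-tool hint telling the calling model how to re-run the
    /// analysis through the `rlm` tool instead of reading the raw output.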
    fn build_rlm_hint(tool_id: &str, args: &serde_json::Value, tokens: usize) -> String {
        let base = format!(
            "⚠️ OUTPUT TOO LARGE ({} tokens). Use RLM for full analysis:",
            tokens
        );

        match tool_id {
            "read" => {
                let path = args
                    .get("filePath")
                    .and_then(|v| v.as_str())
                    .unwrap_or("...");
                format!(
                    "{}\n```\nrlm({{ query: \"Analyze this file\", content_paths: [\"{}\"] }})\n```",
                    base, path
                )
            }
            "bash" => {
                format!(
                    "{}\n```\nrlm({{ query: \"Analyze this command output\", content: \"<paste or use content_paths>\" }})\n```",
                    base
                )
            }
            "grep" => {
                let pattern = args
                    .get("pattern")
                    .and_then(|v| v.as_str())
                    .unwrap_or("...");
                let include = args.get("include").and_then(|v| v.as_str()).unwrap_or("*");
                format!(
                    "{}\n```\nrlm({{ query: \"Summarize search results for {}\", content_glob: \"{}\" }})\n```",
                    base, pattern, include
                )
            }
            _ => {
                format!(
                    "{}\n```\nrlm({{ query: \"Summarize this output\", content: \"...\" }})\n```",
                    base
                )
            }
        }
    }
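
    /// Run the full RLM loop over `output`: compress if very large, build an
    /// exploration summary, then iterate model calls (with the RLM tools)
    /// until a FINAL answer is produced or the iteration/sub-call limits are
    /// hit. Falls back to a structured truncation if no FINAL is produced.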
    pub async fn auto_process(
        output: &str,
        ctx: AutoProcessContext<'_>,
        config: &RlmConfig,
    ) -> Result<RlmResult> {
        let start = Instant::now();
        let input_tokens = RlmChunker::estimate_tokens(output);

        info!(
            tool = ctx.tool_id,
            input_tokens,
            model = %ctx.model,
            "RLM: Starting auto-processing"
        );

        // Optional FunctionGemma router used to reformat free-text tool calls
        // into structured ones.
        #[cfg(feature = "functiongemma")]
        let tool_router: Option<ToolCallRouter> = {
            let cfg = ToolRouterConfig::from_env();
            ToolCallRouter::from_config(&cfg)
                .inspect_err(|e| {
                    tracing::debug!(error = %e, "FunctionGemma router unavailable for RLM router");
                })
                .ok()
                .flatten()
        };

        let tools = rlm_tool_definitions();

        // Detect the content type so the analysis hints match the data.
        let content_type = RlmChunker::detect_content_type(output);
        let content_hints = RlmChunker::get_processing_hints(content_type);

        info!(content_type = ?content_type, tool = ctx.tool_id, "RLM: Content type detected");

        // Pre-compress very large inputs so the REPL and exploration prompt
        // stay manageable.
        let processed_output = if input_tokens > 50000 {
            RlmChunker::compress(output, 40000, None)
        } else {
            output.to_string()
        };

        let mut repl =
            super::repl::RlmRepl::new(processed_output.clone(), super::repl::ReplRuntime::Rust);

        let base_query = Self::build_query_for_tool(ctx.tool_id, &ctx.tool_args);
        let query = format!(
            "{}\n\n## Content Analysis Hints\n{}",
            base_query, content_hints
        );

        let system_prompt = Self::build_rlm_system_prompt(input_tokens, ctx.tool_id, &query);

        let max_iterations = config.max_iterations;
        let max_subcalls = config.max_subcalls;
        let mut iterations = 0;
        let mut subcalls = 0;
        let mut final_answer: Option<String> = None;

        let exploration = Self::build_exploration_summary(&processed_output, input_tokens);

        let mut conversation = vec![Message {
            role: Role::User,
            content: vec![ContentPart::Text {
                text: format!(
                    "{}\n\nHere is the context exploration:\n```\n{}\n```\n\nNow analyze and answer the query.",
                    system_prompt, exploration
                ),
            }],
        }];

        for i in 0..max_iterations {
            iterations = i + 1;

            if let Some(ref progress) = ctx.on_progress {
                progress(ProcessProgress {
                    iteration: iterations,
                    max_iterations,
                    status: "running".to_string(),
                });
            }

            // Bail out early if the caller requested an abort.
            if let Some(ref abort) = ctx.abort {
                if *abort.borrow() {
                    warn!("RLM: Processing aborted");
                    break;
                }
            }

            let request = CompletionRequest {
                messages: conversation.clone(),
                tools: tools.clone(),
                model: ctx.model.clone(),
                temperature: Some(0.7),
                top_p: None,
                max_tokens: Some(4000),
                stop: Vec::new(),
            };

            let response = match ctx.provider.complete(request).await {
                Ok(r) => r,
                Err(e) => {
                    warn!(error = %e, iteration = iterations, "RLM: Model call failed");
                    // After the first iteration we may already have useful
                    // state; otherwise fall back to a truncated summary now.
                    if iterations > 1 {
                        break;
                    }
                    return Ok(Self::fallback_result(
                        output,
                        ctx.tool_id,
                        &ctx.tool_args,
                        input_tokens,
                    ));
                }
            };

            // Let FunctionGemma reformat free-text tool calls into structured ones.
            #[cfg(feature = "functiongemma")]
            let response = if let Some(ref router) = tool_router {
                router.maybe_reformat(response, &tools, true).await
            } else {
                response
            };

            // Collect structured tool calls (id, name, arguments) from the response.
            let tool_calls: Vec<(String, String, String)> = response
                .message
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::ToolCall {
                        id,
                        name,
                        arguments,
                    } => Some((id.clone(), name.clone(), arguments.clone())),
                    _ => None,
                })
                .collect();

            if !tool_calls.is_empty() {
                info!(
                    count = tool_calls.len(),
                    iteration = iterations,
                    "RLM router: dispatching structured tool calls"
                );

                conversation.push(Message {
                    role: Role::Assistant,
                    content: response.message.content.clone(),
                });

                let mut tool_results: Vec<ContentPart> = Vec::new();

                for (call_id, name, arguments) in &tool_calls {
                    match super::tools::dispatch_tool_call(name, arguments, &mut repl) {
                        Some(super::tools::RlmToolResult::Final(answer)) => {
                            final_answer = Some(answer);
                            tool_results.push(ContentPart::ToolResult {
                                tool_call_id: call_id.clone(),
                                content: "FINAL received".to_string(),
                            });
                            break;
                        }
                        Some(super::tools::RlmToolResult::Output(out)) => {
                            tool_results.push(ContentPart::ToolResult {
                                tool_call_id: call_id.clone(),
                                content: out,
                            });
                        }
                        None => {
                            tool_results.push(ContentPart::ToolResult {
                                tool_call_id: call_id.clone(),
                                content: format!("Unknown tool: {name}"),
                            });
                        }
                    }
                }

                if !tool_results.is_empty() {
                    conversation.push(Message {
                        role: Role::Tool,
                        content: tool_results,
                    });
                }

                subcalls += 1;
                if final_answer.is_some() || subcalls >= max_subcalls {
                    break;
                }
                continue;
            }

            // No structured tool calls: fall back to parsing the plain-text response.
            let response_text: String = response
                .message
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            info!(
                iteration = iterations,
                response_len = response_text.len(),
                "RLM: Model response (text-only fallback)"
            );

            if let Some(answer) = Self::extract_final(&response_text) {
                final_answer = Some(answer);
                break;
            }

            // Heuristic: after a few iterations, accept a substantial prose
            // answer with no code blocks as final.
            if iterations >= 3 && response_text.len() > 500 && !response_text.contains("```") {
                final_answer = Some(response_text.clone());
                break;
            }

            conversation.push(Message {
                role: Role::Assistant,
                content: vec![ContentPart::Text {
                    text: response_text,
                }],
            });

            conversation.push(Message {
                role: Role::User,
                content: vec![ContentPart::Text {
                    text: "Continue analysis. Call FINAL(\"your answer\") when ready.".to_string(),
                }],
            });

            subcalls += 1;
            if subcalls >= max_subcalls {
                warn!(subcalls, max = max_subcalls, "RLM: Max subcalls reached");
                break;
            }
        }

        if let Some(ref progress) = ctx.on_progress {
            progress(ProcessProgress {
                iteration: iterations,
                max_iterations,
                status: "completed".to_string(),
            });
        }

        // No FINAL answer: build a structured fallback from the raw output.
        let answer = final_answer.unwrap_or_else(|| {
            warn!(
                iterations,
                subcalls, "RLM: No FINAL produced, using fallback"
            );
            Self::build_enhanced_fallback(output, ctx.tool_id, &ctx.tool_args, input_tokens)
        });

        let output_tokens = RlmChunker::estimate_tokens(&answer);
        let compression_ratio = input_tokens as f64 / output_tokens.max(1) as f64;
        let elapsed_ms = start.elapsed().as_millis() as u64;

        let result = format!(
            "[RLM: {} → {} tokens | {} iterations | {} sub-calls]\n\n{}",
            input_tokens, output_tokens, iterations, subcalls, answer
        );

        info!(
            input_tokens,
            output_tokens,
            iterations,
            subcalls,
            elapsed_ms,
            compression_ratio = format!("{:.1}", compression_ratio),
            "RLM: Processing complete"
        );

        Ok(RlmResult {
            processed: result,
            stats: RlmStats {
                input_tokens,
                output_tokens,
                iterations,
                subcalls,
                elapsed_ms,
                compression_ratio,
            },
            success: true,
            error: None,
        })
    }
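
    /// Extract the answer from a `FINAL("...")` call embedded in free-form
    /// model output. Returns `None` if no non-empty quoted answer is found.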
    fn extract_final(text: &str) -> Option<String> {
        // Look for FINAL("..."), FINAL('...'), or FINAL(`...`) in the text.
        let start_idx = text.find("FINAL")?;
        let after = &text[start_idx..];

        // Indices from `find` are byte offsets, so slice before reading the
        // quote character.
        let open_idx = after.find(['"', '\'', '`'])?;
        let quote_char = after[open_idx..].chars().next()?;
        let content = &after[open_idx + quote_char.len_utf8()..];

        // Everything up to the matching closing quote is the answer.
        let close_idx = content.find(quote_char)?;
        let answer = &content[..close_idx];
        if answer.is_empty() {
            None
        } else {
            Some(answer.to_string())
        }
    }
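
    /// Build a small head/tail exploration of the content (first 30 and last
    /// 50 lines plus size stats) used to seed the first model prompt.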
    fn build_exploration_summary(content: &str, input_tokens: usize) -> String {
        let lines: Vec<&str> = content.lines().collect();
        let total_lines = lines.len();

        let head: String = lines
            .iter()
            .take(30)
            .copied()
            .collect::<Vec<_>>()
            .join("\n");
        let tail: String = lines
            .iter()
            .rev()
            .take(50)
            .collect::<Vec<_>>()
            .into_iter()
            .rev()
            .copied()
            .collect::<Vec<_>>()
            .join("\n");

        format!(
            "=== CONTEXT EXPLORATION ===\n\
             Total: {} chars, {} lines, ~{} tokens\n\n\
             === FIRST 30 LINES ===\n{}\n\n\
             === LAST 50 LINES ===\n{}\n\
             === END EXPLORATION ===",
            content.len(),
            total_lines,
            input_tokens,
            head,
            tail
        )
    }
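
    /// System prompt for the RLM analysis loop: describes the content, the
    /// task, and the FINAL("...") convention for returning an answer.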
    fn build_rlm_system_prompt(input_tokens: usize, tool_id: &str, query: &str) -> String {
        let context_type = if tool_id == "session_context" {
            "conversation history"
        } else {
            "tool output"
        };

        format!(
            r#"You are tasked with analyzing large content that cannot fit in a normal context window.

The content is a {} with {} total tokens.

YOUR TASK: {}

## Analysis Strategy

1. First, examine the exploration (head + tail of content) to understand structure
2. Identify the most important information for answering the query
3. Focus on: errors, key decisions, file paths, recent activity
4. Provide a concise but complete answer

When ready, call FINAL("your detailed answer") with your findings.

Be SPECIFIC - include actual file paths, function names, error messages. Generic summaries are not useful."#,
            context_type, input_tokens, query
        )
    }
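
    /// Pick an analysis query tailored to the tool that produced the output.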
    fn build_query_for_tool(tool_id: &str, args: &serde_json::Value) -> String {
        match tool_id {
            "read" => {
                let path = args.get("filePath").and_then(|v| v.as_str()).unwrap_or("unknown");
                format!("Summarize the key contents of file \"{}\". Focus on: structure, main functions/classes, important logic. Be concise.", path)
            }
            "bash" => {
                "Summarize the command output. Extract key information, results, errors, warnings. Be concise.".to_string()
            }
            "grep" => {
                let pattern = args.get("pattern").and_then(|v| v.as_str()).unwrap_or("pattern");
                format!("Summarize search results for \"{}\". Group by file, highlight most relevant matches. Be concise.", pattern)
            }
            "glob" => {
                "Summarize the file listing. Group by directory, highlight important files. Be concise.".to_string()
            }
            "session_context" => {
                r#"You are a CONTEXT MEMORY SYSTEM. Create a BRIEFING for an AI assistant to continue this conversation.

CRITICAL: The assistant will ONLY see your briefing - it has NO memory of the conversation.

## What to Extract

1. **PRIMARY GOAL**: What is the user ultimately trying to achieve?
2. **CURRENT STATE**: What has been accomplished? Current status?
3. **LAST ACTIONS**: What just happened? (last 3-5 tool calls, their results)
4. **ACTIVE FILES**: Which files were modified?
5. **PENDING TASKS**: What remains to be done?
6. **CRITICAL DETAILS**: File paths, error messages, specific values, decisions made
7. **NEXT STEPS**: What should happen next?

Be SPECIFIC with file paths, function names, error messages."#.to_string()
            }
            _ => "Summarize this output concisely, extracting the most important information.".to_string()
        }
    }
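
    /// Structured fallback used when the RLM loop produces no FINAL answer:
    /// for session context it extracts mentioned files, recent tool calls,
    /// and errors; for other tools it falls back to a smart truncation.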
    fn build_enhanced_fallback(
        output: &str,
        tool_id: &str,
        tool_args: &serde_json::Value,
        input_tokens: usize,
    ) -> String {
        let lines: Vec<&str> = output.lines().collect();

        if tool_id == "session_context" {
            // Salvage what we can from the conversation dump: mentioned files,
            // recent tool calls, and errors.
            let file_matches: Vec<&str> = lines
                .iter()
                .filter_map(|l| {
                    if l.contains(".ts")
                        || l.contains(".rs")
                        || l.contains(".py")
                        || l.contains(".json")
                    {
                        Some(*l)
                    } else {
                        None
                    }
                })
                .take(15)
                .collect();

            let tool_calls: Vec<&str> = lines
                .iter()
                .filter(|l| l.contains("[Tool "))
                .take(10)
                .copied()
                .collect();

            let errors: Vec<&str> = lines
                .iter()
                .filter(|l| {
                    l.to_lowercase().contains("error") || l.to_lowercase().contains("failed")
                })
                .take(5)
                .copied()
                .collect();

            let head: String = lines
                .iter()
                .take(30)
                .copied()
                .collect::<Vec<_>>()
                .join("\n");
            let tail: String = lines
                .iter()
                .rev()
                .take(80)
                .collect::<Vec<_>>()
                .into_iter()
                .rev()
                .copied()
                .collect::<Vec<_>>()
                .join("\n");

            let mut parts = vec![
                "## Context Summary (Fallback Mode)".to_string(),
                format!(
                    "*Original: {} tokens - RLM processing produced insufficient output*",
                    input_tokens
                ),
                String::new(),
            ];

            if !file_matches.is_empty() {
                parts.push(format!("**Files Mentioned:** {}", file_matches.len()));
            }

            if !tool_calls.is_empty() {
                parts.push(format!("**Recent Tool Calls:** {}", tool_calls.join(", ")));
            }

            if !errors.is_empty() {
                parts.push("**Recent Errors:**".to_string());
                for e in errors {
                    parts.push(format!("- {}", e.chars().take(150).collect::<String>()));
                }
            }

            parts.push(String::new());
            parts.push("### Initial Request".to_string());
            parts.push("```".to_string());
            parts.push(head);
            parts.push("```".to_string());
            parts.push(String::new());
            parts.push("### Recent Activity".to_string());
            parts.push("```".to_string());
            parts.push(tail);
            parts.push("```".to_string());

            parts.join("\n")
        } else {
            let (truncated, _, _) = Self::smart_truncate(output, tool_id, tool_args, 8000);
            format!(
                "## Fallback Summary\n*RLM processing failed - showing structured excerpt*\n\n{}",
                truncated
            )
        }
    }
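
    /// Minimal result returned when the very first model call fails: a smart
    /// truncation of the raw output with `success: false`.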
    fn fallback_result(
        output: &str,
        tool_id: &str,
        tool_args: &serde_json::Value,
        input_tokens: usize,
    ) -> RlmResult {
        let (truncated, _, _) = Self::smart_truncate(output, tool_id, tool_args, 8000);
        let output_tokens = RlmChunker::estimate_tokens(&truncated);

        RlmResult {
            processed: format!(
                "[RLM processing failed, showing truncated output]\n\n{}",
                truncated
            ),
            stats: RlmStats {
                input_tokens,
                output_tokens,
                iterations: 0,
                subcalls: 0,
                elapsed_ms: 0,
                compression_ratio: input_tokens as f64 / output_tokens.max(1) as f64,
            },
            success: false,
            error: Some("Model call failed".to_string()),
        }
    }
}