codetether_agent/rlm/router.rs

//! RLM Router - Decides when to route content through RLM processing
//!
//! Routes large tool outputs through RLM when they would exceed
//! the model's context window threshold.

use super::{RlmChunker, RlmConfig, RlmResult, RlmStats};
use crate::provider::{CompletionRequest, ContentPart, Message, Provider, Role};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Instant;
use tracing::{info, warn};

/// Tools eligible for RLM routing
fn rlm_eligible_tools() -> HashSet<&'static str> {
    ["read", "glob", "grep", "bash", "search"].iter().copied().collect()
}

/// Context for routing decisions
#[derive(Debug, Clone)]
pub struct RoutingContext {
    pub tool_id: String,
    pub session_id: String,
    pub call_id: Option<String>,
    pub model_context_limit: usize,
    pub current_context_tokens: Option<usize>,
}

/// Result of routing decision
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingResult {
    pub should_route: bool,
    pub reason: String,
    pub estimated_tokens: usize,
}

/// Context for auto-processing
pub struct AutoProcessContext<'a> {
    pub tool_id: &'a str,
    pub tool_args: serde_json::Value,
    pub session_id: &'a str,
    pub abort: Option<tokio::sync::watch::Receiver<bool>>,
    pub on_progress: Option<Box<dyn Fn(ProcessProgress) + Send + Sync>>,
    pub provider: Arc<dyn Provider>,
    pub model: String,
}

/// Progress update during processing
#[derive(Debug, Clone)]
pub struct ProcessProgress {
    pub iteration: usize,
    pub max_iterations: usize,
    pub status: String,
}

/// RLM Router for large content processing
pub struct RlmRouter;

impl RlmRouter {
    /// Check if a tool output should be routed through RLM
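    ///
    /// Illustrative example (not from the original source): the `RlmConfig`
    /// literal below assumes the struct implements `Default` and exposes the
    /// `mode` and `threshold` fields used in this function.
    ///
    /// ```ignore
    /// let ctx = RoutingContext {
    ///     tool_id: "read".to_string(),
    ///     session_id: "session-1".to_string(),
    ///     call_id: None,
    ///     model_context_limit: 128_000,
    ///     current_context_tokens: Some(90_000),
    /// };
    /// let config = RlmConfig { mode: "auto".to_string(), threshold: 0.3, ..Default::default() };
    /// let output = "x".repeat(400_000); // synthetic filler, ~100k tokens at ~4 chars/token
    /// let decision = RlmRouter::should_route(&output, &ctx, &config);
    /// assert!(decision.should_route); // exceeds 0.3 * 128_000 token threshold
    /// ```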
    pub fn should_route(output: &str, ctx: &RoutingContext, config: &RlmConfig) -> RoutingResult {
        let estimated_tokens = RlmChunker::estimate_tokens(output);

        // Mode: off - never route
        if config.mode == "off" {
            return RoutingResult {
                should_route: false,
                reason: "rlm_mode_off".to_string(),
                estimated_tokens,
            };
        }

        // Mode: always - always route for eligible tools
        if config.mode == "always" {
            if !rlm_eligible_tools().contains(ctx.tool_id.as_str()) {
                return RoutingResult {
                    should_route: false,
                    reason: "tool_not_eligible".to_string(),
                    estimated_tokens,
                };
            }
            return RoutingResult {
                should_route: true,
                reason: "rlm_mode_always".to_string(),
                estimated_tokens,
            };
        }

        // Mode: auto - route based on threshold
        if !rlm_eligible_tools().contains(ctx.tool_id.as_str()) {
            return RoutingResult {
                should_route: false,
                reason: "tool_not_eligible".to_string(),
                estimated_tokens,
            };
        }

        // Check if output exceeds threshold relative to context window
        let threshold_tokens = (ctx.model_context_limit as f64 * config.threshold) as usize;
        if estimated_tokens > threshold_tokens {
            return RoutingResult {
                should_route: true,
                reason: "exceeds_threshold".to_string(),
                estimated_tokens,
            };
        }

        // Check if adding this output would cause overflow
        if let Some(current) = ctx.current_context_tokens {
            let projected_total = current + estimated_tokens;
            if projected_total > (ctx.model_context_limit as f64 * 0.8) as usize {
                return RoutingResult {
                    should_route: true,
                    reason: "would_overflow".to_string(),
                    estimated_tokens,
                };
            }
        }

        RoutingResult {
            should_route: false,
            reason: "within_threshold".to_string(),
            estimated_tokens,
        }
    }

    /// Smart truncate large output with RLM hint
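    ///
    /// Returns `(text, was_truncated, original_token_estimate)`. Sketch of the
    /// intended behaviour (illustrative only; the exact split depends on
    /// `RlmChunker::estimate_tokens`):
    ///
    /// ```ignore
    /// let args = serde_json::json!({ "filePath": "/tmp/build.log" }); // hypothetical path
    /// let (text, was_truncated, original_tokens) =
    ///     RlmRouter::smart_truncate(&huge_output, "read", &args, 8_000);
    /// // When truncation happens, `text` keeps roughly the first 60% and last 30%
    /// // of the character budget and embeds an `rlm(...)` hint in between.
    /// ```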
    pub fn smart_truncate(
        output: &str,
        tool_id: &str,
        tool_args: &serde_json::Value,
        max_tokens: usize,
    ) -> (String, bool, usize) {
        let estimated_tokens = RlmChunker::estimate_tokens(output);

        if estimated_tokens <= max_tokens {
            return (output.to_string(), false, estimated_tokens);
        }

        info!(
            tool = tool_id,
            original_tokens = estimated_tokens,
            max_tokens,
            "Smart truncating large output"
        );

        // Calculate how much to keep (roughly 4 chars per token)
        let max_chars = max_tokens * 4;
        let head_chars = (max_chars as f64 * 0.6) as usize;
        let tail_chars = (max_chars as f64 * 0.3) as usize;

        let head: String = output.chars().take(head_chars).collect();
        let tail: String = output.chars().rev().take(tail_chars).collect::<String>().chars().rev().collect();

        // saturating_sub guards against underflow if the head/tail estimates overshoot the original estimate
        let omitted_tokens = estimated_tokens
            .saturating_sub(RlmChunker::estimate_tokens(&head))
            .saturating_sub(RlmChunker::estimate_tokens(&tail));
        let rlm_hint = Self::build_rlm_hint(tool_id, tool_args, estimated_tokens);

        let truncated = format!(
            "{}\n\n[... {} tokens truncated ...]\n\n{}\n\n{}",
            head, omitted_tokens, rlm_hint, tail
        );

        (truncated, true, estimated_tokens)
    }

    fn build_rlm_hint(tool_id: &str, args: &serde_json::Value, tokens: usize) -> String {
        let base = format!("⚠️ OUTPUT TOO LARGE ({} tokens). Use RLM for full analysis:", tokens);

        match tool_id {
            "read" => {
                let path = args.get("filePath").and_then(|v| v.as_str()).unwrap_or("...");
                format!("{}\n```\nrlm({{ query: \"Analyze this file\", content_paths: [\"{}\"] }})\n```", base, path)
            }
            "bash" => {
                format!("{}\n```\nrlm({{ query: \"Analyze this command output\", content: \"<paste or use content_paths>\" }})\n```", base)
            }
            "grep" => {
                let pattern = args.get("pattern").and_then(|v| v.as_str()).unwrap_or("...");
                let include = args.get("include").and_then(|v| v.as_str()).unwrap_or("*");
                format!("{}\n```\nrlm({{ query: \"Summarize search results for {}\", content_glob: \"{}\" }})\n```", base, pattern, include)
            }
            _ => {
                format!("{}\n```\nrlm({{ query: \"Summarize this output\", content: \"...\" }})\n```", base)
            }
        }
    }

    /// Automatically process large output through RLM
    ///
    /// Based on "Recursive Language Models" (Zhang et al. 2025):
    /// - Context is loaded as a variable in a REPL-like environment
    /// - LLM writes code/queries to analyze, decompose, and recursively sub-call itself
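    ///
    /// Illustrative call site (not from the original source; assumes a
    /// `provider: Arc<dyn Provider>`, a model name, and an `RlmConfig` are
    /// already available in the caller):
    ///
    /// ```ignore
    /// let result = RlmRouter::auto_process(
    ///     &huge_output,
    ///     AutoProcessContext {
    ///         tool_id: "bash",
    ///         tool_args: serde_json::json!({ "command": "cargo test" }),
    ///         session_id: "session-1",
    ///         abort: None,
    ///         on_progress: None,
    ///         provider,
    ///         model: model_name,
    ///     },
    ///     &config,
    /// )
    /// .await?;
    /// println!("{}", result.processed);
    /// ```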
    pub async fn auto_process(output: &str, ctx: AutoProcessContext<'_>, config: &RlmConfig) -> Result<RlmResult> {
        let start = Instant::now();
        let input_tokens = RlmChunker::estimate_tokens(output);

        info!(
            tool = ctx.tool_id,
            input_tokens,
            model = %ctx.model,
            "RLM: Starting auto-processing"
        );

        // Detect content type for smarter processing
        let content_type = RlmChunker::detect_content_type(output);
        let content_hints = RlmChunker::get_processing_hints(content_type);

        info!(content_type = ?content_type, tool = ctx.tool_id, "RLM: Content type detected");

        // For very large contexts, use semantic chunking to preserve important parts
        let processed_output = if input_tokens > 50000 {
            RlmChunker::compress(output, 40000, None)
        } else {
            output.to_string()
        };

        // Build the query based on tool type
        let base_query = Self::build_query_for_tool(ctx.tool_id, &ctx.tool_args);
        let query = format!("{}\n\n## Content Analysis Hints\n{}", base_query, content_hints);

        // Build the RLM system prompt
        let system_prompt = Self::build_rlm_system_prompt(input_tokens, ctx.tool_id, &query);

        let max_iterations = config.max_iterations;
        let max_subcalls = config.max_subcalls;
        let mut iterations = 0;
        let mut subcalls = 0;
        let mut final_answer: Option<String> = None;

        // Build initial exploration prompt
        let exploration = Self::build_exploration_summary(&processed_output, input_tokens);

        // Run iterative analysis
        let mut conversation = vec![
            Message {
                role: Role::User,
                content: vec![ContentPart::Text {
                    text: format!(
                        "{}\n\nHere is the context exploration:\n```\n{}\n```\n\nNow analyze and answer the query.",
                        system_prompt, exploration
                    ),
                }],
            },
        ];

        for i in 0..max_iterations {
            iterations = i + 1;

            if let Some(ref progress) = ctx.on_progress {
                progress(ProcessProgress {
                    iteration: iterations,
                    max_iterations,
                    status: "running".to_string(),
                });
            }

            // Check for abort
            if let Some(ref abort) = ctx.abort {
                if *abort.borrow() {
                    warn!("RLM: Processing aborted");
                    break;
                }
            }

            // Build completion request
            let request = CompletionRequest {
                messages: conversation.clone(),
                tools: Vec::new(),
                model: ctx.model.clone(),
                temperature: Some(0.7),
                top_p: None,
                max_tokens: Some(4000),
                stop: Vec::new(),
            };

            // Call the model
            let response = match ctx.provider.complete(request).await {
                Ok(r) => r,
                Err(e) => {
                    warn!(error = %e, iteration = iterations, "RLM: Model call failed");
                    if iterations > 1 {
                        break; // Use what we have
                    }
                    return Ok(Self::fallback_result(output, ctx.tool_id, &ctx.tool_args, input_tokens));
                }
            };

            let response_text: String = response.message.content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            info!(
                iteration = iterations,
                response_len = response_text.len(),
                "RLM: Model response"
            );

            // Check for FINAL answer
            if let Some(answer) = Self::extract_final(&response_text) {
                final_answer = Some(answer);
                break;
            }

            // Check for analysis that can be used directly
            if iterations >= 3 && response_text.len() > 500 && !response_text.contains("```") {
                // The model is providing direct analysis, use it
                final_answer = Some(response_text.clone());
                break;
            }

            // Add response to conversation
            conversation.push(Message {
                role: Role::Assistant,
                content: vec![ContentPart::Text { text: response_text }],
            });

            // Prompt for continuation
            conversation.push(Message {
                role: Role::User,
                content: vec![ContentPart::Text {
                    text: "Continue analysis. Call FINAL(\"your answer\") when ready.".to_string(),
                }],
            });

            subcalls += 1;
            if subcalls >= max_subcalls {
                warn!(subcalls, max = max_subcalls, "RLM: Max subcalls reached");
                break;
            }
        }

        if let Some(ref progress) = ctx.on_progress {
            progress(ProcessProgress {
                iteration: iterations,
                max_iterations,
                status: "completed".to_string(),
            });
        }

        // Fallback if no FINAL was produced
        let answer = final_answer.unwrap_or_else(|| {
            warn!(iterations, subcalls, "RLM: No FINAL produced, using fallback");
            Self::build_enhanced_fallback(output, ctx.tool_id, &ctx.tool_args, input_tokens)
        });

        let output_tokens = RlmChunker::estimate_tokens(&answer);
        let compression_ratio = input_tokens as f64 / output_tokens.max(1) as f64;
        let elapsed_ms = start.elapsed().as_millis() as u64;

        let result = format!(
            "[RLM: {} → {} tokens | {} iterations | {} sub-calls]\n\n{}",
            input_tokens, output_tokens, iterations, subcalls, answer
        );

        info!(
            input_tokens,
            output_tokens,
            iterations,
            subcalls,
            elapsed_ms,
            compression_ratio = format!("{:.1}", compression_ratio),
            "RLM: Processing complete"
        );

        Ok(RlmResult {
            processed: result,
            stats: RlmStats {
                input_tokens,
                output_tokens,
                iterations,
                subcalls,
                elapsed_ms,
                compression_ratio,
            },
            success: true,
            error: None,
        })
    }

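    /// Extract the argument of a `FINAL("...")`-style call from the model's reply.
    ///
    /// Illustrative: given `FINAL("summary of findings")` anywhere in the text,
    /// this returns `Some("summary of findings".to_string())`.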
    fn extract_final(text: &str) -> Option<String> {
        // Look for FINAL("..."), FINAL('...'), or FINAL(`...`)
        let start_idx = text.find("FINAL")?;
        let after = &text[start_idx..];

        // Find the opening quote/backtick
        let open_idx = after.find(['"', '\'', '`'])?;
        let quote_char = after[open_idx..].chars().next()?;
        let content_start = start_idx + open_idx + 1;

        // Find the matching close
        let content = &text[content_start..];
        let close_idx = content.find(quote_char)?;
        let answer = &content[..close_idx];

        if answer.is_empty() {
            None
        } else {
            Some(answer.to_string())
        }
    }

    fn build_exploration_summary(content: &str, input_tokens: usize) -> String {
        let lines: Vec<&str> = content.lines().collect();
        let total_lines = lines.len();

        let head: String = lines.iter().take(30).copied().collect::<Vec<_>>().join("\n");
        let tail: String = lines[total_lines.saturating_sub(50)..].join("\n");

        format!(
            "=== CONTEXT EXPLORATION ===\n\
             Total: {} chars, {} lines, ~{} tokens\n\n\
             === FIRST 30 LINES ===\n{}\n\n\
             === LAST 50 LINES ===\n{}\n\
             === END EXPLORATION ===",
            content.len(), total_lines, input_tokens, head, tail
        )
    }

    fn build_rlm_system_prompt(input_tokens: usize, tool_id: &str, query: &str) -> String {
        let context_type = if tool_id == "session_context" { "conversation history" } else { "tool output" };

        format!(
            r#"You are tasked with analyzing large content that cannot fit in a normal context window.

The content is a {} with {} total tokens.

YOUR TASK: {}

## Analysis Strategy

1. First, examine the exploration (head + tail of content) to understand structure
2. Identify the most important information for answering the query
3. Focus on: errors, key decisions, file paths, recent activity
4. Provide a concise but complete answer

When ready, call FINAL("your detailed answer") with your findings.

Be SPECIFIC - include actual file paths, function names, error messages. Generic summaries are not useful."#,
            context_type, input_tokens, query
        )
    }

    fn build_query_for_tool(tool_id: &str, args: &serde_json::Value) -> String {
        match tool_id {
            "read" => {
                let path = args.get("filePath").and_then(|v| v.as_str()).unwrap_or("unknown");
                format!("Summarize the key contents of file \"{}\". Focus on: structure, main functions/classes, important logic. Be concise.", path)
            }
            "bash" => {
                "Summarize the command output. Extract key information, results, errors, warnings. Be concise.".to_string()
            }
            "grep" => {
                let pattern = args.get("pattern").and_then(|v| v.as_str()).unwrap_or("pattern");
                format!("Summarize search results for \"{}\". Group by file, highlight most relevant matches. Be concise.", pattern)
            }
            "glob" => {
                "Summarize the file listing. Group by directory, highlight important files. Be concise.".to_string()
            }
            "session_context" => {
                r#"You are a CONTEXT MEMORY SYSTEM. Create a BRIEFING for an AI assistant to continue this conversation.

CRITICAL: The assistant will ONLY see your briefing - it has NO memory of the conversation.

## What to Extract

1. **PRIMARY GOAL**: What is the user ultimately trying to achieve?
2. **CURRENT STATE**: What has been accomplished? Current status?
3. **LAST ACTIONS**: What just happened? (last 3-5 tool calls, their results)
4. **ACTIVE FILES**: Which files were modified?
5. **PENDING TASKS**: What remains to be done?
6. **CRITICAL DETAILS**: File paths, error messages, specific values, decisions made
7. **NEXT STEPS**: What should happen next?

Be SPECIFIC with file paths, function names, error messages."#.to_string()
            }
            _ => "Summarize this output concisely, extracting the most important information.".to_string()
        }
    }

    fn build_enhanced_fallback(output: &str, tool_id: &str, tool_args: &serde_json::Value, input_tokens: usize) -> String {
        let lines: Vec<&str> = output.lines().collect();

        if tool_id == "session_context" {
            // Extract key structural information
            let file_matches: Vec<&str> = lines.iter()
                .filter_map(|l| {
                    if l.contains(".ts") || l.contains(".rs") || l.contains(".py") || l.contains(".json") {
                        Some(*l)
                    } else {
                        None
                    }
                })
                .take(15)
                .collect();

            let tool_calls: Vec<&str> = lines.iter()
                .filter(|l| l.contains("[Tool "))
                .take(10)
                .copied()
                .collect();

            let errors: Vec<&str> = lines.iter()
                .filter(|l| l.to_lowercase().contains("error") || l.to_lowercase().contains("failed"))
                .take(5)
                .copied()
                .collect();

            let head: String = lines.iter().take(30).copied().collect::<Vec<_>>().join("\n");
            let tail: String = lines[lines.len().saturating_sub(80)..].join("\n");

            let mut parts = vec![
                "## Context Summary (Fallback Mode)".to_string(),
                format!("*Original: {} tokens - RLM processing produced insufficient output*", input_tokens),
                String::new(),
            ];

            if !file_matches.is_empty() {
                parts.push(format!("**Files Mentioned:** {}", file_matches.len()));
            }

            if !tool_calls.is_empty() {
                parts.push(format!("**Recent Tool Calls:** {}", tool_calls.join(", ")));
            }

            if !errors.is_empty() {
                parts.push("**Recent Errors:**".to_string());
                for e in errors {
                    parts.push(format!("- {}", e.chars().take(150).collect::<String>()));
                }
            }

            parts.push(String::new());
            parts.push("### Initial Request".to_string());
            parts.push("```".to_string());
            parts.push(head);
            parts.push("```".to_string());
            parts.push(String::new());
            parts.push("### Recent Activity".to_string());
            parts.push("```".to_string());
            parts.push(tail);
            parts.push("```".to_string());

            parts.join("\n")
        } else {
            let (truncated, _, _) = Self::smart_truncate(output, tool_id, tool_args, 8000);
            format!("## Fallback Summary\n*RLM processing failed - showing structured excerpt*\n\n{}", truncated)
        }
    }

    fn fallback_result(output: &str, tool_id: &str, tool_args: &serde_json::Value, input_tokens: usize) -> RlmResult {
        let (truncated, _, _) = Self::smart_truncate(output, tool_id, tool_args, 8000);
        let output_tokens = RlmChunker::estimate_tokens(&truncated);

        RlmResult {
            processed: format!("[RLM processing failed, showing truncated output]\n\n{}", truncated),
            stats: RlmStats {
                input_tokens,
                output_tokens,
                iterations: 0,
                subcalls: 0,
                elapsed_ms: 0,
                compression_ratio: input_tokens as f64 / output_tokens.max(1) as f64,
            },
            success: false,
            error: Some("Model call failed".to_string()),
        }
    }
}
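
// ---------------------------------------------------------------------------
// Minimal test sketch (not part of the original file). It only exercises the
// pure helpers and assumes `RlmChunker::estimate_tokens` stays far below the
// deliberately generous budget used below.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_final_pulls_out_the_quoted_answer() {
        let text = r#"Analysis done. FINAL("the bug is in the router")"#;
        assert_eq!(
            RlmRouter::extract_final(text),
            Some("the bug is in the router".to_string())
        );
        assert_eq!(RlmRouter::extract_final("no final marker here"), None);
    }

    #[test]
    fn smart_truncate_leaves_small_output_untouched() {
        let args = serde_json::json!({});
        let (text, was_truncated, _tokens) =
            RlmRouter::smart_truncate("short output", "bash", &args, 1_000_000);
        assert_eq!(text, "short output");
        assert!(!was_truncated);
    }
}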