codetether_agent/rlm/chunker.rs

//! Semantic chunking for large contexts
//!
//! Splits content intelligently at natural boundaries and prioritizes
//! chunks for token budget selection.
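//!
//! A minimal usage sketch (crate path assumed from this file's location):
//!
//! ```ignore
//! use codetether_agent::rlm::chunker::RlmChunker;
//!
//! let compressed = RlmChunker::compress(&large_context, 8_000, None);
//! ```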

use serde::{Deserialize, Serialize};

/// Content type for optimized processing
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ContentType {
    Code,
    Documents,
    Logs,
    Conversation,
    Mixed,
}

/// A chunk of content with metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
    pub content: String,
    #[serde(rename = "type")]
    pub chunk_type: ChunkType,
    pub start_line: usize,
    pub end_line: usize,
    pub tokens: usize,
    /// Higher = more important to keep
    pub priority: u8,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChunkType {
    Code,
    Text,
    ToolOutput,
    Conversation,
}

/// Options for chunking
#[derive(Debug, Clone)]
pub struct ChunkOptions {
    /// Maximum tokens per chunk
    pub max_chunk_tokens: usize,
    /// Number of recent lines to always preserve
    pub preserve_recent: usize,
}

impl Default for ChunkOptions {
    fn default() -> Self {
        Self {
            max_chunk_tokens: 4000,
            preserve_recent: 100,
        }
    }
}

/// Semantic chunker for large contexts
pub struct RlmChunker;

impl RlmChunker {
    /// Detect the primary type of content for optimized processing
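    ///
    /// Heuristic: sample up to 200 lines from the head and tail, count
    /// indicator lines per category, and return any category holding more
    /// than 30% of all indicators (conversation checked first).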
    pub fn detect_content_type(content: &str) -> ContentType {
        let lines: Vec<&str> = content.lines().collect();
        let sample_size = lines.len().min(200);

        // Sample from head and tail
        let sample: Vec<&str> = lines.iter()
            .take(sample_size / 2)
            .chain(lines.iter().rev().take(sample_size / 2))
            .copied()
            .collect();

        let mut code_indicators = 0;
        let mut log_indicators = 0;
        let mut conversation_indicators = 0;
        let mut document_indicators = 0;

        for line in &sample {
            let trimmed = line.trim();

            // Code indicators
            if Self::is_code_line(trimmed) {
                code_indicators += 1;
            }

            // Log indicators
            if Self::is_log_line(trimmed) {
                log_indicators += 1;
            }

            // Conversation indicators
            if Self::is_conversation_line(trimmed) {
                conversation_indicators += 1;
            }

            // Document indicators
            if Self::is_document_line(trimmed) {
                document_indicators += 1;
            }
        }

        let total = code_indicators + log_indicators + conversation_indicators + document_indicators;
        if total == 0 {
            return ContentType::Mixed;
        }

        let threshold = (total as f64 * 0.3) as usize;

        if conversation_indicators > threshold {
            ContentType::Conversation
        } else if log_indicators > threshold {
            ContentType::Logs
        } else if code_indicators > threshold {
            ContentType::Code
        } else if document_indicators > threshold {
            ContentType::Documents
        } else {
            ContentType::Mixed
        }
    }

    fn is_code_line(line: &str) -> bool {
        // Function/class/import definitions
        let patterns = [
            "function", "class ", "def ", "const ", "let ", "var ",
            "import ", "export ", "async ", "fn ", "impl ", "struct ",
            "enum ", "pub ", "use ", "mod ", "trait ",
        ];

        if patterns.iter().any(|p| line.starts_with(p)) {
            return true;
        }

        // Brace-only or semicolon-only lines
        if matches!(line, "{" | "}" | "(" | ")" | ";" | "{}" | "};") {
            return true;
        }

        // Comment lines
        if line.starts_with("//") || line.starts_with("#") ||
           line.starts_with("*") || line.starts_with("/*") {
            return true;
        }

        false
    }

    fn is_log_line(line: &str) -> bool {
        // ISO date prefix
        if line.len() >= 10 && line.chars().take(4).all(|c| c.is_ascii_digit())
            && line.chars().nth(4) == Some('-') {
            return true;
        }

        // Time prefix [HH:MM
        if line.starts_with('[') && line.len() > 5
            && line.chars().nth(1).is_some_and(|c| c.is_ascii_digit()) {
            return true;
        }

        // Log level prefixes
        let log_levels = ["INFO", "DEBUG", "WARN", "ERROR", "FATAL", "TRACE"];
        for level in log_levels {
            if line.starts_with(level) || line.contains(&format!(" {} ", level)) {
                return true;
            }
        }

        false
    }

    fn is_conversation_line(line: &str) -> bool {
        let patterns = [
            "[User]:", "[Assistant]:", "[Human]:", "[AI]:",
            "User:", "Assistant:", "Human:", "AI:",
            "[Tool ", "<user>", "<assistant>", "<system>",
        ];
        patterns.iter().any(|p| line.starts_with(p))
    }

    fn is_document_line(line: &str) -> bool {
        // Markdown headers
        if line.starts_with('#') && line.chars().nth(1).is_some_and(|c| c == ' ' || c == '#') {
            return true;
        }

        // Bold text (requires a closing marker after the opening "**")
        if line.starts_with("**") && line[2..].contains("**") {
            return true;
        }

        // Blockquotes
        if line.starts_with("> ") {
            return true;
        }

        // List items
        if line.starts_with("- ") && line.len() > 3 {
            return true;
        }

        // Long prose lines without code terminators
        if line.len() > 80 && !line.ends_with('{') && !line.ends_with(';')
            && !line.ends_with('(') && !line.ends_with(')') && !line.ends_with('=') {
            return true;
        }

        false
    }

    /// Get processing hints based on content type
    pub fn get_processing_hints(content_type: ContentType) -> &'static str {
        match content_type {
            ContentType::Code => {
                "This appears to be source code. Focus on:\n\
                 - Function/class definitions and their purposes\n\
                 - Import statements and dependencies\n\
                 - Error handling patterns\n\
                 - Key algorithms and logic flow"
            }
            ContentType::Logs => {
                "This appears to be log output. Focus on:\n\
                 - Error and warning messages\n\
                 - Timestamps and event sequences\n\
                 - Stack traces and exceptions\n\
                 - Key events and state changes"
            }
            ContentType::Conversation => {
                "This appears to be conversation history. Focus on:\n\
                 - User's original request/goal\n\
                 - Key decisions made\n\
                 - Tool calls and their results\n\
                 - Current state and pending tasks"
            }
            ContentType::Documents => {
                "This appears to be documentation or prose. Focus on:\n\
                 - Main topics and structure\n\
                 - Key information and facts\n\
                 - Actionable items\n\
                 - References and links"
            }
            ContentType::Mixed => {
                "Mixed content detected. Analyze the structure first, then extract key information."
            }
        }
    }

    /// Estimate token count (roughly 4 chars per token)
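    ///
    /// For example, a 10-character string estimates to `ceil(10 / 4) = 3`
    /// tokens.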
    pub fn estimate_tokens(text: &str) -> usize {
        text.len().div_ceil(4)
    }

    /// Split content into semantic chunks
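    ///
    /// Single pass over the lines: flush the accumulated buffer at each
    /// semantic boundary, split any chunk that exceeds `max_chunk_tokens`,
    /// and boost priority for the final `preserve_recent` lines.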
    pub fn chunk(content: &str, options: Option<ChunkOptions>) -> Vec<Chunk> {
        let opts = options.unwrap_or_default();
        let lines: Vec<&str> = content.lines().collect();
        let mut chunks = Vec::new();

        // Find semantic boundaries
        let boundaries = Self::find_boundaries(&lines);

        let mut current_chunk: Vec<&str> = Vec::new();
        let mut current_type = ChunkType::Text;
        let mut current_start = 0;
        let mut current_priority: u8 = 1;

        for (i, line) in lines.iter().enumerate() {
            // Check if we hit a boundary
            if let Some((boundary_type, boundary_priority)) = boundaries.get(&i) {
                if !current_chunk.is_empty() {
                    let content = current_chunk.join("\n");
                    let tokens = Self::estimate_tokens(&content);

                    // If chunk is too big, split it
                    if tokens > opts.max_chunk_tokens {
                        let sub_chunks = Self::split_large_chunk(
                            &current_chunk, current_start, current_type, opts.max_chunk_tokens
                        );
                        chunks.extend(sub_chunks);
                    } else {
                        chunks.push(Chunk {
                            content,
                            chunk_type: current_type,
                            start_line: current_start,
                            end_line: i.saturating_sub(1),
                            tokens,
                            priority: current_priority,
                        });
                    }

                    current_chunk = Vec::new();
                }

                // Apply the boundary's type and priority even when the
                // buffer is empty (e.g. a boundary on the very first line).
                current_start = i;
                current_type = *boundary_type;
                current_priority = *boundary_priority;
            }

            current_chunk.push(line);

            // Boost priority for recent lines
            if i >= lines.len().saturating_sub(opts.preserve_recent) {
                current_priority = current_priority.max(8);
            }
        }

        // Final chunk
        if !current_chunk.is_empty() {
            let content = current_chunk.join("\n");
            let tokens = Self::estimate_tokens(&content);

            if tokens > opts.max_chunk_tokens {
                let sub_chunks = Self::split_large_chunk(
                    &current_chunk, current_start, current_type, opts.max_chunk_tokens
                );
                chunks.extend(sub_chunks);
            } else {
                chunks.push(Chunk {
                    content,
                    chunk_type: current_type,
                    start_line: current_start,
                    end_line: lines.len().saturating_sub(1),
                    tokens,
                    priority: current_priority,
                });
            }
        }

        chunks
    }

    /// Find semantic boundaries in content
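    ///
    /// Returns a map from line index to `(chunk type, priority)`. Error
    /// markers rank highest (8), failed tool output 7, section headers 6,
    /// message markers and definitions 5, code fences and file paths 4,
    /// other tool output 3.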
    fn find_boundaries(lines: &[&str]) -> std::collections::HashMap<usize, (ChunkType, u8)> {
        let mut boundaries = std::collections::HashMap::new();

        for (i, line) in lines.iter().enumerate() {
            let trimmed = line.trim();

            // User/Assistant message markers
            if trimmed.starts_with("[User]:") || trimmed.starts_with("[Assistant]:") {
                boundaries.insert(i, (ChunkType::Conversation, 5));
                continue;
            }

            // Tool output markers
            if trimmed.starts_with("[Tool ") {
                let priority = if trimmed.contains("FAILED") || trimmed.contains("error") { 7 } else { 3 };
                boundaries.insert(i, (ChunkType::ToolOutput, priority));
                continue;
            }

            // Code block markers
            if trimmed.starts_with("```") {
                boundaries.insert(i, (ChunkType::Code, 4));
                continue;
            }

            // File path markers
            if trimmed.starts_with('/') || trimmed.starts_with("./") || trimmed.starts_with("~/") {
                boundaries.insert(i, (ChunkType::Code, 4));
                continue;
            }

            // Function/class definitions
            let def_patterns = ["function", "class ", "def ", "async function", "export", "fn ", "impl ", "struct ", "enum "];
            if def_patterns.iter().any(|p| trimmed.starts_with(p)) {
                boundaries.insert(i, (ChunkType::Code, 5));
                continue;
            }

            // Error markers
            let lower = trimmed.to_lowercase();
            if lower.starts_with("error")
                || lower.contains("error:")
                || trimmed.starts_with("Exception")
                || trimmed.contains("FAILED")
            {
                boundaries.insert(i, (ChunkType::Text, 8));
                continue;
            }

            // Section headers
            if trimmed.starts_with('#') && trimmed.len() > 2 && trimmed.chars().nth(1) == Some(' ') {
                boundaries.insert(i, (ChunkType::Text, 6));
                continue;
            }
        }

        boundaries
    }

    /// Split a large chunk into smaller pieces
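    ///
    /// Greedy line-by-line packing: a new piece starts whenever adding the
    /// next line would exceed `max_tokens`. Sub-chunks get priority 3.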
    fn split_large_chunk(
        lines: &[&str],
        start_line: usize,
        chunk_type: ChunkType,
        max_tokens: usize,
    ) -> Vec<Chunk> {
        let mut chunks = Vec::new();
        let mut current: Vec<&str> = Vec::new();
        let mut current_tokens = 0;
        let mut current_start = start_line;

        for (i, line) in lines.iter().enumerate() {
            let line_tokens = Self::estimate_tokens(line);

            if current_tokens + line_tokens > max_tokens && !current.is_empty() {
                chunks.push(Chunk {
                    content: current.join("\n"),
                    chunk_type,
                    start_line: current_start,
                    end_line: start_line + i - 1,
                    tokens: current_tokens,
                    priority: 3,
                });
                current = Vec::new();
                current_tokens = 0;
                current_start = start_line + i;
            }

            current.push(line);
            current_tokens += line_tokens;
        }

        if !current.is_empty() {
            chunks.push(Chunk {
                content: current.join("\n"),
                chunk_type,
                start_line: current_start,
                end_line: start_line + lines.len() - 1,
                tokens: current_tokens,
                priority: 3,
            });
        }

        chunks
    }

    /// Select chunks to fit within a token budget.
    /// Prioritizes high-priority chunks and recent content.
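    ///
    /// Greedy selection: sort by priority descending (ties broken by
    /// recency), take chunks while they fit in the budget, then restore
    /// original line order.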
    pub fn select_chunks(chunks: &[Chunk], max_tokens: usize) -> Vec<Chunk> {
        let mut sorted: Vec<_> = chunks.to_vec();

        // Sort by priority (desc), then by line number (desc for recent)
        sorted.sort_by(|a, b| {
            b.priority
                .cmp(&a.priority)
                .then(b.start_line.cmp(&a.start_line))
        });

        let mut selected = Vec::new();
        let mut total_tokens = 0;

        for chunk in sorted {
            if total_tokens + chunk.tokens <= max_tokens {
                total_tokens += chunk.tokens;
                selected.push(chunk);
            }
        }

        // Re-sort by line number for coherent output
        selected.sort_by_key(|c| c.start_line);

        selected
    }

    /// Reassemble selected chunks into a single string
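    ///
    /// Inserts a `[... N lines omitted ...]` marker wherever consecutive
    /// chunks are not adjacent in the original line numbering.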
    pub fn reassemble(chunks: &[Chunk]) -> String {
        if chunks.is_empty() {
            return String::new();
        }

        let mut parts = Vec::new();
        let mut last_end: Option<usize> = None;

        for chunk in chunks {
            // Add separator if there's a gap
            if let Some(end) = last_end {
                if chunk.start_line > end + 1 {
                    let gap = chunk.start_line - end - 1;
                    parts.push(format!("\n[... {} lines omitted ...]\n", gap));
                }
            }
            parts.push(chunk.content.clone());
            last_end = Some(chunk.end_line);
        }

        parts.join("\n")
    }

    /// Intelligently compress content to fit within token budget
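    ///
    /// A minimal usage sketch (`history` is a placeholder `String`):
    /// ```ignore
    /// let fitted = RlmChunker::compress(&history, 8_000, None);
    /// ```
    /// Note that gap markers inserted during reassembly can push the result
    /// slightly over the budget.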
    pub fn compress(content: &str, max_tokens: usize, options: Option<ChunkOptions>) -> String {
        let chunks = Self::chunk(content, options);
        let selected = Self::select_chunks(&chunks, max_tokens);
        Self::reassemble(&selected)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_code() {
        let content = r#"
fn main() {
    println!("Hello, world!");
}

impl Foo {
    pub fn new() -> Self {
        Self {}
    }
}
"#;
        assert_eq!(RlmChunker::detect_content_type(content), ContentType::Code);
    }

    #[test]
    fn test_detect_conversation() {
        let content = r#"
[User]: Can you help me with this?

[Assistant]: Of course! What do you need?

[User]: I want to implement a feature.
"#;
        assert_eq!(RlmChunker::detect_content_type(content), ContentType::Conversation);
    }

    #[test]
    fn test_compress() {
        let content = "line\n".repeat(1000);
        let compressed = RlmChunker::compress(&content, 100, None);
        let tokens = RlmChunker::estimate_tokens(&compressed);
        assert!(tokens <= 100 || compressed.contains("[..."));
    }
}