1use serde::{Deserialize, Serialize};
7
/// Broad category of a body of text, used to choose summarization hints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ContentType {
    // Mostly source code (definitions, braces, comment markers).
    Code,
    // Prose / markdown-style documentation.
    Documents,
    // Timestamped or severity-tagged log output.
    Logs,
    // Chat transcript with speaker markers ([User]:, [Assistant]:, ...).
    Conversation,
    // No single category dominates the indicator counts.
    Mixed,
}
18
/// A contiguous run of input lines, annotated for selection and reassembly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
    // Chunk text: the original lines joined with '\n'.
    pub content: String,
    // Serialized as "type" for JSON consumers.
    #[serde(rename = "type")]
    pub chunk_type: ChunkType,
    // 0-based index of the first line within the original input.
    pub start_line: usize,
    // 0-based index of the last line (inclusive).
    pub end_line: usize,
    // Estimated token count of `content` (see RlmChunker::estimate_tokens).
    pub tokens: usize,
    // Selection priority; higher values are kept first by select_chunks.
    pub priority: u8,
}
31
/// Kind of content a chunk holds, set by the boundary that started it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChunkType {
    // Code fences, file paths, definition lines.
    Code,
    // Default kind; also used for errors and headings.
    Text,
    // Output following a "[Tool ..." marker.
    ToolOutput,
    // Conversation turns ([User]:, [Assistant]:).
    Conversation,
}
40
/// Tuning knobs for RlmChunker::chunk.
#[derive(Debug, Clone)]
pub struct ChunkOptions {
    // Chunks estimated above this token count are split further.
    pub max_chunk_tokens: usize,
    // The trailing N input lines get a priority floor of 8 so recent
    // context survives budget-based selection.
    pub preserve_recent: usize,
}
49
50impl Default for ChunkOptions {
51 fn default() -> Self {
52 Self {
53 max_chunk_tokens: 4000,
54 preserve_recent: 100,
55 }
56 }
57}
58
/// Stateless chunking/compression utility; all methods are associated fns.
pub struct RlmChunker;
61
62impl RlmChunker {
63 pub fn detect_content_type(content: &str) -> ContentType {
65 let lines: Vec<&str> = content.lines().collect();
66 let sample_size = lines.len().min(200);
67
68 let sample: Vec<&str> = lines
70 .iter()
71 .take(sample_size / 2)
72 .chain(lines.iter().rev().take(sample_size / 2))
73 .copied()
74 .collect();
75
76 let mut code_indicators = 0;
77 let mut log_indicators = 0;
78 let mut conversation_indicators = 0;
79 let mut document_indicators = 0;
80
81 for line in &sample {
82 let trimmed = line.trim();
83
84 if Self::is_code_line(trimmed) {
86 code_indicators += 1;
87 }
88
89 if Self::is_log_line(trimmed) {
91 log_indicators += 1;
92 }
93
94 if Self::is_conversation_line(trimmed) {
96 conversation_indicators += 1;
97 }
98
99 if Self::is_document_line(trimmed) {
101 document_indicators += 1;
102 }
103 }
104
105 let total =
106 code_indicators + log_indicators + conversation_indicators + document_indicators;
107 if total == 0 {
108 return ContentType::Mixed;
109 }
110
111 let threshold = (total as f64 * 0.3) as usize;
112
113 if conversation_indicators > threshold {
114 ContentType::Conversation
115 } else if log_indicators > threshold {
116 ContentType::Logs
117 } else if code_indicators > threshold {
118 ContentType::Code
119 } else if document_indicators > threshold {
120 ContentType::Documents
121 } else {
122 ContentType::Mixed
123 }
124 }
125
126 fn is_code_line(line: &str) -> bool {
127 let patterns = [
129 "function", "class ", "def ", "const ", "let ", "var ", "import ", "export ", "async ",
130 "fn ", "impl ", "struct ", "enum ", "pub ", "use ", "mod ", "trait ",
131 ];
132
133 if patterns.iter().any(|p| line.starts_with(p)) {
134 return true;
135 }
136
137 if matches!(line, "{" | "}" | "(" | ")" | ";" | "{}" | "};") {
139 return true;
140 }
141
142 if line.starts_with("//")
144 || line.starts_with("#")
145 || line.starts_with("*")
146 || line.starts_with("/*")
147 {
148 return true;
149 }
150
151 false
152 }
153
154 fn is_log_line(line: &str) -> bool {
155 if line.len() >= 10
157 && line.chars().take(4).all(|c| c.is_ascii_digit())
158 && line.chars().nth(4) == Some('-')
159 {
160 return true;
161 }
162
163 if line.starts_with('[')
165 && line.len() > 5
166 && line.chars().nth(1).is_some_and(|c| c.is_ascii_digit())
167 {
168 return true;
169 }
170
171 let log_levels = ["INFO", "DEBUG", "WARN", "ERROR", "FATAL", "TRACE"];
173 for level in log_levels {
174 if line.starts_with(level) || line.contains(&format!(" {} ", level)) {
175 return true;
176 }
177 }
178
179 false
180 }
181
182 fn is_conversation_line(line: &str) -> bool {
183 let patterns = [
184 "[User]:",
185 "[Assistant]:",
186 "[Human]:",
187 "[AI]:",
188 "User:",
189 "Assistant:",
190 "Human:",
191 "AI:",
192 "[Tool ",
193 "<user>",
194 "<assistant>",
195 "<system>",
196 ];
197 patterns.iter().any(|p| line.starts_with(p))
198 }
199
200 fn is_document_line(line: &str) -> bool {
201 if line.starts_with('#') && line.chars().nth(1).is_some_and(|c| c == ' ' || c == '#') {
203 return true;
204 }
205
206 if line.starts_with("**") && line.contains("**") {
208 return true;
209 }
210
211 if line.starts_with("> ") {
213 return true;
214 }
215
216 if line.starts_with("- ") && line.len() > 3 {
218 return true;
219 }
220
221 if line.len() > 80
223 && !line.ends_with('{')
224 && !line.ends_with(';')
225 && !line.ends_with('(')
226 && !line.ends_with(')')
227 && !line.ends_with('=')
228 {
229 return true;
230 }
231
232 false
233 }
234
235 pub fn get_processing_hints(content_type: ContentType) -> &'static str {
237 match content_type {
238 ContentType::Code => {
239 "This appears to be source code. Focus on:\n\
240 - Function/class definitions and their purposes\n\
241 - Import statements and dependencies\n\
242 - Error handling patterns\n\
243 - Key algorithms and logic flow"
244 }
245 ContentType::Logs => {
246 "This appears to be log output. Focus on:\n\
247 - Error and warning messages\n\
248 - Timestamps and event sequences\n\
249 - Stack traces and exceptions\n\
250 - Key events and state changes"
251 }
252 ContentType::Conversation => {
253 "This appears to be conversation history. Focus on:\n\
254 - User's original request/goal\n\
255 - Key decisions made\n\
256 - Tool calls and their results\n\
257 - Current state and pending tasks"
258 }
259 ContentType::Documents => {
260 "This appears to be documentation or prose. Focus on:\n\
261 - Main topics and structure\n\
262 - Key information and facts\n\
263 - Actionable items\n\
264 - References and links"
265 }
266 ContentType::Mixed => {
267 "Mixed content detected. Analyze the structure first, then extract key information."
268 }
269 }
270 }
271
272 pub fn estimate_tokens(text: &str) -> usize {
274 if text.is_empty() {
275 return 0;
276 }
277
278 let len = text.len();
286 let whitespace = text
287 .as_bytes()
288 .iter()
289 .filter(|b| b.is_ascii_whitespace())
290 .count();
291
292 let ws_ratio = whitespace as f64 / len as f64;
293 let chars_per_token = if ws_ratio < 0.05 {
294 2.8
295 } else if ws_ratio < 0.10 {
296 3.2
297 } else if ws_ratio < 0.20 {
298 3.6
299 } else {
300 4.0
301 };
302
303 ((len as f64) / chars_per_token).ceil() as usize
304 }
305
306 pub fn chunk(content: &str, options: Option<ChunkOptions>) -> Vec<Chunk> {
308 let opts = options.unwrap_or_default();
309 let lines: Vec<&str> = content.lines().collect();
310 let mut chunks = Vec::new();
311
312 let boundaries = Self::find_boundaries(&lines);
314
315 let mut current_chunk: Vec<&str> = Vec::new();
316 let mut current_type = ChunkType::Text;
317 let mut current_start = 0;
318 let mut current_priority: u8 = 1;
319
320 for (i, line) in lines.iter().enumerate() {
321 if let Some((boundary_type, boundary_priority)) = boundaries.get(&i)
323 && !current_chunk.is_empty()
324 {
325 let content = current_chunk.join("\n");
326 let tokens = Self::estimate_tokens(&content);
327
328 if tokens > opts.max_chunk_tokens {
330 let sub_chunks = Self::split_large_chunk(
331 ¤t_chunk,
332 current_start,
333 current_type,
334 opts.max_chunk_tokens,
335 );
336 chunks.extend(sub_chunks);
337 } else {
338 chunks.push(Chunk {
339 content,
340 chunk_type: current_type,
341 start_line: current_start,
342 end_line: i.saturating_sub(1),
343 tokens,
344 priority: current_priority,
345 });
346 }
347
348 current_chunk = Vec::new();
349 current_start = i;
350 current_type = *boundary_type;
351 current_priority = *boundary_priority;
352 }
353
354 current_chunk.push(line);
355
356 if i >= lines.len().saturating_sub(opts.preserve_recent) {
358 current_priority = current_priority.max(8);
359 }
360 }
361
362 if !current_chunk.is_empty() {
364 let content = current_chunk.join("\n");
365 let tokens = Self::estimate_tokens(&content);
366
367 if tokens > opts.max_chunk_tokens {
368 let sub_chunks = Self::split_large_chunk(
369 ¤t_chunk,
370 current_start,
371 current_type,
372 opts.max_chunk_tokens,
373 );
374 chunks.extend(sub_chunks);
375 } else {
376 chunks.push(Chunk {
377 content,
378 chunk_type: current_type,
379 start_line: current_start,
380 end_line: lines.len().saturating_sub(1),
381 tokens,
382 priority: current_priority,
383 });
384 }
385 }
386
387 chunks
388 }
389
390 fn find_boundaries(lines: &[&str]) -> std::collections::HashMap<usize, (ChunkType, u8)> {
392 let mut boundaries = std::collections::HashMap::new();
393
394 for (i, line) in lines.iter().enumerate() {
395 let trimmed = line.trim();
396
397 if trimmed.starts_with("[User]:") || trimmed.starts_with("[Assistant]:") {
399 boundaries.insert(i, (ChunkType::Conversation, 5));
400 continue;
401 }
402
403 if trimmed.starts_with("[Tool ") {
405 let priority = if trimmed.contains("FAILED") || trimmed.contains("error") {
406 7
407 } else {
408 3
409 };
410 boundaries.insert(i, (ChunkType::ToolOutput, priority));
411 continue;
412 }
413
414 if trimmed.starts_with("```") {
416 boundaries.insert(i, (ChunkType::Code, 4));
417 continue;
418 }
419
420 if trimmed.starts_with('/') || trimmed.starts_with("./") || trimmed.starts_with("~/") {
422 boundaries.insert(i, (ChunkType::Code, 4));
423 continue;
424 }
425
426 let def_patterns = [
428 "function",
429 "class ",
430 "def ",
431 "async function",
432 "export",
433 "fn ",
434 "impl ",
435 "struct ",
436 "enum ",
437 ];
438 if def_patterns.iter().any(|p| trimmed.starts_with(p)) {
439 boundaries.insert(i, (ChunkType::Code, 5));
440 continue;
441 }
442
443 if trimmed.to_lowercase().starts_with("error")
445 || trimmed.to_lowercase().contains("error:")
446 || trimmed.starts_with("Exception")
447 || trimmed.contains("FAILED")
448 {
449 boundaries.insert(i, (ChunkType::Text, 8));
450 continue;
451 }
452
453 if trimmed.starts_with('#') && trimmed.len() > 2 && trimmed.chars().nth(1) == Some(' ')
455 {
456 boundaries.insert(i, (ChunkType::Text, 6));
457 continue;
458 }
459 }
460
461 boundaries
462 }
463
464 fn split_large_chunk(
466 lines: &[&str],
467 start_line: usize,
468 chunk_type: ChunkType,
469 max_tokens: usize,
470 ) -> Vec<Chunk> {
471 let mut chunks = Vec::new();
472 let mut current: Vec<&str> = Vec::new();
473 let mut current_tokens = 0;
474 let mut current_start = start_line;
475
476 for (i, line) in lines.iter().enumerate() {
477 let line_tokens = Self::estimate_tokens(line);
478
479 if current_tokens + line_tokens > max_tokens && !current.is_empty() {
480 chunks.push(Chunk {
481 content: current.join("\n"),
482 chunk_type,
483 start_line: current_start,
484 end_line: start_line + i - 1,
485 tokens: current_tokens,
486 priority: 3,
487 });
488 current = Vec::new();
489 current_tokens = 0;
490 current_start = start_line + i;
491 }
492
493 current.push(line);
494 current_tokens += line_tokens;
495 }
496
497 if !current.is_empty() {
498 chunks.push(Chunk {
499 content: current.join("\n"),
500 chunk_type,
501 start_line: current_start,
502 end_line: start_line + lines.len() - 1,
503 tokens: current_tokens,
504 priority: 3,
505 });
506 }
507
508 chunks
509 }
510
511 pub fn select_chunks(chunks: &[Chunk], max_tokens: usize) -> Vec<Chunk> {
514 let mut sorted: Vec<_> = chunks.to_vec();
515
516 sorted.sort_by(|a, b| match b.priority.cmp(&a.priority) {
518 std::cmp::Ordering::Equal => b.start_line.cmp(&a.start_line),
519 other => other,
520 });
521
522 let mut selected = Vec::new();
523 let mut total_tokens = 0;
524
525 for chunk in sorted {
526 if total_tokens + chunk.tokens <= max_tokens {
527 selected.push(chunk.clone());
528 total_tokens += chunk.tokens;
529 }
530 }
531
532 selected.sort_by_key(|c| c.start_line);
534
535 selected
536 }
537
538 pub fn reassemble(chunks: &[Chunk]) -> String {
540 if chunks.is_empty() {
541 return String::new();
542 }
543
544 let mut parts = Vec::new();
545 let mut last_end: Option<usize> = None;
546
547 for chunk in chunks {
548 if let Some(end) = last_end
550 && chunk.start_line > end + 1
551 {
552 let gap = chunk.start_line - end - 1;
553 parts.push(format!("\n[... {} lines omitted ...]\n", gap));
554 }
555 parts.push(chunk.content.clone());
556 last_end = Some(chunk.end_line);
557 }
558
559 parts.join("\n")
560 }
561
562 pub fn compress(content: &str, max_tokens: usize, options: Option<ChunkOptions>) -> String {
564 let chunks = Self::chunk(content, options);
565 let selected = Self::select_chunks(&chunks, max_tokens);
566 Self::reassemble(&selected)
567 }
568}
569
#[cfg(test)]
mod tests {
    use super::*;

    // Rust-looking source should be classified as code.
    #[test]
    fn test_detect_code() {
        let content = r#"
fn main() {
    println!("Hello, world!");
}

impl Foo {
    pub fn new() -> Self {
        Self {}
    }
}
"#;
        assert_eq!(RlmChunker::detect_content_type(content), ContentType::Code);
    }

    // Speaker-tagged transcript should be classified as conversation.
    #[test]
    fn test_detect_conversation() {
        let content = r#"
[User]: Can you help me with this?

[Assistant]: Of course! What do you need?

[User]: I want to implement a feature.
"#;
        assert_eq!(
            RlmChunker::detect_content_type(content),
            ContentType::Conversation
        );
    }

    // Compression either fits the budget or visibly marks omitted lines.
    #[test]
    fn test_compress() {
        let content = "line\n".repeat(1000);
        let compressed = RlmChunker::compress(&content, 100, None);
        let tokens = RlmChunker::estimate_tokens(&compressed);
        assert!(tokens <= 100 || compressed.contains("[..."));
    }
}