1use std::collections::HashMap;
12
13use crate::error::Result;
14
#[derive(Debug, Clone)]
pub struct PrunerConfig {
    /// Conditional probability P(w3 | w1, w2) above which the third word
    /// of a trigram is dropped by `prune`.
    pub predictability_threshold: f64,
    /// Tokens shorter than this (after stripping surrounding punctuation)
    /// are never pruned.
    pub min_token_length: usize,
    /// When true, lines that look like source code (see `is_code_line`)
    /// are passed through `prune` untouched.
    pub preserve_code_tokens: bool,
}
29
30impl Default for PrunerConfig {
31 fn default() -> Self {
32 Self {
33 predictability_threshold: 0.85,
34 min_token_length: 2,
35 preserve_code_tokens: true,
36 }
37 }
38}
39
/// Prunes statistically predictable tokens from prose using a trigram model.
pub struct TokenPruner {
    config: PrunerConfig,
    /// (w1, w2) -> { w3 -> count }: how often w3 was seen following the
    /// bigram (w1, w2), from built-in patterns plus `train` calls.
    trigram_table: HashMap<(String, String), HashMap<String, u32>>,
    /// (w1, w2) -> total count of all observed continuations; the
    /// denominator for `predictability`.
    bigram_totals: HashMap<(String, String), u32>,
}
48
49impl TokenPruner {
50 pub fn new() -> Self {
52 Self::with_config(PrunerConfig::default())
53 }
54
55 pub fn with_config(config: PrunerConfig) -> Self {
57 let mut pruner = Self {
58 config,
59 trigram_table: HashMap::new(),
60 bigram_totals: HashMap::new(),
61 };
62 pruner.load_builtin_patterns();
63 pruner
64 }
65
66 fn load_builtin_patterns(&mut self) {
68 let patterns: &[(&str, &str, &str, u32)] = &[
70 ("in", "the", "same", 80),
72 ("in", "the", "following", 75),
73 ("of", "the", "same", 70),
74 ("on", "the", "other", 65),
75 ("at", "the", "same", 60),
76 ("to", "the", "same", 55),
77 ("is", "a", "function", 50),
78 ("is", "a", "method", 48),
79 ("is", "a", "type", 45),
80 ("is", "the", "same", 70),
81 ("as", "a", "result", 60),
82 ("this", "is", "a", 90),
84 ("this", "is", "the", 85),
85 ("this", "is", "an", 80),
86 ("it", "is", "a", 85),
87 ("it", "is", "the", 80),
88 ("it", "is", "not", 75),
89 ("there", "is", "a", 80),
90 ("there", "is", "no", 75),
91 ("there", "are", "no", 70),
92 ("that", "is", "a", 75),
93 ("that", "is", "the", 70),
94 ("which", "is", "a", 70),
95 ("which", "is", "the", 65),
96 ("error", "in", "the", 60),
98 ("failed", "to", "connect", 55),
99 ("failed", "to", "open", 50),
100 ("failed", "to", "read", 50),
101 ("unable", "to", "find", 55),
102 ("unable", "to", "open", 50),
103 ("could", "not", "find", 60),
104 ("could", "not", "open", 55),
105 ("does", "not", "exist", 65),
106 ("is", "not", "a", 60),
107 ("is", "not", "the", 55),
108 ("for", "more", "information", 80),
110 ("for", "more", "details", 75),
111 ("see", "the", "documentation", 70),
112 ("refer", "to", "the", 65),
113 ("please", "refer", "to", 60),
114 ("note", "that", "the", 55),
115 ("note", "that", "this", 50),
116 ("make", "sure", "that", 60),
117 ("make", "sure", "to", 55),
118 ("in", "order", "to", 90),
120 ("as", "well", "as", 85),
121 ("due", "to", "the", 70),
122 ("based", "on", "the", 65),
123 ("with", "respect", "to", 60),
124 ("in", "addition", "to", 55),
125 ("as", "opposed", "to", 50),
126 ("on", "behalf", "of", 50),
127 ];
128
129 for &(w1, w2, w3, count) in patterns {
130 let key = (w1.to_lowercase(), w2.to_lowercase());
131 self.trigram_table
132 .entry(key.clone())
133 .or_default()
134 .insert(w3.to_lowercase(), count);
135 *self.bigram_totals.entry(key).or_insert(0) += count;
136 }
137 }
138
139 pub fn train(&mut self, text: &str) {
141 let words: Vec<String> = tokenize_words(text);
142 if words.len() < 3 {
143 return;
144 }
145 for window in words.windows(3) {
146 let key = (window[0].clone(), window[1].clone());
147 self.trigram_table
148 .entry(key.clone())
149 .or_default()
150 .entry(window[2].clone())
151 .and_modify(|c| *c += 1)
152 .or_insert(1);
153 *self.bigram_totals.entry(key).or_insert(0) += 1;
154 }
155 }
156
157 fn predictability(&self, w1: &str, w2: &str, w3: &str) -> f64 {
160 let key = (w1.to_lowercase(), w2.to_lowercase());
161 let total = match self.bigram_totals.get(&key) {
162 Some(&t) if t > 0 => t,
163 _ => return 0.0,
164 };
165 let count = self
166 .trigram_table
167 .get(&key)
168 .and_then(|m| m.get(&w3.to_lowercase()))
169 .copied()
170 .unwrap_or(0);
171 count as f64 / total as f64
172 }
173
174 pub fn prune(&self, text: &str) -> Result<PruneResult> {
178 let lines: Vec<&str> = text.lines().collect();
179 let mut output_lines = Vec::with_capacity(lines.len());
180 let mut total_removed = 0u32;
181 let mut total_original = 0u32;
182
183 for line in &lines {
184 if self.config.preserve_code_tokens && is_code_line(line) {
186 output_lines.push(line.to_string());
187 total_original += count_words(line) as u32;
188 continue;
189 }
190
191 let words: Vec<&str> = line.split_whitespace().collect();
192 total_original += words.len() as u32;
193
194 if words.len() < 3 {
195 output_lines.push(line.to_string());
196 continue;
197 }
198
199 let mut kept: Vec<&str> = Vec::with_capacity(words.len());
200 kept.push(words[0]);
202 kept.push(words[1]);
203
204 for i in 2..words.len() {
205 let w1 = words[i - 2].to_lowercase();
206 let w2 = words[i - 1].to_lowercase();
207 let w3_clean = words[i]
208 .trim_matches(|c: char| !c.is_alphanumeric())
209 .to_lowercase();
210
211 if w3_clean.len() < self.config.min_token_length {
212 kept.push(words[i]);
213 continue;
214 }
215
216 let p = self.predictability(&w1, &w2, &w3_clean);
217 if p > self.config.predictability_threshold {
218 total_removed += 1;
219 } else {
220 kept.push(words[i]);
221 }
222 }
223
224 output_lines.push(kept.join(" "));
225 }
226
227 let pruned_text = output_lines.join("\n");
228 let result = if text.ends_with('\n') && !pruned_text.ends_with('\n') {
230 format!("{pruned_text}\n")
231 } else {
232 pruned_text
233 };
234
235 Ok(PruneResult {
236 text: result,
237 tokens_removed: total_removed,
238 tokens_original: total_original,
239 })
240 }
241
242 pub fn zipf_prune(&self, text: &str) -> Result<PruneResult> {
251 let words: Vec<&str> = text.split_whitespace().collect();
252 let total_original = words.len() as u32;
253
254 if words.len() < 10 {
255 return Ok(PruneResult {
256 text: text.to_string(),
257 tokens_removed: 0,
258 tokens_original: total_original,
259 });
260 }
261
262 let mut freq_map: HashMap<String, usize> = HashMap::new();
264 for &w in &words {
265 *freq_map.entry(w.to_lowercase()).or_insert(0) += 1;
266 }
267
268 let mut ranked: Vec<(String, usize)> = freq_map.into_iter().collect();
270 ranked.sort_by(|a, b| b.1.cmp(&a.1));
271
272 let _n = ranked.len() as f64;
275 let harmonic: f64 = (1..=ranked.len()).map(|r| 1.0 / r as f64).sum();
276 let c = words.len() as f64 / harmonic;
277
278 let mut redundant_words: std::collections::HashSet<String> = std::collections::HashSet::new();
280 for (rank_idx, (word, actual_freq)) in ranked.iter().enumerate() {
281 let rank = rank_idx + 1;
282 let expected = c / rank as f64;
283 if *actual_freq as f64 > expected * 1.5
286 && word.len() <= 4
287 && !is_technical_word(word)
288 {
289 redundant_words.insert(word.clone());
290 }
291 }
292
293 if redundant_words.is_empty() {
294 return Ok(PruneResult {
295 text: text.to_string(),
296 tokens_removed: 0,
297 tokens_original: total_original,
298 });
299 }
300
301 let mut seen_counts: HashMap<String, usize> = HashMap::new();
303 let mut kept = Vec::new();
304 let mut removed = 0u32;
305
306 for &w in &words {
307 let lower = w.to_lowercase();
308 if redundant_words.contains(&lower) {
309 let count = seen_counts.entry(lower.clone()).or_insert(0);
310 *count += 1;
311 if *count <= 1 {
313 kept.push(w);
314 } else {
315 removed += 1;
316 }
317 } else {
318 kept.push(w);
319 }
320 }
321
322 let result = kept.join(" ");
323 let result = if text.ends_with('\n') && !result.ends_with('\n') {
324 format!("{result}\n")
325 } else {
326 result
327 };
328
329 Ok(PruneResult {
330 text: result,
331 tokens_removed: removed,
332 tokens_original: total_original,
333 })
334 }
335}
336
337impl Default for TokenPruner {
338 fn default() -> Self {
339 Self::new()
340 }
341}
342
#[derive(Debug, Clone)]
pub struct PruneResult {
    /// The pruned text.
    pub text: String,
    /// Number of tokens removed by the pruning pass.
    pub tokens_removed: u32,
    /// Number of whitespace-delimited tokens in the original input.
    pub tokens_original: u32,
}
353
354impl PruneResult {
355 pub fn reduction_ratio(&self) -> f64 {
357 if self.tokens_original == 0 {
358 0.0
359 } else {
360 self.tokens_removed as f64 / self.tokens_original as f64
361 }
362 }
363}
364
/// Splits `text` into lowercased words. Any character that is neither
/// alphanumeric nor an apostrophe acts as a separator; empty fragments are
/// discarded.
fn tokenize_words(text: &str) -> Vec<String> {
    let mut words = Vec::new();
    for piece in text.split(|c: char| !c.is_alphanumeric() && c != '\'') {
        if !piece.is_empty() {
            words.push(piece.to_lowercase());
        }
    }
    words
}
374
/// Counts whitespace-delimited tokens in `text`.
fn count_words(text: &str) -> usize {
    let mut total = 0;
    for _ in text.split_whitespace() {
        total += 1;
    }
    total
}
379
/// Heuristically decides whether `line` looks like source code rather than
/// prose, by checking common keyword prefixes, code-ish trailing characters,
/// and operator/call infixes. Empty (or all-whitespace) lines are not code.
fn is_code_line(line: &str) -> bool {
    const PREFIXES: [&str; 18] = [
        "fn ", "pub ", "let ", "const ", "var ", "def ", "class ", "import ",
        "from ", "use ", "return ", "if ", "for ", "while ", "#", "//", "/*",
        "*",
    ];
    const SUFFIXES: [&str; 4] = ["{", "}", ";", ")"];
    const INFIXES: [&str; 4] = ["->", "=>", "::", "()"];

    let trimmed = line.trim();
    if trimmed.is_empty() {
        return false;
    }
    PREFIXES.iter().any(|&p| trimmed.starts_with(p))
        || SUFFIXES.iter().any(|&s| trimmed.ends_with(s))
        || INFIXES.iter().any(|&i| trimmed.contains(i))
}
414
/// Reports whether `word` belongs to a fixed list of technical vocabulary
/// (language keywords, common identifiers, formats) that the Zipf pruner
/// must never remove. The check is exact and case-sensitive.
fn is_technical_word(word: &str) -> bool {
    const TECHNICAL: &[&str] = &[
        "null", "none", "true", "false", "void", "self", "this",
        "type", "enum", "impl", "func", "main", "test", "init",
        "open", "read", "send", "recv", "lock", "drop", "move",
        "copy", "sync", "push", "pull", "port", "host", "path",
        "file", "line", "code", "data", "node", "root", "hash",
        "size", "name", "list", "loop", "exit", "fail", "pass",
        "skip", "todo", "warn", "info", "http", "json", "yaml",
        "toml", "html", "rust", "java", "bash",
    ];
    TECHNICAL.contains(&word)
}
429
#[cfg(test)]
mod tests {
    use super::*;

    // Construction: the built-in patterns must be loaded eagerly.
    #[test]
    fn test_default_creates_pruner() {
        let pruner = TokenPruner::new();
        assert!(!pruner.trigram_table.is_empty());
        assert!(!pruner.bigram_totals.is_empty());
    }

    #[test]
    fn test_prune_empty_input() {
        let pruner = TokenPruner::new();
        let result = pruner.prune("").unwrap();
        assert_eq!(result.text, "");
        assert_eq!(result.tokens_removed, 0);
    }

    // Lines with fewer than three words have no trigram context and must
    // pass through unchanged.
    #[test]
    fn test_prune_short_input_unchanged() {
        let pruner = TokenPruner::new();
        let result = pruner.prune("hello world").unwrap();
        assert_eq!(result.text, "hello world");
        assert_eq!(result.tokens_removed, 0);
    }

    // "in order to" is a built-in pattern with a pseudo-count of 90, so at
    // least one token should be dropped (or at minimum the text must not
    // grow).
    #[test]
    fn test_prune_removes_predictable_tokens() {
        let pruner = TokenPruner::new();
        let result = pruner.prune("We need in order to do this task").unwrap();
        assert!(
            result.tokens_removed > 0 || result.text.len() <= "We need in order to do this task".len(),
            "expected some pruning on predictable prose"
        );
    }

    // preserve_code_tokens is on by default: code-looking lines must be
    // emitted verbatim.
    #[test]
    fn test_prune_preserves_code_lines() {
        let pruner = TokenPruner::new();
        let code = "fn main() {\n let x = 42;\n}";
        let result = pruner.prune(code).unwrap();
        assert_eq!(result.text, code);
        assert_eq!(result.tokens_removed, 0);
    }

    // prune() rebuilds text from lines(); the trailing newline must be
    // restored.
    #[test]
    fn test_prune_preserves_trailing_newline() {
        let pruner = TokenPruner::new();
        let result = pruner.prune("hello world\n").unwrap();
        assert!(result.text.ends_with('\n'));
    }

    // Training must only ever add entries, never shrink the table.
    #[test]
    fn test_train_adds_patterns() {
        let mut pruner = TokenPruner::new();
        let initial_size = pruner.trigram_table.len();
        pruner.train("the quick brown fox jumps over the lazy dog and the quick brown cat");
        assert!(pruner.trigram_table.len() >= initial_size);
    }

    // Unseen bigram context must yield probability 0.0, not an error.
    #[test]
    fn test_predictability_unknown_context() {
        let pruner = TokenPruner::new();
        let p = pruner.predictability("xyzzy", "plugh", "foo");
        assert_eq!(p, 0.0);
    }

    // ("in", "order") -> "to" is built in with a dominant count.
    #[test]
    fn test_predictability_known_pattern() {
        let pruner = TokenPruner::new();
        let p = pruner.predictability("in", "order", "to");
        assert!(p > 0.5, "expected high predictability, got {p}");
    }

    // Division-by-zero guard in reduction_ratio.
    #[test]
    fn test_reduction_ratio_zero_for_empty() {
        let result = PruneResult {
            text: String::new(),
            tokens_removed: 0,
            tokens_original: 0,
        };
        assert_eq!(result.reduction_ratio(), 0.0);
    }

    #[test]
    fn test_is_code_line_detection() {
        assert!(is_code_line("fn main() {"));
        assert!(is_code_line(" let x = 42;"));
        assert!(is_code_line("// comment"));
        assert!(is_code_line("import os"));
        assert!(!is_code_line("This is a normal sentence."));
        assert!(!is_code_line("The error occurred in the module."));
        assert!(!is_code_line(""));
    }

    // A custom config (lower threshold, code preservation off) must still
    // produce non-empty output.
    #[test]
    fn test_custom_config() {
        let config = PrunerConfig {
            predictability_threshold: 0.5,
            min_token_length: 1,
            preserve_code_tokens: false,
        };
        let pruner = TokenPruner::with_config(config);
        let result = pruner.prune("this is a very long sentence with many words in order to test").unwrap();
        assert!(!result.text.is_empty());
    }

    use proptest::prelude::*;

    proptest! {
        // Pruning only removes tokens, so output length must not exceed
        // input length (the +1 slack covers a restored trailing newline).
        #[test]
        fn prop_prune_never_increases_length(
            text in "[a-z ]{10,200}"
        ) {
            let pruner = TokenPruner::new();
            let result = pruner.prune(&text).unwrap();
            prop_assert!(
                result.text.len() <= text.len() + 1, "pruned text ({}) should not be longer than input ({})",
                result.text.len(), text.len()
            );
        }

        // Conservation check: removed + remaining tokens must not exceed
        // the original count (with one token of slack).
        #[test]
        fn prop_prune_token_accounting(
            text in "[a-z ]{10,200}"
        ) {
            let pruner = TokenPruner::new();
            let result = pruner.prune(&text).unwrap();
            let remaining = count_words(&result.text) as u32;
            prop_assert!(
                result.tokens_removed + remaining <= result.tokens_original + 1,
                "removed ({}) + remaining ({}) should be <= original ({})",
                result.tokens_removed, remaining, result.tokens_original
            );
        }
    }

    // Fewer than 10 words: zipf_prune bails out early and returns the
    // input unchanged.
    #[test]
    fn test_zipf_prune_short_text_unchanged() {
        let pruner = TokenPruner::new();
        let result = pruner.zipf_prune("hello world").unwrap();
        assert_eq!(result.text, "hello world");
        assert_eq!(result.tokens_removed, 0);
    }

    // "the" appears far above its Zipf expectation; repeats should be
    // dropped but the first occurrence kept.
    #[test]
    fn test_zipf_prune_removes_overrepresented_fillers() {
        let pruner = TokenPruner::new();
        let text = "the cat the dog the bird the fish the tree the rock the sky the sun the moon the star";
        let result = pruner.zipf_prune(text).unwrap();
        assert!(result.text.contains("the"), "should keep at least one 'the'");
        assert!(
            result.tokens_removed > 0,
            "should prune overrepresented filler words"
        );
    }

    // "null" is in the technical-word list, so even heavy repetition must
    // not be pruned.
    #[test]
    fn test_zipf_prune_preserves_technical_words() {
        let pruner = TokenPruner::new();
        let text = "null null null null null null null null null null check for null values";
        let result = pruner.zipf_prune(text).unwrap();
        assert_eq!(result.tokens_removed, 0, "technical words should be preserved");
    }

    #[test]
    fn test_is_technical_word() {
        assert!(is_technical_word("null"));
        assert!(is_technical_word("type"));
        assert!(is_technical_word("json"));
        assert!(!is_technical_word("the"));
        assert!(!is_technical_word("and"));
        assert!(!is_technical_word("xyz"));
    }

    proptest! {
        // zipf_prune keeps the first occurrence of every word, so output
        // can never be empty for non-empty input.
        #[test]
        fn prop_zipf_prune_non_empty(
            text in "[a-z]{2,5}( [a-z]{2,5}){10,30}"
        ) {
            let pruner = TokenPruner::new();
            let result = pruner.zipf_prune(&text).unwrap();
            prop_assert!(!result.text.is_empty());
        }
    }
}