Skip to main content

sqz_engine/
token_pruner.rs

//! Self-information token pruning (inspired by CompactPrompt, arXiv 2510.18043).
//!
//! Uses n-gram frequency as a lightweight proxy for token predictability.
//! Tokens with high predictability (low self-information) can be pruned with
//! little loss of meaning for an LLM reader. The intuition comes from Shannon's
//! source coding theorem: only the "surprise" bits need to be transmitted.
//!
//! This module uses a small built-in trigram frequency table derived from
//! common code/prose patterns rather than shipping a large external asset.
10
11use std::collections::HashMap;
12
13use crate::error::Result;
14
/// Tuning knobs for [`TokenPruner`].
///
/// Start from [`PrunerConfig::default`] and override individual fields, e.g.
/// `PrunerConfig { predictability_threshold: 0.9, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct PrunerConfig {
    /// A word is dropped when its estimated `P(word | two preceding words)`
    /// is strictly greater than this value. Default: `0.85`.
    pub predictability_threshold: f64,
    /// Words shorter than this length are never dropped. Default: `2`.
    pub min_token_length: usize,
    /// When `true`, lines that look like source code are left untouched.
    /// Default: `true`.
    pub preserve_code_tokens: bool,
}
29
30impl Default for PrunerConfig {
31    fn default() -> Self {
32        Self {
33            predictability_threshold: 0.85,
34            min_token_length: 2,
35            preserve_code_tokens: true,
36        }
37    }
38}
39
40/// Trigram-based self-information token pruner.
41pub struct TokenPruner {
42    config: PrunerConfig,
43    /// Trigram frequency table: (w1, w2) -> {w3 -> count}
44    trigram_table: HashMap<(String, String), HashMap<String, u32>>,
45    /// Bigram totals for normalization: (w1, w2) -> total_count
46    bigram_totals: HashMap<(String, String), u32>,
47}
48
49impl TokenPruner {
50    /// Create a new pruner with default config and built-in frequency table.
51    pub fn new() -> Self {
52        Self::with_config(PrunerConfig::default())
53    }
54
55    /// Create a new pruner with custom config.
56    pub fn with_config(config: PrunerConfig) -> Self {
57        let mut pruner = Self {
58            config,
59            trigram_table: HashMap::new(),
60            bigram_totals: HashMap::new(),
61        };
62        pruner.load_builtin_patterns();
63        pruner
64    }
65
66    /// Load built-in trigram patterns from common code and prose.
67    fn load_builtin_patterns(&mut self) {
68        // Common English prose trigrams (highly predictable completions)
69        let patterns: &[(&str, &str, &str, u32)] = &[
70            // Articles and prepositions
71            ("in", "the", "same", 80),
72            ("in", "the", "following", 75),
73            ("of", "the", "same", 70),
74            ("on", "the", "other", 65),
75            ("at", "the", "same", 60),
76            ("to", "the", "same", 55),
77            ("is", "a", "function", 50),
78            ("is", "a", "method", 48),
79            ("is", "a", "type", 45),
80            ("is", "the", "same", 70),
81            ("as", "a", "result", 60),
82            // Common code-adjacent prose
83            ("this", "is", "a", 90),
84            ("this", "is", "the", 85),
85            ("this", "is", "an", 80),
86            ("it", "is", "a", 85),
87            ("it", "is", "the", 80),
88            ("it", "is", "not", 75),
89            ("there", "is", "a", 80),
90            ("there", "is", "no", 75),
91            ("there", "are", "no", 70),
92            ("that", "is", "a", 75),
93            ("that", "is", "the", 70),
94            ("which", "is", "a", 70),
95            ("which", "is", "the", 65),
96            // Error message patterns
97            ("error", "in", "the", 60),
98            ("failed", "to", "connect", 55),
99            ("failed", "to", "open", 50),
100            ("failed", "to", "read", 50),
101            ("unable", "to", "find", 55),
102            ("unable", "to", "open", 50),
103            ("could", "not", "find", 60),
104            ("could", "not", "open", 55),
105            ("does", "not", "exist", 65),
106            ("is", "not", "a", 60),
107            ("is", "not", "the", 55),
108            // Documentation patterns
109            ("for", "more", "information", 80),
110            ("for", "more", "details", 75),
111            ("see", "the", "documentation", 70),
112            ("refer", "to", "the", 65),
113            ("please", "refer", "to", 60),
114            ("note", "that", "the", 55),
115            ("note", "that", "this", 50),
116            ("make", "sure", "that", 60),
117            ("make", "sure", "to", 55),
118            // Filler phrases
119            ("in", "order", "to", 90),
120            ("as", "well", "as", 85),
121            ("due", "to", "the", 70),
122            ("based", "on", "the", 65),
123            ("with", "respect", "to", 60),
124            ("in", "addition", "to", 55),
125            ("as", "opposed", "to", 50),
126            ("on", "behalf", "of", 50),
127        ];
128
129        for &(w1, w2, w3, count) in patterns {
130            let key = (w1.to_lowercase(), w2.to_lowercase());
131            self.trigram_table
132                .entry(key.clone())
133                .or_default()
134                .insert(w3.to_lowercase(), count);
135            *self.bigram_totals.entry(key).or_insert(0) += count;
136        }
137    }
138
139    /// Train the pruner on additional text to improve frequency estimates.
140    pub fn train(&mut self, text: &str) {
141        let words: Vec<String> = tokenize_words(text);
142        if words.len() < 3 {
143            return;
144        }
145        for window in words.windows(3) {
146            let key = (window[0].clone(), window[1].clone());
147            self.trigram_table
148                .entry(key.clone())
149                .or_default()
150                .entry(window[2].clone())
151                .and_modify(|c| *c += 1)
152                .or_insert(1);
153            *self.bigram_totals.entry(key).or_insert(0) += 1;
154        }
155    }
156
157    /// Compute the predictability P(w3 | w1, w2) for a trigram.
158    /// Returns 0.0 if the bigram context has never been seen.
159    fn predictability(&self, w1: &str, w2: &str, w3: &str) -> f64 {
160        let key = (w1.to_lowercase(), w2.to_lowercase());
161        let total = match self.bigram_totals.get(&key) {
162            Some(&t) if t > 0 => t,
163            _ => return 0.0,
164        };
165        let count = self
166            .trigram_table
167            .get(&key)
168            .and_then(|m| m.get(&w3.to_lowercase()))
169            .copied()
170            .unwrap_or(0);
171        count as f64 / total as f64
172    }
173
174    /// Prune highly predictable tokens from prose text.
175    ///
176    /// Returns the pruned text and the number of tokens removed.
177    pub fn prune(&self, text: &str) -> Result<PruneResult> {
178        let lines: Vec<&str> = text.lines().collect();
179        let mut output_lines = Vec::with_capacity(lines.len());
180        let mut total_removed = 0u32;
181        let mut total_original = 0u32;
182
183        for line in &lines {
184            // Skip code-like lines if preserve_code_tokens is enabled
185            if self.config.preserve_code_tokens && is_code_line(line) {
186                output_lines.push(line.to_string());
187                total_original += count_words(line) as u32;
188                continue;
189            }
190
191            let words: Vec<&str> = line.split_whitespace().collect();
192            total_original += words.len() as u32;
193
194            if words.len() < 3 {
195                output_lines.push(line.to_string());
196                continue;
197            }
198
199            let mut kept: Vec<&str> = Vec::with_capacity(words.len());
200            // Always keep first two words (context)
201            kept.push(words[0]);
202            kept.push(words[1]);
203
204            for i in 2..words.len() {
205                let w1 = words[i - 2].to_lowercase();
206                let w2 = words[i - 1].to_lowercase();
207                let w3_clean = words[i]
208                    .trim_matches(|c: char| !c.is_alphanumeric())
209                    .to_lowercase();
210
211                if w3_clean.len() < self.config.min_token_length {
212                    kept.push(words[i]);
213                    continue;
214                }
215
216                let p = self.predictability(&w1, &w2, &w3_clean);
217                if p > self.config.predictability_threshold {
218                    total_removed += 1;
219                } else {
220                    kept.push(words[i]);
221                }
222            }
223
224            output_lines.push(kept.join(" "));
225        }
226
227        let pruned_text = output_lines.join("\n");
228        // Preserve trailing newline
229        let result = if text.ends_with('\n') && !pruned_text.ends_with('\n') {
230            format!("{pruned_text}\n")
231        } else {
232            pruned_text
233        };
234
235        Ok(PruneResult {
236            text: result,
237            tokens_removed: total_removed,
238            tokens_original: total_original,
239        })
240    }
241
242    /// Zipf's Law vocabulary pruning.
243    ///
244    /// Zipf's law: f(r) ∝ 1/r — the frequency of a word is inversely
245    /// proportional to its rank. Words appearing at or above their expected
246    /// Zipf frequency are redundant and can be pruned. Words appearing
247    /// below their expected frequency carry more information and are preserved.
248    ///
249    /// Returns the pruned text and stats.
250    pub fn zipf_prune(&self, text: &str) -> Result<PruneResult> {
251        let words: Vec<&str> = text.split_whitespace().collect();
252        let total_original = words.len() as u32;
253
254        if words.len() < 10 {
255            return Ok(PruneResult {
256                text: text.to_string(),
257                tokens_removed: 0,
258                tokens_original: total_original,
259            });
260        }
261
262        // Count word frequencies
263        let mut freq_map: HashMap<String, usize> = HashMap::new();
264        for &w in &words {
265            *freq_map.entry(w.to_lowercase()).or_insert(0) += 1;
266        }
267
268        // Rank words by frequency (descending)
269        let mut ranked: Vec<(String, usize)> = freq_map.into_iter().collect();
270        ranked.sort_by(|a, b| b.1.cmp(&a.1));
271
272        // Compute expected Zipf frequency for each rank
273        // f_expected(r) = C / r, where C = total_words / H_n (harmonic number)
274        let _n = ranked.len() as f64;
275        let harmonic: f64 = (1..=ranked.len()).map(|r| 1.0 / r as f64).sum();
276        let c = words.len() as f64 / harmonic;
277
278        // Mark words that are "Zipf-redundant": actual frequency >= 1.5× expected
279        let mut redundant_words: std::collections::HashSet<String> = std::collections::HashSet::new();
280        for (rank_idx, (word, actual_freq)) in ranked.iter().enumerate() {
281            let rank = rank_idx + 1;
282            let expected = c / rank as f64;
283            // Word is redundant if it appears much more than Zipf predicts
284            // AND it's a common filler word (short, non-technical)
285            if *actual_freq as f64 > expected * 1.5
286                && word.len() <= 4
287                && !is_technical_word(word)
288            {
289                redundant_words.insert(word.clone());
290            }
291        }
292
293        if redundant_words.is_empty() {
294            return Ok(PruneResult {
295                text: text.to_string(),
296                tokens_removed: 0,
297                tokens_original: total_original,
298            });
299        }
300
301        // Remove redundant words, keeping at least one occurrence of each
302        let mut seen_counts: HashMap<String, usize> = HashMap::new();
303        let mut kept = Vec::new();
304        let mut removed = 0u32;
305
306        for &w in &words {
307            let lower = w.to_lowercase();
308            if redundant_words.contains(&lower) {
309                let count = seen_counts.entry(lower.clone()).or_insert(0);
310                *count += 1;
311                // Keep the first occurrence, prune subsequent ones
312                if *count <= 1 {
313                    kept.push(w);
314                } else {
315                    removed += 1;
316                }
317            } else {
318                kept.push(w);
319            }
320        }
321
322        let result = kept.join(" ");
323        let result = if text.ends_with('\n') && !result.ends_with('\n') {
324            format!("{result}\n")
325        } else {
326            result
327        };
328
329        Ok(PruneResult {
330            text: result,
331            tokens_removed: removed,
332            tokens_original: total_original,
333        })
334    }
335}
336
337impl Default for TokenPruner {
338    fn default() -> Self {
339        Self::new()
340    }
341}
342
/// Outcome of a prune operation: the rewritten text plus word counts.
///
/// Counts are of whitespace-delimited words. Derives `PartialEq`/`Eq` so
/// whole results can be compared directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PruneResult {
    /// The pruned text.
    pub text: String,
    /// Number of words removed.
    pub tokens_removed: u32,
    /// Number of words in the original input.
    pub tokens_original: u32,
}
353
354impl PruneResult {
355    /// Fraction of tokens removed (0.0 to 1.0).
356    pub fn reduction_ratio(&self) -> f64 {
357        if self.tokens_original == 0 {
358            0.0
359        } else {
360            self.tokens_removed as f64 / self.tokens_original as f64
361        }
362    }
363}
364
365// ── Helpers ───────────────────────────────────────────────────────────────────
366
/// Split `text` into lowercase word tokens (used for training).
///
/// Tokens are maximal runs of alphanumeric characters and apostrophes, either
/// ASCII `'` or typographic `’`, so contractions like "don't" and "it’s" stay
/// whole. Apostrophes at the edges of a token are quote marks rather than part
/// of the word, so they are trimmed. This matches how pruning normalizes words
/// (edge punctuation trimmed), so trained trigrams can actually be matched.
/// Tokens made only of apostrophes are dropped.
fn tokenize_words(text: &str) -> Vec<String> {
    let is_apostrophe = |c: char| c == '\'' || c == '\u{2019}';
    text.split(|c: char| !c.is_alphanumeric() && !is_apostrophe(c))
        .map(|s| s.trim_matches(is_apostrophe))
        .filter(|s| !s.is_empty())
        .map(str::to_lowercase)
        .collect()
}
374
/// Number of non-empty runs of non-whitespace characters in `text`
/// (Unicode whitespace is the separator).
fn count_words(text: &str) -> usize {
    text.split(char::is_whitespace)
        .filter(|piece| !piece.is_empty())
        .count()
}
379
/// Heuristic: does `line` look like source code rather than prose?
///
/// After trimming, a non-empty line counts as code when any of these hold:
/// - it starts with a common keyword (`fn `, `let `, `def `, `import `, ...)
///   or a comment/attribute/bullet marker (`#`, `//`, `/*`, `*`);
/// - it ends with `{`, `}`, `;` or `)`;
/// - it contains `->`, `=>`, `::` or `()`.
///
/// NOTE(review): matching is purely lexical and case-sensitive, so lowercase
/// prose such as "for more information, see ..." or "if needed, ..." is also
/// classified as code.
fn is_code_line(line: &str) -> bool {
    const CODE_PREFIXES: [&str; 18] = [
        "fn ", "pub ", "let ", "const ", "var ", "def ", "class ", "import ", "from ", "use ",
        "return ", "if ", "for ", "while ", "#", "//", "/*", "*",
    ];
    const CODE_SUFFIXES: [char; 4] = ['{', '}', ';', ')'];
    const CODE_FRAGMENTS: [&str; 4] = ["->", "=>", "::", "()"];

    let trimmed = line.trim();
    !trimmed.is_empty()
        && (CODE_PREFIXES.iter().any(|prefix| trimmed.starts_with(prefix))
            || trimmed.ends_with(&CODE_SUFFIXES[..])
            || CODE_FRAGMENTS.iter().any(|fragment| trimmed.contains(fragment)))
}
414
/// Is `word` one of the short words that carry technical meaning?
///
/// Such words are never treated as Zipf-redundant filler. Matching is exact
/// and case-sensitive; the Zipf pass looks words up in lowercase.
fn is_technical_word(word: &str) -> bool {
    const TECHNICAL_WORDS: &[&str] = &[
        "null", "none", "true", "false", "void", "self", "this",
        "type", "enum", "impl", "func", "main", "test", "init",
        "open", "read", "send", "recv", "lock", "drop", "move",
        "copy", "sync", "push", "pull", "port", "host", "path",
        "file", "line", "code", "data", "node", "root", "hash",
        "size", "name", "list", "loop", "exit", "fail", "pass",
        "skip", "todo", "warn", "info", "http", "json", "yaml",
        "toml", "html", "rust", "java", "bash",
    ];
    TECHNICAL_WORDS.iter().any(|&known| known == word)
}
429
430// ── Tests ─────────────────────────────────────────────────────────────────────
431
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_default_creates_pruner() {
        let pruner = TokenPruner::new();
        assert!(!pruner.trigram_table.is_empty());
        assert!(!pruner.bigram_totals.is_empty());
    }

    #[test]
    fn test_prune_empty_input() {
        let pruner = TokenPruner::new();
        let result = pruner.prune("").unwrap();
        assert_eq!(result.text, "");
        assert_eq!(result.tokens_removed, 0);
        assert_eq!(result.tokens_original, 0);
    }

    #[test]
    fn test_prune_short_input_unchanged() {
        let pruner = TokenPruner::new();
        let result = pruner.prune("hello world").unwrap();
        assert_eq!(result.text, "hello world");
        assert_eq!(result.tokens_removed, 0);
    }

    #[test]
    fn test_prune_removes_predictable_tokens() {
        let pruner = TokenPruner::new();
        // ("in", "order") has a single built-in completion, so P(to) = 1.0
        // and "to" is pruned.
        let result = pruner.prune("We need in order to do this task").unwrap();
        assert_eq!(result.text, "We need in order do this task");
        assert_eq!(result.tokens_removed, 1);
        assert_eq!(result.tokens_original, 8);
    }

    #[test]
    fn test_prune_preserves_code_lines() {
        let pruner = TokenPruner::new();
        let code = "fn main() {\n    let x = 42;\n}";
        let result = pruner.prune(code).unwrap();
        assert_eq!(result.text, code);
        assert_eq!(result.tokens_removed, 0);
    }

    #[test]
    fn test_prune_preserves_trailing_newline() {
        let pruner = TokenPruner::new();
        let result = pruner.prune("hello world\n").unwrap();
        assert_eq!(result.text, "hello world\n");
        assert_eq!(result.tokens_original, 2);
    }

    #[test]
    fn test_train_adds_patterns() {
        let mut pruner = TokenPruner::new();
        let initial_size = pruner.trigram_table.len();
        pruner.train("the quick brown fox jumps over the lazy dog and the quick brown cat");
        assert!(pruner.trigram_table.len() > initial_size);
        // ("quick", "brown") was followed once by "fox" and once by "cat".
        assert_eq!(pruner.predictability("quick", "brown", "fox"), 0.5);
    }

    #[test]
    fn test_predictability_unknown_context() {
        let pruner = TokenPruner::new();
        let p = pruner.predictability("xyzzy", "plugh", "foo");
        assert_eq!(p, 0.0);
    }

    #[test]
    fn test_predictability_known_pattern() {
        let pruner = TokenPruner::new();
        // "in order to" is the only built-in completion of ("in", "order").
        let p = pruner.predictability("in", "order", "to");
        assert_eq!(p, 1.0, "expected certain completion, got {p}");
    }

    #[test]
    fn test_reduction_ratio() {
        let empty = PruneResult {
            text: String::new(),
            tokens_removed: 0,
            tokens_original: 0,
        };
        assert_eq!(empty.reduction_ratio(), 0.0);
        let quarter = PruneResult {
            text: "x".to_string(),
            tokens_removed: 1,
            tokens_original: 4,
        };
        assert_eq!(quarter.reduction_ratio(), 0.25);
    }

    #[test]
    fn test_is_code_line_detection() {
        assert!(is_code_line("fn main() {"));
        assert!(is_code_line("  let x = 42;"));
        assert!(is_code_line("// comment"));
        assert!(is_code_line("import os"));
        assert!(!is_code_line("This is a normal sentence."));
        assert!(!is_code_line("The error occurred in the module."));
        assert!(!is_code_line(""));
    }

    #[test]
    fn test_custom_config() {
        let config = PrunerConfig {
            predictability_threshold: 0.5,
            min_token_length: 1,
            preserve_code_tokens: false,
        };
        let pruner = TokenPruner::with_config(config);
        let result = pruner
            .prune("this is a very long sentence with many words in order to test")
            .unwrap();
        // P(a | this, is) = 90/255 stays below 0.5; P(to | in, order) = 1.0 does not.
        assert_eq!(
            result.text,
            "this is a very long sentence with many words in order test"
        );
        assert_eq!(result.tokens_removed, 1);
    }

    // ── Property tests ────────────────────────────────────────────────────────

    proptest! {
        /// Pruning single-line input never produces longer output.
        #[test]
        fn prop_prune_never_increases_length(
            text in "[a-z ]{10,200}"
        ) {
            let pruner = TokenPruner::new();
            let result = pruner.prune(&text).unwrap();
            prop_assert!(
                result.text.len() <= text.len(),
                "pruned text ({}) should not be longer than input ({})",
                result.text.len(), text.len()
            );
        }

        /// tokens_removed + remaining tokens == tokens_original
        #[test]
        fn prop_prune_token_accounting(
            text in "[a-z ]{10,200}"
        ) {
            let pruner = TokenPruner::new();
            let result = pruner.prune(&text).unwrap();
            let remaining = count_words(&result.text) as u32;
            prop_assert_eq!(
                result.tokens_removed + remaining,
                result.tokens_original,
                "removed + remaining must equal original"
            );
        }
    }

    // ── Zipf's Law pruning tests ──────────────────────────────────────────

    #[test]
    fn test_zipf_prune_short_text_unchanged() {
        let pruner = TokenPruner::new();
        let result = pruner.zipf_prune("hello world").unwrap();
        assert_eq!(result.text, "hello world");
        assert_eq!(result.tokens_removed, 0);
    }

    #[test]
    fn test_zipf_prune_removes_overrepresented_fillers() {
        let pruner = TokenPruner::new();
        // "the" (10 of 20 words) exceeds 1.5x its Zipf expectation at rank 1.
        let text = "the cat the dog the bird the fish the tree the rock the sky the sun the moon the star";
        let result = pruner.zipf_prune(text).unwrap();
        // Only the first "the" survives.
        assert_eq!(result.text, "the cat dog bird fish tree rock sky sun moon star");
        assert_eq!(result.tokens_removed, 9);
        assert_eq!(result.tokens_original, 20);
    }

    #[test]
    fn test_zipf_prune_preserves_technical_words() {
        let pruner = TokenPruner::new();
        let text = "null null null null null null null null null null check for null values";
        let result = pruner.zipf_prune(text).unwrap();
        // "null" is a technical word and must not be pruned.
        assert_eq!(result.tokens_removed, 0, "technical words should be preserved");
        assert_eq!(result.text, text);
    }

    #[test]
    fn test_is_technical_word() {
        assert!(is_technical_word("null"));
        assert!(is_technical_word("type"));
        assert!(is_technical_word("json"));
        assert!(!is_technical_word("the"));
        assert!(!is_technical_word("and"));
        assert!(!is_technical_word("xyz"));
    }

    proptest! {
        /// Zipf pruning never empties non-empty input and its counts add up.
        #[test]
        fn prop_zipf_prune_non_empty(
            text in "[a-z]{2,5}( [a-z]{2,5}){10,30}"
        ) {
            let pruner = TokenPruner::new();
            let result = pruner.zipf_prune(&text).unwrap();
            prop_assert!(!result.text.is_empty());
            prop_assert_eq!(
                result.tokens_removed + count_words(&result.text) as u32,
                result.tokens_original
            );
        }
    }
}