lean_ctx/core/
surprise.rs1use std::sync::OnceLock;
14
15use super::tokens::encode_tokens;
16
static VOCAB_LOG_PROBS: OnceLock<Vec<f64>> = OnceLock::new();

/// Returns the cached per-rank surprisal table, in bits.
///
/// Index `i` corresponds to rank `i + 1` in a 200,000-entry vocabulary that is
/// modelled as Zipf-distributed: `p(rank) = 1 / (rank * H_n)`, where `H_n` is
/// the 200,000th harmonic number (the normaliser that makes the probabilities
/// sum to 1). Each entry stores `-log2(p(rank))`, so values grow with the index.
///
/// The table is built on the first call and reused on every later call.
fn get_vocab_log_probs() -> &'static Vec<f64> {
    VOCAB_LOG_PROBS.get_or_init(|| {
        const VOCAB_SIZE: usize = 200_000;

        // Harmonic number H_n, accumulated from rank 1 upwards.
        let harmonic = (1..=VOCAB_SIZE).fold(0.0_f64, |acc, rank| acc + 1.0 / rank as f64);

        (1..=VOCAB_SIZE)
            .map(|rank| -(1.0 / (rank as f64 * harmonic)).log2())
            .collect()
    })
}
35
36pub fn line_surprise(text: &str) -> f64 {
44 let tokens = encode_tokens(text);
45 if tokens.is_empty() {
46 return 0.0;
47 }
48 let log_probs = get_vocab_log_probs();
49 let max_id = log_probs.len();
50
51 let total: f64 = tokens
52 .iter()
53 .map(|&t| {
54 let id = t as usize;
55 if id < max_id {
56 log_probs[id]
57 } else {
58 17.6 }
60 })
61 .sum();
62
63 total / tokens.len() as f64
64}
65
/// Coarse bucket for a line's mean token surprisal, as assigned by
/// `classify_surprise`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SurpriseLevel {
    /// Mean surprisal below 8 bits per token.
    Low,
    /// Mean surprisal of at least 8 but below 12 bits per token.
    Medium,
    /// Mean surprisal of 12 bits per token or more.
    High,
}
77
78pub fn classify_surprise(text: &str) -> SurpriseLevel {
79 let s = line_surprise(text);
80 if s < 8.0 {
81 SurpriseLevel::Low
82 } else if s < 12.0 {
83 SurpriseLevel::Medium
84 } else {
85 SurpriseLevel::High
86 }
87}
88
89pub fn should_keep_line(trimmed: &str, entropy_threshold: f64) -> bool {
93 if trimmed.is_empty() || trimmed.len() < 3 {
94 return true;
95 }
96
97 let tokens = encode_tokens(trimmed);
98 let h = super::entropy::token_entropy_from_ids(&tokens);
99 if h >= entropy_threshold {
100 return true;
101 }
102
103 let h_norm = super::entropy::normalized_token_entropy_from_ids(&tokens);
104 if h_norm >= 0.3 {
105 return true;
106 }
107
108 let surprise = line_surprise(trimmed);
112 surprise >= 11.0
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Ordinal of a level, so tests can compare levels without requiring
    /// `SurpriseLevel` to implement `Ord`.
    fn level_rank(level: SurpriseLevel) -> u8 {
        match level {
            SurpriseLevel::Low => 0,
            SurpriseLevel::Medium => 1,
            SurpriseLevel::High => 2,
        }
    }

    #[test]
    fn vocab_table_is_monotone_zipf() {
        let table = get_vocab_log_probs();
        assert_eq!(table.len(), 200_000);
        // Rarer ranks must never be less surprising than more common ones.
        assert!(table.windows(2).all(|w| w[0] <= w[1]));
        // Rank 1 under Zipf over 200k entries: -log2(1 / H_n) ≈ 3.68 bits.
        assert!((table[0] - 3.676).abs() < 0.01, "got {}", table[0]);
    }

    #[test]
    fn common_code_has_low_surprise() {
        let common = "let x = 1;";
        let s = line_surprise(common);
        assert!(s > 0.0, "surprise should be positive");
        assert!(s.is_finite(), "surprise should be finite, got {s}");
    }

    #[test]
    fn rare_identifiers_have_higher_surprise() {
        let common = "let x = 1;";
        let rare = "let zygomorphic_validator = XenolithProcessor::new();";
        assert!(
            line_surprise(rare) > line_surprise(common),
            "rare identifiers should have higher surprise"
        );
    }

    #[test]
    fn empty_returns_zero() {
        assert_eq!(line_surprise(""), 0.0);
        assert_eq!(classify_surprise(""), SurpriseLevel::Low);
    }

    #[test]
    fn classify_surprise_is_consistent() {
        let simple = "let x = 1;";
        let complex = "ZygomorphicXenolithValidator::process_quantum_state(&mut ctx)";
        let s_simple = line_surprise(simple);
        let s_complex = line_surprise(complex);
        assert!(
            s_complex > s_simple,
            "rare identifiers ({s_complex}) should have higher surprise than common code ({s_simple})"
        );
        // The bucket thresholds are monotone, so the rarer line's bucket must
        // not be lower than the common line's bucket.
        assert!(
            level_rank(classify_surprise(complex)) >= level_rank(classify_surprise(simple)),
            "classification should not rank the rarer line lower"
        );
    }

    #[test]
    fn should_keep_preserves_rare_lines() {
        // Lines shorter than 3 bytes are always kept, whatever the threshold.
        assert!(should_keep_line("", f64::INFINITY));
        assert!(should_keep_line("ab", f64::INFINITY));

        let rare = "ZygomorphicValidator::process_xenolith(&mut state)";
        // With an unreachable entropy threshold, a line whose surprise clears
        // the 11-bit cut-off must still be kept by the surprise gate.
        if line_surprise(rare) >= 11.0 {
            assert!(
                should_keep_line(rare, f64::INFINITY),
                "high-surprise lines must be preserved"
            );
        }
    }
}