lean_ctx/core/
progressive_compression.rs1use super::tokens::count_tokens;
4
5fn truncate_to_token_budget(s: &str, max_tokens: usize) -> String {
6 if max_tokens == 0 {
7 return String::new();
8 }
9 if count_tokens(s) <= max_tokens {
10 return s.to_string();
11 }
12 let mut lo = 0usize;
13 let mut hi = s.len();
14 while lo + 1 < hi {
15 let mid = usize::midpoint(lo, hi);
16 let pref = s.get(..mid).unwrap_or("");
17 if count_tokens(pref) <= max_tokens {
18 lo = mid;
19 } else {
20 hi = mid;
21 }
22 }
23 let pref = s.get(..lo).unwrap_or("");
24 format!("{pref} …")
25}
26
27fn map_like(s: &str, max_tokens: usize) -> String {
28 let keywords = [
29 "fn ", "pub ", "struct ", "enum ", "trait ", "impl ", "mod ", "use ", "def ", "class ",
30 ];
31 let mut picked: Vec<&str> = Vec::new();
32 for (i, line) in s.lines().enumerate() {
33 if i == 0 || keywords.iter().any(|k| line.contains(k)) {
34 picked.push(line);
35 }
36 if picked.len() >= 48 {
37 break;
38 }
39 }
40 if picked.is_empty() {
41 picked.push(s.lines().next().unwrap_or(""));
42 }
43 let draft = picked.join("\n");
44 truncate_to_token_budget(&draft, max_tokens.max(4))
45}
46
47fn one_line_summary(segment_idx: usize, s: &str, max_tokens: usize) -> String {
48 let preview = s
49 .lines()
50 .next()
51 .unwrap_or("")
52 .chars()
53 .take(120)
54 .collect::<String>();
55 let draft = format!(
56 "// seg[{segment_idx}] {} lines, {} chars | {preview}",
57 s.lines().count(),
58 s.len(),
59 );
60 truncate_to_token_budget(&draft, max_tokens.max(8))
61}
62
/// Maps position `i` within a sequence of `n` items to a tier in `0..=2`.
///
/// The relative position `i / (n - 1)` is split into thirds: the first third
/// yields tier 0, the middle third tier 1, and the rest tier 2. Sequences with
/// at most one item always yield tier 2.
fn tier_for_index(i: usize, n: usize) -> usize {
    if n <= 1 {
        return 2;
    }
    let last = n.saturating_sub(1);
    let position = i as f64 / last as f64;
    match position {
        p if p < 1.0 / 3.0 => 0,
        p if p < 2.0 / 3.0 => 1,
        _ => 2,
    }
}
76
/// Splits `budget_tokens` across `w.len()` slots in proportion to the weights.
///
/// Each slot first receives the floor of its proportional share. The tokens
/// left over are then handed out one at a time to the slots with the largest
/// fractional remainders; ties keep index order because the sort is stable.
/// Returns all zeros when there are no weights or the budget is zero.
fn allocate_budget_chunks(budget_tokens: usize, w: &[f64]) -> Vec<usize> {
    if w.is_empty() || budget_tokens == 0 {
        return vec![0; w.len()];
    }

    // Guard against dividing by a zero weight total.
    let total_weight = w.iter().sum::<f64>().max(f64::EPSILON);

    let shares: Vec<f64> = w
        .iter()
        .map(|&weight| budget_tokens as f64 * weight / total_weight)
        .collect();
    let mut allotted: Vec<usize> = shares.iter().map(|share| share.floor() as usize).collect();
    let remainders: Vec<f64> = shares
        .iter()
        .zip(&allotted)
        .map(|(share, &whole)| share - whole as f64)
        .collect();

    let assigned: usize = allotted.iter().sum();
    let leftover = budget_tokens.saturating_sub(assigned);

    // Rank slots by descending fractional remainder.
    let mut ranking: Vec<usize> = (0..w.len()).collect();
    ranking.sort_by(|&a, &b| {
        remainders[b]
            .partial_cmp(&remainders[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    for &slot in ranking.iter().take(leftover) {
        allotted[slot] += 1;
    }
    allotted
}
107
/// Returns `n` exponentially growing weights; each weight is `e^1.35` times
/// the previous one.
///
/// The exponent is anchored at the last index, so the final weight is exactly
/// `1.0` and earlier weights shrink towards zero. Relative proportions are the
/// same as for `exp(1.35 * i)`, but the values can no longer overflow to
/// infinity for long inputs (the unanchored form overflows once `1.35 * i`
/// exceeds roughly 709). Very small weights may underflow to `0.0`, which is
/// harmless for proportional use.
fn exp_weights(n: usize) -> Vec<f64> {
    if n == 0 {
        return Vec::new();
    }
    let lambda = 1.35_f64;
    let last = n - 1;
    (0..n)
        .map(|i| (-(lambda * (last - i) as f64)).exp())
        .collect()
}
115
116pub fn compress_progressive(segments: &[String], budget_tokens: usize) -> Vec<String> {
118 let n = segments.len();
119 if n == 0 {
120 return Vec::new();
121 }
122 if budget_tokens == 0 {
123 return segments.iter().map(|_| String::new()).collect();
124 }
125
126 let w = exp_weights(n);
127 let allocs = allocate_budget_chunks(budget_tokens, &w);
128
129 let mut out = Vec::with_capacity(n);
130 for i in 0..n {
131 let alloc = allocs[i];
132
133 let tier = tier_for_index(i, n);
134 let seg = &segments[i];
135
136 let compressed = if alloc == 0 {
137 String::new()
138 } else {
139 match tier {
140 2 => truncate_to_token_budget(seg, alloc),
141 1 => map_like(seg, alloc),
142 _ => one_line_summary(i, seg, alloc),
143 }
144 };
145
146 let capped = if alloc == 0 {
147 String::new()
148 } else {
149 truncate_to_token_budget(&compressed, alloc)
150 };
151 out.push(capped);
152 }
153
154 out
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty segment list yields an empty result.
    #[test]
    fn empty_segments() {
        let result = compress_progressive(&[], 100);
        assert!(result.is_empty());
    }

    /// The last segment must keep more tokens (and real code) than the first.
    #[test]
    fn newest_more_verbose_than_oldest() {
        let segments: Vec<String> = (0..9)
            .map(|i| {
                let line = format!(
                    "pub fn func_{i}(x: u32, y: &str) -> Option<()> {{ let z = x.wrapping_add({i}); Some(()) }}\n",
                );
                line.repeat(4)
            })
            .collect();

        let result = compress_progressive(&segments, 5000usize);

        assert_eq!(result.len(), segments.len());
        let oldest_tokens = count_tokens(&result[0]);
        let newest_tokens = count_tokens(&result[8]);
        assert!(oldest_tokens < newest_tokens);
        assert!(
            result[0].starts_with("// seg[") || oldest_tokens < 16,
            "oldest tier should be highly compressed"
        );
        assert!(result[8].contains("pub fn"));
    }

    /// The combined output must not exceed the overall budget for small inputs.
    #[test]
    fn respects_global_budget_order_of_magnitude() {
        let mut segments: Vec<String> = Vec::new();
        for i in 0..4 {
            segments.push(format!("line {i}\nabc\n"));
        }
        let result = compress_progressive(&segments, 80);
        let mut total = 0usize;
        for piece in &result {
            total += count_tokens(piece);
        }
        assert!(total <= 80);
    }

    /// A lone segment is rendered and comes back non-empty.
    #[test]
    fn single_segment_full_path() {
        let segments = vec![String::from("hello world token budget")];
        let result = compress_progressive(&segments, 50);
        assert_eq!(result.len(), 1);
        assert!(!result[0].is_empty());
    }
}