lean_ctx/core/
adaptive_chunking.rs1use crate::core::tokens::count_tokens;
3
/// One chunk of source text selected by [`adaptive_chunk`].
#[derive(Debug, Clone, PartialEq)]
pub struct ChunkResult {
    /// Text emitted for this chunk. Depending on the budget mode this is the
    /// full span, a leading portion of it, or only its signature lines.
    pub content: String,
    /// 1-based number of the first line of the span this chunk was cut from.
    pub start_line: usize,
    /// 1-based number of the last line of that span (inclusive), even when
    /// `content` holds fewer lines than the span.
    pub end_line: usize,
    /// Heuristic importance score assigned to the span; higher means more important.
    pub priority: f64,
}
12
/// Per-item token budget below which `adaptive_chunk` emits signatures only.
const TIGHT_PER_ITEM: usize = 50;
/// Per-item token budget above which `adaptive_chunk` emits full chunk bodies.
const GENEROUS_PER_ITEM: usize = 200;
15
/// Returns `true` when `line` looks like the start of a Rust function item.
///
/// Leading whitespace is ignored. An optional visibility (`pub`, `pub(crate)`,
/// `pub(super)`, `pub(in some::path)`) and any sequence of the qualifiers
/// `const`, `async`, `unsafe` and `extern` (with an optional `"ABI"` string)
/// may precede the `fn` keyword, which must be followed by a space.
///
/// This is a line-based heuristic, not a parser: it does not check that the
/// qualifiers appear in an order the compiler would accept.
fn is_fn_line(line: &str) -> bool {
    let mut rest = line.trim_start();

    // Optional visibility: bare `pub` or a restricted `pub(...)`.
    if let Some(after_pub) = rest.strip_prefix("pub") {
        if let Some(restriction) = after_pub.strip_prefix('(') {
            match restriction.split_once(')') {
                Some((_, tail)) => rest = tail.trim_start(),
                // Unterminated `pub(` cannot be a function header.
                None => return false,
            }
        } else if after_pub.starts_with(char::is_whitespace) {
            rest = after_pub.trim_start();
        }
        // Otherwise `pub` was just the start of an identifier (e.g. `publish`);
        // `rest` is left untouched and fails the final check.
    }

    // Any number of function qualifiers. Each iteration consumes at least one
    // character of `rest`, so the loop terminates.
    loop {
        let simple = ["const ", "async ", "unsafe "]
            .iter()
            .find_map(|qualifier| rest.strip_prefix(*qualifier));
        if let Some(tail) = simple {
            rest = tail.trim_start();
        } else if let Some(tail) = rest.strip_prefix("extern ") {
            let tail = tail.trim_start();
            rest = match tail.strip_prefix('"') {
                // `extern "C" fn ...`: skip the quoted ABI name.
                Some(abi) => match abi.split_once('"') {
                    Some((_, after_abi)) => after_abi.trim_start(),
                    None => return false,
                },
                // Bare `extern fn ...`.
                None => tail,
            };
        } else {
            break;
        }
    }

    rest.starts_with("fn ")
}
26
27fn chunk_ranges(lines: &[&str]) -> Vec<(usize, usize)> {
28 let n = lines.len();
29 if n == 0 {
30 return Vec::new();
31 }
32 let starts: Vec<usize> = lines
33 .iter()
34 .enumerate()
35 .filter_map(|(i, l)| is_fn_line(l).then_some(i))
36 .collect();
37 if starts.is_empty() {
38 return vec![(0, n - 1)];
39 }
40 let mut ranges = Vec::new();
41 if starts[0] > 0 {
42 ranges.push((0, starts[0] - 1));
43 }
44 for (k, &s) in starts.iter().enumerate() {
45 let end = if k + 1 < starts.len() {
46 starts[k + 1] - 1
47 } else {
48 n - 1
49 };
50 ranges.push((s, end));
51 }
52 ranges
53}
54
/// Counts the lines that, after leading whitespace, begin with `use ` or `import `.
fn import_hits(lines: &[&str]) -> usize {
    let mut hits = 0;
    for line in lines {
        let trimmed = line.trim_start();
        let is_import = ["use ", "import "]
            .iter()
            .any(|prefix| trimmed.starts_with(*prefix));
        if is_import {
            hits += 1;
        }
    }
    hits
}
64
/// Scores `text` by bracket nesting and control-flow keyword density.
///
/// The result is `max_depth * 0.18 + keyword_hits * 0.06`, where `max_depth`
/// is the deepest nesting of `{`, `(` and `[` reached while scanning (closers
/// lower the running depth, which may go negative), and `keyword_hits` is the
/// number of plain substring matches of the control-flow keywords below.
fn brace_complexity(text: &str) -> f64 {
    const CONTROL_KEYWORDS: [&str; 6] = ["for ", "while ", "match ", "loop ", "if ", "else"];

    let mut level: i32 = 0;
    let mut deepest: i32 = 0;
    for ch in text.chars() {
        if matches!(ch, '{' | '(' | '[') {
            level += 1;
            deepest = deepest.max(level);
        } else if matches!(ch, '}' | ')' | ']') {
            level -= 1;
        }
    }

    let keyword_hits: usize = CONTROL_KEYWORDS
        .iter()
        .map(|keyword| text.matches(*keyword).count())
        .sum();

    f64::from(deepest) * 0.18 + keyword_hits as f64 * 0.06
}
87
/// Joins `lines[start..=end]` with `\n` separators (no trailing newline).
///
/// Panics if the inclusive range is out of bounds for `lines`.
fn build_chunk_body(lines: &[&str], start: usize, end: usize) -> String {
    let span = &lines[start..=end];
    let mut body = String::new();
    for (offset, line) in span.iter().enumerate() {
        if offset > 0 {
            body.push('\n');
        }
        body.push_str(line);
    }
    body
}
91
/// Extracts the leading "signature" lines of the span `start..=end`.
///
/// Lines are appended in order until one contains `{`, one ends with `;`
/// (ignoring trailing whitespace), six lines have been collected, or the span
/// (or `lines` itself) is exhausted. Trailing whitespace is trimmed from the
/// result. Indices past the end of `lines` are ignored rather than panicking.
fn signature_body(lines: &[&str], start: usize, end: usize) -> String {
    const MAX_SIGNATURE_LINES: usize = 6;

    let mut signature = String::new();
    for (index, line) in (start..=end).zip(lines.iter().skip(start)) {
        signature.push_str(line);
        if index != end {
            signature.push('\n');
        }

        let closes_signature = line.contains('{') || line.trim_end().ends_with(';');
        if closes_signature || signature.lines().count() >= MAX_SIGNATURE_LINES {
            break;
        }
    }
    signature.trim_end().to_string()
}
109
110fn chunk_priority(lines: &[&str], start: usize, end: usize, total_lines: usize) -> f64 {
111 let slice = &lines[start..=end];
112 let body = slice.join("\n");
113 let cx = brace_complexity(&body);
114 let im = import_hits(slice) as f64 * 0.12;
115 let denom = total_lines.max(1) as f64;
116 let recency = (usize::midpoint(start, end) + 1) as f64 / denom * 0.55;
117 (cx + im + recency).min(12.0)
118}
119
120fn proportional_body(lines: &[&str], start: usize, end: usize, target_tokens: usize) -> String {
121 let full = build_chunk_body(lines, start, end);
122 let ftoks = count_tokens(&full).max(1);
123 let frac = (target_tokens as f64 / ftoks as f64).clamp(0.12, 1.0);
124 let nlines = end - start + 1;
125 let take = ((nlines as f64 * frac).ceil() as usize).clamp(1, nlines);
126 lines[start..start + take].join("\n")
127}
128
129pub fn adaptive_chunk(content: &str, budget_tokens: usize, total_items: usize) -> Vec<ChunkResult> {
131 let lines: Vec<&str> = content.lines().collect();
132 let total_lines = lines.len().max(1);
133 let ranges = chunk_ranges(&lines);
134 let per_item = budget_tokens / total_items.max(1);
135
136 let mut raw: Vec<(usize, usize, f64)> = ranges
137 .into_iter()
138 .map(|(s, e)| {
139 let p = chunk_priority(&lines, s, e, total_lines);
140 (s, e, p)
141 })
142 .collect();
143
144 if raw.is_empty() {
145 return Vec::new();
146 }
147
148 raw.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
149
150 let mut results = Vec::new();
151
152 if per_item < TIGHT_PER_ITEM {
153 let mut used = 0usize;
154 for (s, e, pri) in raw {
155 let body = signature_body(&lines, s, e);
156 if body.is_empty() {
157 continue;
158 }
159 let t = count_tokens(&body);
160 if used + t > budget_tokens {
161 continue;
162 }
163 used += t;
164 results.push(ChunkResult {
165 content: body,
166 start_line: s + 1,
167 end_line: e + 1,
168 priority: pri,
169 });
170 }
171 results.sort_by(|a, b| {
172 a.start_line
173 .cmp(&b.start_line)
174 .then_with(|| b.priority.partial_cmp(&a.priority).unwrap())
175 });
176 return results;
177 }
178
179 if per_item > GENEROUS_PER_ITEM {
180 for (s, e, pri) in raw {
181 let body = build_chunk_body(&lines, s, e);
182 results.push(ChunkResult {
183 content: body,
184 start_line: s + 1,
185 end_line: e + 1,
186 priority: pri,
187 });
188 }
189 results.sort_by_key(|c| c.start_line);
190 return results;
191 }
192
193 let mut tmp = Vec::new();
195 for (s, e, pri) in raw {
196 let body = proportional_body(&lines, s, e, per_item);
197 tmp.push(ChunkResult {
198 content: body,
199 start_line: s + 1,
200 end_line: e + 1,
201 priority: pri,
202 });
203 }
204 tmp.sort_by_key(|c| c.start_line);
205 tmp
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Small fixture: one import line followed by two short functions.
    const SAMPLE: &str = r#"use std::io;

fn foo() {
    if true {
        println!("a");
    }
}

fn bar(x: i32) -> i32 {
    let mut s = 0;
    for i in 0..x {
        s += i;
    }
    s
}
"#;

    /// With a per-item budget under the tight threshold only signatures are
    /// emitted, and their combined token count stays within the total budget.
    #[test]
    fn tight_mode_prefers_signatures_and_respects_budget() {
        let budget = 80;
        let chunks = adaptive_chunk(SAMPLE, budget, 4);
        assert!(!chunks.is_empty());
        assert!(chunks.iter().all(|c| !c.content.contains("println!")));

        let mut tok_total = 0usize;
        for chunk in &chunks {
            tok_total += count_tokens(&chunk.content);
        }
        assert!(tok_total <= budget);
    }

    /// A very large per-item budget keeps function bodies intact.
    #[test]
    fn generous_mode_keeps_full_bodies() {
        let chunks = adaptive_chunk(SAMPLE, 50_000, 1);
        let has = |needle: &str| chunks.iter().any(|c| c.content.contains(needle));
        assert!(has("println!"));
        assert!(has("for i"));
    }

    /// A mid-range per-item budget truncates a long function body.
    #[test]
    fn middle_mode_partial_body() {
        let mut big = String::from(SAMPLE);
        big.push_str("\nfn baz() {\n");
        big.extend((0..120).map(|i| format!(" let _z{i} = {i};\n")));
        big.push_str("}\n");

        let chunks = adaptive_chunk(&big, 750, 10);
        let baz = chunks
            .iter()
            .find(|c| c.content.contains("baz"))
            .expect("baz chunk");

        let count_z = |text: &str| text.lines().filter(|l| l.contains("_z")).count();
        let baz_full_lines = count_z(&big);
        let baz_kept_lines = count_z(&baz.content);
        assert!(
            baz_kept_lines < baz_full_lines,
            "expected proportional truncation inside baz, kept={baz_kept_lines} full={baz_full_lines}"
        );
    }

    /// Input without any function line comes back as one chunk starting at line 1.
    #[test]
    fn non_fn_file_single_chunk() {
        let chunks = adaptive_chunk("hello world\nline two\n", 500, 1);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].start_line, 1);
    }
}