Skip to main content

lean_ctx/core/
adaptive_chunking.rs

1//! Adaptive chunk sizing from a rough “prefrontal” budget controller (tight → signatures, generous → full bodies).
2use crate::core::tokens::count_tokens;
3
/// A span of the input selected by `adaptive_chunk`, together with its ranking score.
#[derive(Debug, Clone, PartialEq)]
pub struct ChunkResult {
    /// Text emitted for the span: the full body, a leading portion of it, or just a
    /// signature-style prefix, depending on the budget mode that produced it.
    pub content: String,
    /// 1-based number of the first line of the span.
    pub start_line: usize,
    /// 1-based number of the last line of the span. This always describes the whole span,
    /// even when `content` holds only part of it.
    pub end_line: usize,
    /// Heuristic score of the span; larger values are considered first when the budget is tight.
    pub priority: f64,
}
12
/// Per-item token budgets strictly below this value select signature-only output.
const TIGHT_PER_ITEM: usize = 50;
/// Per-item token budgets strictly above this value select full, untruncated chunk bodies.
const GENEROUS_PER_ITEM: usize = 200;
15
/// Heuristically decides whether `line` begins a Rust function item.
///
/// Leading whitespace is ignored. The line may start with an optional visibility
/// (`pub`, `pub(crate)`, `pub(super)`, `pub(in some::path)`), followed by any run of the
/// qualifiers `const`, `async`, `unsafe` and `extern` (with an optional ABI string such as
/// `"C"`), and must then contain the `fn` keyword followed by whitespace.
///
/// This is a prefix scan, not a parser: it looks only at the start of the line, so a
/// header whose `fn` keyword sits on a later line is not detected.
fn is_fn_line(line: &str) -> bool {
    let mut rest = line.trim_start();

    // Optional visibility: `pub`, optionally followed by a parenthesised restriction.
    if let Some(after_pub) = rest.strip_prefix("pub") {
        let trimmed = after_pub.trim_start();
        if let Some(restriction) = trimmed.strip_prefix('(') {
            let Some((_, tail)) = restriction.split_once(')') else {
                return false;
            };
            rest = tail.trim_start();
        } else if trimmed.len() < after_pub.len() {
            rest = trimmed;
        } else {
            // `pub` glued to further identifier characters (e.g. `publish`) or alone on the line.
            return false;
        }
    }

    // Consume qualifier words until `fn` (accept) or anything else (reject).
    // `rest` shrinks on every iteration, so the loop terminates.
    loop {
        let Some((word, tail)) = rest.split_once(char::is_whitespace) else {
            // No whitespace left: the keyword `fn` followed by whitespace cannot occur.
            return false;
        };
        rest = tail.trim_start();
        match word {
            "fn" => return true,
            "const" | "async" | "unsafe" => {}
            "extern" => {
                // Skip an optional ABI string literal, e.g. `extern "C" fn`.
                if let Some(abi) = rest.strip_prefix('"') {
                    let Some((_, after_abi)) = abi.split_once('"') else {
                        return false;
                    };
                    rest = after_abi.trim_start();
                }
            }
            _ => return false,
        }
    }
}
26
27fn chunk_ranges(lines: &[&str]) -> Vec<(usize, usize)> {
28    let n = lines.len();
29    if n == 0 {
30        return Vec::new();
31    }
32    let starts: Vec<usize> = lines
33        .iter()
34        .enumerate()
35        .filter_map(|(i, l)| is_fn_line(l).then_some(i))
36        .collect();
37    if starts.is_empty() {
38        return vec![(0, n - 1)];
39    }
40    let mut ranges = Vec::new();
41    if starts[0] > 0 {
42        ranges.push((0, starts[0] - 1));
43    }
44    for (k, &s) in starts.iter().enumerate() {
45        let end = if k + 1 < starts.len() {
46            starts[k + 1] - 1
47        } else {
48            n - 1
49        };
50        ranges.push((s, end));
51    }
52    ranges
53}
54
/// Counts the lines that, after leading whitespace is removed, begin with `use ` or `import `.
fn import_hits(lines: &[&str]) -> usize {
    let mut hits = 0usize;
    for raw in lines {
        let head = raw.trim_start();
        let is_import = head.starts_with("import ") || head.starts_with("use ");
        if is_import {
            hits += 1;
        }
    }
    hits
}
64
/// Heuristic complexity score for a piece of text.
///
/// The score is `0.18 ×` the deepest bracket nesting reached plus `0.06 ×` the number of
/// occurrences of the substrings `for `, `while `, `match `, `loop `, `if ` and `else`.
/// Any of `{`, `(`, `[` opens a level and any of `}`, `)`, `]` closes one. The running depth
/// is not clamped, so unmatched closers push it below zero. Keyword matching is plain
/// substring search.
fn brace_complexity(text: &str) -> f64 {
    const BRANCH_WORDS: [&str; 6] = ["for ", "while ", "match ", "loop ", "if ", "else"];

    // Track (current depth, deepest depth seen) across every character.
    let (_, deepest) = text
        .chars()
        .fold((0i32, 0i32), |(level, peak), ch| match ch {
            '{' | '(' | '[' => (level + 1, peak.max(level + 1)),
            '}' | ')' | ']' => (level - 1, peak),
            _ => (level, peak),
        });

    let keyword_hits: usize = BRANCH_WORDS
        .iter()
        .map(|&word| text.matches(word).count())
        .sum();

    f64::from(deepest) * 0.18 + keyword_hits as f64 * 0.06
}
87
/// Joins the lines in the inclusive index range `start..=end` with `\n` between them and no
/// trailing newline.
///
/// Panics if the range does not lie within `lines`.
fn build_chunk_body(lines: &[&str], start: usize, end: usize) -> String {
    let window = &lines[start..=end];
    let mut text = String::new();
    for (pos, row) in window.iter().enumerate() {
        if pos > 0 {
            text.push('\n');
        }
        text.push_str(row);
    }
    text
}
91
/// Builds a short, signature-style prefix of the chunk spanning lines `start..=end`.
///
/// Lines are copied in order until one contains `{` or ends (ignoring trailing whitespace)
/// with `;`. That line is included. Copying also stops once six lines have been gathered.
/// Trailing whitespace is stripped from the result. A range lying entirely past the end of
/// `lines` yields an empty string.
fn signature_body(lines: &[&str], start: usize, end: usize) -> String {
    let mut sig = String::new();
    // One past the last index to visit, limited to what `lines` actually holds.
    let stop = lines.len().min(end + 1);
    let mut idx = start;
    while idx < stop {
        let row = lines[idx];
        sig.push_str(row);
        if idx < end {
            sig.push('\n');
        }
        let closes_signature = row.contains('{') || row.trim_end().ends_with(';');
        if closes_signature || sig.lines().count() >= 6 {
            break;
        }
        idx += 1;
    }
    sig.trim_end().to_string()
}
109
110fn chunk_priority(lines: &[&str], start: usize, end: usize, total_lines: usize) -> f64 {
111    let slice = &lines[start..=end];
112    let body = slice.join("\n");
113    let cx = brace_complexity(&body);
114    let im = import_hits(slice) as f64 * 0.12;
115    let denom = total_lines.max(1) as f64;
116    let recency = (usize::midpoint(start, end) + 1) as f64 / denom * 0.55;
117    (cx + im + recency).min(12.0)
118}
119
120fn proportional_body(lines: &[&str], start: usize, end: usize, target_tokens: usize) -> String {
121    let full = build_chunk_body(lines, start, end);
122    let ftoks = count_tokens(&full).max(1);
123    let frac = (target_tokens as f64 / ftoks as f64).clamp(0.12, 1.0);
124    let nlines = end - start + 1;
125    let take = ((nlines as f64 * frac).ceil() as usize).clamp(1, nlines);
126    lines[start..start + take].join("\n")
127}
128
129/// Split `content` into prioritized chunks sized to `budget_tokens` spread across `total_items` sibling slices.
130pub fn adaptive_chunk(content: &str, budget_tokens: usize, total_items: usize) -> Vec<ChunkResult> {
131    let lines: Vec<&str> = content.lines().collect();
132    let total_lines = lines.len().max(1);
133    let ranges = chunk_ranges(&lines);
134    let per_item = budget_tokens / total_items.max(1);
135
136    let mut raw: Vec<(usize, usize, f64)> = ranges
137        .into_iter()
138        .map(|(s, e)| {
139            let p = chunk_priority(&lines, s, e, total_lines);
140            (s, e, p)
141        })
142        .collect();
143
144    if raw.is_empty() {
145        return Vec::new();
146    }
147
148    raw.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
149
150    let mut results = Vec::new();
151
152    if per_item < TIGHT_PER_ITEM {
153        let mut used = 0usize;
154        for (s, e, pri) in raw {
155            let body = signature_body(&lines, s, e);
156            if body.is_empty() {
157                continue;
158            }
159            let t = count_tokens(&body);
160            if used + t > budget_tokens {
161                continue;
162            }
163            used += t;
164            results.push(ChunkResult {
165                content: body,
166                start_line: s + 1,
167                end_line: e + 1,
168                priority: pri,
169            });
170        }
171        results.sort_by(|a, b| {
172            a.start_line.cmp(&b.start_line).then_with(|| {
173                b.priority
174                    .partial_cmp(&a.priority)
175                    .unwrap_or(std::cmp::Ordering::Equal)
176            })
177        });
178        return results;
179    }
180
181    if per_item > GENEROUS_PER_ITEM {
182        for (s, e, pri) in raw {
183            let body = build_chunk_body(&lines, s, e);
184            results.push(ChunkResult {
185                content: body,
186                start_line: s + 1,
187                end_line: e + 1,
188                priority: pri,
189            });
190        }
191        results.sort_by_key(|c| c.start_line);
192        return results;
193    }
194
195    // Middle: proportional inclusion per chunk, processed in priority order but output sorted by line.
196    let mut tmp = Vec::new();
197    for (s, e, pri) in raw {
198        let body = proportional_body(&lines, s, e, per_item);
199        tmp.push(ChunkResult {
200            content: body,
201            start_line: s + 1,
202            end_line: e + 1,
203            priority: pri,
204        });
205    }
206    tmp.sort_by_key(|c| c.start_line);
207    tmp
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: one import line, a blank line, then two small functions.
    const SAMPLE: &str = r#"use std::io;

fn foo() {
    if true {
        println!("a");
    }
}

fn bar(x: i32) -> i32 {
    let mut s = 0;
    for i in 0..x {
        s += i;
    }
    s
}
"#;

    /// Total token count across the contents of all chunks.
    fn total_tokens(chunks: &[ChunkResult]) -> usize {
        chunks.iter().map(|c| count_tokens(&c.content)).sum()
    }

    #[test]
    fn tight_mode_prefers_signatures_and_respects_budget() {
        // 80 tokens over 4 items gives 20 per item, below the tight threshold.
        let chunks = adaptive_chunk(SAMPLE, 80, 4);
        assert!(!chunks.is_empty());
        // Function bodies must not leak into signature-only output.
        assert!(chunks.iter().all(|c| !c.content.contains("println!")));
        assert!(total_tokens(&chunks) <= 80);
    }

    #[test]
    fn generous_mode_keeps_full_bodies() {
        let chunks = adaptive_chunk(SAMPLE, 50_000, 1);
        let has = |needle: &str| chunks.iter().any(|c| c.content.contains(needle));
        assert!(has("println!"));
        assert!(has("for i"));
    }

    #[test]
    fn middle_mode_partial_body() {
        // Append a 120-statement function so proportional clipping has something to cut.
        let filler: String = (0..120)
            .map(|i| format!("    let _z{i} = {i};\n"))
            .collect();
        let big = format!("{SAMPLE}\nfn baz() {{\n{filler}}}\n");

        // 750 tokens over 10 items gives 75 per item, between the two thresholds.
        let chunks = adaptive_chunk(&big, 750, 10);
        let baz = chunks
            .iter()
            .find(|c| c.content.contains("baz"))
            .expect("baz chunk");

        let count_z = |text: &str| text.lines().filter(|l| l.contains("_z")).count();
        let baz_full_lines = count_z(&big);
        let baz_kept_lines = count_z(&baz.content);
        assert!(
            baz_kept_lines < baz_full_lines,
            "expected proportional truncation inside baz, kept={baz_kept_lines} full={baz_full_lines}"
        );
    }

    #[test]
    fn non_fn_file_single_chunk() {
        // Without any function line the whole input is a single chunk starting at line 1.
        let chunks = adaptive_chunk("hello world\nline two\n", 500, 1);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].start_line, 1);
    }
}