Skip to main content

lean_ctx/core/
semantic_chunks.rs

//! Semantic chunking with attention bridges.
//!
//! Groups content into semantic chunks (function bodies, import blocks, type
//! definitions) rather than treating lines independently, then orders the
//! chunks for LLM attention flow:
//!
//! 1. The most relevant chunk first (the high-attention position).
//! 2. Its immediate dependencies (imports and types it uses) right after it.
//! 3. Supporting context in the middle.
//! 4. A tail anchor: a brief reference back to the primary chunk (the
//!    attention bridge).
//!
//! The premise is that transformer attention benefits more from local
//! coherence plus global anchors than from scattered high-importance lines.
14
15use std::collections::HashSet;
16
17#[derive(Debug, Clone)]
18pub struct SemanticChunk {
19    pub lines: Vec<String>,
20    pub kind: ChunkKind,
21    pub relevance: f64,
22    pub start_line: usize,
23    pub identifier: Option<String>,
24}
25
/// Coarse classification of what a chunk (or a single line) contains.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ChunkKind {
    /// Import/include lines (e.g. `use`, `import`, `from`, `#include`).
    Imports,
    /// A type definition: struct, enum, trait, type alias, interface or class.
    TypeDefinition,
    /// A function definition (e.g. `fn`, `function`, `def`, `func`).
    FunctionDef,
    /// Any other non-blank code.
    Logic,
    /// Blank content; also the starting kind before any line is seen.
    Empty,
}
34
35/// Detect semantic boundaries in content and group lines into chunks.
36pub fn detect_chunks(content: &str) -> Vec<SemanticChunk> {
37    let lines: Vec<&str> = content.lines().collect();
38    if lines.is_empty() {
39        return Vec::new();
40    }
41
42    let mut chunks: Vec<SemanticChunk> = Vec::new();
43    let mut current_lines: Vec<String> = Vec::new();
44    let mut current_kind = ChunkKind::Empty;
45    let mut current_start = 0;
46    let mut current_ident: Option<String> = None;
47    let mut brace_depth: i32 = 0;
48    let mut in_block = false;
49
50    for (i, &line) in lines.iter().enumerate() {
51        let trimmed = line.trim();
52        let line_kind = classify_line(trimmed);
53
54        let opens = trimmed.matches('{').count() as i32;
55        let closes = trimmed.matches('}').count() as i32;
56
57        if !in_block && is_block_start(trimmed) {
58            if !current_lines.is_empty() {
59                chunks.push(SemanticChunk {
60                    lines: current_lines.clone(),
61                    kind: current_kind,
62                    relevance: 0.0,
63                    start_line: current_start,
64                    identifier: current_ident.take(),
65                });
66                current_lines.clear();
67            }
68            current_start = i;
69            current_kind = line_kind;
70            current_ident = extract_identifier(trimmed);
71            in_block = opens > closes;
72            brace_depth = opens - closes;
73            current_lines.push(line.to_string());
74            continue;
75        }
76
77        if in_block {
78            brace_depth += opens - closes;
79            current_lines.push(line.to_string());
80            if brace_depth <= 0 {
81                in_block = false;
82                chunks.push(SemanticChunk {
83                    lines: current_lines.clone(),
84                    kind: current_kind,
85                    relevance: 0.0,
86                    start_line: current_start,
87                    identifier: current_ident.take(),
88                });
89                current_lines.clear();
90            }
91            continue;
92        }
93
94        // Boundary detection: blank lines or kind changes
95        let is_boundary =
96            trimmed.is_empty() || (line_kind != current_kind && !current_lines.is_empty());
97
98        if is_boundary && !current_lines.is_empty() {
99            chunks.push(SemanticChunk {
100                lines: current_lines.clone(),
101                kind: current_kind,
102                relevance: 0.0,
103                start_line: current_start,
104                identifier: current_ident.take(),
105            });
106            current_lines.clear();
107        }
108
109        if !trimmed.is_empty() {
110            if current_lines.is_empty() {
111                current_start = i;
112                current_kind = line_kind;
113            }
114            current_lines.push(line.to_string());
115        }
116    }
117
118    if !current_lines.is_empty() {
119        chunks.push(SemanticChunk {
120            lines: current_lines,
121            kind: current_kind,
122            relevance: 0.0,
123            start_line: current_start,
124            identifier: current_ident,
125        });
126    }
127
128    chunks
129}
130
131/// Score chunks by task relevance and reorder for optimal attention flow.
132pub fn order_for_attention(
133    mut chunks: Vec<SemanticChunk>,
134    task_keywords: &[String],
135) -> Vec<SemanticChunk> {
136    if chunks.is_empty() {
137        return chunks;
138    }
139
140    let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
141
142    // Score each chunk
143    for chunk in &mut chunks {
144        let text = chunk.lines.join(" ").to_lowercase();
145        let keyword_score: f64 = kw_lower
146            .iter()
147            .filter(|kw| text.contains(kw.as_str()))
148            .count() as f64;
149
150        let kind_weight = match chunk.kind {
151            ChunkKind::FunctionDef => 2.0,
152            ChunkKind::TypeDefinition => 1.8,
153            ChunkKind::Imports => 1.0,
154            ChunkKind::Logic => 0.8,
155            ChunkKind::Empty => 0.1,
156        };
157
158        let size_factor = (chunk.lines.len() as f64 / 5.0).min(1.5);
159
160        chunk.relevance = keyword_score * 2.0 + kind_weight + size_factor * 0.3;
161    }
162
163    // Sort by relevance (most relevant first)
164    chunks.sort_by(|a, b| {
165        b.relevance
166            .partial_cmp(&a.relevance)
167            .unwrap_or(std::cmp::Ordering::Equal)
168            .then_with(|| a.start_line.cmp(&b.start_line))
169    });
170
171    if chunks.len() <= 2 {
172        return chunks;
173    }
174
175    // Reorder: primary chunk first, then its dependencies, then rest
176    let primary = &chunks[0];
177    let primary_tokens: HashSet<String> = primary
178        .lines
179        .iter()
180        .flat_map(|l| l.split_whitespace().map(str::to_lowercase))
181        .collect();
182
183    let (mut deps, mut rest): (Vec<_>, Vec<_>) = chunks[1..].iter().cloned().partition(|chunk| {
184        if chunk.kind == ChunkKind::Imports || chunk.kind == ChunkKind::TypeDefinition {
185            let chunk_tokens: HashSet<String> = chunk
186                .lines
187                .iter()
188                .flat_map(|l| l.split_whitespace().map(str::to_lowercase))
189                .collect();
190            let overlap = primary_tokens.intersection(&chunk_tokens).count();
191            overlap >= 2
192        } else {
193            false
194        }
195    });
196
197    deps.sort_by(|a, b| {
198        b.relevance
199            .partial_cmp(&a.relevance)
200            .unwrap_or(std::cmp::Ordering::Equal)
201            .then_with(|| a.start_line.cmp(&b.start_line))
202    });
203    rest.sort_by(|a, b| {
204        b.relevance
205            .partial_cmp(&a.relevance)
206            .unwrap_or(std::cmp::Ordering::Equal)
207            .then_with(|| a.start_line.cmp(&b.start_line))
208    });
209
210    let mut ordered = Vec::with_capacity(chunks.len());
211    ordered.push(chunks[0].clone());
212    ordered.extend(deps);
213    ordered.extend(rest);
214
215    ordered
216}
217
218/// Render chunks back to text with attention bridges.
219pub fn render_with_bridges(chunks: &[SemanticChunk]) -> String {
220    if chunks.is_empty() {
221        return String::new();
222    }
223
224    let mut output = Vec::new();
225
226    for (i, chunk) in chunks.iter().enumerate() {
227        if i > 0 {
228            output.push(String::new());
229        }
230        for line in &chunk.lines {
231            output.push(line.clone());
232        }
233    }
234
235    // Tail anchor: reference back to primary chunk
236    if chunks.len() > 2 {
237        if let Some(ref ident) = chunks[0].identifier {
238            output.push(String::new());
239            output.push(format!("[primary: {ident}]"));
240        }
241    }
242
243    output.join("\n")
244}
245
246fn classify_line(trimmed: &str) -> ChunkKind {
247    if trimmed.is_empty() {
248        return ChunkKind::Empty;
249    }
250    if is_import(trimmed) {
251        return ChunkKind::Imports;
252    }
253    if is_type_def(trimmed) {
254        return ChunkKind::TypeDefinition;
255    }
256    if is_fn_start(trimmed) {
257        return ChunkKind::FunctionDef;
258    }
259    ChunkKind::Logic
260}
261
262fn is_block_start(trimmed: &str) -> bool {
263    is_fn_start(trimmed) || is_type_def(trimmed)
264}
265
/// Whether an already-trimmed `line` starts a function definition.
///
/// Only leading keywords are inspected. Covered forms:
/// * Rust: `fn`, optionally preceded by `pub` / `pub(crate)` / `pub(super)`
///   and `const` / `async` / `unsafe`.
/// * JavaScript/TypeScript: `function`, optionally with `export`,
///   `export default` and/or `async`.
/// * Python: `def` / `async def`.
/// * Go: `func`.
fn is_fn_start(line: &str) -> bool {
    const STARTERS: &[&str] = &[
        // Rust
        "fn ",
        "pub fn ",
        "pub(crate) fn ",
        "pub(super) fn ",
        "async fn ",
        "pub async fn ",
        "pub(crate) async fn ",
        "pub(super) async fn ",
        "const fn ",
        "pub const fn ",
        "pub(crate) const fn ",
        "pub(super) const fn ",
        "unsafe fn ",
        "pub unsafe fn ",
        "pub(crate) unsafe fn ",
        "pub(super) unsafe fn ",
        // JavaScript / TypeScript
        "function ",
        "async function ",
        "export function ",
        "export async function ",
        "export default function ",
        "export default async function ",
        // Python
        "def ",
        "async def ",
        // Go
        "func ",
    ];
    STARTERS.iter().any(|s| line.starts_with(s))
}
283
/// Whether an already-trimmed `line` starts a type definition.
///
/// Only leading keywords are inspected. Covered forms:
/// * Rust: `struct` / `enum` / `trait` / `type`, plain, `pub`, `pub(crate)`
///   or `pub(super)`, plus `unsafe trait`.
/// * TypeScript: `interface` / `type` / `enum`, optionally `export`ed.
/// * `class` declarations, optionally `export`, `export default` or
///   `abstract`.
fn is_type_def(line: &str) -> bool {
    const STARTERS: &[&str] = &[
        // Rust (the bare forms also cover other C-family / TypeScript code)
        "struct ",
        "pub struct ",
        "pub(crate) struct ",
        "pub(super) struct ",
        "enum ",
        "pub enum ",
        "pub(crate) enum ",
        "pub(super) enum ",
        "trait ",
        "pub trait ",
        "pub(crate) trait ",
        "pub(super) trait ",
        "unsafe trait ",
        "pub unsafe trait ",
        "type ",
        "pub type ",
        "pub(crate) type ",
        "pub(super) type ",
        // TypeScript / JavaScript / Python / Java-like
        "interface ",
        "export interface ",
        "export type ",
        "export enum ",
        "class ",
        "export class ",
        "export default class ",
        "abstract class ",
        "export abstract class ",
    ];
    STARTERS.iter().any(|s| line.starts_with(s))
}
301
/// Whether an already-trimmed `line` is an import/include line.
///
/// Recognised forms:
/// * Rust: `use`, including `pub use` / `pub(crate) use` re-exports, and
///   `extern crate`.
/// * JavaScript/TypeScript/Python: `import`.
/// * Python: `from`.
/// * C/C++: `#include`.
fn is_import(line: &str) -> bool {
    const STARTERS: &[&str] = &[
        "use ",
        "pub use ",
        "pub(crate) use ",
        "extern crate ",
        "import ",
        "from ",
        "#include",
    ];
    STARTERS.iter().any(|s| line.starts_with(s))
}
308
/// Extract the name declared by a definition line such as `pub fn foo(` or
/// `export class Bar {`.
///
/// Any run of leading modifiers (`pub`, `pub(crate)`, `pub(super)`,
/// `pub(self)`, `export`, `default`, `async`, `const`, `unsafe`, `abstract`)
/// is skipped. The name is the run of letters, digits and `_` right after the
/// definition keyword.
///
/// Returns `None` when no recognised keyword follows the modifiers, or no
/// name follows the keyword.
fn extract_identifier(line: &str) -> Option<String> {
    const MODIFIERS: &[&str] = &[
        "pub(crate) ",
        "pub(super) ",
        "pub(self) ",
        "pub ",
        "export ",
        "default ",
        "async ",
        "const ",
        "unsafe ",
        "abstract ",
    ];
    const KEYWORDS: &[&str] = &[
        "fn ",
        "struct ",
        "enum ",
        "trait ",
        "type ",
        "class ",
        "interface ",
        "function ",
        "def ",
        "func ",
    ];

    // Skip leading modifiers in whatever order they appear; each pass removes
    // a non-empty prefix, so the loop terminates.
    let mut rest = line.trim();
    while let Some(stripped) = MODIFIERS.iter().find_map(|m| rest.strip_prefix(m)) {
        rest = stripped.trim_start();
    }

    KEYWORDS.iter().find_map(|kw| {
        let name: String = rest
            .strip_prefix(kw)?
            .chars()
            .take_while(|c| c.is_alphanumeric() || *c == '_')
            .collect();
        if name.is_empty() {
            None
        } else {
            Some(name)
        }
    })
}
340
#[cfg(test)]
mod tests {
    // Unit tests for chunk detection, ordering, rendering and name extraction.
    use super::*;

    #[test]
    fn detect_chunks_basic() {
        let found = detect_chunks(
            "use std::io;\nuse std::fs;\n\nfn main() {\n    let x = 1;\n}\n\nfn helper() {\n    let y = 2;\n}",
        );
        let count = found.len();
        assert!(count >= 2, "should detect multiple chunks, got {}", count);
    }

    #[test]
    fn detect_chunks_identifies_functions() {
        let found = detect_chunks("fn main() {\n    println!(\"hello\");\n}");
        let has_function = found
            .iter()
            .any(|c| matches!(c.kind, ChunkKind::FunctionDef));
        assert!(has_function, "should detect function definition");
    }

    #[test]
    fn order_puts_relevant_first() {
        let source =
            "fn unrelated() {\n    let x = 1;\n}\n\nfn validate_token() {\n    check();\n}";
        let keywords = vec!["validate".to_string()];
        let ordered = order_for_attention(detect_chunks(source), &keywords);
        let leader = ordered[0].identifier.as_deref();
        assert!(
            leader == Some("validate_token"),
            "most relevant chunk should be first"
        );
    }

    #[test]
    fn render_with_bridges_adds_anchor() {
        let fixture = [
            SemanticChunk {
                lines: vec!["fn main() {".into(), "  let x = 1;".into(), "}".into()],
                kind: ChunkKind::FunctionDef,
                relevance: 5.0,
                start_line: 0,
                identifier: Some("main".into()),
            },
            SemanticChunk {
                lines: vec!["use std::io;".into()],
                kind: ChunkKind::Imports,
                relevance: 1.0,
                start_line: 5,
                identifier: None,
            },
            SemanticChunk {
                lines: vec!["fn helper() {".into(), "}".into()],
                kind: ChunkKind::FunctionDef,
                relevance: 0.5,
                start_line: 8,
                identifier: Some("helper".into()),
            },
        ];
        let text = render_with_bridges(&fixture);
        assert!(text.contains("[primary: main]"), "should have tail anchor");
    }

    #[test]
    fn extract_identifier_fn() {
        let cases: [(&str, Option<String>); 3] = [
            ("pub fn validate_token() {", Some("validate_token".into())),
            ("struct Config {", Some("Config".into())),
            ("let x = 1;", None),
        ];
        for (line, expected) in cases {
            assert_eq!(extract_identifier(line), expected);
        }
    }
}