//! Heuristic attention prediction model for LLM context optimization.
//!
//! Based on empirical findings from "Lost in the Middle" (Liu et al., 2023):
//! - Transformers attend strongly to the beginning and end of the context
//! - Middle positions receive roughly half as much attention
//! - Structural markers (definitions, errors) attract attention regardless of position
//!
//! This module provides a position + structure based attention estimator
//! that can be used to reorder or filter context for maximum LLM utilization.

/// U-shaped positional attention weight.
///
/// `position` is the normalized position in the context: 0.0 = beginning,
/// 1.0 = end (out-of-range values are clamped to the edge weights).
/// `alpha`, `beta`, and `gamma` are the attention weights at the beginning,
/// middle, and end respectively; the return value interpolates between them.
///
/// The curve is a pair of quadratic ease-ins joined at the midpoint:
/// - first half  (x in [0, 0.5]): alpha + (beta - alpha) * (2x)^2
/// - second half (x in (0.5, 1]): beta + (gamma - beta) * (2x - 1)^2
///
/// so attention is flat near the edges and changes fastest at the middle,
/// modeling the "Lost in the Middle" (Liu et al., 2023) observation that
/// mid-context positions are under-attended. The result lies within the
/// range spanned by the three weights.
pub fn positional_attention(position: f64, alpha: f64, beta: f64, gamma: f64) -> f64 {
    // Clamp out-of-range positions to the nearest edge weight.
    if position <= 0.0 {
        return alpha;
    }
    if position >= 1.0 {
        return gamma;
    }

    // Select the half-curve, then interpolate quadratically on t in (0, 1].
    let (from, to, t) = if position <= 0.5 {
        (alpha, beta, position * 2.0)
    } else {
        (beta, gamma, (position - 0.5) * 2.0)
    };
    from + (to - from) * t * t
}

40/// Estimate the structural importance of a line.
41/// Returns a multiplier [0.1, 2.0] based on syntactic patterns.
42///
43/// Weights updated 2026-04-02 based on empirical attention analysis
44/// (Lab Experiment B: TinyLlama 1.1B on 106 Rust files):
45///   import  → 0.0285 mean attn (was rated 0.6, now 1.6)
46///   comment → 0.0123 mean attn (was rated 0.4, now 1.2)
47///   definition → 0.0038 (was 1.8, adjusted to 1.5)
48///   test/assert → 0.0004 (was 1.5, lowered to 0.8)
49pub fn structural_importance(line: &str) -> f64 {
50    let trimmed = line.trim();
51    if trimmed.is_empty() {
52        return 0.1;
53    }
54
55    if trimmed.starts_with("error")
56        || trimmed.starts_with("Error")
57        || trimmed.contains("ERROR")
58        || trimmed.starts_with("panic")
59        || trimmed.starts_with("FAIL")
60    {
61        return 2.0;
62    }
63
64    // Lab finding: imports get 3x more attention than definitions.
65    // They establish namespace context the model needs for all subsequent code.
66    if trimmed.starts_with("use ")
67        || trimmed.starts_with("import ")
68        || trimmed.starts_with("from ")
69        || trimmed.starts_with("#include")
70    {
71        return 1.6;
72    }
73
74    if is_definition(trimmed) {
75        return 1.5;
76    }
77
78    // Lab finding: comments are semantic anchors — 3x more attention than logic.
79    if trimmed.starts_with("//")
80        || trimmed.starts_with("#")
81        || trimmed.starts_with("/*")
82        || trimmed.starts_with("*")
83    {
84        return 1.2;
85    }
86
87    if trimmed.starts_with("return ") || trimmed.starts_with("yield ") {
88        return 1.0;
89    }
90
91    if trimmed.starts_with("if ")
92        || trimmed.starts_with("match ")
93        || trimmed.starts_with("for ")
94        || trimmed.starts_with("while ")
95    {
96        return 0.9;
97    }
98
99    // Lab finding: test assertions get minimal attention (0.0004) —
100    // lowest of all line types unless the task is about testing.
101    if trimmed.starts_with("assert")
102        || trimmed.starts_with("expect(")
103        || trimmed.starts_with("#[test]")
104        || trimmed.starts_with("@Test")
105    {
106        return 0.8;
107    }
108
109    if trimmed == "}" || trimmed == "};" || trimmed == "})" {
110        return 0.3;
111    }
112
113    0.8
114}
115
116/// Compute combined attention score for a line at a given position.
117/// Combines positional U-curve with structural importance.
118pub fn combined_attention(line: &str, position: f64, alpha: f64, beta: f64, gamma: f64) -> f64 {
119    let pos_weight = positional_attention(position, alpha, beta, gamma);
120    let struct_weight = structural_importance(line);
121    // Geometric mean balances both factors
122    (pos_weight * struct_weight).sqrt()
123}
124
125/// Reorder lines to maximize predicted attention utilization.
126/// Places high-attention lines at begin and end positions.
127pub fn attention_optimize(lines: &[&str], _alpha: f64, _beta: f64, _gamma: f64) -> Vec<String> {
128    if lines.len() <= 3 {
129        return lines.iter().map(|l| l.to_string()).collect();
130    }
131
132    let mut scored: Vec<(usize, f64)> = lines
133        .iter()
134        .enumerate()
135        .map(|(i, line)| {
136            let importance = structural_importance(line);
137            (i, importance)
138        })
139        .collect();
140
141    // Sort by importance (high first)
142    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
143
144    // Place most important at begin (alpha) and end (gamma) positions
145    let n = scored.len();
146    let mut result = vec![String::new(); n];
147    let mut begin_idx = 0;
148    let mut end_idx = n - 1;
149    let mut mid_idx = n / 4; // start mid section after first quarter
150
151    for (i, (orig_idx, _importance)) in scored.iter().enumerate() {
152        if i % 3 == 0 && begin_idx < n / 3 {
153            result[begin_idx] = lines[*orig_idx].to_string();
154            begin_idx += 1;
155        } else if i % 3 == 1 && end_idx > 2 * n / 3 {
156            result[end_idx] = lines[*orig_idx].to_string();
157            end_idx -= 1;
158        } else {
159            if mid_idx < 2 * n / 3 {
160                result[mid_idx] = lines[*orig_idx].to_string();
161                mid_idx += 1;
162            }
163        }
164    }
165
166    // Fill any remaining empty slots with original order
167    let mut remaining: Vec<String> = lines.iter().map(|l| l.to_string()).collect();
168    for slot in &mut result {
169        if slot.is_empty() {
170            if let Some(line) = remaining.pop() {
171                *slot = line;
172            }
173        }
174    }
175
176    result
177}
178
179/// Compute the theoretical attention efficiency for a given context layout.
180/// Returns a percentage [0, 100] indicating how much of the context
181/// is in attention-optimal positions.
182pub fn attention_efficiency(line_importances: &[f64], alpha: f64, beta: f64, gamma: f64) -> f64 {
183    if line_importances.is_empty() {
184        return 0.0;
185    }
186
187    let n = line_importances.len();
188    let mut weighted_sum = 0.0;
189    let mut total_importance = 0.0;
190
191    for (i, &importance) in line_importances.iter().enumerate() {
192        let pos = i as f64 / (n - 1).max(1) as f64;
193        let pos_weight = positional_attention(pos, alpha, beta, gamma);
194        weighted_sum += importance * pos_weight;
195        total_importance += importance;
196    }
197
198    if total_importance == 0.0 {
199        return 0.0;
200    }
201
202    (weighted_sum / total_importance) * 100.0
203}
204
/// True if `line` (expected pre-trimmed) begins a named definition in one of
/// the languages this heuristic recognizes: Rust, JS/TS, Java-ish
/// `class`/`interface`, Python, and Go.
fn is_definition(line: &str) -> bool {
    const DEFINITION_PREFIXES: &[&str] = &[
        // Rust
        "fn ", "pub fn ", "async fn ", "pub async fn ",
        "struct ", "pub struct ", "enum ", "pub enum ",
        "trait ", "pub trait ", "impl ",
        "type ", "pub type ", "const ", "pub const ", "static ",
        // JS / TS / Java-ish
        "class ", "export class ", "interface ", "export interface ",
        "function ", "export function ", "async function ",
        // Python
        "def ", "async def ",
        // Go
        "func ",
    ];
    DEFINITION_PREFIXES
        .iter()
        .any(|prefix| line.starts_with(prefix))
}

#[cfg(test)]
mod tests {
    use super::*;

    // Canonical weights used by every test: begin 0.9, middle 0.5, end 0.85.
    const A: f64 = 0.9;
    const B: f64 = 0.5;
    const G: f64 = 0.85;

    #[test]
    fn positional_u_curve() {
        let attn = |p: f64| positional_attention(p, A, B, G);

        // Edge positions return the configured weights…
        assert!((attn(0.0) - A).abs() < 0.01);
        assert!((attn(0.5) - B).abs() < 0.01);
        assert!((attn(1.0) - G).abs() < 0.01);
        // …and the curve is U-shaped: both edges beat the middle.
        assert!(attn(0.0) > attn(0.5));
        assert!(attn(1.0) > attn(0.5));
    }

    #[test]
    fn structural_errors_highest() {
        // Expected ranking, highest importance first; each message describes
        // the comparison with the next entry down.
        let ranked = [
            ("error[E0433]: failed to resolve", "errors should be highest"),
            ("use std::collections::HashMap;", "imports should outrank definitions (lab finding)"),
            ("fn main() {", "definitions should outrank comments"),
            ("// just a comment", "comments should outrank closing braces"),
            ("}", ""),
        ];
        for pair in ranked.windows(2) {
            let (higher, msg) = pair[0];
            let (lower, _) = pair[1];
            assert!(structural_importance(higher) > structural_importance(lower), "{msg}");
        }
    }

    #[test]
    fn combined_high_at_begin_with_definition() {
        // The same definition line scores higher at the start than mid-context.
        let at_begin = combined_attention("fn main() {", 0.0, A, B, G);
        let at_middle = combined_attention("fn main() {", 0.5, A, B, G);
        assert!(at_begin > at_middle);
    }

    #[test]
    fn efficiency_higher_when_important_at_edges() {
        let edges_layout = [1.8, 0.3, 0.3, 0.3, 1.5]; // important lines at begin + end
        let middle_layout = [0.3, 0.3, 1.8, 1.5, 0.3]; // important lines buried mid-context

        let eff_edges = attention_efficiency(&edges_layout, A, B, G);
        let eff_middle = attention_efficiency(&middle_layout, A, B, G);
        assert!(
            eff_edges > eff_middle,
            "edges layout ({eff_edges:.1}) should beat middle layout ({eff_middle:.1})"
        );
    }
}