// lean_ctx/core/attention_model.rs
//! Heuristic attention-prediction model for LLM context optimization.
//!
//! Based on empirical findings from "Lost in the Middle" (Liu et al., 2023):
//! - Transformers attend strongly to the beginning and end of the context
//! - Middle positions receive roughly half as much attention
//! - Structural markers (definitions, errors) attract attention regardless of position
//!
//! This module provides a position + structure based attention estimator
//! that can be used to reorder or filter context for maximum LLM utilization.

/// Compute a U-shaped attention weight for a given position.
///
/// `position` is normalized: 0.0 = beginning of context, 1.0 = end.
/// Returns a weight interpolated between three anchors: `alpha` at the
/// beginning, `beta` in the middle, and `gamma` at the end. Positions
/// outside [0, 1] clamp to the nearest anchor.
///
/// The curve is piecewise quadratic, modeling the empirical finding of
/// Liu et al. (2023), "Lost in the Middle", that attention falls off
/// steeply toward the middle of the context. What the code computes:
///
///   first half  (x <= 0.5): f(x) = alpha * (1 - t^2) + beta  * t^2,  t = 2x
///   second half (x >  0.5): f(x) = beta  * (1 - t^2) + gamma * t^2,  t = 2x - 1
///
/// Both halves meet at f(0.5) = beta, so the curve is continuous.
pub fn positional_attention(position: f64, alpha: f64, beta: f64, gamma: f64) -> f64 {
    // Clamp out-of-range positions to the edge anchors.
    if position <= 0.0 {
        return alpha;
    }
    if position >= 1.0 {
        return gamma;
    }

    if position <= 0.5 {
        let t = position / 0.5; // 0.0 at the start, 1.0 at the midpoint
        let t2 = t * t;
        alpha * (1.0 - t2) + beta * t2
    } else {
        let t = (position - 0.5) / 0.5; // 0.0 at the midpoint, 1.0 at the end
        let t2 = t * t;
        beta * (1.0 - t2) + gamma * t2
    }
}

40/// Estimate the structural importance of a line.
41/// Returns a multiplier [0.1, 2.0] based on syntactic patterns.
42///
43/// Weights updated 2026-04-02 based on empirical attention analysis
44/// (Lab Experiment B: TinyLlama 1.1B on 106 Rust files):
45///   import  → 0.0285 mean attn (was rated 0.6, now 1.6)
46///   comment → 0.0123 mean attn (was rated 0.4, now 1.2)
47///   definition → 0.0038 (was 1.8, adjusted to 1.5)
48///   test/assert → 0.0004 (was 1.5, lowered to 0.8)
49pub fn structural_importance(line: &str) -> f64 {
50    let trimmed = line.trim();
51    if trimmed.is_empty() {
52        return 0.1;
53    }
54
55    if trimmed.starts_with("error")
56        || trimmed.starts_with("Error")
57        || trimmed.contains("ERROR")
58        || trimmed.starts_with("panic")
59        || trimmed.starts_with("FAIL")
60    {
61        return 2.0;
62    }
63
64    // Lab finding: imports get 3x more attention than definitions.
65    // They establish namespace context the model needs for all subsequent code.
66    if trimmed.starts_with("use ")
67        || trimmed.starts_with("import ")
68        || trimmed.starts_with("from ")
69        || trimmed.starts_with("#include")
70    {
71        return 1.6;
72    }
73
74    if is_definition(trimmed) {
75        return 1.5;
76    }
77
78    // Lab finding: comments are semantic anchors — 3x more attention than logic.
79    if trimmed.starts_with("//")
80        || trimmed.starts_with("#")
81        || trimmed.starts_with("/*")
82        || trimmed.starts_with("*")
83    {
84        return 1.2;
85    }
86
87    if trimmed.starts_with("return ") || trimmed.starts_with("yield ") {
88        return 1.0;
89    }
90
91    if trimmed.starts_with("if ")
92        || trimmed.starts_with("match ")
93        || trimmed.starts_with("for ")
94        || trimmed.starts_with("while ")
95    {
96        return 0.9;
97    }
98
99    // Lab finding: test assertions get minimal attention (0.0004) —
100    // lowest of all line types unless the task is about testing.
101    if trimmed.starts_with("assert")
102        || trimmed.starts_with("expect(")
103        || trimmed.starts_with("#[test]")
104        || trimmed.starts_with("@Test")
105    {
106        return 0.8;
107    }
108
109    if trimmed == "}" || trimmed == "};" || trimmed == "})" {
110        return 0.3;
111    }
112
113    0.8
114}
115
116/// Compute combined attention score for a line at a given position.
117/// Combines positional U-curve with structural importance.
118pub fn combined_attention(line: &str, position: f64, alpha: f64, beta: f64, gamma: f64) -> f64 {
119    let pos_weight = positional_attention(position, alpha, beta, gamma);
120    let struct_weight = structural_importance(line);
121    // Geometric mean balances both factors
122    (pos_weight * struct_weight).sqrt()
123}
124
125/// Reorder lines to maximize predicted attention utilization.
126/// Places high-attention lines at begin and end positions.
127pub fn attention_optimize(lines: &[&str], _alpha: f64, _beta: f64, _gamma: f64) -> Vec<String> {
128    if lines.len() <= 3 {
129        return lines.iter().map(|l| l.to_string()).collect();
130    }
131
132    let mut scored: Vec<(usize, f64)> = lines
133        .iter()
134        .enumerate()
135        .map(|(i, line)| {
136            let importance = structural_importance(line);
137            (i, importance)
138        })
139        .collect();
140
141    // Sort by importance (high first)
142    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
143
144    // Place most important at begin (alpha) and end (gamma) positions
145    let n = scored.len();
146    let mut result = vec![String::new(); n];
147    let mut begin_idx = 0;
148    let mut end_idx = n - 1;
149    let mut mid_idx = n / 4; // start mid section after first quarter
150
151    for (i, (orig_idx, _importance)) in scored.iter().enumerate() {
152        if i % 3 == 0 && begin_idx < n / 3 {
153            result[begin_idx] = lines[*orig_idx].to_string();
154            begin_idx += 1;
155        } else if i % 3 == 1 && end_idx > 2 * n / 3 {
156            result[end_idx] = lines[*orig_idx].to_string();
157            end_idx -= 1;
158        } else if mid_idx < 2 * n / 3 {
159            result[mid_idx] = lines[*orig_idx].to_string();
160            mid_idx += 1;
161        }
162    }
163
164    // Fill any remaining empty slots with original order
165    let mut remaining: Vec<String> = lines.iter().map(|l| l.to_string()).collect();
166    for slot in &mut result {
167        if slot.is_empty() {
168            if let Some(line) = remaining.pop() {
169                *slot = line;
170            }
171        }
172    }
173
174    result
175}
176
177/// Compute the theoretical attention efficiency for a given context layout.
178/// Returns a percentage [0, 100] indicating how much of the context
179/// is in attention-optimal positions.
180pub fn attention_efficiency(line_importances: &[f64], alpha: f64, beta: f64, gamma: f64) -> f64 {
181    if line_importances.is_empty() {
182        return 0.0;
183    }
184
185    let n = line_importances.len();
186    let mut weighted_sum = 0.0;
187    let mut total_importance = 0.0;
188
189    for (i, &importance) in line_importances.iter().enumerate() {
190        let pos = i as f64 / (n - 1).max(1) as f64;
191        let pos_weight = positional_attention(pos, alpha, beta, gamma);
192        weighted_sum += importance * pos_weight;
193        total_importance += importance;
194    }
195
196    if total_importance == 0.0 {
197        return 0.0;
198    }
199
200    (weighted_sum / total_importance) * 100.0
201}
202
/// Whether a (pre-trimmed) line opens a named definition in one of the
/// languages this heuristic recognizes: Rust, TypeScript/JavaScript,
/// Python, Go, and class/interface declarations generally.
fn is_definition(line: &str) -> bool {
    const PREFIXES: &[&str] = &[
        // Rust
        "fn ", "pub fn ", "async fn ", "pub async fn ",
        "struct ", "pub struct ", "enum ", "pub enum ",
        "trait ", "pub trait ", "impl ", "type ", "pub type ",
        "const ", "pub const ", "static ",
        // TypeScript / JavaScript / class-based languages
        "class ", "export class ", "interface ", "export interface ",
        "function ", "export function ", "async function ",
        // Python
        "def ", "async def ",
        // Go
        "func ",
    ];
    PREFIXES.iter().any(|prefix| line.starts_with(prefix))
}

#[cfg(test)]
mod tests {
    use super::*;

    // Shared anchor weights used across the positional tests.
    const ALPHA: f64 = 0.9;
    const BETA: f64 = 0.5;
    const GAMMA: f64 = 0.85;

    #[test]
    fn positional_u_curve() {
        let begin = positional_attention(0.0, ALPHA, BETA, GAMMA);
        let middle = positional_attention(0.5, ALPHA, BETA, GAMMA);
        let end = positional_attention(1.0, ALPHA, BETA, GAMMA);

        // Anchors are hit exactly (within tolerance)...
        assert!((begin - 0.9).abs() < 0.01);
        assert!((middle - 0.5).abs() < 0.01);
        assert!((end - 0.85).abs() < 0.01);
        // ...and the curve dips in the middle.
        assert!(begin > middle);
        assert!(end > middle);
    }

    #[test]
    fn structural_errors_highest() {
        let error = structural_importance("error[E0433]: failed to resolve");
        let import = structural_importance("use std::collections::HashMap;");
        let def = structural_importance("fn main() {");
        let comment = structural_importance("// just a comment");
        let brace = structural_importance("}");

        assert!(error > import, "errors should be highest");
        assert!(
            import > def,
            "imports should outrank definitions (lab finding)"
        );
        assert!(def > comment, "definitions should outrank comments");
        assert!(comment > brace, "comments should outrank closing braces");
    }

    #[test]
    fn combined_high_at_begin_with_definition() {
        let line = "fn main() {";
        let score_begin = combined_attention(line, 0.0, ALPHA, BETA, GAMMA);
        let score_middle = combined_attention(line, 0.5, ALPHA, BETA, GAMMA);
        assert!(score_begin > score_middle);
    }

    #[test]
    fn efficiency_higher_when_important_at_edges() {
        let good_layout = [1.8, 0.3, 0.3, 0.3, 1.5]; // high importance at the edges
        let bad_layout = [0.3, 0.3, 1.8, 1.5, 0.3]; // high importance buried mid-context

        let eff_good = attention_efficiency(&good_layout, ALPHA, BETA, GAMMA);
        let eff_bad = attention_efficiency(&bad_layout, ALPHA, BETA, GAMMA);
        assert!(
            eff_good > eff_bad,
            "edges layout ({eff_good:.1}) should beat middle layout ({eff_bad:.1})"
        );
    }
}