// Source: lean_ctx/core/attention_model.rs
//! Heuristic attention prediction model for LLM context optimization.
//!
//! Based on empirical findings from "Lost in the Middle" (Liu et al., 2023):
//! - Transformers attend strongly to begin and end positions
//! - Middle positions receive ~50% less attention
//! - Structural markers (definitions, errors) attract attention regardless of position
//!
//! This module provides a position + structure based attention estimator
//! that can be used to reorder or filter context for maximum LLM utilization.

/// Compute a U-shaped attention weight for a given position.
/// position: normalized 0.0 (begin) to 1.0 (end)
/// Returns attention weight in [0, 1].
pub fn positional_attention(position: f64, alpha: f64, beta: f64, gamma: f64) -> f64 {
    // Out-of-range positions saturate at the curve's endpoint weights.
    if position <= 0.0 {
        return alpha;
    }
    if position >= 1.0 {
        return gamma;
    }

    // Piecewise linear U-curve: alpha at 0.0, beta at 0.5, gamma at 1.0.
    // Pick the active half, then interpolate between its two anchors.
    let (lo, hi, t) = if position <= 0.5 {
        (alpha, beta, position / 0.5)
    } else {
        (beta, gamma, (position - 0.5) / 0.5)
    };
    lo * (1.0 - t) + hi * t
}
31
32/// Estimate the structural importance of a line.
33/// Returns a multiplier [0.5, 2.0] based on syntactic patterns.
34pub fn structural_importance(line: &str) -> f64 {
35    let trimmed = line.trim();
36    if trimmed.is_empty() {
37        return 0.1;
38    }
39
40    // Error/warning lines always attract attention
41    if trimmed.starts_with("error")
42        || trimmed.starts_with("Error")
43        || trimmed.contains("ERROR")
44        || trimmed.starts_with("panic")
45        || trimmed.starts_with("FAIL")
46    {
47        return 2.0;
48    }
49
50    // Function/type definitions
51    if is_definition(trimmed) {
52        return 1.8;
53    }
54
55    // Assertions and test expectations
56    if trimmed.starts_with("assert")
57        || trimmed.starts_with("expect(")
58        || trimmed.starts_with("#[test]")
59        || trimmed.starts_with("@Test")
60    {
61        return 1.5;
62    }
63
64    // Return statements, assignments with computation
65    if trimmed.starts_with("return ") || trimmed.starts_with("yield ") {
66        return 1.3;
67    }
68
69    // Control flow
70    if trimmed.starts_with("if ")
71        || trimmed.starts_with("match ")
72        || trimmed.starts_with("for ")
73        || trimmed.starts_with("while ")
74    {
75        return 1.1;
76    }
77
78    // Import/use statements — less important
79    if trimmed.starts_with("use ")
80        || trimmed.starts_with("import ")
81        || trimmed.starts_with("from ")
82        || trimmed.starts_with("#include")
83    {
84        return 0.6;
85    }
86
87    // Comments — usually low attention
88    if trimmed.starts_with("//")
89        || trimmed.starts_with("#")
90        || trimmed.starts_with("/*")
91        || trimmed.starts_with("*")
92    {
93        return 0.4;
94    }
95
96    // Closing braces — minimal attention
97    if trimmed == "}" || trimmed == "};" || trimmed == "})" {
98        return 0.3;
99    }
100
101    // Default: moderate attention
102    0.8
103}
104
105/// Compute combined attention score for a line at a given position.
106/// Combines positional U-curve with structural importance.
107pub fn combined_attention(line: &str, position: f64, alpha: f64, beta: f64, gamma: f64) -> f64 {
108    let pos_weight = positional_attention(position, alpha, beta, gamma);
109    let struct_weight = structural_importance(line);
110    // Geometric mean balances both factors
111    (pos_weight * struct_weight).sqrt()
112}
113
114/// Reorder lines to maximize predicted attention utilization.
115/// Places high-attention lines at begin and end positions.
116pub fn attention_optimize(lines: &[&str], _alpha: f64, _beta: f64, _gamma: f64) -> Vec<String> {
117    if lines.len() <= 3 {
118        return lines.iter().map(|l| l.to_string()).collect();
119    }
120
121    let mut scored: Vec<(usize, f64)> = lines
122        .iter()
123        .enumerate()
124        .map(|(i, line)| {
125            let importance = structural_importance(line);
126            (i, importance)
127        })
128        .collect();
129
130    // Sort by importance (high first)
131    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
132
133    // Place most important at begin (alpha) and end (gamma) positions
134    let n = scored.len();
135    let mut result = vec![String::new(); n];
136    let mut begin_idx = 0;
137    let mut end_idx = n - 1;
138    let mut mid_idx = n / 4; // start mid section after first quarter
139
140    for (i, (orig_idx, _importance)) in scored.iter().enumerate() {
141        if i % 3 == 0 && begin_idx < n / 3 {
142            result[begin_idx] = lines[*orig_idx].to_string();
143            begin_idx += 1;
144        } else if i % 3 == 1 && end_idx > 2 * n / 3 {
145            result[end_idx] = lines[*orig_idx].to_string();
146            end_idx -= 1;
147        } else {
148            if mid_idx < 2 * n / 3 {
149                result[mid_idx] = lines[*orig_idx].to_string();
150                mid_idx += 1;
151            }
152        }
153    }
154
155    // Fill any remaining empty slots with original order
156    let mut remaining: Vec<String> = lines.iter().map(|l| l.to_string()).collect();
157    for slot in &mut result {
158        if slot.is_empty() {
159            if let Some(line) = remaining.pop() {
160                *slot = line;
161            }
162        }
163    }
164
165    result
166}
167
168/// Compute the theoretical attention efficiency for a given context layout.
169/// Returns a percentage [0, 100] indicating how much of the context
170/// is in attention-optimal positions.
171pub fn attention_efficiency(line_importances: &[f64], alpha: f64, beta: f64, gamma: f64) -> f64 {
172    if line_importances.is_empty() {
173        return 0.0;
174    }
175
176    let n = line_importances.len();
177    let mut weighted_sum = 0.0;
178    let mut total_importance = 0.0;
179
180    for (i, &importance) in line_importances.iter().enumerate() {
181        let pos = i as f64 / (n - 1).max(1) as f64;
182        let pos_weight = positional_attention(pos, alpha, beta, gamma);
183        weighted_sum += importance * pos_weight;
184        total_importance += importance;
185    }
186
187    if total_importance == 0.0 {
188        return 0.0;
189    }
190
191    (weighted_sum / total_importance) * 100.0
192}
193
/// Does this (already-trimmed) line open a named definition?
/// Covers common Rust, TypeScript/JavaScript, Python, and Go forms.
fn is_definition(line: &str) -> bool {
    const PREFIXES: &[&str] = &[
        // Rust
        "fn ", "pub fn ", "async fn ", "pub async fn ",
        "struct ", "pub struct ", "enum ", "pub enum ",
        "trait ", "pub trait ", "impl ",
        "type ", "pub type ", "const ", "pub const ", "static ",
        // TypeScript / JavaScript / Java-like
        "class ", "export class ", "interface ", "export interface ",
        "function ", "export function ", "async function ",
        // Python
        "def ", "async def ",
        // Go
        "func ",
    ];
    PREFIXES.iter().any(|prefix| line.starts_with(prefix))
}
225
#[cfg(test)]
mod tests {
    use super::*;

    // The U-curve must hit the supplied alpha/beta/gamma at positions
    // 0.0/0.5/1.0 and be higher at both edges than in the middle.
    #[test]
    fn positional_u_curve() {
        let begin = positional_attention(0.0, 0.9, 0.5, 0.85);
        let middle = positional_attention(0.5, 0.9, 0.5, 0.85);
        let end = positional_attention(1.0, 0.9, 0.5, 0.85);

        assert!((begin - 0.9).abs() < 0.01);
        assert!((middle - 0.5).abs() < 0.01);
        assert!((end - 0.85).abs() < 0.01);
        assert!(begin > middle);
        assert!(end > middle);
    }

    // Structural scores are strictly ordered:
    // errors > definitions > comments > closing braces.
    #[test]
    fn structural_errors_highest() {
        let error = structural_importance("error[E0433]: failed to resolve");
        let def = structural_importance("fn main() {");
        let comment = structural_importance("// just a comment");
        let brace = structural_importance("}");

        assert!(error > def);
        assert!(def > comment);
        assert!(comment > brace);
    }

    // With line content fixed, the combined score must still follow the
    // positional curve: a definition scores higher at the begin position.
    #[test]
    fn combined_high_at_begin_with_definition() {
        let score_begin = combined_attention("fn main() {", 0.0, 0.9, 0.5, 0.85);
        let score_middle = combined_attention("fn main() {", 0.5, 0.9, 0.5, 0.85);
        assert!(score_begin > score_middle);
    }

    // A layout whose important lines sit at the edges should score a
    // higher efficiency than one with the same lines in the middle.
    #[test]
    fn efficiency_higher_when_important_at_edges() {
        let good_layout = vec![1.8, 0.3, 0.3, 0.3, 1.5]; // important at begin+end
        let bad_layout = vec![0.3, 0.3, 1.8, 1.5, 0.3]; // important in middle

        let eff_good = attention_efficiency(&good_layout, 0.9, 0.5, 0.85);
        let eff_bad = attention_efficiency(&bad_layout, 0.9, 0.5, 0.85);
        assert!(
            eff_good > eff_bad,
            "edges layout ({eff_good:.1}) should beat middle layout ({eff_bad:.1})"
        );
    }
}