// lean_ctx/core/attention_model.rs
/// Position-based attention weight forming a "U"-shaped curve.
///
/// `position` is the normalized location in `[0, 1]`; values outside that
/// range are clamped to the endpoint weights. Returns `alpha` at the start,
/// `beta` at the midpoint, and `gamma` at the end, linearly interpolating
/// within each half.
pub fn positional_attention(position: f64, alpha: f64, beta: f64, gamma: f64) -> f64 {
    if position <= 0.0 {
        return alpha;
    }
    if position >= 1.0 {
        return gamma;
    }

    // Interpolate alpha -> beta over the first half, beta -> gamma over the second.
    let (from, to, t) = if position <= 0.5 {
        (alpha, beta, position / 0.5)
    } else {
        (beta, gamma, (position - 0.5) / 0.5)
    };
    from * (1.0 - t) + to * t
}
31
32pub fn structural_importance(line: &str) -> f64 {
35 let trimmed = line.trim();
36 if trimmed.is_empty() {
37 return 0.1;
38 }
39
40 if trimmed.starts_with("error")
42 || trimmed.starts_with("Error")
43 || trimmed.contains("ERROR")
44 || trimmed.starts_with("panic")
45 || trimmed.starts_with("FAIL")
46 {
47 return 2.0;
48 }
49
50 if is_definition(trimmed) {
52 return 1.8;
53 }
54
55 if trimmed.starts_with("assert")
57 || trimmed.starts_with("expect(")
58 || trimmed.starts_with("#[test]")
59 || trimmed.starts_with("@Test")
60 {
61 return 1.5;
62 }
63
64 if trimmed.starts_with("return ") || trimmed.starts_with("yield ") {
66 return 1.3;
67 }
68
69 if trimmed.starts_with("if ")
71 || trimmed.starts_with("match ")
72 || trimmed.starts_with("for ")
73 || trimmed.starts_with("while ")
74 {
75 return 1.1;
76 }
77
78 if trimmed.starts_with("use ")
80 || trimmed.starts_with("import ")
81 || trimmed.starts_with("from ")
82 || trimmed.starts_with("#include")
83 {
84 return 0.6;
85 }
86
87 if trimmed.starts_with("//")
89 || trimmed.starts_with("#")
90 || trimmed.starts_with("/*")
91 || trimmed.starts_with("*")
92 {
93 return 0.4;
94 }
95
96 if trimmed == "}" || trimmed == "};" || trimmed == "})" {
98 return 0.3;
99 }
100
101 0.8
103}
104
105pub fn combined_attention(line: &str, position: f64, alpha: f64, beta: f64, gamma: f64) -> f64 {
108 let pos_weight = positional_attention(position, alpha, beta, gamma);
109 let struct_weight = structural_importance(line);
110 (pos_weight * struct_weight).sqrt()
112}
113
114pub fn attention_optimize(lines: &[&str], _alpha: f64, _beta: f64, _gamma: f64) -> Vec<String> {
117 if lines.len() <= 3 {
118 return lines.iter().map(|l| l.to_string()).collect();
119 }
120
121 let mut scored: Vec<(usize, f64)> = lines
122 .iter()
123 .enumerate()
124 .map(|(i, line)| {
125 let importance = structural_importance(line);
126 (i, importance)
127 })
128 .collect();
129
130 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
132
133 let n = scored.len();
135 let mut result = vec![String::new(); n];
136 let mut begin_idx = 0;
137 let mut end_idx = n - 1;
138 let mut mid_idx = n / 4; for (i, (orig_idx, _importance)) in scored.iter().enumerate() {
141 if i % 3 == 0 && begin_idx < n / 3 {
142 result[begin_idx] = lines[*orig_idx].to_string();
143 begin_idx += 1;
144 } else if i % 3 == 1 && end_idx > 2 * n / 3 {
145 result[end_idx] = lines[*orig_idx].to_string();
146 end_idx -= 1;
147 } else {
148 if mid_idx < 2 * n / 3 {
149 result[mid_idx] = lines[*orig_idx].to_string();
150 mid_idx += 1;
151 }
152 }
153 }
154
155 let mut remaining: Vec<String> = lines.iter().map(|l| l.to_string()).collect();
157 for slot in &mut result {
158 if slot.is_empty() {
159 if let Some(line) = remaining.pop() {
160 *slot = line;
161 }
162 }
163 }
164
165 result
166}
167
168pub fn attention_efficiency(line_importances: &[f64], alpha: f64, beta: f64, gamma: f64) -> f64 {
172 if line_importances.is_empty() {
173 return 0.0;
174 }
175
176 let n = line_importances.len();
177 let mut weighted_sum = 0.0;
178 let mut total_importance = 0.0;
179
180 for (i, &importance) in line_importances.iter().enumerate() {
181 let pos = i as f64 / (n - 1).max(1) as f64;
182 let pos_weight = positional_attention(pos, alpha, beta, gamma);
183 weighted_sum += importance * pos_weight;
184 total_importance += importance;
185 }
186
187 if total_importance == 0.0 {
188 return 0.0;
189 }
190
191 (weighted_sum / total_importance) * 100.0
192}
193
/// True when `line` begins with a definition keyword — covers Rust
/// (`fn`/`struct`/`impl`/…), JS/TS (`class`/`function`/`interface`, with
/// `export`/`async` variants), Python (`def`), and Go (`func`).
fn is_definition(line: &str) -> bool {
    const DEFINITION_PREFIXES: &[&str] = &[
        "fn ",
        "pub fn ",
        "async fn ",
        "pub async fn ",
        "struct ",
        "pub struct ",
        "enum ",
        "pub enum ",
        "trait ",
        "pub trait ",
        "impl ",
        "type ",
        "pub type ",
        "const ",
        "pub const ",
        "static ",
        "class ",
        "export class ",
        "interface ",
        "export interface ",
        "function ",
        "export function ",
        "async function ",
        "def ",
        "async def ",
        "func ",
    ];
    DEFINITION_PREFIXES.iter().any(|prefix| line.starts_with(prefix))
}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// The positional curve is high at both ends and low in the middle.
    #[test]
    fn positional_u_curve() {
        let (a, b, g) = (0.9, 0.5, 0.85);
        let begin = positional_attention(0.0, a, b, g);
        let middle = positional_attention(0.5, a, b, g);
        let end = positional_attention(1.0, a, b, g);

        assert!((begin - 0.9).abs() < 0.01);
        assert!((middle - 0.5).abs() < 0.01);
        assert!((end - 0.85).abs() < 0.01);
        assert!(begin > middle);
        assert!(end > middle);
    }

    /// Errors outrank definitions, which outrank comments and lone braces.
    #[test]
    fn structural_errors_highest() {
        let scores = [
            structural_importance("error[E0433]: failed to resolve"),
            structural_importance("fn main() {"),
            structural_importance("// just a comment"),
            structural_importance("}"),
        ];
        assert!(scores[0] > scores[1]);
        assert!(scores[1] > scores[2]);
        assert!(scores[2] > scores[3]);
    }

    /// The same definition line scores higher at the file start than mid-file.
    #[test]
    fn combined_high_at_begin_with_definition() {
        let line = "fn main() {";
        let score_begin = combined_attention(line, 0.0, 0.9, 0.5, 0.85);
        let score_middle = combined_attention(line, 0.5, 0.9, 0.5, 0.85);
        assert!(score_begin > score_middle);
    }

    /// Important lines at the edges should beat the same lines buried mid-list.
    #[test]
    fn efficiency_higher_when_important_at_edges() {
        let good_layout = vec![1.8, 0.3, 0.3, 0.3, 1.5];
        let bad_layout = vec![0.3, 0.3, 1.8, 1.5, 0.3];
        let eff_good = attention_efficiency(&good_layout, 0.9, 0.5, 0.85);
        let eff_bad = attention_efficiency(&bad_layout, 0.9, 0.5, 0.85);
        assert!(
            eff_good > eff_bad,
            "edges layout ({eff_good:.1}) should beat middle layout ({eff_bad:.1})"
        );
    }
}