// lean_ctx/core/attention_context.rs
use std::collections::HashSet;

/// Per-chunk scoring and token-budget allocation produced by
/// [`attention_weighted_assembly`].
#[derive(Debug, Clone, PartialEq)]
pub struct ChunkDensity {
    /// Caller-supplied index identifying the chunk (copied from the input tuple).
    pub chunk_idx: usize,
    /// Lexical-diversity signal for the chunk's content.
    pub lexical_diversity: f64,
    /// Structural multiplier: `1.3` for definition chunks, `1.0` otherwise.
    pub structural_weight: f64,
    /// Density score after the penalty for redundancy with nearby preceding chunks.
    pub attention_score: f64,
    /// Number of tokens allotted to this chunk.
    pub token_budget: usize,
}

/// Scores how information-dense a chunk of text is.
///
/// The score starts from the chunk's lexical diversity: the number of distinct
/// whitespace-separated tokens divided by the total number of tokens. It is then
/// multiplied by `1.3` when `is_definition` is true. It is also halved when fewer
/// than 30% of the tokens are distinct, which penalises highly repetitive text.
///
/// Returns `0.0` for empty or whitespace-only input.
pub fn compute_density(content: &str, is_definition: bool) -> f64 {
    // Count tokens and collect the distinct ones in a single pass.
    let mut distinct = HashSet::new();
    let mut total = 0usize;
    for token in content.split_whitespace() {
        distinct.insert(token);
        total += 1;
    }
    if total == 0 {
        return 0.0;
    }

    let diversity = distinct.len() as f64 / total as f64;
    // Definition chunks are boosted by 30%.
    let structural_boost = if is_definition { 1.3 } else { 1.0 };
    // Text with fewer than 30% distinct tokens has its score halved.
    let repetition_penalty = if diversity >= 0.3 { 1.0 } else { 0.5 };

    diversity * structural_boost * repetition_penalty
}

/// Jaccard similarity of the whitespace-separated token sets of two chunks.
///
/// The result lies in `[0.0, 1.0]`:
/// - `1.0` when both chunks contain exactly the same distinct tokens.
/// - `0.0` when the chunks share no tokens.
/// - `0.0` when either chunk has no tokens at all, so an empty chunk is never
///   considered redundant.
///
/// Repeated tokens count once, because both inputs are reduced to sets.
pub fn compute_redundancy(content_a: &str, content_b: &str) -> f64 {
    let tokens_a: HashSet<&str> = content_a.split_whitespace().collect();
    let tokens_b: HashSet<&str> = content_b.split_whitespace().collect();

    if tokens_a.is_empty() || tokens_b.is_empty() {
        return 0.0;
    }

    // Both sets are non-empty here, so the union has at least one element and
    // the division cannot be by zero. Using |A ∪ B| = |A| + |B| - |A ∩ B| avoids
    // walking the union a second time.
    let intersection = tokens_a.intersection(&tokens_b).count();
    let union = tokens_a.len() + tokens_b.len() - intersection;
    intersection as f64 / union as f64
}

68pub fn attention_weighted_assembly(
74 chunks: &[(usize, &str, bool)], total_budget: usize,
76) -> Vec<ChunkDensity> {
77 if chunks.is_empty() {
78 return Vec::new();
79 }
80
81 let mut densities: Vec<ChunkDensity> = chunks
83 .iter()
84 .map(|&(idx, content, is_def)| {
85 let density = compute_density(content, is_def);
86 ChunkDensity {
87 chunk_idx: idx,
88 lexical_diversity: density,
89 structural_weight: if is_def { 1.3 } else { 1.0 },
90 attention_score: density,
91 token_budget: 0,
92 }
93 })
94 .collect();
95
96 let window_size = 20.min(densities.len());
99 for i in 1..densities.len() {
100 let mut max_redundancy = 0.0f64;
101 let start = i.saturating_sub(window_size);
102 for j in start..i {
103 let redundancy = compute_redundancy(chunks[i].1, chunks[j].1);
104 max_redundancy = max_redundancy.max(redundancy);
105 }
106 densities[i].attention_score *= 1.0 - (max_redundancy * 0.7);
107 }
108
109 let total_attention: f64 = densities.iter().map(|d| d.attention_score).sum();
111 if total_attention > 0.0 {
112 for density in &mut densities {
113 let fraction = density.attention_score / total_attention;
114 let equal_share = total_budget as f64 / chunks.len() as f64;
116 let raw_budget = fraction * total_budget as f64;
117 let clamped = raw_budget.max(equal_share * 0.1).min(equal_share * 3.0);
118 density.token_budget = clamped as usize;
119 }
120 } else {
121 let per_chunk = total_budget / chunks.len().max(1);
123 for density in &mut densities {
124 density.token_budget = per_chunk;
125 }
126 }
127
128 densities
129}
130
/// Cuts `content` down to roughly `token_budget` tokens, assuming about 4 bytes
/// per token.
///
/// If `content` already fits in `token_budget * 4` bytes, it is returned
/// unchanged. Otherwise the function takes the prefix up to that byte limit:
/// - The cut is moved back to the nearest UTF-8 character boundary, so
///   multi-byte text cannot cause a panic.
/// - If the prefix contains a newline, it is shortened further to end just after
///   the last newline, so no partial line is returned.
pub fn truncate_to_budget(content: &str, token_budget: usize) -> &str {
    // About 4 bytes per token. Saturate so that a huge budget means "no limit"
    // instead of overflowing.
    let byte_budget = token_budget.saturating_mul(4);
    if content.len() <= byte_budget {
        return content;
    }

    // Here `byte_budget < content.len()`. Slicing inside a multi-byte character
    // panics, so step back to the start of the character that straddles the
    // limit. Index 0 is always a boundary, so the loop terminates.
    let mut end = byte_budget;
    while !content.is_char_boundary(end) {
        end -= 1;
    }
    let truncated = &content[..end];

    match truncated.rfind('\n') {
        Some(pos) => &truncated[..=pos],
        None => truncated,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is within a small tolerance of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 0.001);
    }

    #[test]
    fn high_diversity_gets_more_budget() {
        let diverse = "fn unique_function_name() { let x = compute_something(); }";
        let repetitive = "test test test test test test test test test test";

        let allocation =
            attention_weighted_assembly(&[(0, diverse, true), (1, repetitive, false)], 1000);

        assert_eq!(allocation.len(), 2);
        assert!(allocation[1].token_budget < allocation[0].token_budget);
    }

    #[test]
    fn redundant_chunks_get_less_budget() {
        let login = "fn auth_login() { validate_token(jwt) }";
        let query = "fn database_query() { execute_sql(conn) }";

        // The second chunk duplicates the first and should be squeezed.
        let allocation = attention_weighted_assembly(
            &[(0, login, true), (1, login, false), (2, query, true)],
            1000,
        );

        assert!(allocation[0].token_budget > allocation[1].token_budget);
    }

    #[test]
    fn empty_input_returns_empty() {
        assert!(attention_weighted_assembly(&[], 1000).is_empty());
    }

    #[test]
    fn compute_density_values_make_sense() {
        let diverse = compute_density("fn unique name with diverse tokens here now", false);
        let repetitive = compute_density("test test test test test test", false);
        assert!(diverse > repetitive);
    }

    #[test]
    fn redundancy_of_identical_is_one() {
        assert_close(compute_redundancy("hello world foo bar", "hello world foo bar"), 1.0);
    }

    #[test]
    fn redundancy_of_disjoint_is_zero() {
        assert_close(compute_redundancy("alpha beta gamma", "delta epsilon zeta"), 0.0);
    }

    #[test]
    fn truncate_respects_budget() {
        let content = "line1\nline2\nline3\nline4\nline5\n";
        // A 3-token budget at about 4 bytes per token allows at most 12 bytes.
        assert!(truncate_to_budget(content, 3).len() <= 12);
    }
}