// lean_ctx/core/neural/cache_alignment.rs

/// Estimated-token threshold below which `compute_breakpoints` returns no
/// breakpoints at all.
// NOTE(review): presumably mirrors the provider's minimum cacheable prompt
// length — confirm against the current prompt-caching documentation.
const CLAUDE_CACHE_MIN_TOKENS: usize = 1024;

/// Upper bound on the number of breakpoints `compute_breakpoints` emits; the
/// content is split into this many + 1 roughly equal token budgets.
// NOTE(review): presumably mirrors the provider's per-request breakpoint
// limit — confirm against the current prompt-caching documentation.
const CLAUDE_MAX_CACHE_BREAKPOINTS: usize = 4;
15
/// One piece of output content together with the metadata used to order it.
///
/// Blocks are created by `CacheAlignedOutput::add_stable_block` and
/// `CacheAlignedOutput::add_variable_block`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheBlock {
    /// Caller-supplied identifier for the block.
    pub id: String,
    /// Text that is written to the rendered output.
    pub content: String,
    /// `true` for blocks added as stable; stable blocks are rendered before
    /// variable ones.
    pub is_stable: bool,
    /// Sort key within the stable / variable group (lower values render first).
    pub priority: u8,
    /// Token estimate for `content`, computed when the block is added.
    pub estimated_tokens: usize,
}
24
25#[derive(Default)]
26pub struct CacheAlignedOutput {
27 blocks: Vec<CacheBlock>,
28}
29
30impl CacheAlignedOutput {
31 pub fn new() -> Self {
32 Self::default()
33 }
34
35 pub fn add_stable_block(&mut self, id: &str, content: String, priority: u8) {
36 let tokens = estimate_tokens(&content);
37 self.blocks.push(CacheBlock {
38 id: id.to_string(),
39 content,
40 is_stable: true,
41 priority,
42 estimated_tokens: tokens,
43 });
44 }
45
46 pub fn add_variable_block(&mut self, id: &str, content: String, priority: u8) {
47 let tokens = estimate_tokens(&content);
48 self.blocks.push(CacheBlock {
49 id: id.to_string(),
50 content,
51 is_stable: false,
52 priority,
53 estimated_tokens: tokens,
54 });
55 }
56
57 pub fn render(&self) -> String {
60 let mut stable: Vec<&CacheBlock> = self.blocks.iter().filter(|b| b.is_stable).collect();
61 let mut variable: Vec<&CacheBlock> = self.blocks.iter().filter(|b| !b.is_stable).collect();
62
63 stable.sort_by_key(|b| b.priority);
64 variable.sort_by_key(|b| b.priority);
65
66 let mut output = String::new();
67
68 for block in &stable {
69 output.push_str(&block.content);
70 output.push('\n');
71 }
72
73 for block in &variable {
74 output.push_str(&block.content);
75 output.push('\n');
76 }
77
78 output
79 }
80
81 pub fn render_with_breakpoints(&self) -> (String, Vec<usize>) {
84 let rendered = self.render();
85 let breakpoints = compute_breakpoints(&rendered);
86 (rendered, breakpoints)
87 }
88
89 pub fn stable_token_count(&self) -> usize {
90 self.blocks
91 .iter()
92 .filter(|b| b.is_stable)
93 .map(|b| b.estimated_tokens)
94 .sum()
95 }
96
97 pub fn variable_token_count(&self) -> usize {
98 self.blocks
99 .iter()
100 .filter(|b| !b.is_stable)
101 .map(|b| b.estimated_tokens)
102 .sum()
103 }
104
105 pub fn cache_efficiency(&self) -> f64 {
106 let total = self.stable_token_count() + self.variable_token_count();
107 if total == 0 {
108 return 0.0;
109 }
110 self.stable_token_count() as f64 / total as f64
111 }
112}
113
114fn compute_breakpoints(content: &str) -> Vec<usize> {
118 let total_tokens = estimate_tokens(content);
119 if total_tokens < CLAUDE_CACHE_MIN_TOKENS {
120 return Vec::new();
121 }
122
123 let mut breakpoints = Vec::new();
124 let lines: Vec<&str> = content.lines().collect();
125 let mut accumulated_tokens = 0;
126 let target_block_size = total_tokens / (CLAUDE_MAX_CACHE_BREAKPOINTS + 1);
127
128 for (i, line) in lines.iter().enumerate() {
129 accumulated_tokens += estimate_tokens(line);
130
131 if accumulated_tokens >= target_block_size
132 && breakpoints.len() < CLAUDE_MAX_CACHE_BREAKPOINTS
133 && is_natural_boundary(line, lines.get(i + 1).copied())
134 {
135 breakpoints.push(i);
136 accumulated_tokens = 0;
137 }
138 }
139
140 breakpoints
141}
142
/// Reports whether a split after `line` would fall on a visually natural
/// boundary.
///
/// After trimming, `line` qualifies on its own when it is blank or starts
/// with `---`, `===`, `##` or `//`. Otherwise it still qualifies when the
/// following line (if any) is blank or starts with `---` after trimming.
fn is_natural_boundary(line: &str, next_line: Option<&str>) -> bool {
    const OWN_MARKERS: [&str; 4] = ["---", "===", "##", "//"];

    let current = line.trim();
    if current.is_empty() || OWN_MARKERS.iter().any(|marker| current.starts_with(marker)) {
        return true;
    }

    match next_line.map(str::trim) {
        Some(following) => following.is_empty() || following.starts_with("---"),
        None => false,
    }
}
162
/// Rough token estimate for `text`: one token per four bytes (integer
/// division), plus one. The result is therefore never zero, even for an
/// empty string.
fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    let whole_chunks = text.len() / BYTES_PER_TOKEN;
    whole_chunks + 1
}
166
167pub fn compute_delta(previous: &str, current: &str) -> DeltaResult {
170 let prev_lines: Vec<&str> = previous.lines().collect();
171 let curr_lines: Vec<&str> = current.lines().collect();
172
173 let common_prefix = prev_lines
174 .iter()
175 .zip(curr_lines.iter())
176 .take_while(|(a, b)| a == b)
177 .count();
178
179 let common_suffix = prev_lines
180 .iter()
181 .rev()
182 .zip(curr_lines.iter().rev())
183 .take_while(|(a, b)| a == b)
184 .count();
185
186 let prev_changed = prev_lines
187 .len()
188 .saturating_sub(common_prefix + common_suffix);
189 let curr_changed = curr_lines
190 .len()
191 .saturating_sub(common_prefix + common_suffix);
192
193 let changed_lines: Vec<String> = curr_lines
194 [common_prefix..curr_lines.len().saturating_sub(common_suffix)]
195 .iter()
196 .map(|l| l.to_string())
197 .collect();
198
199 let prefix_tokens = estimate_tokens(
200 &prev_lines[..common_prefix].to_vec().join("\n"),
201 );
202
203 DeltaResult {
204 common_prefix_lines: common_prefix,
205 common_suffix_lines: common_suffix,
206 removed_lines: prev_changed,
207 added_lines: curr_changed,
208 changed_content: changed_lines.join("\n"),
209 cached_prefix_tokens: prefix_tokens,
210 total_delta_tokens: estimate_tokens(&changed_lines.join("\n")),
211 }
212}
213
/// Line-level difference between two texts, as produced by `compute_delta`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeltaResult {
    /// Number of leading lines shared by both texts.
    pub common_prefix_lines: usize,
    /// Number of trailing lines shared by both texts.
    pub common_suffix_lines: usize,
    /// Lines of the previous text that lie outside the shared prefix/suffix.
    pub removed_lines: usize,
    /// Lines of the current text that lie outside the shared prefix/suffix.
    pub added_lines: usize,
    /// The current text's changed lines, joined with `'\n'`.
    pub changed_content: String,
    /// Token estimate for the shared prefix.
    pub cached_prefix_tokens: usize,
    /// Token estimate for `changed_content`.
    pub total_delta_tokens: usize,
}

impl DeltaResult {
    /// Share of `cached_prefix_tokens` in the sum of prefix and delta tokens,
    /// in `0.0..=1.0`. Returns `0.0` when both counts are zero.
    pub fn savings_ratio(&self) -> f64 {
        match self.cached_prefix_tokens + self.total_delta_tokens {
            0 => 0.0,
            total => self.cached_prefix_tokens as f64 / total as f64,
        }
    }
}
234
235pub fn cache_order_code(content: &str) -> String {
238 let lines: Vec<&str> = content.lines().collect();
239
240 let mut imports = Vec::new();
241 let mut definitions = Vec::new();
242 let mut body = Vec::new();
243
244 for line in &lines {
245 let trimmed = line.trim();
246 if trimmed.starts_with("import ")
247 || trimmed.starts_with("use ")
248 || trimmed.starts_with("from ")
249 || trimmed.starts_with("#include")
250 {
251 imports.push(*line);
252 } else if is_type_definition(trimmed) {
253 definitions.push(*line);
254 } else {
255 body.push(*line);
256 }
257 }
258
259 let mut result = Vec::new();
260 let has_imports = !imports.is_empty();
261 let has_definitions = !definitions.is_empty();
262 let has_body = !body.is_empty();
263 result.extend(imports);
264 if has_imports && has_definitions {
265 result.push("");
266 }
267 result.extend(definitions);
268 if has_definitions && has_body {
269 result.push("");
270 }
271 result.extend(body);
272
273 result.join("\n")
274}
275
/// Reports whether `line` begins with a type-definition keyword.
///
/// Accepted openings are `struct `, `enum `, `trait ` and `type `, each
/// optionally preceded by `pub `, and `interface `, `type ` and `class `,
/// each optionally preceded by `export `. The check is a plain prefix match,
/// so callers are expected to pass an already-trimmed line.
fn is_type_definition(line: &str) -> bool {
    const PUB_QUALIFIABLE: [&str; 4] = ["struct ", "enum ", "trait ", "type "];
    const EXPORT_QUALIFIABLE: [&str; 3] = ["interface ", "type ", "class "];

    let without_pub = line.strip_prefix("pub ").unwrap_or(line);
    if PUB_QUALIFIABLE.iter().any(|keyword| without_pub.starts_with(keyword)) {
        return true;
    }

    let without_export = line.strip_prefix("export ").unwrap_or(line);
    EXPORT_QUALIFIABLE
        .iter()
        .any(|keyword| without_export.starts_with(keyword))
}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Stable content is rendered ahead of variable content even when the
    /// variable block was added first.
    #[test]
    fn stable_blocks_come_first() {
        let mut aligned = CacheAlignedOutput::new();
        aligned.add_variable_block("var1", String::from("variable content"), 1);
        aligned.add_stable_block("stable1", String::from("stable content"), 1);

        let text = aligned.render();
        let [stable_at, variable_at] =
            ["stable content", "variable content"].map(|needle| text.find(needle).unwrap());
        assert!(stable_at < variable_at);
    }

    /// A single modified line is reported between the shared prefix and suffix.
    #[test]
    fn delta_detects_changes() {
        let before = "line1\nline2\nline3\nline4";
        let after = "line1\nline2\nmodified\nline4";

        let result = compute_delta(before, after);
        assert_eq!(result.common_prefix_lines, 2);
        assert_eq!(result.common_suffix_lines, 1);
        assert!(result.changed_content.contains("modified"));
    }

    /// Output dominated by stable content scores above 0.8.
    #[test]
    fn cache_efficiency_high_for_stable() {
        let mut aligned = CacheAlignedOutput::new();
        aligned.add_stable_block("s1", "x".repeat(1000), 1);
        aligned.add_variable_block("v1", "y".repeat(100), 1);

        let efficiency = aligned.cache_efficiency();
        assert!(efficiency > 0.8);
    }

    /// The first line after reordering is an import-like line.
    #[test]
    fn code_reordering_puts_imports_first() {
        let source = "fn main() {}\nuse std::io;\nimport os\nstruct Foo;";
        let reordered = cache_order_code(source);
        let first_line = reordered.lines().collect::<Vec<&str>>()[0];
        assert!(first_line.starts_with("use ") || first_line.starts_with("import "));
    }
}