// lean_ctx/core/neural/cache_alignment.rs

/// Minimum estimated size, in tokens, that rendered content must reach before
/// `compute_breakpoints` proposes any cache breakpoints; smaller content gets
/// an empty breakpoint list.
///
/// NOTE(review): the value presumably mirrors the provider's minimum cacheable
/// prompt length — confirm against the current prompt-caching documentation.
const CLAUDE_CACHE_MIN_TOKENS: usize = 1024;

/// Maximum number of breakpoints `compute_breakpoints` will emit for one
/// rendered document. It also determines the target block size, which is the
/// total estimate divided by this value plus one.
///
/// NOTE(review): presumably mirrors the provider's per-request breakpoint
/// limit — confirm against the current prompt-caching documentation.
const CLAUDE_MAX_CACHE_BREAKPOINTS: usize = 4;
15
/// One piece of output text together with the metadata used to order it.
///
/// Blocks are collected by `CacheAlignedOutput`, which renders every stable
/// block before any variable block and orders each group by `priority`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheBlock {
    /// Caller-supplied identifier for the block.
    pub id: String,
    /// Text that is written to the rendered output for this block.
    pub content: String,
    /// `true` for blocks added as stable, `false` for blocks added as variable.
    pub is_stable: bool,
    /// Ordering key within the block's group; lower values are rendered first.
    pub priority: u8,
    /// Estimated token count of `content`, recorded when the block was created.
    pub estimated_tokens: usize,
}
24
25#[derive(Default)]
26pub struct CacheAlignedOutput {
27 blocks: Vec<CacheBlock>,
28}
29
30impl CacheAlignedOutput {
31 pub fn new() -> Self {
32 Self::default()
33 }
34
35 pub fn add_stable_block(&mut self, id: &str, content: String, priority: u8) {
36 let tokens = estimate_tokens(&content);
37 self.blocks.push(CacheBlock {
38 id: id.to_string(),
39 content,
40 is_stable: true,
41 priority,
42 estimated_tokens: tokens,
43 });
44 }
45
46 pub fn add_variable_block(&mut self, id: &str, content: String, priority: u8) {
47 let tokens = estimate_tokens(&content);
48 self.blocks.push(CacheBlock {
49 id: id.to_string(),
50 content,
51 is_stable: false,
52 priority,
53 estimated_tokens: tokens,
54 });
55 }
56
57 pub fn render(&self) -> String {
60 let mut stable: Vec<&CacheBlock> = self.blocks.iter().filter(|b| b.is_stable).collect();
61 let mut variable: Vec<&CacheBlock> = self.blocks.iter().filter(|b| !b.is_stable).collect();
62
63 stable.sort_by_key(|b| b.priority);
64 variable.sort_by_key(|b| b.priority);
65
66 let mut output = String::new();
67
68 for block in &stable {
69 output.push_str(&block.content);
70 output.push('\n');
71 }
72
73 for block in &variable {
74 output.push_str(&block.content);
75 output.push('\n');
76 }
77
78 output
79 }
80
81 pub fn render_with_breakpoints(&self) -> (String, Vec<usize>) {
84 let rendered = self.render();
85 let breakpoints = compute_breakpoints(&rendered);
86 (rendered, breakpoints)
87 }
88
89 pub fn stable_token_count(&self) -> usize {
90 self.blocks
91 .iter()
92 .filter(|b| b.is_stable)
93 .map(|b| b.estimated_tokens)
94 .sum()
95 }
96
97 pub fn variable_token_count(&self) -> usize {
98 self.blocks
99 .iter()
100 .filter(|b| !b.is_stable)
101 .map(|b| b.estimated_tokens)
102 .sum()
103 }
104
105 pub fn cache_efficiency(&self) -> f64 {
106 let total = self.stable_token_count() + self.variable_token_count();
107 if total == 0 {
108 return 0.0;
109 }
110 self.stable_token_count() as f64 / total as f64
111 }
112}
113
114fn compute_breakpoints(content: &str) -> Vec<usize> {
118 let total_tokens = estimate_tokens(content);
119 if total_tokens < CLAUDE_CACHE_MIN_TOKENS {
120 return Vec::new();
121 }
122
123 let mut breakpoints = Vec::new();
124 let lines: Vec<&str> = content.lines().collect();
125 let mut accumulated_tokens = 0;
126 let target_block_size = total_tokens / (CLAUDE_MAX_CACHE_BREAKPOINTS + 1);
127
128 for (i, line) in lines.iter().enumerate() {
129 accumulated_tokens += estimate_tokens(line);
130
131 if accumulated_tokens >= target_block_size
132 && breakpoints.len() < CLAUDE_MAX_CACHE_BREAKPOINTS
133 && is_natural_boundary(line, lines.get(i + 1).copied())
134 {
135 breakpoints.push(i);
136 accumulated_tokens = 0;
137 }
138 }
139
140 breakpoints
141}
142
/// Reports whether a split directly after `line` falls on a natural boundary.
///
/// A line qualifies on its own when, ignoring surrounding whitespace, it is
/// empty or starts with `---`, `===`, `##` or `//`. Otherwise it qualifies
/// only if `next_line` is present and, once trimmed, is empty or starts
/// with `---`.
fn is_natural_boundary(line: &str, next_line: Option<&str>) -> bool {
    const BOUNDARY_MARKERS: [&str; 4] = ["---", "===", "##", "//"];

    let current = line.trim();
    if current.is_empty() || BOUNDARY_MARKERS.iter().any(|marker| current.starts_with(marker)) {
        return true;
    }

    match next_line.map(str::trim) {
        Some(upcoming) => upcoming.is_empty() || upcoming.starts_with("---"),
        None => false,
    }
}
162
/// Rough token estimate for `text`: one token per four bytes, plus one.
///
/// The length used is the UTF-8 byte length, not the character count, and
/// the result is never zero (an empty string yields `1`).
fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;

    let whole_tokens = text.len() / BYTES_PER_TOKEN;
    whole_tokens + 1
}
166
167pub fn compute_delta(previous: &str, current: &str) -> DeltaResult {
170 let prev_lines: Vec<&str> = previous.lines().collect();
171 let curr_lines: Vec<&str> = current.lines().collect();
172
173 let common_prefix = prev_lines
174 .iter()
175 .zip(curr_lines.iter())
176 .take_while(|(a, b)| a == b)
177 .count();
178
179 let common_suffix = prev_lines
180 .iter()
181 .rev()
182 .zip(curr_lines.iter().rev())
183 .take_while(|(a, b)| a == b)
184 .count();
185
186 let prev_changed = prev_lines
187 .len()
188 .saturating_sub(common_prefix + common_suffix);
189 let curr_changed = curr_lines
190 .len()
191 .saturating_sub(common_prefix + common_suffix);
192
193 let changed_lines: Vec<String> = curr_lines
194 [common_prefix..curr_lines.len().saturating_sub(common_suffix)]
195 .iter()
196 .map(|l| l.to_string())
197 .collect();
198
199 let prefix_tokens = estimate_tokens(&prev_lines[..common_prefix].to_vec().join("\n"));
200
201 DeltaResult {
202 common_prefix_lines: common_prefix,
203 common_suffix_lines: common_suffix,
204 removed_lines: prev_changed,
205 added_lines: curr_changed,
206 changed_content: changed_lines.join("\n"),
207 cached_prefix_tokens: prefix_tokens,
208 total_delta_tokens: estimate_tokens(&changed_lines.join("\n")),
209 }
210}
211
/// Line-level difference between two texts, as produced by `compute_delta`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeltaResult {
    /// Number of leading lines that are identical in both texts.
    pub common_prefix_lines: usize,
    /// Number of trailing lines that are identical in both texts.
    pub common_suffix_lines: usize,
    /// Lines of the previous text that lie between the shared prefix and suffix.
    pub removed_lines: usize,
    /// Lines of the current text that lie between the shared prefix and suffix.
    pub added_lines: usize,
    /// The changed lines of the current text, joined with `'\n'`.
    pub changed_content: String,
    /// Token estimate for the shared leading lines.
    pub cached_prefix_tokens: usize,
    /// Token estimate for `changed_content`.
    pub total_delta_tokens: usize,
}

impl DeltaResult {
    /// Share of the estimated tokens that belongs to the shared prefix.
    ///
    /// Computed as `cached_prefix_tokens / (cached_prefix_tokens +
    /// total_delta_tokens)`; returns `0.0` when that sum is zero.
    pub fn savings_ratio(&self) -> f64 {
        let total = self.cached_prefix_tokens + self.total_delta_tokens;
        if total == 0 {
            0.0
        } else {
            self.cached_prefix_tokens as f64 / total as f64
        }
    }
}
232
233pub fn cache_order_code(content: &str) -> String {
236 let lines: Vec<&str> = content.lines().collect();
237
238 let mut imports = Vec::new();
239 let mut definitions = Vec::new();
240 let mut body = Vec::new();
241
242 for line in &lines {
243 let trimmed = line.trim();
244 if trimmed.starts_with("import ")
245 || trimmed.starts_with("use ")
246 || trimmed.starts_with("from ")
247 || trimmed.starts_with("#include")
248 {
249 imports.push(*line);
250 } else if is_type_definition(trimmed) {
251 definitions.push(*line);
252 } else {
253 body.push(*line);
254 }
255 }
256
257 let mut result = Vec::new();
258 let has_imports = !imports.is_empty();
259 let has_definitions = !definitions.is_empty();
260 let has_body = !body.is_empty();
261 result.extend(imports);
262 if has_imports && has_definitions {
263 result.push("");
264 }
265 result.extend(definitions);
266 if has_definitions && has_body {
267 result.push("");
268 }
269 result.extend(body);
270
271 result.join("\n")
272}
273
/// Reports whether `line` begins with a recognised type-declaration keyword.
///
/// Accepted openings are `struct`, `enum`, `trait`, `type`, `interface` and
/// `class` on their own; `struct`, `enum`, `trait` and `type` after `pub `;
/// and `interface`, `type` and `class` after `export `. Every keyword must be
/// followed by a single space, and `line` is matched as given, so callers
/// are expected to trim leading whitespace first.
fn is_type_definition(line: &str) -> bool {
    const BARE: &[&str] = &["struct ", "enum ", "trait ", "type ", "interface ", "class "];
    const AFTER_PUB: &[&str] = &["struct ", "enum ", "trait ", "type "];
    const AFTER_EXPORT: &[&str] = &["interface ", "type ", "class "];

    // Pick the keyword set that is valid for the modifier, if there is one.
    let (keywords, rest) = if let Some(rest) = line.strip_prefix("pub ") {
        (AFTER_PUB, rest)
    } else if let Some(rest) = line.strip_prefix("export ") {
        (AFTER_EXPORT, rest)
    } else {
        (BARE, line)
    };

    keywords.iter().any(|keyword| rest.starts_with(keyword))
}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// A stable block is rendered before a variable block even when the
    /// variable block was added first.
    #[test]
    fn stable_blocks_come_first() {
        let mut aligned = CacheAlignedOutput::new();
        aligned.add_variable_block("var1", String::from("variable content"), 1);
        aligned.add_stable_block("stable1", String::from("stable content"), 1);

        let text = aligned.render();
        let stable_at = text.find("stable content").unwrap();
        let variable_at = text.find("variable content").unwrap();

        assert!(stable_at < variable_at);
    }

    /// A single modified line in the middle is reported with the correct
    /// shared prefix and suffix sizes.
    #[test]
    fn delta_detects_changes() {
        let before = "line1\nline2\nline3\nline4";
        let after = "line1\nline2\nmodified\nline4";

        let result = compute_delta(before, after);

        assert_eq!(result.common_prefix_lines, 2);
        assert_eq!(result.common_suffix_lines, 1);
        assert!(result.changed_content.contains("modified"));
    }

    /// Output dominated by stable content reports an efficiency above 0.8.
    #[test]
    fn cache_efficiency_high_for_stable() {
        let mut aligned = CacheAlignedOutput::new();
        aligned.add_stable_block("s1", "x".repeat(1000), 1);
        aligned.add_variable_block("v1", "y".repeat(100), 1);

        let efficiency = aligned.cache_efficiency();

        assert!(efficiency > 0.8);
    }

    /// After reordering, the first line is one of the import-like lines.
    #[test]
    fn code_reordering_puts_imports_first() {
        let source = "fn main() {}\nuse std::io;\nimport os\nstruct Foo;";

        let reordered = cache_order_code(source);
        let first_line = reordered.lines().next().unwrap();

        assert!(first_line.starts_with("use ") || first_line.starts_with("import "));
    }
}