Skip to main content

lean_ctx/core/neural/
cache_alignment.rs

1//! KV-Cache alignment for commercial LLM prompt caching.
2//!
3//! Claude's prompt caching stores KV-tensors for byte-exact prefix matches.
4//! GPT models have similar mechanisms. This module ensures lean-ctx outputs
5//! are structured to maximize cache hit rates.
6//!
7//! Key strategies:
8//! 1. Stable prefix: invariant content (instructions, tool defs) comes first
9//! 2. Cache-block alignment: content segmented to match provider breakpoints
10//! 3. Delta-only after cached prefix: only send changes, rest stays in KV-cache
11//! 4. Deterministic ordering: same inputs always produce byte-identical output
12
/// Estimated-token threshold used by `compute_breakpoints`: content whose
/// estimate falls below this value gets no breakpoints at all.
/// NOTE(review): presumably mirrors Claude's minimum cacheable prompt length —
/// confirm against the provider's current documentation.
const CLAUDE_CACHE_MIN_TOKENS: usize = 1024;
/// Upper bound on the number of breakpoints `compute_breakpoints` emits; the
/// target block size is derived as `total_tokens / (this + 1)`.
/// NOTE(review): presumably mirrors Claude's per-request breakpoint limit — confirm.
const CLAUDE_MAX_CACHE_BREAKPOINTS: usize = 4;
15
/// One unit of content tracked by `CacheAlignedOutput`.
///
/// Derives `PartialEq`/`Eq` so blocks can be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheBlock {
    /// Caller-supplied identifier. Rendering does not emit it.
    pub id: String,
    /// Text emitted for this block when the output is rendered.
    pub content: String,
    /// `true` for blocks rendered in the leading (stable) group,
    /// `false` for blocks rendered after all stable blocks.
    pub is_stable: bool,
    /// Ordering key inside the block's group; lower values are rendered earlier.
    pub priority: u8,
    /// Token estimate for `content`, filled in when the block is added
    /// through `CacheAlignedOutput`.
    pub estimated_tokens: usize,
}
24
25#[derive(Default)]
26pub struct CacheAlignedOutput {
27    blocks: Vec<CacheBlock>,
28}
29
30impl CacheAlignedOutput {
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    pub fn add_stable_block(&mut self, id: &str, content: String, priority: u8) {
36        let tokens = estimate_tokens(&content);
37        self.blocks.push(CacheBlock {
38            id: id.to_string(),
39            content,
40            is_stable: true,
41            priority,
42            estimated_tokens: tokens,
43        });
44    }
45
46    pub fn add_variable_block(&mut self, id: &str, content: String, priority: u8) {
47        let tokens = estimate_tokens(&content);
48        self.blocks.push(CacheBlock {
49            id: id.to_string(),
50            content,
51            is_stable: false,
52            priority,
53            estimated_tokens: tokens,
54        });
55    }
56
57    /// Render the output with cache-optimal ordering:
58    /// stable blocks first (sorted by priority), then variable blocks.
59    pub fn render(&self) -> String {
60        let mut stable: Vec<&CacheBlock> = self.blocks.iter().filter(|b| b.is_stable).collect();
61        let mut variable: Vec<&CacheBlock> = self.blocks.iter().filter(|b| !b.is_stable).collect();
62
63        stable.sort_by_key(|b| b.priority);
64        variable.sort_by_key(|b| b.priority);
65
66        let mut output = String::new();
67
68        for block in &stable {
69            output.push_str(&block.content);
70            output.push('\n');
71        }
72
73        for block in &variable {
74            output.push_str(&block.content);
75            output.push('\n');
76        }
77
78        output
79    }
80
81    /// Render with explicit cache breakpoint markers for Claude.
82    /// Places up to CLAUDE_MAX_CACHE_BREAKPOINTS markers at optimal positions.
83    pub fn render_with_breakpoints(&self) -> (String, Vec<usize>) {
84        let rendered = self.render();
85        let breakpoints = compute_breakpoints(&rendered);
86        (rendered, breakpoints)
87    }
88
89    pub fn stable_token_count(&self) -> usize {
90        self.blocks
91            .iter()
92            .filter(|b| b.is_stable)
93            .map(|b| b.estimated_tokens)
94            .sum()
95    }
96
97    pub fn variable_token_count(&self) -> usize {
98        self.blocks
99            .iter()
100            .filter(|b| !b.is_stable)
101            .map(|b| b.estimated_tokens)
102            .sum()
103    }
104
105    pub fn cache_efficiency(&self) -> f64 {
106        let total = self.stable_token_count() + self.variable_token_count();
107        if total == 0 {
108            return 0.0;
109        }
110        self.stable_token_count() as f64 / total as f64
111    }
112}
113
114/// Compute optimal cache breakpoint positions in the output.
115/// Tries to place breakpoints at natural content boundaries
116/// that align with Claude's minimum cache block size.
117fn compute_breakpoints(content: &str) -> Vec<usize> {
118    let total_tokens = estimate_tokens(content);
119    if total_tokens < CLAUDE_CACHE_MIN_TOKENS {
120        return Vec::new();
121    }
122
123    let mut breakpoints = Vec::new();
124    let lines: Vec<&str> = content.lines().collect();
125    let mut accumulated_tokens = 0;
126    let target_block_size = total_tokens / (CLAUDE_MAX_CACHE_BREAKPOINTS + 1);
127
128    for (i, line) in lines.iter().enumerate() {
129        accumulated_tokens += estimate_tokens(line);
130
131        if accumulated_tokens >= target_block_size
132            && breakpoints.len() < CLAUDE_MAX_CACHE_BREAKPOINTS
133            && is_natural_boundary(line, lines.get(i + 1).copied())
134        {
135            breakpoints.push(i);
136            accumulated_tokens = 0;
137        }
138    }
139
140    breakpoints
141}
142
/// Decides whether a breakpoint may be placed after `line`.
///
/// After trimming surrounding whitespace, `line` qualifies when it is blank or
/// begins with `---`, `===`, `##` or `//`. Otherwise it still qualifies when
/// the following line (if any) is blank or begins with `---` once trimmed.
fn is_natural_boundary(line: &str, next_line: Option<&str>) -> bool {
    const LINE_MARKERS: [&str; 4] = ["---", "===", "##", "//"];

    let current = line.trim();
    if current.is_empty() || LINE_MARKERS.iter().any(|marker| current.starts_with(marker)) {
        return true;
    }

    // The current line is ordinary text; a separator right after it also counts.
    match next_line.map(str::trim) {
        Some(upcoming) => upcoming.is_empty() || upcoming.starts_with("---"),
        None => false,
    }
}
162
/// Cheap token estimate: one token per four bytes of `text`, plus one.
///
/// The result is based on the UTF-8 byte length, not the character count, and
/// is never zero — an empty string is estimated at one token.
fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    1 + text.len() / BYTES_PER_TOKEN
}
166
167/// Generate a delta between two versions of content for cache-efficient updates.
168/// Returns only the changed portions, prefixed with stable context identifiers.
169pub fn compute_delta(previous: &str, current: &str) -> DeltaResult {
170    let prev_lines: Vec<&str> = previous.lines().collect();
171    let curr_lines: Vec<&str> = current.lines().collect();
172
173    let common_prefix = prev_lines
174        .iter()
175        .zip(curr_lines.iter())
176        .take_while(|(a, b)| a == b)
177        .count();
178
179    let common_suffix = prev_lines
180        .iter()
181        .rev()
182        .zip(curr_lines.iter().rev())
183        .take_while(|(a, b)| a == b)
184        .count();
185
186    let prev_changed = prev_lines
187        .len()
188        .saturating_sub(common_prefix + common_suffix);
189    let curr_changed = curr_lines
190        .len()
191        .saturating_sub(common_prefix + common_suffix);
192
193    let changed_lines: Vec<String> = curr_lines
194        [common_prefix..curr_lines.len().saturating_sub(common_suffix)]
195        .iter()
196        .map(|l| l.to_string())
197        .collect();
198
199    let prefix_tokens = estimate_tokens(&prev_lines[..common_prefix].to_vec().join("\n"));
200
201    DeltaResult {
202        common_prefix_lines: common_prefix,
203        common_suffix_lines: common_suffix,
204        removed_lines: prev_changed,
205        added_lines: curr_changed,
206        changed_content: changed_lines.join("\n"),
207        cached_prefix_tokens: prefix_tokens,
208        total_delta_tokens: estimate_tokens(&changed_lines.join("\n")),
209    }
210}
211
/// Outcome of `compute_delta`: how two versions of a text differ, line-wise.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeltaResult {
    /// Number of leading lines shared by both versions.
    pub common_prefix_lines: usize,
    /// Number of trailing lines shared by both versions.
    pub common_suffix_lines: usize,
    /// Lines of the previous version that lie between the shared prefix and suffix.
    pub removed_lines: usize,
    /// Lines of the current version that lie between the shared prefix and suffix.
    pub added_lines: usize,
    /// The current version's changed lines, joined with `'\n'`.
    pub changed_content: String,
    /// Token estimate for the shared prefix.
    pub cached_prefix_tokens: usize,
    /// Token estimate for `changed_content`.
    pub total_delta_tokens: usize,
}

impl DeltaResult {
    /// Share of the estimated tokens (prefix + delta) covered by the shared
    /// prefix, in `0.0..=1.0`. Returns `0.0` when both estimates are zero.
    pub fn savings_ratio(&self) -> f64 {
        match self.cached_prefix_tokens + self.total_delta_tokens {
            0 => 0.0,
            total => self.cached_prefix_tokens as f64 / total as f64,
        }
    }
}
232
233/// Order file contents for maximum cache reuse across tool calls.
234/// Stable elements (imports, type defs) first, then variable elements (function bodies).
235pub fn cache_order_code(content: &str) -> String {
236    let lines: Vec<&str> = content.lines().collect();
237
238    let mut imports = Vec::new();
239    let mut definitions = Vec::new();
240    let mut body = Vec::new();
241
242    for line in &lines {
243        let trimmed = line.trim();
244        if trimmed.starts_with("import ")
245            || trimmed.starts_with("use ")
246            || trimmed.starts_with("from ")
247            || trimmed.starts_with("#include")
248        {
249            imports.push(*line);
250        } else if is_type_definition(trimmed) {
251            definitions.push(*line);
252        } else {
253            body.push(*line);
254        }
255    }
256
257    let mut result = Vec::new();
258    let has_imports = !imports.is_empty();
259    let has_definitions = !definitions.is_empty();
260    let has_body = !body.is_empty();
261    result.extend(imports);
262    if has_imports && has_definitions {
263        result.push("");
264    }
265    result.extend(definitions);
266    if has_definitions && has_body {
267        result.push("");
268    }
269    result.extend(body);
270
271    result.join("\n")
272}
273
/// Reports whether `line` begins with a type-definition keyword.
///
/// Recognised forms:
/// * `struct `, `enum `, `trait `, `type ` — bare or preceded by `pub `,
/// * `interface `, `type `, `class ` — bare or preceded by `export `.
///
/// The check is a plain prefix test; callers pass an already-trimmed line.
fn is_type_definition(line: &str) -> bool {
    const PUB_KINDS: [&str; 4] = ["struct ", "enum ", "trait ", "type "];
    const EXPORT_KINDS: [&str; 3] = ["interface ", "type ", "class "];

    /// True when `text` starts with any of `kinds`.
    fn starts_with_any(text: &str, kinds: &[&str]) -> bool {
        kinds.iter().any(|kind| text.starts_with(kind))
    }

    // A visibility/export modifier restricts which keywords may follow it.
    if let Some(rest) = line.strip_prefix("pub ") {
        return starts_with_any(rest, &PUB_KINDS);
    }
    if let Some(rest) = line.strip_prefix("export ") {
        return starts_with_any(rest, &EXPORT_KINDS);
    }
    starts_with_any(line, &PUB_KINDS) || starts_with_any(line, &EXPORT_KINDS)
}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// A stable block must be rendered ahead of a variable block even when the
    /// variable block was added first.
    #[test]
    fn stable_blocks_come_first() {
        let mut aligned = CacheAlignedOutput::new();
        aligned.add_variable_block("var1", String::from("variable content"), 1);
        aligned.add_stable_block("stable1", String::from("stable content"), 1);

        let text = aligned.render();
        let stable_at = text.find("stable content").unwrap();
        let variable_at = text.find("variable content").unwrap();
        assert!(stable_at < variable_at);
    }

    /// A single modified line in the middle leaves two prefix lines and one
    /// suffix line in common.
    #[test]
    fn delta_detects_changes() {
        let before = "line1\nline2\nline3\nline4";
        let after = "line1\nline2\nmodified\nline4";

        let result = compute_delta(before, after);
        assert_eq!(result.common_prefix_lines, 2);
        assert_eq!(result.common_suffix_lines, 1);
        assert!(result.changed_content.contains("modified"));
    }

    /// Mostly-stable content yields an efficiency above 0.8.
    #[test]
    fn cache_efficiency_high_for_stable() {
        let mut aligned = CacheAlignedOutput::new();
        aligned.add_stable_block("s1", "x".repeat(1000), 1);
        aligned.add_variable_block("v1", "y".repeat(100), 1);

        let efficiency = aligned.cache_efficiency();
        assert!(efficiency > 0.8);
    }

    /// Import-like lines are moved to the top of the reordered code.
    #[test]
    fn code_reordering_puts_imports_first() {
        let source = "fn main() {}\nuse std::io;\nimport os\nstruct Foo;";
        let reordered = cache_order_code(source);
        let first_line = reordered.lines().next().unwrap();
        assert!(first_line.starts_with("use ") || first_line.starts_with("import "));
    }
}
336}