Skip to main content

lean_ctx/core/neural/
cache_alignment.rs

1//! KV-Cache alignment for commercial LLM prompt caching.
2//!
3//! Claude's prompt caching stores KV-tensors for byte-exact prefix matches.
4//! GPT models have similar mechanisms. This module ensures lean-ctx outputs
5//! are structured to maximize cache hit rates.
6//!
7//! Key strategies:
8//! 1. Stable prefix: invariant content (instructions, tool defs) comes first
9//! 2. Cache-block alignment: content segmented to match provider breakpoints
10//! 3. Delta-only after cached prefix: only send changes, rest stays in KV-cache
11//! 4. Deterministic ordering: same inputs always produce byte-identical output
12
/// Content whose estimated size (see `estimate_tokens`) is below this many tokens gets no
/// cache breakpoints from `compute_breakpoints`. Intended to mirror Claude's minimum cacheable
/// prompt length.
/// NOTE(review): Anthropic documents different minimums for different models (some require
/// more than 1024) — confirm against the target model.
const CLAUDE_CACHE_MIN_TOKENS: usize = 1024;
/// Upper bound on the number of breakpoints `compute_breakpoints` returns; it also sets how many
/// segments the content is split into (`MAX + 1`). Intended to match the maximum number of cache
/// breakpoints Claude accepts per request — verify against current API limits.
const CLAUDE_MAX_CACHE_BREAKPOINTS: usize = 4;
15
/// One unit of content collected by `CacheAlignedOutput`.
#[derive(Debug, Clone)]
pub struct CacheBlock {
    /// Caller-supplied identifier.
    /// NOTE(review): not read anywhere in this module; presumably used by callers — confirm.
    pub id: String,
    /// Text emitted verbatim by `CacheAlignedOutput::render`, followed by a `'\n'`.
    pub content: String,
    /// `true` for content expected to stay identical across renders; all stable blocks are
    /// rendered before any variable block.
    pub is_stable: bool,
    /// Sort key within the stable or the variable group; lower values are rendered first.
    pub priority: u8,
    /// Token estimate; the `add_*_block` methods set it to `estimate_tokens(&content)`.
    pub estimated_tokens: usize,
}
24
25#[derive(Default)]
26pub struct CacheAlignedOutput {
27    blocks: Vec<CacheBlock>,
28}
29
30impl CacheAlignedOutput {
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    pub fn add_stable_block(&mut self, id: &str, content: String, priority: u8) {
36        let tokens = estimate_tokens(&content);
37        self.blocks.push(CacheBlock {
38            id: id.to_string(),
39            content,
40            is_stable: true,
41            priority,
42            estimated_tokens: tokens,
43        });
44    }
45
46    pub fn add_variable_block(&mut self, id: &str, content: String, priority: u8) {
47        let tokens = estimate_tokens(&content);
48        self.blocks.push(CacheBlock {
49            id: id.to_string(),
50            content,
51            is_stable: false,
52            priority,
53            estimated_tokens: tokens,
54        });
55    }
56
57    /// Render the output with cache-optimal ordering:
58    /// stable blocks first (sorted by priority), then variable blocks.
59    pub fn render(&self) -> String {
60        let mut stable: Vec<&CacheBlock> = self.blocks.iter().filter(|b| b.is_stable).collect();
61        let mut variable: Vec<&CacheBlock> = self.blocks.iter().filter(|b| !b.is_stable).collect();
62
63        stable.sort_by_key(|b| b.priority);
64        variable.sort_by_key(|b| b.priority);
65
66        let mut output = String::new();
67
68        for block in &stable {
69            output.push_str(&block.content);
70            output.push('\n');
71        }
72
73        for block in &variable {
74            output.push_str(&block.content);
75            output.push('\n');
76        }
77
78        output
79    }
80
81    /// Render with explicit cache breakpoint markers for Claude.
82    /// Places up to CLAUDE_MAX_CACHE_BREAKPOINTS markers at optimal positions.
83    pub fn render_with_breakpoints(&self) -> (String, Vec<usize>) {
84        let rendered = self.render();
85        let breakpoints = compute_breakpoints(&rendered);
86        (rendered, breakpoints)
87    }
88
89    pub fn stable_token_count(&self) -> usize {
90        self.blocks
91            .iter()
92            .filter(|b| b.is_stable)
93            .map(|b| b.estimated_tokens)
94            .sum()
95    }
96
97    pub fn variable_token_count(&self) -> usize {
98        self.blocks
99            .iter()
100            .filter(|b| !b.is_stable)
101            .map(|b| b.estimated_tokens)
102            .sum()
103    }
104
105    pub fn cache_efficiency(&self) -> f64 {
106        let total = self.stable_token_count() + self.variable_token_count();
107        if total == 0 {
108            return 0.0;
109        }
110        self.stable_token_count() as f64 / total as f64
111    }
112}
113
114/// Compute optimal cache breakpoint positions in the output.
115/// Tries to place breakpoints at natural content boundaries
116/// that align with Claude's minimum cache block size.
117fn compute_breakpoints(content: &str) -> Vec<usize> {
118    let total_tokens = estimate_tokens(content);
119    if total_tokens < CLAUDE_CACHE_MIN_TOKENS {
120        return Vec::new();
121    }
122
123    let mut breakpoints = Vec::new();
124    let lines: Vec<&str> = content.lines().collect();
125    let mut accumulated_tokens = 0;
126    let target_block_size = total_tokens / (CLAUDE_MAX_CACHE_BREAKPOINTS + 1);
127
128    for (i, line) in lines.iter().enumerate() {
129        accumulated_tokens += estimate_tokens(line);
130
131        if accumulated_tokens >= target_block_size
132            && breakpoints.len() < CLAUDE_MAX_CACHE_BREAKPOINTS
133            && is_natural_boundary(line, lines.get(i + 1).copied())
134        {
135            breakpoints.push(i);
136            accumulated_tokens = 0;
137        }
138    }
139
140    breakpoints
141}
142
/// Heuristic: is the end of `line` a good place to split content?
///
/// After trimming, `line` qualifies when it is blank or starts with `---`, `===`, `##` or `//`.
/// Otherwise it qualifies only when the following line, trimmed, is blank or starts with `---`.
fn is_natural_boundary(line: &str, next_line: Option<&str>) -> bool {
    const LINE_MARKERS: [&str; 4] = ["---", "===", "##", "//"];

    let current = line.trim();
    if current.is_empty() || LINE_MARKERS.iter().any(|marker| current.starts_with(marker)) {
        return true;
    }

    match next_line.map(str::trim) {
        Some(next) => next.is_empty() || next.starts_with("---"),
        None => false,
    }
}
162
/// Cheap token estimate: UTF-8 byte length divided by four (integer division), plus one.
///
/// The `+ 1` means even an empty string counts as one token. This is a heuristic based on
/// byte length, not a real tokenizer.
fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    let whole_tokens = text.len() / BYTES_PER_TOKEN;
    whole_tokens + 1
}
166
167/// Generate a delta between two versions of content for cache-efficient updates.
168/// Returns only the changed portions, prefixed with stable context identifiers.
169pub fn compute_delta(previous: &str, current: &str) -> DeltaResult {
170    let prev_lines: Vec<&str> = previous.lines().collect();
171    let curr_lines: Vec<&str> = current.lines().collect();
172
173    let common_prefix = prev_lines
174        .iter()
175        .zip(curr_lines.iter())
176        .take_while(|(a, b)| a == b)
177        .count();
178
179    let common_suffix = prev_lines
180        .iter()
181        .rev()
182        .zip(curr_lines.iter().rev())
183        .take_while(|(a, b)| a == b)
184        .count();
185
186    let prev_changed = prev_lines
187        .len()
188        .saturating_sub(common_prefix + common_suffix);
189    let curr_changed = curr_lines
190        .len()
191        .saturating_sub(common_prefix + common_suffix);
192
193    let changed_lines: Vec<String> = curr_lines
194        [common_prefix..curr_lines.len().saturating_sub(common_suffix)]
195        .iter()
196        .map(|l| l.to_string())
197        .collect();
198
199    let prefix_tokens = estimate_tokens(
200        &prev_lines[..common_prefix].to_vec().join("\n"),
201    );
202
203    DeltaResult {
204        common_prefix_lines: common_prefix,
205        common_suffix_lines: common_suffix,
206        removed_lines: prev_changed,
207        added_lines: curr_changed,
208        changed_content: changed_lines.join("\n"),
209        cached_prefix_tokens: prefix_tokens,
210        total_delta_tokens: estimate_tokens(&changed_lines.join("\n")),
211    }
212}
213
/// Summary of a line-based diff produced by `compute_delta`.
#[derive(Debug)]
pub struct DeltaResult {
    /// Number of leading lines identical in both inputs.
    pub common_prefix_lines: usize,
    /// Number of trailing lines identical in both inputs.
    pub common_suffix_lines: usize,
    /// Count of previous-input lines not covered by the common prefix or suffix.
    pub removed_lines: usize,
    /// Count of current-input lines not covered by the common prefix or suffix.
    pub added_lines: usize,
    /// The current input's lines between the common prefix and suffix, joined with `'\n'`.
    pub changed_content: String,
    /// Estimated tokens of the common prefix (taken from the previous input).
    pub cached_prefix_tokens: usize,
    /// Estimated tokens of `changed_content`.
    pub total_delta_tokens: usize,
}
224
225impl DeltaResult {
226    pub fn savings_ratio(&self) -> f64 {
227        let total = self.cached_prefix_tokens + self.total_delta_tokens;
228        if total == 0 {
229            return 0.0;
230        }
231        self.cached_prefix_tokens as f64 / total as f64
232    }
233}
234
235/// Order file contents for maximum cache reuse across tool calls.
236/// Stable elements (imports, type defs) first, then variable elements (function bodies).
237pub fn cache_order_code(content: &str) -> String {
238    let lines: Vec<&str> = content.lines().collect();
239
240    let mut imports = Vec::new();
241    let mut definitions = Vec::new();
242    let mut body = Vec::new();
243
244    for line in &lines {
245        let trimmed = line.trim();
246        if trimmed.starts_with("import ")
247            || trimmed.starts_with("use ")
248            || trimmed.starts_with("from ")
249            || trimmed.starts_with("#include")
250        {
251            imports.push(*line);
252        } else if is_type_definition(trimmed) {
253            definitions.push(*line);
254        } else {
255            body.push(*line);
256        }
257    }
258
259    let mut result = Vec::new();
260    let has_imports = !imports.is_empty();
261    let has_definitions = !definitions.is_empty();
262    let has_body = !body.is_empty();
263    result.extend(imports);
264    if has_imports && has_definitions {
265        result.push("");
266    }
267    result.extend(definitions);
268    if has_definitions && has_body {
269        result.push("");
270    }
271    result.extend(body);
272
273    result.join("\n")
274}
275
/// Returns `true` if `line` (already trimmed by the caller) begins a type definition.
///
/// Recognised keywords are `struct`, `enum`, `trait`, `type`, `interface` and `class`. Any run
/// of leading modifiers is skipped first (see `strip_leading_modifiers`). So Rust
/// `pub(crate) struct` and `pub unsafe trait`, TypeScript `export default class`, and Java
/// `public abstract class` are detected, not only the bare and `pub `/`export ` forms.
/// Every line accepted by the previous fixed prefix list is still accepted.
fn is_type_definition(line: &str) -> bool {
    const KEYWORDS: &[&str] = &["struct ", "enum ", "trait ", "type ", "interface ", "class "];
    let rest = strip_leading_modifiers(line);
    KEYWORDS.iter().any(|keyword| rest.starts_with(*keyword))
}

/// Repeatedly strips visibility/declaration modifiers from the start of `line`.
///
/// Handled modifiers: `pub`, `pub(...)` (e.g. `pub(crate)`, `pub(in path)`), `export`,
/// `default`, `declare`, `public`, `abstract`, `final` and `unsafe`.
fn strip_leading_modifiers(line: &str) -> &str {
    const MODIFIERS: &[&str] = &[
        "pub ", "export ", "default ", "declare ", "public ", "abstract ", "final ", "unsafe ",
    ];

    let mut rest = line;
    loop {
        let stripped = match rest.strip_prefix("pub(") {
            // Restricted visibility: skip through the closing parenthesis.
            Some(inner) => inner.find(')').map(|close| &inner[close + 1..]),
            None => MODIFIERS.iter().find_map(|m| rest.strip_prefix(*m)),
        };
        match stripped {
            // Each step removes at least one byte, so the loop terminates.
            Some(next) => rest = next.trim_start(),
            None => return rest,
        }
    }
}
294
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn stable_blocks_come_first() {
        // Insert the variable block first so the ordering must come from stability,
        // not from insertion order.
        let mut out = CacheAlignedOutput::new();
        out.add_variable_block("var1", "variable content".into(), 1);
        out.add_stable_block("stable1", "stable content".into(), 1);

        let text = out.render();
        let offset_of = |needle: &str| text.find(needle).unwrap();
        assert!(offset_of("stable content") < offset_of("variable content"));
    }

    #[test]
    fn delta_detects_changes() {
        let delta = compute_delta(
            "line1\nline2\nline3\nline4",
            "line1\nline2\nmodified\nline4",
        );

        assert_eq!(
            (delta.common_prefix_lines, delta.common_suffix_lines),
            (2, 1)
        );
        assert!(delta.changed_content.contains("modified"));
    }

    #[test]
    fn cache_efficiency_high_for_stable() {
        let mut out = CacheAlignedOutput::new();
        out.add_stable_block("s1", "x".repeat(1000), 1);
        out.add_variable_block("v1", "y".repeat(100), 1);

        let efficiency = out.cache_efficiency();
        assert!(efficiency > 0.8, "expected > 0.8, got {}", efficiency);
    }

    #[test]
    fn code_reordering_puts_imports_first() {
        let reordered = cache_order_code("fn main() {}\nuse std::io;\nimport os\nstruct Foo;");
        let first = reordered.lines().next().unwrap();
        assert!(["use ", "import "].iter().any(|prefix| first.starts_with(prefix)));
    }
}