1use std::collections::HashMap;
16
17use md5::{Digest, Md5};
18
19#[derive(Debug, Clone)]
21pub struct CacheableSection {
22 pub id: String,
23 pub content: String,
24 pub hash: String,
25 pub priority: SectionPriority,
26 pub stable: bool,
27}
28
/// Placement rank of a section; variants with smaller discriminants sort first.
///
/// Discriminants are spelled out so the derived ordering and any `as` casts
/// stay explicit and stable if variants are ever reordered in source.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SectionPriority {
    /// Rank 0 — sorts before every other variant.
    System = 0,
    /// Rank 1.
    ProjectStructure = 1,
    /// Rank 2.
    TypeDefinitions = 2,
    /// Rank 3.
    Dependencies = 3,
    /// Rank 4.
    RecentContext = 4,
    /// Rank 5 — sorts after every other variant.
    CurrentTask = 5,
}
38
/// Record of which section contents have already been sent (presumably to one
/// provider — confirm at call sites), plus hit/miss counters gathered by
/// `filter_changed`.
#[derive(Debug)]
pub struct ProviderCacheState {
    /// Section id -> hash last recorded via `mark_sent`.
    sent_hashes: HashMap<String, String>,
    /// Number of sections `filter_changed` found unchanged.
    cache_hits: u64,
    /// Number of sections `filter_changed` reported as new or changed.
    cache_misses: u64,
}
46
47impl ProviderCacheState {
48 pub fn new() -> Self {
49 Self {
50 sent_hashes: HashMap::new(),
51 cache_hits: 0,
52 cache_misses: 0,
53 }
54 }
55
56 pub fn needs_update(&self, section: &CacheableSection) -> bool {
58 match self.sent_hashes.get(§ion.id) {
59 Some(prev_hash) => prev_hash != §ion.hash,
60 None => true,
61 }
62 }
63
64 pub fn mark_sent(&mut self, section: &CacheableSection) {
66 self.sent_hashes
67 .insert(section.id.clone(), section.hash.clone());
68 }
69
70 pub fn filter_changed<'a>(
73 &mut self,
74 sections: &'a [CacheableSection],
75 ) -> Vec<&'a CacheableSection> {
76 let mut result = Vec::new();
77 for section in sections {
78 if self.needs_update(section) {
79 self.cache_misses += 1;
80 result.push(section);
81 } else {
82 self.cache_hits += 1;
83 }
84 }
85 result
86 }
87
88 pub fn cache_hit_rate(&self) -> f64 {
89 let total = self.cache_hits + self.cache_misses;
90 if total == 0 {
91 return 0.0;
92 }
93 self.cache_hits as f64 / total as f64
94 }
95
96 pub fn reset(&mut self) {
97 self.sent_hashes.clear();
98 self.cache_hits = 0;
99 self.cache_misses = 0;
100 }
101}
102
103impl Default for ProviderCacheState {
104 fn default() -> Self {
105 Self::new()
106 }
107}
108
109impl CacheableSection {
110 pub fn new(id: &str, content: String, priority: SectionPriority, stable: bool) -> Self {
111 let hash = content_hash(&content);
112 Self {
113 id: id.to_string(),
114 content,
115 hash,
116 priority,
117 stable,
118 }
119 }
120}
121
122pub fn order_for_caching(mut sections: Vec<CacheableSection>) -> Vec<CacheableSection> {
126 sections.sort_by(|a, b| {
127 a.stable
128 .cmp(&b.stable)
129 .reverse()
130 .then(a.priority.cmp(&b.priority))
131 });
132 sections
133}
134
135pub fn render_with_cache_hints(sections: &[CacheableSection]) -> String {
137 let mut output = String::new();
138 let mut last_stable = true;
139
140 for section in sections {
141 if last_stable && !section.stable {
142 output.push_str("\n--- dynamic context ---\n");
143 }
144 output.push_str(§ion.content);
145 output.push('\n');
146 last_stable = section.stable;
147 }
148
149 output
150}
151
152fn content_hash(content: &str) -> String {
153 let mut hasher = Md5::new();
154 hasher.update(content.as_bytes());
155 format!("{:x}", hasher.finalize())
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn section_hash_deterministic() {
        let s1 = CacheableSection::new("id", "content".into(), SectionPriority::System, true);
        let s2 = CacheableSection::new("id", "content".into(), SectionPriority::System, true);
        assert_eq!(s1.hash, s2.hash);
    }

    #[test]
    fn section_hash_changes_with_content() {
        let s1 = CacheableSection::new("id", "content_v1".into(), SectionPriority::System, true);
        let s2 = CacheableSection::new("id", "content_v2".into(), SectionPriority::System, true);
        assert_ne!(s1.hash, s2.hash);
    }

    #[test]
    fn needs_update_new_section() {
        let state = ProviderCacheState::new();
        let section =
            CacheableSection::new("test", "content".into(), SectionPriority::System, true);
        assert!(state.needs_update(&section));
    }

    #[test]
    fn needs_update_unchanged() {
        let mut state = ProviderCacheState::new();
        let section =
            CacheableSection::new("test", "content".into(), SectionPriority::System, true);
        state.mark_sent(&section);
        assert!(!state.needs_update(&section));
    }

    #[test]
    fn needs_update_changed() {
        let mut state = ProviderCacheState::new();
        let s1 = CacheableSection::new("test", "v1".into(), SectionPriority::System, true);
        state.mark_sent(&s1);
        let s2 = CacheableSection::new("test", "v2".into(), SectionPriority::System, true);
        assert!(state.needs_update(&s2));
    }

    #[test]
    fn mark_sent_overwrites_previous_hash() {
        let mut state = ProviderCacheState::new();
        let s1 = CacheableSection::new("test", "v1".into(), SectionPriority::System, true);
        let s2 = CacheableSection::new("test", "v2".into(), SectionPriority::System, true);
        state.mark_sent(&s1);
        state.mark_sent(&s2);
        assert!(!state.needs_update(&s2));
        assert!(state.needs_update(&s1));
    }

    #[test]
    fn filter_changed_tracks_hits() {
        let mut state = ProviderCacheState::new();
        let s1 = CacheableSection::new("a", "stable".into(), SectionPriority::System, true);
        state.mark_sent(&s1);

        let sections = vec![
            s1.clone(),
            CacheableSection::new("b", "new".into(), SectionPriority::CurrentTask, false),
        ];
        let changed = state.filter_changed(&sections);
        assert_eq!(changed.len(), 1);
        assert_eq!(changed[0].id, "b");
        assert!((state.cache_hit_rate() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn filter_changed_does_not_mark_sent() {
        let mut state = ProviderCacheState::new();
        let section =
            CacheableSection::new("a", "content".into(), SectionPriority::System, true);
        let sections = vec![section.clone()];
        assert_eq!(state.filter_changed(&sections).len(), 1);
        // Only `mark_sent` records a section, so it is still reported as changed.
        assert!(state.needs_update(&section));
        assert_eq!(state.filter_changed(&sections).len(), 1);
    }

    #[test]
    fn cache_hit_rate_zero_before_any_lookup() {
        let state = ProviderCacheState::default();
        assert_eq!(state.cache_hit_rate(), 0.0);
    }

    #[test]
    fn reset_forgets_hashes_and_counters() {
        let mut state = ProviderCacheState::new();
        let section =
            CacheableSection::new("a", "content".into(), SectionPriority::System, true);
        state.mark_sent(&section);
        let sections = vec![section.clone()];
        assert!(state.filter_changed(&sections).is_empty());
        assert!((state.cache_hit_rate() - 1.0).abs() < 1e-6);

        state.reset();
        assert_eq!(state.cache_hit_rate(), 0.0);
        assert!(state.needs_update(&section));
    }

    #[test]
    fn order_stable_first() {
        let sections = vec![
            CacheableSection::new(
                "task",
                "current".into(),
                SectionPriority::CurrentTask,
                false,
            ),
            CacheableSection::new("system", "system".into(), SectionPriority::System, true),
            CacheableSection::new(
                "types",
                "types".into(),
                SectionPriority::TypeDefinitions,
                true,
            ),
        ];
        let ordered = order_for_caching(sections);
        assert!(ordered[0].stable);
        assert!(ordered[1].stable);
        assert!(!ordered[2].stable);
        assert_eq!(ordered[0].id, "system");
        assert_eq!(ordered[1].id, "types");
    }

    #[test]
    fn render_marks_dynamic_boundary() {
        let sections = vec![
            CacheableSection::new("sys", "system prompt".into(), SectionPriority::System, true),
            CacheableSection::new(
                "task",
                "current task".into(),
                SectionPriority::CurrentTask,
                false,
            ),
        ];
        let output = render_with_cache_hints(&sections);
        assert!(output.contains("--- dynamic context ---"));
        assert!(output.contains("system prompt"));
        assert!(output.contains("current task"));
    }

    #[test]
    fn render_marks_leading_dynamic_section() {
        let sections = vec![CacheableSection::new(
            "task",
            "current task".into(),
            SectionPriority::CurrentTask,
            false,
        )];
        assert_eq!(
            render_with_cache_hints(&sections),
            "\n--- dynamic context ---\ncurrent task\n"
        );
    }

    #[test]
    fn render_all_stable_has_no_marker() {
        let sections = vec![
            CacheableSection::new("sys", "system prompt".into(), SectionPriority::System, true),
            CacheableSection::new(
                "types",
                "types".into(),
                SectionPriority::TypeDefinitions,
                true,
            ),
        ];
        assert_eq!(render_with_cache_hints(&sections), "system prompt\ntypes\n");
    }
}