1use std::collections::HashMap;
16
17use md5::{Digest, Md5};
18
/// A piece of prompt content tracked for change detection between sends.
///
/// All fields are public, so `hash` is only guaranteed to match `content`
/// immediately after [`CacheableSection::new`]; code that edits `content`
/// afterwards is responsible for keeping `hash` in sync.
#[derive(Debug, Clone)]
pub struct CacheableSection {
    /// Key under which [`ProviderCacheState`] remembers the last sent hash.
    pub id: String,
    /// Text emitted verbatim by [`render_with_cache_hints`].
    pub content: String,
    /// Digest of `content` as computed by `content_hash` (lowercase hex).
    pub hash: String,
    /// Ordering key used by [`order_for_caching`] within a stability group.
    pub priority: SectionPriority,
    /// Sections with `stable == true` are ordered ahead of the others by
    /// [`order_for_caching`]; presumably these are the parts expected to stay
    /// unchanged between requests (confirm with callers).
    pub stable: bool,
}
28
/// Ordering class of a [`CacheableSection`].
///
/// The derived `Ord` follows declaration order (matching the explicit
/// discriminants), and [`order_for_caching`] sorts ascending, so `System`
/// comes first and `CurrentTask` last within a stability group.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SectionPriority {
    /// Sorts first.
    System = 0,
    ProjectStructure = 1,
    TypeDefinitions = 2,
    Dependencies = 3,
    RecentContext = 4,
    /// Sorts last.
    CurrentTask = 5,
}
38
/// Remembers which section hashes have already been sent and keeps hit/miss
/// statistics for the lookups made through `filter_changed`.
#[derive(Debug)]
pub struct ProviderCacheState {
    /// Section id -> hash recorded by the most recent `mark_sent` for that id.
    sent_hashes: HashMap<String, String>,
    /// Number of sections `filter_changed` found unchanged.
    cache_hits: u64,
    /// Number of sections `filter_changed` found new or changed.
    cache_misses: u64,
}
46
47impl ProviderCacheState {
48 pub fn new() -> Self {
49 Self {
50 sent_hashes: HashMap::new(),
51 cache_hits: 0,
52 cache_misses: 0,
53 }
54 }
55
56 pub fn needs_update(&self, section: &CacheableSection) -> bool {
58 match self.sent_hashes.get(§ion.id) {
59 Some(prev_hash) => prev_hash != §ion.hash,
60 None => true,
61 }
62 }
63
64 pub fn mark_sent(&mut self, section: &CacheableSection) {
66 self.sent_hashes
67 .insert(section.id.clone(), section.hash.clone());
68 }
69
70 pub fn filter_changed<'a>(&mut self, sections: &'a [CacheableSection]) -> Vec<&'a CacheableSection> {
73 let mut result = Vec::new();
74 for section in sections {
75 if self.needs_update(section) {
76 self.cache_misses += 1;
77 result.push(section);
78 } else {
79 self.cache_hits += 1;
80 }
81 }
82 result
83 }
84
85 pub fn cache_hit_rate(&self) -> f64 {
86 let total = self.cache_hits + self.cache_misses;
87 if total == 0 {
88 return 0.0;
89 }
90 self.cache_hits as f64 / total as f64
91 }
92
93 pub fn reset(&mut self) {
94 self.sent_hashes.clear();
95 self.cache_hits = 0;
96 self.cache_misses = 0;
97 }
98}
99
/// `Default` delegates to [`ProviderCacheState::new`], i.e. an empty state.
impl Default for ProviderCacheState {
    fn default() -> Self {
        Self::new()
    }
}
105
106impl CacheableSection {
107 pub fn new(id: &str, content: String, priority: SectionPriority, stable: bool) -> Self {
108 let hash = content_hash(&content);
109 Self {
110 id: id.to_string(),
111 content,
112 hash,
113 priority,
114 stable,
115 }
116 }
117}
118
119pub fn order_for_caching(mut sections: Vec<CacheableSection>) -> Vec<CacheableSection> {
123 sections.sort_by(|a, b| {
124 a.stable
125 .cmp(&b.stable)
126 .reverse()
127 .then(a.priority.cmp(&b.priority))
128 });
129 sections
130}
131
132pub fn render_with_cache_hints(sections: &[CacheableSection]) -> String {
134 let mut output = String::new();
135 let mut last_stable = true;
136
137 for section in sections {
138 if last_stable && !section.stable {
139 output.push_str("\n--- dynamic context ---\n");
140 }
141 output.push_str(§ion.content);
142 output.push('\n');
143 last_stable = section.stable;
144 }
145
146 output
147}
148
149fn content_hash(content: &str) -> String {
150 let mut hasher = Md5::new();
151 hasher.update(content.as_bytes());
152 format!("{:x}", hasher.finalize())
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    // Identical content must always hash identically.
    #[test]
    fn section_hash_deterministic() {
        let s1 = CacheableSection::new("id", "content".into(), SectionPriority::System, true);
        let s2 = CacheableSection::new("id", "content".into(), SectionPriority::System, true);
        assert_eq!(s1.hash, s2.hash);
    }

    // The hash depends on the content, not on the id.
    #[test]
    fn section_hash_changes_with_content() {
        let s1 = CacheableSection::new("id", "content_v1".into(), SectionPriority::System, true);
        let s2 = CacheableSection::new("id", "content_v2".into(), SectionPriority::System, true);
        assert_ne!(s1.hash, s2.hash);
    }

    // A section that was never marked as sent always needs an update.
    #[test]
    fn needs_update_new_section() {
        let state = ProviderCacheState::new();
        let section =
            CacheableSection::new("test", "content".into(), SectionPriority::System, true);
        assert!(state.needs_update(&section));
    }

    // Once marked as sent, the same section no longer needs an update.
    #[test]
    fn needs_update_unchanged() {
        let mut state = ProviderCacheState::new();
        let section =
            CacheableSection::new("test", "content".into(), SectionPriority::System, true);
        state.mark_sent(&section);
        assert!(!state.needs_update(&section));
    }

    // Same id with different content must be detected as changed.
    #[test]
    fn needs_update_changed() {
        let mut state = ProviderCacheState::new();
        let s1 = CacheableSection::new("test", "v1".into(), SectionPriority::System, true);
        state.mark_sent(&s1);
        let s2 = CacheableSection::new("test", "v2".into(), SectionPriority::System, true);
        assert!(state.needs_update(&s2));
    }

    // One already-sent section (hit) and one new section (miss) -> 50% hit rate.
    #[test]
    fn filter_changed_tracks_hits() {
        let mut state = ProviderCacheState::new();
        let s1 = CacheableSection::new("a", "stable".into(), SectionPriority::System, true);
        state.mark_sent(&s1);

        let sections = vec![
            s1.clone(),
            CacheableSection::new("b", "new".into(), SectionPriority::CurrentTask, false),
        ];
        let changed = state.filter_changed(&sections);
        assert_eq!(changed.len(), 1);
        assert_eq!(changed[0].id, "b");
        assert!((state.cache_hit_rate() - 0.5).abs() < 1e-6);
    }

    // With no lookups counted, the rate is defined as zero rather than NaN.
    #[test]
    fn hit_rate_is_zero_before_any_lookup() {
        let state = ProviderCacheState::new();
        assert_eq!(state.cache_hit_rate(), 0.0);
    }

    // After a reset, previously sent sections look new again.
    #[test]
    fn reset_forgets_sent_sections() {
        let mut state = ProviderCacheState::new();
        let section =
            CacheableSection::new("test", "content".into(), SectionPriority::System, true);
        state.mark_sent(&section);
        state.reset();
        assert!(state.needs_update(&section));
        assert_eq!(state.cache_hit_rate(), 0.0);
    }

    // Stable sections come first, ordered by ascending priority.
    #[test]
    fn order_stable_first() {
        let sections = vec![
            CacheableSection::new("task", "current".into(), SectionPriority::CurrentTask, false),
            CacheableSection::new("system", "system".into(), SectionPriority::System, true),
            CacheableSection::new("types", "types".into(), SectionPriority::TypeDefinitions, true),
        ];
        let ordered = order_for_caching(sections);
        assert!(ordered[0].stable);
        assert!(ordered[1].stable);
        assert!(!ordered[2].stable);
        assert_eq!(ordered[0].id, "system");
        assert_eq!(ordered[1].id, "types");
    }

    // The marker appears between the stable and the dynamic part of the output.
    #[test]
    fn render_marks_dynamic_boundary() {
        let sections = vec![
            CacheableSection::new("sys", "system prompt".into(), SectionPriority::System, true),
            CacheableSection::new("task", "current task".into(), SectionPriority::CurrentTask, false),
        ];
        let output = render_with_cache_hints(&sections);
        assert!(output.contains("--- dynamic context ---"));
        assert!(output.contains("system prompt"));
        assert!(output.contains("current task"));
    }
}