1use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
8
9use super::types::PromptHashInfo;
10
/// Roughly estimates how many LLM tokens `text` will consume.
///
/// This is a heuristic, not a tokenizer. It divides the character count by an
/// average characters-per-token ratio chosen from the content, then adds small
/// surcharges for punctuation/brackets and line breaks:
///
/// * Text containing any character in U+4E00..=U+9FA5 (CJK ideographs),
///   U+3040..=U+309F (hiragana) or U+30A0..=U+30FF (katakana) uses 2.0 chars
///   per token.
/// * Otherwise, text that starts with a ```` ``` ```` fence or contains one of
///   `function `, `class `, `const `, `let `, `var `, `import `, `export `
///   uses 3.0 chars per token.
/// * Everything else uses 3.5 chars per token.
///
/// Each of `{ } [ ] ( ) . , ; : ! ? < >` adds 0.1 tokens and each `\n` adds
/// 0.5 tokens. The total is rounded up; empty input returns 0.
pub fn estimate_tokens(text: &str) -> usize {
    if text.is_empty() {
        return 0;
    }

    let has_asian = text.chars().any(|c| {
        matches!(
            c,
            '\u{4e00}'..='\u{9fa5}' | '\u{3040}'..='\u{309f}' | '\u{30a0}'..='\u{30ff}'
        )
    });

    const CODE_MARKERS: [&str; 7] = [
        "function ", "class ", "const ", "let ", "var ", "import ", "export ",
    ];
    let has_code =
        text.starts_with("```") || CODE_MARKERS.iter().any(|&marker| text.contains(marker));

    let chars_per_token = if has_asian {
        2.0
    } else if has_code {
        3.0
    } else {
        3.5
    };

    // Count characters (Unicode scalar values), not bytes. `str::len` is the
    // UTF-8 byte length, which would triple the estimate for CJK text (3 bytes
    // per ideograph) and inflate it for any other non-ASCII text. The
    // punctuation and newline tallies are gathered in the same pass.
    let (char_count, special_chars, newlines) =
        text.chars()
            .fold((0usize, 0usize, 0usize), |(count, special, lines), c| {
                let is_special = matches!(
                    c,
                    '{' | '}' | '[' | ']' | '(' | ')' | '.' | ',' | ';' | ':' | '!' | '?' | '<' | '>'
                );
                (
                    count + 1,
                    special + usize::from(is_special),
                    lines + usize::from(c == '\n'),
                )
            });

    // Summed left to right, in the same order as the component terms above.
    let tokens = char_count as f64 / chars_per_token
        + special_chars as f64 * 0.1
        + newlines as f64 * 0.5;

    tokens.ceil() as usize
}
64
65struct CacheEntry {
67 content: String,
68 hash_info: PromptHashInfo,
69 expires_at: Instant,
70}
71
72pub struct PromptCache {
74 cache: HashMap<String, CacheEntry>,
75 ttl: Duration,
76 max_entries: usize,
77}
78
79impl PromptCache {
80 pub fn new(ttl_ms: Option<u64>, max_entries: Option<usize>) -> Self {
82 Self {
83 cache: HashMap::new(),
84 ttl: Duration::from_millis(ttl_ms.unwrap_or(5 * 60 * 1000)), max_entries: max_entries.unwrap_or(100),
86 }
87 }
88
89 pub fn compute_hash(&self, content: &str) -> PromptHashInfo {
91 let mut hasher = Sha256::new();
92 hasher.update(content.as_bytes());
93 let result = hasher.finalize();
94 let hash = hex::encode(&result[..8]); let estimated_tokens = estimate_tokens(content);
97 let computed_at = SystemTime::now()
98 .duration_since(UNIX_EPOCH)
99 .unwrap_or_default()
100 .as_millis() as u64;
101
102 PromptHashInfo {
103 hash,
104 computed_at,
105 length: content.len(),
106 estimated_tokens,
107 }
108 }
109
110 pub fn get(&self, key: &str) -> Option<(String, PromptHashInfo)> {
112 let entry = self.cache.get(key)?;
113
114 if Instant::now() > entry.expires_at {
116 return None;
117 }
118
119 Some((entry.content.clone(), entry.hash_info.clone()))
120 }
121
122 pub fn set(
124 &mut self,
125 key: String,
126 content: String,
127 hash_info: Option<PromptHashInfo>,
128 ) -> PromptHashInfo {
129 self.cleanup();
131
132 if self.cache.len() >= self.max_entries {
134 if let Some(oldest_key) = self
136 .cache
137 .iter()
138 .min_by_key(|(_, v)| v.expires_at)
139 .map(|(k, _)| k.clone())
140 {
141 self.cache.remove(&oldest_key);
142 }
143 }
144
145 let computed_hash_info = hash_info.unwrap_or_else(|| self.compute_hash(&content));
146
147 self.cache.insert(
148 key,
149 CacheEntry {
150 content,
151 hash_info: computed_hash_info.clone(),
152 expires_at: Instant::now() + self.ttl,
153 },
154 );
155
156 computed_hash_info
157 }
158
159 pub fn is_valid(&self, key: &str, hash: &str) -> bool {
161 match self.cache.get(key) {
162 Some(entry) => {
163 if Instant::now() > entry.expires_at {
164 return false;
165 }
166 entry.hash_info.hash == hash
167 }
168 None => false,
169 }
170 }
171
172 fn cleanup(&mut self) {
174 let now = Instant::now();
175 self.cache.retain(|_, entry| now <= entry.expires_at);
176 }
177
178 pub fn clear(&mut self) {
180 self.cache.clear();
181 }
182
183 pub fn size(&self) -> usize {
185 self.cache.len()
186 }
187
188 pub fn get_stats(&self) -> CacheStats {
190 let mut total_bytes = 0;
191 let mut oldest_entry: Option<u64> = None;
192 let mut newest_entry: Option<u64> = None;
193
194 for entry in self.cache.values() {
195 total_bytes += entry.content.len();
196 let computed_at = entry.hash_info.computed_at;
197
198 match oldest_entry {
199 Some(old) if computed_at < old => oldest_entry = Some(computed_at),
200 None => oldest_entry = Some(computed_at),
201 _ => {}
202 }
203
204 match newest_entry {
205 Some(new) if computed_at > new => newest_entry = Some(computed_at),
206 None => newest_entry = Some(computed_at),
207 _ => {}
208 }
209 }
210
211 CacheStats {
212 size: self.cache.len(),
213 total_bytes,
214 oldest_entry,
215 newest_entry,
216 }
217 }
218}
219
220impl Default for PromptCache {
221 fn default() -> Self {
222 Self::new(None, None)
223 }
224}
225
/// Point-in-time summary of a `PromptCache`, as produced by its `get_stats`.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly (e.g. in
/// tests), and `Default` for an empty-cache summary.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of stored entries, including expired ones not yet purged.
    pub size: usize,
    /// Sum of the byte lengths of all cached prompt strings.
    pub total_bytes: usize,
    /// Smallest `computed_at` timestamp among entries, or `None` when empty.
    /// When produced by `compute_hash`, this is milliseconds since the Unix epoch.
    pub oldest_entry: Option<u64>,
    /// Largest `computed_at` timestamp among entries, or `None` when empty.
    pub newest_entry: Option<u64>,
}
234
/// Builds the cache key for a prompt configuration.
///
/// The key is `working_dir:model:permission_mode:mode`:
///
/// * a missing `model` or `permission_mode` is written as `default`;
/// * `mode` is `plan` when `plan_mode` is set and `normal` otherwise.
///
/// NOTE(review): components are not escaped. A `working_dir` or `model`
/// containing `:` can therefore collide with a different combination, and
/// `None` is indistinguishable from `Some("default")`.
pub fn generate_cache_key(
    working_dir: &str,
    model: Option<&str>,
    permission_mode: Option<&str>,
    plan_mode: bool,
) -> String {
    let mode = if plan_mode { "plan" } else { "normal" };
    let parts = [
        working_dir,
        model.unwrap_or("default"),
        permission_mode.unwrap_or("default"),
        mode,
    ];
    parts.join(":")
}