Skip to main content

aster/prompt/
cache.rs

//! Prompt caching system.
//!
//! Implements `system_prompt_hash` computation and cache optimization.
4
5use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
8
9use super::types::PromptHashInfo;
10
/// Estimates the number of LLM tokens in `text` with a character-based heuristic.
///
/// The base estimate is the number of characters (Unicode scalar values) divided by an
/// assumed characters-per-token ratio:
/// - `2.0` if the text contains any CJK ideograph (U+4E00..=U+9FA5), Hiragana or Katakana;
/// - otherwise `3.0` if the text looks like code (starts with a ``` fence or contains one of
///   a few keyword-plus-space markers such as `"function "` or `"let "`);
/// - otherwise `3.5`.
///
/// Each character in `{}[]().,;:!?<>` adds `0.1` and each `'\n'` adds `0.5`. The sum is
/// rounded up. Empty input yields `0`.
pub fn estimate_tokens(text: &str) -> usize {
    if text.is_empty() {
        return 0;
    }

    // Substrings that suggest the text is source code.
    const CODE_MARKERS: [&str; 7] = [
        "function ", "class ", "const ", "let ", "var ", "import ", "export ",
    ];
    let has_code = text.starts_with("```") || CODE_MARKERS.iter().any(|m| text.contains(m));

    // One pass over the characters collects every per-character statistic.
    let mut char_count = 0usize;
    let mut has_asian = false;
    let mut special_chars = 0usize;
    let mut newlines = 0usize;
    for c in text.chars() {
        char_count += 1;
        match c {
            // CJK ideographs, Hiragana, Katakana
            '\u{4e00}'..='\u{9fa5}' | '\u{3040}'..='\u{309f}' | '\u{30a0}'..='\u{30ff}' => {
                has_asian = true
            }
            '{' | '}' | '[' | ']' | '(' | ')' | '.' | ',' | ';' | ':' | '!' | '?' | '<' | '>' => {
                special_chars += 1
            }
            '\n' => newlines += 1,
            _ => {}
        }
    }

    let chars_per_token = if has_asian {
        2.0
    } else if has_code {
        3.0
    } else {
        3.5
    };

    // The ratio is per *character*; dividing `text.len()` (UTF-8 bytes) instead would
    // triple the estimate for CJK text, whose characters are 3 bytes each.
    let mut tokens = char_count as f64 / chars_per_token;
    tokens += special_chars as f64 * 0.1;
    tokens += newlines as f64 * 0.5;

    tokens.ceil() as usize
}
64
65/// 缓存条目
66struct CacheEntry {
67    content: String,
68    hash_info: PromptHashInfo,
69    expires_at: Instant,
70}
71
72/// 提示词缓存
73pub struct PromptCache {
74    cache: HashMap<String, CacheEntry>,
75    ttl: Duration,
76    max_entries: usize,
77}
78
79impl PromptCache {
80    /// 创建新的缓存实例
81    pub fn new(ttl_ms: Option<u64>, max_entries: Option<usize>) -> Self {
82        Self {
83            cache: HashMap::new(),
84            ttl: Duration::from_millis(ttl_ms.unwrap_or(5 * 60 * 1000)), // 5 分钟
85            max_entries: max_entries.unwrap_or(100),
86        }
87    }
88
89    /// 计算提示词哈希
90    pub fn compute_hash(&self, content: &str) -> PromptHashInfo {
91        let mut hasher = Sha256::new();
92        hasher.update(content.as_bytes());
93        let result = hasher.finalize();
94        let hash = hex::encode(&result[..8]); // 取前 16 个字符
95
96        let estimated_tokens = estimate_tokens(content);
97        let computed_at = SystemTime::now()
98            .duration_since(UNIX_EPOCH)
99            .unwrap_or_default()
100            .as_millis() as u64;
101
102        PromptHashInfo {
103            hash,
104            computed_at,
105            length: content.len(),
106            estimated_tokens,
107        }
108    }
109
110    /// 获取缓存的提示词
111    pub fn get(&self, key: &str) -> Option<(String, PromptHashInfo)> {
112        let entry = self.cache.get(key)?;
113
114        // 检查是否过期
115        if Instant::now() > entry.expires_at {
116            return None;
117        }
118
119        Some((entry.content.clone(), entry.hash_info.clone()))
120    }
121
122    /// 设置缓存
123    pub fn set(
124        &mut self,
125        key: String,
126        content: String,
127        hash_info: Option<PromptHashInfo>,
128    ) -> PromptHashInfo {
129        // 清理过期条目
130        self.cleanup();
131
132        // 检查容量
133        if self.cache.len() >= self.max_entries {
134            // 删除最旧的条目
135            if let Some(oldest_key) = self
136                .cache
137                .iter()
138                .min_by_key(|(_, v)| v.expires_at)
139                .map(|(k, _)| k.clone())
140            {
141                self.cache.remove(&oldest_key);
142            }
143        }
144
145        let computed_hash_info = hash_info.unwrap_or_else(|| self.compute_hash(&content));
146
147        self.cache.insert(
148            key,
149            CacheEntry {
150                content,
151                hash_info: computed_hash_info.clone(),
152                expires_at: Instant::now() + self.ttl,
153            },
154        );
155
156        computed_hash_info
157    }
158
159    /// 检查缓存是否有效
160    pub fn is_valid(&self, key: &str, hash: &str) -> bool {
161        match self.cache.get(key) {
162            Some(entry) => {
163                if Instant::now() > entry.expires_at {
164                    return false;
165                }
166                entry.hash_info.hash == hash
167            }
168            None => false,
169        }
170    }
171
172    /// 清理过期条目
173    fn cleanup(&mut self) {
174        let now = Instant::now();
175        self.cache.retain(|_, entry| now <= entry.expires_at);
176    }
177
178    /// 清空缓存
179    pub fn clear(&mut self) {
180        self.cache.clear();
181    }
182
183    /// 获取缓存大小
184    pub fn size(&self) -> usize {
185        self.cache.len()
186    }
187
188    /// 获取缓存统计
189    pub fn get_stats(&self) -> CacheStats {
190        let mut total_bytes = 0;
191        let mut oldest_entry: Option<u64> = None;
192        let mut newest_entry: Option<u64> = None;
193
194        for entry in self.cache.values() {
195            total_bytes += entry.content.len();
196            let computed_at = entry.hash_info.computed_at;
197
198            match oldest_entry {
199                Some(old) if computed_at < old => oldest_entry = Some(computed_at),
200                None => oldest_entry = Some(computed_at),
201                _ => {}
202            }
203
204            match newest_entry {
205                Some(new) if computed_at > new => newest_entry = Some(computed_at),
206                None => newest_entry = Some(computed_at),
207                _ => {}
208            }
209        }
210
211        CacheStats {
212            size: self.cache.len(),
213            total_bytes,
214            oldest_entry,
215            newest_entry,
216        }
217    }
218}
219
220impl Default for PromptCache {
221    fn default() -> Self {
222        Self::new(None, None)
223    }
224}
225
/// Snapshot of prompt-cache statistics, as produced by `PromptCache::get_stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of stored entries (may include expired entries not yet purged).
    pub size: usize,
    /// Sum of the byte lengths of all cached contents.
    pub total_bytes: usize,
    /// Smallest `computed_at` timestamp among the entries; `None` if there are none.
    /// Timestamps produced by `compute_hash` are milliseconds since the Unix epoch.
    pub oldest_entry: Option<u64>,
    /// Largest `computed_at` timestamp among the entries; `None` if there are none.
    pub newest_entry: Option<u64>,
}
234
/// Builds the cache key for a prompt from its context.
///
/// The key has the form `"{working_dir}:{model}:{permission_mode}:{mode}"`. A missing
/// `model` or `permission_mode` becomes `"default"`. `mode` is `"plan"` when `plan_mode`
/// is set and `"normal"` otherwise.
///
/// NOTE(review): the parts are joined with `:` without escaping, so inputs containing `:`
/// can collide. For example, (`"a:b"`, `Some("c")`) and (`"a"`, `Some("b:c")`) give the
/// same key.
pub fn generate_cache_key(
    working_dir: &str,
    model: Option<&str>,
    permission_mode: Option<&str>,
    plan_mode: bool,
) -> String {
    let mode = match plan_mode {
        true => "plan",
        false => "normal",
    };
    let parts = [
        working_dir,
        model.unwrap_or("default"),
        permission_mode.unwrap_or("default"),
        mode,
    ];
    parts.join(":")
}