1use crate::error::{EngramError, Result};
7use serde::{Deserialize, Serialize};
8
/// Strategy for compressing memory content.
///
/// Serde-serialized in lowercase ("none", "headtail", "summary").
/// Not referenced by the code in this file; presumably consumed by
/// summarization logic elsewhere — TODO confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum CompressionStrategy {
    /// No compression (the default).
    #[default]
    None,
    /// Presumably keeps the head and tail of the content — confirm semantics.
    HeadTail,
    /// Presumably replaces the content with a summary — confirm semantics.
    Summary,
}
21
/// tiktoken encodings supported for token counting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenEncoding {
    /// "cl100k_base" — mapped from gpt-4 / gpt-3.5 / text-embedding
    /// (and, as an approximation, claude) model names.
    Cl100kBase,
    /// "o200k_base" — mapped from gpt-4o-family model names.
    O200kBase,
}
30
31impl TokenEncoding {
32 pub fn as_str(&self) -> &'static str {
33 match self {
34 TokenEncoding::Cl100kBase => "cl100k_base",
35 TokenEncoding::O200kBase => "o200k_base",
36 }
37 }
38}
39
40pub fn detect_encoding(model: &str) -> Option<TokenEncoding> {
42 let model_lower = model.to_lowercase();
43
44 if model_lower.contains("gpt-4o") {
46 return Some(TokenEncoding::O200kBase);
47 }
48
49 if model_lower.contains("gpt-4") || model_lower.contains("gpt-3.5") {
51 return Some(TokenEncoding::Cl100kBase);
52 }
53
54 if model_lower.contains("text-embedding") {
56 return Some(TokenEncoding::Cl100kBase);
57 }
58
59 if model_lower.contains("claude") {
62 return Some(TokenEncoding::Cl100kBase);
63 }
64
65 if let Some(stripped) = model_lower.strip_prefix("openai/") {
67 return detect_encoding(stripped);
68 }
69 if model_lower.starts_with("anthropic/") {
70 return Some(TokenEncoding::Cl100kBase);
71 }
72
73 None
74}
75
76pub fn parse_encoding(encoding: &str) -> Option<TokenEncoding> {
78 match encoding.to_lowercase().as_str() {
79 "cl100k_base" | "cl100k" => Some(TokenEncoding::Cl100kBase),
80 "o200k_base" | "o200k" => Some(TokenEncoding::O200kBase),
81 _ => None,
82 }
83}
84
85pub fn count_tokens(text: &str, model: &str, encoding: Option<&str>) -> Result<usize> {
100 let token_encoding = if let Some(enc) = encoding {
102 parse_encoding(enc).ok_or_else(|| {
103 EngramError::InvalidInput(format!(
104 "Unknown encoding '{}'. Supported: cl100k_base, o200k_base",
105 enc
106 ))
107 })?
108 } else {
109 detect_encoding(model).ok_or_else(|| {
111 EngramError::InvalidInput(format!(
112 "Unknown model '{}'. Provide 'encoding' parameter (cl100k_base or o200k_base) or use a known model (gpt-4, gpt-4o, claude-*, text-embedding-*).",
113 model
114 ))
115 })?
116 };
117
118 let bpe = match token_encoding {
120 TokenEncoding::Cl100kBase => tiktoken_rs::cl100k_base(),
121 TokenEncoding::O200kBase => tiktoken_rs::o200k_base(),
122 };
123
124 match bpe {
125 Ok(encoder) => Ok(encoder.encode_with_special_tokens(text).len()),
126 Err(e) => Err(EngramError::Internal(format!(
127 "Failed to initialize tokenizer: {}",
128 e
129 ))),
130 }
131}
132
/// Input parameters for a context-budget check.
///
/// Not referenced elsewhere in this file; presumably deserialized from a
/// tool/API request before the caller loads the memory contents —
/// TODO confirm against the handler that builds it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextBudgetInput {
    /// Ids of the memories whose content should be counted.
    pub memory_ids: Vec<i64>,
    /// Model name used to infer the token encoding (see `detect_encoding`).
    pub model: String,
    /// Optional explicit encoding override ("cl100k_base" or "o200k_base").
    pub encoding: Option<String>,
    /// Token budget to check against.
    pub budget: usize,
}
145
/// Outcome of a context-budget check, built by [`ContextBudgetResult::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextBudgetResult {
    /// Total tokens across all counted memories.
    pub total_tokens: usize,
    /// The budget that was checked against.
    pub budget: usize,
    /// Tokens left under the budget; 0 when over budget.
    pub remaining: usize,
    /// True when `total_tokens > budget`.
    pub over_budget: bool,
    /// Number of memories included in the count.
    pub memories_counted: usize,
    /// Model name the caller supplied.
    pub model_used: String,
    /// Encoding actually used (e.g. "cl100k_base").
    pub encoding_used: String,
    /// Human-readable remediation hints; empty when under budget.
    pub suggestions: Vec<String>,
    /// Per-memory token breakdown.
    pub memory_tokens: Vec<MemoryTokenCount>,
}
168
/// Token count for a single memory, with a short content preview.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryTokenCount {
    /// Database id of the memory.
    pub memory_id: i64,
    /// Token count of the memory's content.
    pub tokens: usize,
    /// Truncated preview of the content for display.
    pub content_preview: String,
}
176
177impl ContextBudgetResult {
178 pub fn new(
179 total_tokens: usize,
180 budget: usize,
181 model: &str,
182 encoding: TokenEncoding,
183 memory_tokens: Vec<MemoryTokenCount>,
184 ) -> Self {
185 let over_budget = total_tokens > budget;
186 let remaining = if over_budget {
187 0
188 } else {
189 budget - total_tokens
190 };
191
192 let mut suggestions = Vec::new();
193 if over_budget {
194 let excess = total_tokens - budget;
195 suggestions.push(format!(
196 "Over budget by {} tokens ({:.1}% of budget)",
197 excess,
198 (excess as f64 / budget as f64) * 100.0
199 ));
200
201 let mut sorted = memory_tokens.clone();
203 sorted.sort_by(|a, b| b.tokens.cmp(&a.tokens));
204
205 if let Some(largest) = sorted.first() {
206 suggestions.push(format!(
207 "Largest memory: id={} ({} tokens) - consider summarizing",
208 largest.memory_id, largest.tokens
209 ));
210 }
211
212 suggestions.push("Use memory_summarize to compress large memories".to_string());
213 suggestions.push("Use memory_archive_old to batch summarize old memories".to_string());
214 }
215
216 Self {
217 total_tokens,
218 budget,
219 remaining,
220 over_budget,
221 memories_counted: memory_tokens.len(),
222 model_used: model.to_string(),
223 encoding_used: encoding.as_str().to_string(),
224 suggestions,
225 memory_tokens,
226 }
227 }
228}
229
230pub fn check_context_budget(
232 contents: &[(i64, String)],
233 model: &str,
234 encoding: Option<&str>,
235 budget: usize,
236) -> Result<ContextBudgetResult> {
237 let token_encoding = if let Some(enc) = encoding {
239 parse_encoding(enc).ok_or_else(|| {
240 EngramError::InvalidInput(format!(
241 "Unknown encoding '{}'. Supported: cl100k_base, o200k_base",
242 enc
243 ))
244 })?
245 } else {
246 detect_encoding(model).ok_or_else(|| {
247 EngramError::InvalidInput(format!(
248 "Unknown model '{}'. Provide 'encoding' parameter (cl100k_base or o200k_base) or use a known model.",
249 model
250 ))
251 })?
252 };
253
254 let bpe = match token_encoding {
255 TokenEncoding::Cl100kBase => tiktoken_rs::cl100k_base(),
256 TokenEncoding::O200kBase => tiktoken_rs::o200k_base(),
257 }
258 .map_err(|e| EngramError::Internal(format!("Failed to initialize tokenizer: {}", e)))?;
259
260 let mut memory_tokens = Vec::new();
261 let mut total_tokens = 0;
262
263 for (id, content) in contents {
264 let tokens = bpe.encode_with_special_tokens(content).len();
265 total_tokens += tokens;
266
267 let preview = if content.len() > 50 {
269 format!("{}...", &content[..50])
270 } else {
271 content.clone()
272 };
273
274 memory_tokens.push(MemoryTokenCount {
275 memory_id: *id,
276 tokens,
277 content_preview: preview,
278 });
279 }
280
281 Ok(ContextBudgetResult::new(
282 total_tokens,
283 budget,
284 model,
285 token_encoding,
286 memory_tokens,
287 ))
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Model-name detection covers each documented family and falls back
    /// to `None` for unrecognized names.
    #[test]
    fn test_detect_encoding() {
        let cases = [
            ("gpt-4", Some(TokenEncoding::Cl100kBase)),
            ("gpt-4-turbo", Some(TokenEncoding::Cl100kBase)),
            ("gpt-4o", Some(TokenEncoding::O200kBase)),
            ("gpt-4o-mini", Some(TokenEncoding::O200kBase)),
            ("claude-3-opus", Some(TokenEncoding::Cl100kBase)),
            ("text-embedding-3-small", Some(TokenEncoding::Cl100kBase)),
            ("unknown-model", None),
        ];
        for (model, expected) in cases {
            assert_eq!(detect_encoding(model), expected, "model: {}", model);
        }
    }

    #[test]
    fn test_count_tokens_known_model() {
        let count =
            count_tokens("Hello, world!", "gpt-4", None).expect("gpt-4 should be recognized");
        assert!(count > 0);
    }

    #[test]
    fn test_count_tokens_unknown_model_no_encoding() {
        match count_tokens("Hello, world!", "unknown-model", None) {
            Ok(n) => panic!("expected an error, got {} tokens", n),
            Err(e) => assert!(e.to_string().contains("Unknown model")),
        }
    }

    #[test]
    fn test_count_tokens_unknown_model_with_encoding() {
        // An explicit encoding makes the model name irrelevant.
        assert!(count_tokens("Hello, world!", "unknown-model", Some("cl100k_base")).is_ok());
    }

    #[test]
    fn test_context_budget_under() {
        let contents = [
            (1_i64, String::from("Hello world")),
            (2, String::from("Test content")),
        ];
        let result = check_context_budget(&contents, "gpt-4", None, 1000).unwrap();
        assert!(!result.over_budget);
        assert!(result.remaining > 0);
        assert_eq!(result.memories_counted, 2);
    }

    #[test]
    fn test_context_budget_over() {
        let oversized = vec![(1_i64, "A".repeat(10000))];
        let result = check_context_budget(&oversized, "gpt-4", None, 100).unwrap();
        assert!(result.over_budget);
        assert_eq!(result.remaining, 0);
        assert!(!result.suggestions.is_empty());
    }
}