Skip to main content

engram/intelligence/
compression.rs

1//! Context Compression Engine (Phase 2 - ENG-34)
2//!
3//! Provides token counting and context budget management for LLM interactions.
4//! Uses tiktoken-rs for accurate token counting with explicit error handling.
5
6use crate::error::{EngramError, Result};
7use serde::{Deserialize, Serialize};
8
/// Compression strategy for memory content.
///
/// Serialized in lowercase via `serde(rename_all)`, so `HeadTail`
/// round-trips as `"headtail"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum CompressionStrategy {
    /// Raw content, no compression (default).
    #[default]
    None,
    /// Keep first 60% + last 30% with ellipsis (uses soft_trim).
    HeadTail,
    /// LLM-generated summary (creates a new Summary memory).
    Summary,
}
21
/// Supported tiktoken encodings for token counting.
///
/// Each variant maps to one tiktoken-rs BPE table; see [`detect_encoding`]
/// for the model-name → encoding mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenEncoding {
    /// cl100k_base - GPT-4, GPT-4-turbo, text-embedding-3-*.
    Cl100kBase,
    /// o200k_base - GPT-4o, GPT-4o-mini.
    O200kBase,
}
30
31impl TokenEncoding {
32    pub fn as_str(&self) -> &'static str {
33        match self {
34            TokenEncoding::Cl100kBase => "cl100k_base",
35            TokenEncoding::O200kBase => "o200k_base",
36        }
37    }
38}
39
40/// Detect the appropriate encoding for a model name
41pub fn detect_encoding(model: &str) -> Option<TokenEncoding> {
42    let model_lower = model.to_lowercase();
43
44    // GPT-4o and GPT-4o-mini use o200k_base
45    if model_lower.contains("gpt-4o") {
46        return Some(TokenEncoding::O200kBase);
47    }
48
49    // GPT-4, GPT-4-turbo use cl100k_base
50    if model_lower.contains("gpt-4") || model_lower.contains("gpt-3.5") {
51        return Some(TokenEncoding::Cl100kBase);
52    }
53
54    // text-embedding models use cl100k_base
55    if model_lower.contains("text-embedding") {
56        return Some(TokenEncoding::Cl100kBase);
57    }
58
59    // Claude models - use cl100k_base as approximation
60    // (Claude's actual tokenizer is different but cl100k is close enough for budgeting)
61    if model_lower.contains("claude") {
62        return Some(TokenEncoding::Cl100kBase);
63    }
64
65    // OpenRouter prefixed models
66    if let Some(stripped) = model_lower.strip_prefix("openai/") {
67        return detect_encoding(stripped);
68    }
69    if model_lower.starts_with("anthropic/") {
70        return Some(TokenEncoding::Cl100kBase);
71    }
72
73    None
74}
75
76/// Parse encoding string to TokenEncoding
77pub fn parse_encoding(encoding: &str) -> Option<TokenEncoding> {
78    match encoding.to_lowercase().as_str() {
79        "cl100k_base" | "cl100k" => Some(TokenEncoding::Cl100kBase),
80        "o200k_base" | "o200k" => Some(TokenEncoding::O200kBase),
81        _ => None,
82    }
83}
84
85/// Count tokens in text using the specified model or encoding.
86///
87/// # Arguments
88/// * `text` - The text to count tokens for
89/// * `model` - Model name (e.g., "gpt-4", "gpt-4o", "claude-3-opus")
90/// * `encoding` - Optional encoding override (e.g., "cl100k_base", "o200k_base")
91///
92/// # Returns
93/// * `Ok(usize)` - Number of tokens
94/// * `Err` - If model is unknown AND no encoding provided
95///
96/// # Errors
97/// This function will NOT silently fall back to chars/4. If the model is unknown
98/// and no encoding is provided, it returns an error with a helpful message.
99pub fn count_tokens(text: &str, model: &str, encoding: Option<&str>) -> Result<usize> {
100    // First try explicit encoding override
101    let token_encoding = if let Some(enc) = encoding {
102        parse_encoding(enc).ok_or_else(|| {
103            EngramError::InvalidInput(format!(
104                "Unknown encoding '{}'. Supported: cl100k_base, o200k_base",
105                enc
106            ))
107        })?
108    } else {
109        // Try to detect from model name
110        detect_encoding(model).ok_or_else(|| {
111            EngramError::InvalidInput(format!(
112                "Unknown model '{}'. Provide 'encoding' parameter (cl100k_base or o200k_base) or use a known model (gpt-4, gpt-4o, claude-*, text-embedding-*).",
113                model
114            ))
115        })?
116    };
117
118    // Use tiktoken-rs to count tokens
119    let bpe = match token_encoding {
120        TokenEncoding::Cl100kBase => tiktoken_rs::cl100k_base(),
121        TokenEncoding::O200kBase => tiktoken_rs::o200k_base(),
122    };
123
124    match bpe {
125        Ok(encoder) => Ok(encoder.encode_with_special_tokens(text).len()),
126        Err(e) => Err(EngramError::Internal(format!(
127            "Failed to initialize tokenizer: {}",
128            e
129        ))),
130    }
131}
132
/// Input for context budget checking.
///
/// Deserialized from the tool-call payload; `memory_ids` are resolved to
/// content elsewhere before [`check_context_budget`] runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextBudgetInput {
    /// Memory IDs to check.
    pub memory_ids: Vec<i64>,
    /// Model name (required); used to detect the encoding.
    pub model: String,
    /// Optional encoding override ("cl100k_base" or "o200k_base").
    pub encoding: Option<String>,
    /// Token budget to check against.
    pub budget: usize,
}
145
/// Result of context budget check.
///
/// Built via [`ContextBudgetResult::new`], which also generates the
/// over-budget suggestions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextBudgetResult {
    /// Total tokens across all memories.
    pub total_tokens: usize,
    /// The budget that was checked against.
    pub budget: usize,
    /// Remaining tokens (budget - total, clamped to 0 if over).
    pub remaining: usize,
    /// Whether total exceeds budget.
    pub over_budget: bool,
    /// Number of memories counted.
    pub memories_counted: usize,
    /// Model used for counting.
    pub model_used: String,
    /// Encoding used for counting (tiktoken name, e.g. "cl100k_base").
    pub encoding_used: String,
    /// Suggestions if over budget (empty when under budget).
    pub suggestions: Vec<String>,
    /// Per-memory token counts.
    pub memory_tokens: Vec<MemoryTokenCount>,
}
168
/// Token count for a single memory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryTokenCount {
    /// ID of the memory the count refers to.
    pub memory_id: i64,
    /// Number of tokens in the memory's content.
    pub tokens: usize,
    /// Short preview of the content (roughly the first 50 characters).
    pub content_preview: String,
}
176
177impl ContextBudgetResult {
178    pub fn new(
179        total_tokens: usize,
180        budget: usize,
181        model: &str,
182        encoding: TokenEncoding,
183        memory_tokens: Vec<MemoryTokenCount>,
184    ) -> Self {
185        let over_budget = total_tokens > budget;
186        let remaining = if over_budget {
187            0
188        } else {
189            budget - total_tokens
190        };
191
192        let mut suggestions = Vec::new();
193        if over_budget {
194            let excess = total_tokens - budget;
195            suggestions.push(format!(
196                "Over budget by {} tokens ({:.1}% of budget)",
197                excess,
198                (excess as f64 / budget as f64) * 100.0
199            ));
200
201            // Find largest memories
202            let mut sorted = memory_tokens.clone();
203            sorted.sort_by(|a, b| b.tokens.cmp(&a.tokens));
204
205            if let Some(largest) = sorted.first() {
206                suggestions.push(format!(
207                    "Largest memory: id={} ({} tokens) - consider summarizing",
208                    largest.memory_id, largest.tokens
209                ));
210            }
211
212            suggestions.push("Use memory_summarize to compress large memories".to_string());
213            suggestions.push("Use memory_archive_old to batch summarize old memories".to_string());
214        }
215
216        Self {
217            total_tokens,
218            budget,
219            remaining,
220            over_budget,
221            memories_counted: memory_tokens.len(),
222            model_used: model.to_string(),
223            encoding_used: encoding.as_str().to_string(),
224            suggestions,
225            memory_tokens,
226        }
227    }
228}
229
230/// Check token budget for a set of memories
231pub fn check_context_budget(
232    contents: &[(i64, String)],
233    model: &str,
234    encoding: Option<&str>,
235    budget: usize,
236) -> Result<ContextBudgetResult> {
237    // Determine encoding (validates model/encoding)
238    let token_encoding = if let Some(enc) = encoding {
239        parse_encoding(enc).ok_or_else(|| {
240            EngramError::InvalidInput(format!(
241                "Unknown encoding '{}'. Supported: cl100k_base, o200k_base",
242                enc
243            ))
244        })?
245    } else {
246        detect_encoding(model).ok_or_else(|| {
247            EngramError::InvalidInput(format!(
248                "Unknown model '{}'. Provide 'encoding' parameter (cl100k_base or o200k_base) or use a known model.",
249                model
250            ))
251        })?
252    };
253
254    let bpe = match token_encoding {
255        TokenEncoding::Cl100kBase => tiktoken_rs::cl100k_base(),
256        TokenEncoding::O200kBase => tiktoken_rs::o200k_base(),
257    }
258    .map_err(|e| EngramError::Internal(format!("Failed to initialize tokenizer: {}", e)))?;
259
260    let mut memory_tokens = Vec::new();
261    let mut total_tokens = 0;
262
263    for (id, content) in contents {
264        let tokens = bpe.encode_with_special_tokens(content).len();
265        total_tokens += tokens;
266
267        // Create preview (first 50 chars)
268        let preview = if content.len() > 50 {
269            format!("{}...", &content[..50])
270        } else {
271            content.clone()
272        };
273
274        memory_tokens.push(MemoryTokenCount {
275            memory_id: *id,
276            tokens,
277            content_preview: preview,
278        });
279    }
280
281    Ok(ContextBudgetResult::new(
282        total_tokens,
283        budget,
284        model,
285        token_encoding,
286        memory_tokens,
287    ))
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Model-name detection maps known families to the right encoding
    /// and returns None for unrecognized names.
    #[test]
    fn test_detect_encoding() {
        assert_eq!(detect_encoding("gpt-4"), Some(TokenEncoding::Cl100kBase));
        assert_eq!(
            detect_encoding("gpt-4-turbo"),
            Some(TokenEncoding::Cl100kBase)
        );
        assert_eq!(detect_encoding("gpt-4o"), Some(TokenEncoding::O200kBase));
        assert_eq!(
            detect_encoding("gpt-4o-mini"),
            Some(TokenEncoding::O200kBase)
        );
        assert_eq!(
            detect_encoding("claude-3-opus"),
            Some(TokenEncoding::Cl100kBase)
        );
        assert_eq!(
            detect_encoding("text-embedding-3-small"),
            Some(TokenEncoding::Cl100kBase)
        );
        assert_eq!(detect_encoding("unknown-model"), None);
    }

    /// A known model counts tokens without needing an explicit encoding.
    #[test]
    fn test_count_tokens_known_model() {
        let result = count_tokens("Hello, world!", "gpt-4", None);
        assert!(result.is_ok());
        assert!(result.unwrap() > 0);
    }

    /// Unknown model with no encoding is a hard error — no chars/4 fallback.
    #[test]
    fn test_count_tokens_unknown_model_no_encoding() {
        let result = count_tokens("Hello, world!", "unknown-model", None);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Unknown model"));
    }

    /// An explicit encoding override rescues an unknown model name.
    #[test]
    fn test_count_tokens_unknown_model_with_encoding() {
        let result = count_tokens("Hello, world!", "unknown-model", Some("cl100k_base"));
        assert!(result.is_ok());
    }

    /// Small contents against a generous budget: under budget, counts match.
    #[test]
    fn test_context_budget_under() {
        let contents = vec![
            (1, "Hello world".to_string()),
            (2, "Test content".to_string()),
        ];
        let result = check_context_budget(&contents, "gpt-4", None, 1000).unwrap();
        assert!(!result.over_budget);
        assert!(result.remaining > 0);
        assert_eq!(result.memories_counted, 2);
    }

    /// Oversized content: over budget, remaining clamps to 0, suggestions emitted.
    #[test]
    fn test_context_budget_over() {
        let contents = vec![(1, "A".repeat(10000))];
        let result = check_context_budget(&contents, "gpt-4", None, 100).unwrap();
        assert!(result.over_budget);
        assert_eq!(result.remaining, 0);
        assert!(!result.suggestions.is_empty());
    }
}