use crate::error::{EngramError, Result};
use serde::{Deserialize, Serialize};
/// Strategy for compressing memory content (not referenced in this file;
/// presumably consumed by summarize/archive code elsewhere — verify at callers).
///
/// Serializes in lowercase via serde (e.g. `HeadTail` -> "headtail").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum CompressionStrategy {
    /// No compression (the default).
    #[default]
    None,
    /// Keep head and tail of the content, dropping the middle — TODO confirm.
    HeadTail,
    /// Replace content with a generated summary — TODO confirm.
    Summary,
}
/// The tiktoken BPE encodings this module can count tokens with.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenEncoding {
    /// `cl100k_base` — GPT-4 / GPT-3.5 / text-embedding era encoding.
    Cl100kBase,
    /// `o200k_base` — GPT-4o era encoding.
    O200kBase,
}

impl TokenEncoding {
    /// Canonical tiktoken name for this encoding.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Cl100kBase => "cl100k_base",
            Self::O200kBase => "o200k_base",
        }
    }
}
/// Best-effort, case-insensitive mapping from a model name to an encoding.
///
/// Returns `None` for unrecognized models so callers can ask for an explicit
/// encoding instead. NOTE(review): Claude models do not actually use tiktoken;
/// `cl100k_base` here is an approximation — confirm that is intended.
pub fn detect_encoding(model: &str) -> Option<TokenEncoding> {
    let lower = model.to_lowercase();

    // Order matters: "gpt-4o" must be tested before the broader "gpt-4".
    if lower.contains("gpt-4o") {
        Some(TokenEncoding::O200kBase)
    } else if lower.contains("gpt-4")
        || lower.contains("gpt-3.5")
        || lower.contains("text-embedding")
        || lower.contains("claude")
    {
        Some(TokenEncoding::Cl100kBase)
    } else if let Some(rest) = lower.strip_prefix("openai/") {
        // Provider-prefixed ids (e.g. "openai/<model>"): retry without the prefix.
        detect_encoding(rest)
    } else if lower.starts_with("anthropic/") {
        Some(TokenEncoding::Cl100kBase)
    } else {
        None
    }
}
/// Parses an explicit encoding name, case-insensitively.
/// Accepts the full tiktoken names and the short aliases without `_base`.
pub fn parse_encoding(encoding: &str) -> Option<TokenEncoding> {
    let name = encoding.to_lowercase();
    if name == "cl100k_base" || name == "cl100k" {
        Some(TokenEncoding::Cl100kBase)
    } else if name == "o200k_base" || name == "o200k" {
        Some(TokenEncoding::O200kBase)
    } else {
        None
    }
}
/// Counts tokens in `text` for the given model.
///
/// An explicit `encoding` override takes precedence; otherwise the encoding is
/// inferred from `model` via [`detect_encoding`]. Counting uses
/// `encode_with_special_tokens`, so special-token text is included in the count.
///
/// # Errors
/// * `EngramError::InvalidInput` — unknown `encoding` string, or `model` that
///   cannot be mapped to an encoding.
/// * `EngramError::Internal` — the tokenizer failed to initialize.
pub fn count_tokens(text: &str, model: &str, encoding: Option<&str>) -> Result<usize> {
    let token_encoding = if let Some(enc) = encoding {
        parse_encoding(enc).ok_or_else(|| {
            EngramError::InvalidInput(format!(
                "Unknown encoding '{}'. Supported: cl100k_base, o200k_base",
                enc
            ))
        })?
    } else {
        detect_encoding(model).ok_or_else(|| {
            EngramError::InvalidInput(format!(
                "Unknown model '{}'. Provide 'encoding' parameter (cl100k_base or o200k_base) or use a known model (gpt-4, gpt-4o, claude-*, text-embedding-*).",
                model
            ))
        })?
    };
    // Same error-propagation shape as check_context_budget: map the init error
    // once and use `?` instead of matching on the Result.
    let bpe = match token_encoding {
        TokenEncoding::Cl100kBase => tiktoken_rs::cl100k_base(),
        TokenEncoding::O200kBase => tiktoken_rs::o200k_base(),
    }
    .map_err(|e| EngramError::Internal(format!("Failed to initialize tokenizer: {}", e)))?;
    Ok(bpe.encode_with_special_tokens(text).len())
}
/// Request payload for a context-budget check.
/// Not consumed in this file — presumably the caller resolves `memory_ids`
/// to contents before invoking `check_context_budget`; verify at call sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextBudgetInput {
    /// IDs of the memories whose contents should be token-counted.
    pub memory_ids: Vec<i64>,
    /// Model name used to auto-detect the tokenizer encoding.
    pub model: String,
    /// Optional explicit encoding override ("cl100k_base" or "o200k_base").
    pub encoding: Option<String>,
    /// Token budget to compare the total against.
    pub budget: usize,
}
/// Outcome of a context-budget check (see `ContextBudgetResult::new`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextBudgetResult {
    /// Sum of token counts across all memories.
    pub total_tokens: usize,
    /// The budget the total was compared against.
    pub budget: usize,
    /// Tokens left under budget; 0 when over budget (never negative).
    pub remaining: usize,
    /// True when `total_tokens > budget`.
    pub over_budget: bool,
    /// Number of memories included in the count.
    pub memories_counted: usize,
    /// Model name echoed back to the caller.
    pub model_used: String,
    /// Canonical name of the encoding actually used (e.g. "cl100k_base").
    pub encoding_used: String,
    /// Human-readable remediation hints; populated only when over budget.
    pub suggestions: Vec<String>,
    /// Per-memory token counts with content previews.
    pub memory_tokens: Vec<MemoryTokenCount>,
}
/// Token count for a single memory, with a short content preview.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryTokenCount {
    /// ID of the memory this count belongs to.
    pub memory_id: i64,
    /// Token count of the memory's content.
    pub tokens: usize,
    /// Truncated content snippet ("..."-suffixed when truncated).
    pub content_preview: String,
}
impl ContextBudgetResult {
    /// Builds a result from precomputed per-memory token counts.
    ///
    /// `remaining` is clamped to 0 when over budget. Suggestions (excess
    /// percentage, largest memory, remediation tools) are generated only in
    /// the over-budget case.
    pub fn new(
        total_tokens: usize,
        budget: usize,
        model: &str,
        encoding: TokenEncoding,
        memory_tokens: Vec<MemoryTokenCount>,
    ) -> Self {
        let over_budget = total_tokens > budget;
        // Clamp at zero instead of branching; usize cannot go negative.
        let remaining = budget.saturating_sub(total_tokens);
        let mut suggestions = Vec::new();
        if over_budget {
            let excess = total_tokens - budget;
            // NOTE: with budget == 0 this prints "inf%", matching prior behavior.
            suggestions.push(format!(
                "Over budget by {} tokens ({:.1}% of budget)",
                excess,
                (excess as f64 / budget as f64) * 100.0
            ));
            // Single O(n) scan for the largest memory; previously the whole
            // vec was cloned and sorted (O(n log n) plus a full copy) just to
            // read its first element.
            if let Some(largest) = memory_tokens.iter().max_by_key(|m| m.tokens) {
                suggestions.push(format!(
                    "Largest memory: id={} ({} tokens) - consider summarizing",
                    largest.memory_id, largest.tokens
                ));
            }
            suggestions.push("Use memory_summarize to compress large memories".to_string());
            suggestions.push("Use memory_archive_old to batch summarize old memories".to_string());
        }
        Self {
            total_tokens,
            budget,
            remaining,
            over_budget,
            memories_counted: memory_tokens.len(),
            model_used: model.to_string(),
            encoding_used: encoding.as_str().to_string(),
            suggestions,
            memory_tokens,
        }
    }
}
/// Counts tokens for each `(memory_id, content)` pair and compares the total
/// against `budget`, producing per-memory counts and over-budget suggestions.
///
/// Encoding resolution mirrors [`count_tokens`]: an explicit `encoding`
/// override wins, otherwise it is inferred from `model`.
///
/// # Errors
/// * `EngramError::InvalidInput` — unknown encoding string or unmappable model.
/// * `EngramError::Internal` — the tokenizer failed to initialize.
pub fn check_context_budget(
    contents: &[(i64, String)],
    model: &str,
    encoding: Option<&str>,
    budget: usize,
) -> Result<ContextBudgetResult> {
    let token_encoding = if let Some(enc) = encoding {
        parse_encoding(enc).ok_or_else(|| {
            EngramError::InvalidInput(format!(
                "Unknown encoding '{}'. Supported: cl100k_base, o200k_base",
                enc
            ))
        })?
    } else {
        detect_encoding(model).ok_or_else(|| {
            EngramError::InvalidInput(format!(
                "Unknown model '{}'. Provide 'encoding' parameter (cl100k_base or o200k_base) or use a known model.",
                model
            ))
        })?
    };
    let bpe = match token_encoding {
        TokenEncoding::Cl100kBase => tiktoken_rs::cl100k_base(),
        TokenEncoding::O200kBase => tiktoken_rs::o200k_base(),
    }
    .map_err(|e| EngramError::Internal(format!("Failed to initialize tokenizer: {}", e)))?;
    let mut memory_tokens = Vec::with_capacity(contents.len());
    let mut total_tokens = 0;
    for (id, content) in contents {
        let tokens = bpe.encode_with_special_tokens(content).len();
        total_tokens += tokens;
        // Truncate the preview on a character boundary. The previous
        // `&content[..50]` sliced at a byte offset and would panic whenever
        // byte 50 fell inside a multi-byte UTF-8 character.
        let preview = match content.char_indices().nth(50) {
            // `byte_idx` is the start of the 51st char, i.e. a valid boundary
            // right after the first 50 chars.
            Some((byte_idx, _)) => format!("{}...", &content[..byte_idx]),
            None => content.clone(),
        };
        memory_tokens.push(MemoryTokenCount {
            memory_id: *id,
            tokens,
            content_preview: preview,
        });
    }
    Ok(ContextBudgetResult::new(
        total_tokens,
        budget,
        model,
        token_encoding,
        memory_tokens,
    ))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_encoding() {
        // Models that resolve to cl100k_base.
        let cl100k = ["gpt-4", "gpt-4-turbo", "claude-3-opus", "text-embedding-3-small"];
        for model in cl100k {
            assert_eq!(
                detect_encoding(model),
                Some(TokenEncoding::Cl100kBase),
                "model: {}",
                model
            );
        }
        // Models that resolve to o200k_base.
        for model in ["gpt-4o", "gpt-4o-mini"] {
            assert_eq!(
                detect_encoding(model),
                Some(TokenEncoding::O200kBase),
                "model: {}",
                model
            );
        }
        // Unrecognized models yield no encoding.
        assert_eq!(detect_encoding("unknown-model"), None);
    }

    #[test]
    fn test_count_tokens_known_model() {
        let count = count_tokens("Hello, world!", "gpt-4", None).unwrap();
        assert!(count > 0);
    }

    #[test]
    fn test_count_tokens_unknown_model_no_encoding() {
        let err = count_tokens("Hello, world!", "unknown-model", None).unwrap_err();
        assert!(err.to_string().contains("Unknown model"));
    }

    #[test]
    fn test_count_tokens_unknown_model_with_encoding() {
        // An explicit encoding rescues an otherwise-unknown model.
        assert!(count_tokens("Hello, world!", "unknown-model", Some("cl100k_base")).is_ok());
    }

    #[test]
    fn test_context_budget_under() {
        let contents = vec![
            (1, String::from("Hello world")),
            (2, String::from("Test content")),
        ];
        let result = check_context_budget(&contents, "gpt-4", None, 1000).unwrap();
        assert!(!result.over_budget);
        assert!(result.remaining > 0);
        assert_eq!(result.memories_counted, 2);
    }

    #[test]
    fn test_context_budget_over() {
        let contents = vec![(1, "A".repeat(10000))];
        let result = check_context_budget(&contents, "gpt-4", None, 100).unwrap();
        assert!(result.over_budget);
        assert_eq!(result.remaining, 0);
        assert!(!result.suggestions.is_empty());
    }
}