//! Token counting, context-window tracking, and budget-aware text
//! truncation built on [`tiktoken_rs`].

use anyhow::Result;
use serde::{Deserialize, Serialize};

pub use tiktoken_rs;
/// The tiktoken vocabularies this crate can count with.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TokenEncoding {
    Cl100kBase,
    P50kBase,
    P50kEdit,
    R50kBase,
    /// Anthropic's tokenizer is not public, so Claude counts are
    /// approximated with `cl100k_base`.
    ClaudeApprox,
}
impl TokenEncoding {
    /// The tiktoken encoding name backing this variant.
    pub fn encoding_name(&self) -> &'static str {
        match self {
            TokenEncoding::Cl100kBase | TokenEncoding::ClaudeApprox => "cl100k_base",
            TokenEncoding::P50kBase => "p50k_base",
            TokenEncoding::P50kEdit => "p50k_edit",
            TokenEncoding::R50kBase => "r50k_base",
        }
    }

    /// Guesses an encoding from a model name by substring matching,
    /// falling back to `cl100k_base` for unknown models.
    pub fn for_model(model: &str) -> Self {
        let model_lower = model.to_lowercase();
        if model_lower.contains("gpt-4") || model_lower.contains("gpt-3.5") {
            TokenEncoding::Cl100kBase
        } else if model_lower.contains("claude") {
            TokenEncoding::ClaudeApprox
        } else if model_lower.contains("edit") {
            // Edit models (e.g. text-davinci-edit-001) use p50k_edit;
            // this check must precede the davinci/codex one below.
            TokenEncoding::P50kEdit
        } else if model_lower.contains("davinci")
            || model_lower.contains("curie")
            || model_lower.contains("codex")
        {
            TokenEncoding::P50kBase
        } else {
            TokenEncoding::Cl100kBase
        }
    }
}
/// Counts tokens in text for a given encoding.
#[derive(Debug, Clone)]
pub struct TokenCounter {
    encoding: TokenEncoding,
}
impl TokenCounter {
    pub fn new(encoding: TokenEncoding) -> Self {
        Self { encoding }
    }

    pub fn for_model(model: &str) -> Self {
        Self::new(TokenEncoding::for_model(model))
    }

    /// Counts the tokens in `text`.
    ///
    /// `tiktoken_rs::get_bpe_from_model` expects a *model* name such as
    /// "gpt-4", not an encoding name like "cl100k_base", so the BPE is
    /// selected directly here. The encoder is rebuilt on every call;
    /// cache it if you count in a hot loop.
    pub fn count(&self, text: &str) -> Result<usize> {
        let bpe = match self.encoding {
            TokenEncoding::Cl100kBase | TokenEncoding::ClaudeApprox => {
                tiktoken_rs::cl100k_base()?
            }
            TokenEncoding::P50kBase => tiktoken_rs::p50k_base()?,
            TokenEncoding::P50kEdit => tiktoken_rs::p50k_edit()?,
            TokenEncoding::R50kBase => tiktoken_rs::r50k_base()?,
        };
        Ok(bpe.encode_ordinary(text).len())
    }

    /// Approximates the tokens a chat message consumes; the constant 4
    /// accounts for the per-message framing overhead of the OpenAI chat
    /// format.
    pub fn count_chat_message(&self, role: &str, content: &str) -> Result<usize> {
        let content_tokens = self.count(content)?;
        let role_tokens = self.count(role)?;
        Ok(content_tokens + role_tokens + 4)
    }

    pub fn encoding(&self) -> TokenEncoding {
        self.encoding
    }
}
impl Default for TokenCounter {
    fn default() -> Self {
        Self::new(TokenEncoding::Cl100kBase)
    }
}
/// Tracks how much of a model's context window is in use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextWindow {
    pub max_tokens: usize,
    pub reserved_for_response: usize,
    pub current_tokens: usize,
}
impl ContextWindow {
    pub fn new(max_tokens: usize, reserved_for_response: usize) -> Self {
        Self {
            max_tokens,
            reserved_for_response,
            current_tokens: 0,
        }
    }

    /// Looks up a window size and response reservation by model name.
    /// Figures reflect published context limits at the time of writing;
    /// unknown models fall back to a conservative 8k window.
    pub fn for_model(model: &str) -> Self {
        let (max_tokens, reserved) = match model.to_lowercase().as_str() {
            m if m.contains("gpt-4-turbo") || m.contains("gpt-4o") => (128_000, 4_096),
            m if m.contains("gpt-4-32k") => (32_768, 4_096),
            m if m.contains("gpt-4") => (8_192, 2_048),
            m if m.contains("gpt-3.5-turbo-16k") => (16_384, 4_096),
            m if m.contains("gpt-3.5") => (4_096, 1_024),
            // The whole Claude 3 family (opus, sonnet, haiku) exposes a
            // 200k-token context.
            m if m.contains("claude-3") => (200_000, 4_096),
            m if m.contains("claude-2") => (100_000, 4_096),
            _ => (8_192, 2_048),
        };
        Self::new(max_tokens, reserved)
    }

    /// Tokens still available for prompt content.
    pub fn available(&self) -> usize {
        self.max_tokens
            .saturating_sub(self.reserved_for_response)
            .saturating_sub(self.current_tokens)
    }

    pub fn add(&mut self, tokens: usize) {
        self.current_tokens = self.current_tokens.saturating_add(tokens);
    }

    pub fn reset(&mut self) {
        self.current_tokens = 0;
    }

    pub fn can_fit(&self, tokens: usize) -> bool {
        self.available() >= tokens
    }

    /// Usage as a percentage of the usable (non-reserved) window.
    pub fn usage_percent(&self) -> f64 {
        let usable = self.max_tokens.saturating_sub(self.reserved_for_response);
        if usable == 0 {
            return 100.0;
        }
        (self.current_tokens as f64 / usable as f64) * 100.0
    }
}
/// Truncates text so that it fits within a token budget.
pub struct TextTruncator {
    counter: TokenCounter,
}

impl TextTruncator {
    pub fn new(encoding: TokenEncoding) -> Self {
        Self {
            counter: TokenCounter::new(encoding),
        }
    }

    /// Returns the longest prefix of `text` that fits in `max_tokens`.
    pub fn truncate(&self, text: &str, max_tokens: usize) -> Result<String> {
        let tokens = self.counter.count(text)?;
        if tokens <= max_tokens {
            return Ok(text.to_string());
        }
        // Binary-search over char boundaries rather than raw byte
        // indices: slicing a &str at an arbitrary byte offset panics on
        // multi-byte UTF-8 sequences.
        let boundaries: Vec<usize> = text
            .char_indices()
            .map(|(i, _)| i)
            .skip(1)
            .chain(std::iter::once(text.len()))
            .collect();
        let mut low = 0;
        let mut high = boundaries.len();
        let mut best = 0;
        while low < high {
            let mid = (low + high) / 2;
            let end = boundaries[mid];
            if self.counter.count(&text[..end])? <= max_tokens {
                best = end;
                low = mid + 1;
            } else {
                high = mid;
            }
        }
        Ok(text[..best].to_string())
    }

    /// Like [`truncate`](Self::truncate), but appends "..." whenever the
    /// text was actually cut.
    pub fn truncate_with_ellipsis(&self, text: &str, max_tokens: usize) -> Result<String> {
        let ellipsis = "...";
        let ellipsis_tokens = self.counter.count(ellipsis)?;
        if max_tokens <= ellipsis_tokens {
            // No room for the marker itself; fall back to a plain
            // truncation so the budget is still respected.
            return self.truncate(text, max_tokens);
        }
        let truncated = self.truncate(text, max_tokens - ellipsis_tokens)?;
        if truncated.len() < text.len() {
            Ok(format!("{}{}", truncated, ellipsis))
        } else {
            Ok(text.to_string())
        }
    }
}

impl Default for TextTruncator {
    fn default() -> Self {
        Self::new(TokenEncoding::Cl100kBase)
    }
}
/// Tracks messages against a model's context window, evicting the
/// oldest first when room is needed.
#[derive(Debug, Clone)]
pub struct TokenBudget {
    window: ContextWindow,
    counter: TokenCounter,
    messages: Vec<(String, usize)>,
}

impl TokenBudget {
    pub fn for_model(model: &str) -> Self {
        Self {
            window: ContextWindow::for_model(model),
            counter: TokenCounter::for_model(model),
            messages: Vec::new(),
        }
    }

    /// Adds a message if it fits; returns `Ok(false)` when it does not.
    pub fn add_message(&mut self, content: &str) -> Result<bool> {
        let tokens = self.counter.count(content)?;
        if !self.window.can_fit(tokens) {
            return Ok(false);
        }
        self.window.add(tokens);
        self.messages.push((content.to_string(), tokens));
        Ok(true)
    }

    /// Drops the oldest messages until `needed_tokens` fit or no
    /// messages remain.
    pub fn make_room(&mut self, needed_tokens: usize) {
        while !self.window.can_fit(needed_tokens) && !self.messages.is_empty() {
            let (_, tokens) = self.messages.remove(0);
            self.window.current_tokens = self.window.current_tokens.saturating_sub(tokens);
        }
    }

    pub fn stats(&self) -> BudgetStats {
        BudgetStats {
            current_tokens: self.window.current_tokens,
            max_tokens: self.window.max_tokens,
            available_tokens: self.window.available(),
            usage_percent: self.window.usage_percent(),
            message_count: self.messages.len(),
        }
    }
}
/// A point-in-time snapshot of a [`TokenBudget`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetStats {
    pub current_tokens: usize,
    pub max_tokens: usize,
    pub available_tokens: usize,
    pub usage_percent: f64,
    pub message_count: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoding_for_model() {
        assert_eq!(TokenEncoding::for_model("gpt-4"), TokenEncoding::Cl100kBase);
        assert_eq!(
            TokenEncoding::for_model("claude-3-opus"),
            TokenEncoding::ClaudeApprox
        );
    }
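
    // A smoke test for TokenCounter. Exact counts depend on the
    // cl100k_base vocabulary, so assert a loose range rather than a
    // specific number.
    #[test]
    fn test_count() {
        let counter = TokenCounter::default();
        let tokens = counter.count("hello world").unwrap();
        assert!(tokens >= 1 && tokens <= 4);
    }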

    #[test]
    fn test_context_window() {
        let mut window = ContextWindow::new(8192, 2048);
        assert_eq!(window.available(), 6144);
        window.add(1000);
        assert_eq!(window.available(), 5144);
        assert!(window.can_fit(5000));
        assert!(!window.can_fit(6000));
    }
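
    // A sketch exercising truncation on multi-byte UTF-8 input; it
    // asserts only that the result is a prefix that fits the budget,
    // since exact token counts depend on the vocabulary.
    #[test]
    fn test_truncate_multibyte() {
        let truncator = TextTruncator::default();
        let text = "héllo wörld ".repeat(50);
        let result = truncator.truncate(&text, 10).unwrap();
        assert!(text.starts_with(&result));
        assert!(TokenCounter::default().count(&result).unwrap() <= 10);
    }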

    #[test]
    fn test_context_window_for_model() {
        let window = ContextWindow::for_model("gpt-4-turbo");
        assert_eq!(window.max_tokens, 128_000);
        let window = ContextWindow::for_model("claude-3-opus");
        assert_eq!(window.max_tokens, 200_000);
    }
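
    // A sketch of the eviction flow: add messages, then force make_room
    // to drop them oldest-first.
    #[test]
    fn test_token_budget_make_room() {
        let mut budget = TokenBudget::for_model("gpt-4");
        assert!(budget.add_message("first message").unwrap());
        assert!(budget.add_message("second message").unwrap());
        assert_eq!(budget.stats().message_count, 2);
        let max = budget.stats().max_tokens;
        // Request more than the window can ever provide, evicting all.
        budget.make_room(max);
        assert_eq!(budget.stats().message_count, 0);
        assert_eq!(budget.stats().current_tokens, 0);
    }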
}