use super::ContentError;
use serde::{Deserialize, Serialize};
/// Model selection used to size a [`TokenBudget`]'s context window.
///
/// Each variant maps to a window size via `window_size()` and to a short
/// identifier via `name()`. `Claude200K` is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ModelContext {
/// 200,000-token window; `name()` reports `"claude-sonnet"`.
#[default]
Claude200K,
/// 200,000-token window; `name()` reports `"claude-haiku"`.
ClaudeHaiku,
/// 1,000,000-token window; `name()` reports `"gemini-pro"`.
GeminiPro,
/// 1,000,000-token window; `name()` reports `"gemini-flash"`.
GeminiFlash,
/// 128,000-token window; `name()` reports `"gpt-4-turbo"`.
Gpt4Turbo,
/// Caller-supplied window size; `window_size()` returns the wrapped value
/// and `name()` reports `"custom"`.
Custom(usize),
}
impl ModelContext {
pub fn window_size(&self) -> usize {
match self {
ModelContext::Claude200K => 200_000,
ModelContext::ClaudeHaiku => 200_000,
ModelContext::GeminiPro => 1_000_000,
ModelContext::GeminiFlash => 1_000_000,
ModelContext::Gpt4Turbo => 128_000,
ModelContext::Custom(size) => *size,
}
}
pub fn name(&self) -> &'static str {
match self {
ModelContext::Claude200K => "claude-sonnet",
ModelContext::ClaudeHaiku => "claude-haiku",
ModelContext::GeminiPro => "gemini-pro",
ModelContext::GeminiFlash => "gemini-flash",
ModelContext::Gpt4Turbo => "gpt-4-turbo",
ModelContext::Custom(_) => "custom",
}
}
}
/// Breakdown of how a model's context window is split between the prompt
/// components and the reserved output.
///
/// All fields are token counts. `validate()` checks that the components fit
/// inside `context_window`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TokenBudget {
/// Total context window; `new` takes it from `ModelContext::window_size()`.
pub context_window: usize,
/// Tokens set aside for the system prompt (`new` defaults this to 2,000).
pub system_reserve: usize,
/// Tokens allotted to source context (`new` starts this at 0).
pub source_context: usize,
/// Tokens allotted to RAG context (`new` starts this at 0).
pub rag_context: usize,
/// Tokens allotted to few-shot examples (`new` defaults this to 1,500).
pub few_shot: usize,
/// Tokens reserved for the generated output (`new` defaults this to 4,000).
pub output_target: usize,
}
impl TokenBudget {
    /// Tokens-per-word ratio used by the word/token estimators.
    const TOKENS_PER_WORD: f64 = 1.3;

    /// Creates a budget sized to `model`'s context window.
    ///
    /// Defaults: 2,000 tokens of system reserve, 1,500 for few-shot examples
    /// and 4,000 for output; source and RAG context start at zero.
    pub fn new(model: ModelContext) -> Self {
        Self {
            context_window: model.window_size(),
            system_reserve: 2_000,
            source_context: 0,
            rag_context: 0,
            few_shot: 1_500,
            output_target: 4_000,
        }
    }

    /// Returns the budget with `source_context` set to `tokens`.
    #[must_use]
    pub fn with_source_context(mut self, tokens: usize) -> Self {
        self.source_context = tokens;
        self
    }

    /// Returns the budget with `rag_context` set to `tokens`.
    #[must_use]
    pub fn with_rag_context(mut self, tokens: usize) -> Self {
        self.rag_context = tokens;
        self
    }

    /// Returns the budget with `output_target` set to `tokens`.
    #[must_use]
    pub fn with_output_target(mut self, tokens: usize) -> Self {
        self.output_target = tokens;
        self
    }

    /// Sum of every prompt-side component: system reserve, source context,
    /// RAG context and few-shot examples.
    ///
    /// The sum saturates at `usize::MAX` instead of overflowing, because the
    /// fields are public and may hold arbitrary values.
    pub fn prompt_tokens(&self) -> usize {
        self.system_reserve
            .saturating_add(self.source_context)
            .saturating_add(self.rag_context)
            .saturating_add(self.few_shot)
    }

    /// Tokens left in the context window after the prompt and the reserved
    /// output are accounted for; `0` when the budget is full or exceeded.
    pub fn available_margin(&self) -> usize {
        let used = self.prompt_tokens().saturating_add(self.output_target);
        self.context_window.saturating_sub(used)
    }

    /// Checks that prompt plus reserved output fit inside the context window.
    ///
    /// # Errors
    ///
    /// Returns [`ContentError::TokenBudgetExceeded`] when the total is larger
    /// than `context_window`. If the total does not even fit in a `usize`,
    /// the budget is rejected and `used` is reported as `usize::MAX`.
    pub fn validate(&self) -> Result<(), ContentError> {
        // Checked arithmetic: a wrapped sum could look small and pass.
        let total = self
            .system_reserve
            .checked_add(self.source_context)
            .and_then(|sum| sum.checked_add(self.rag_context))
            .and_then(|sum| sum.checked_add(self.few_shot))
            .and_then(|sum| sum.checked_add(self.output_target));

        match total {
            Some(used) if used <= self.context_window => Ok(()),
            _ => Err(ContentError::TokenBudgetExceeded {
                used: total.unwrap_or(usize::MAX),
                limit: self.context_window,
            }),
        }
    }

    /// Estimates the token count for `words` words (1.3 tokens per word,
    /// rounded up).
    pub fn words_to_tokens(words: usize) -> usize {
        (words as f64 * Self::TOKENS_PER_WORD).ceil() as usize
    }

    /// Estimates the word count for `tokens` tokens (1.3 tokens per word,
    /// rounded down).
    pub fn tokens_to_words(tokens: usize) -> usize {
        (tokens as f64 / Self::TOKENS_PER_WORD).floor() as usize
    }

    /// Renders the budget as a multi-line tree labelled with `model_name`.
    ///
    /// The last line shows the available margin followed by `✓` when the
    /// margin is greater than zero and `✗` otherwise.
    pub fn format_display(&self, model_name: &str) -> String {
        let margin = self.available_margin();
        let status = if margin > 0 { "✓" } else { "✗" };

        // One `format!` call instead of a temporary `String` per line.
        format!(
            concat!(
                "Token Budget for {} ({}K context):\n",
                "├── System prompt: {:>6} tokens\n",
                "├── Source context: {:>6} tokens\n",
                "├── RAG context: {:>6} tokens\n",
                "├── Few-shot examples: {:>6} tokens\n",
                "├── Output reserved: {:>6} tokens (~{} words)\n",
                "└── Available margin: {:>6} tokens {}\n",
            ),
            model_name,
            self.context_window / 1000,
            self.system_reserve,
            self.source_context,
            self.rag_context,
            self.few_shot,
            self.output_target,
            Self::tokens_to_words(self.output_target),
            margin,
            status,
        )
    }
}
impl Default for TokenBudget {
fn default() -> Self {
Self::new(ModelContext::Claude200K)
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `ModelContext` and `TokenBudget`: window sizes and
    //! names, builder setters, budget arithmetic, validation, the word/token
    //! estimators, the display rendering and JSON round-trips.

    use super::*;

    #[test]
    fn test_model_context_default() {
        let ctx = ModelContext::default();
        assert_eq!(ctx, ModelContext::Claude200K);
    }

    #[test]
    fn test_model_context_window_sizes() {
        assert_eq!(ModelContext::Claude200K.window_size(), 200_000);
        assert_eq!(ModelContext::ClaudeHaiku.window_size(), 200_000);
        assert_eq!(ModelContext::GeminiPro.window_size(), 1_000_000);
        assert_eq!(ModelContext::GeminiFlash.window_size(), 1_000_000);
        assert_eq!(ModelContext::Gpt4Turbo.window_size(), 128_000);
        assert_eq!(ModelContext::Custom(50_000).window_size(), 50_000);
    }

    #[test]
    fn test_model_context_names() {
        assert_eq!(ModelContext::Claude200K.name(), "claude-sonnet");
        assert_eq!(ModelContext::ClaudeHaiku.name(), "claude-haiku");
        assert_eq!(ModelContext::GeminiPro.name(), "gemini-pro");
        assert_eq!(ModelContext::GeminiFlash.name(), "gemini-flash");
        assert_eq!(ModelContext::Gpt4Turbo.name(), "gpt-4-turbo");
        assert_eq!(ModelContext::Custom(1000).name(), "custom");
    }

    #[test]
    fn test_model_context_serialization() {
        let ctx = ModelContext::GeminiPro;
        let json = serde_json::to_string(&ctx).expect("json serialize failed");
        let deserialized: ModelContext =
            serde_json::from_str(&json).expect("json deserialize failed");
        assert_eq!(deserialized, ctx);
    }

    #[test]
    fn test_model_context_custom_serialization() {
        let ctx = ModelContext::Custom(75_000);
        let json = serde_json::to_string(&ctx).expect("json serialize failed");
        let deserialized: ModelContext =
            serde_json::from_str(&json).expect("json deserialize failed");
        assert_eq!(deserialized, ctx);
        assert_eq!(deserialized.window_size(), 75_000);
    }

    #[test]
    fn test_token_budget_new() {
        let budget = TokenBudget::new(ModelContext::Claude200K);
        assert_eq!(budget.context_window, 200_000);
        assert_eq!(budget.system_reserve, 2_000);
        assert_eq!(budget.source_context, 0);
        assert_eq!(budget.rag_context, 0);
        assert_eq!(budget.few_shot, 1_500);
        assert_eq!(budget.output_target, 4_000);
    }

    #[test]
    fn test_token_budget_default() {
        let budget = TokenBudget::default();
        assert_eq!(budget.context_window, 200_000);
    }

    #[test]
    fn test_token_budget_with_source_context() {
        let budget = TokenBudget::new(ModelContext::Claude200K).with_source_context(10_000);
        assert_eq!(budget.source_context, 10_000);
    }

    #[test]
    fn test_token_budget_with_rag_context() {
        let budget = TokenBudget::new(ModelContext::Claude200K).with_rag_context(5_000);
        assert_eq!(budget.rag_context, 5_000);
    }

    #[test]
    fn test_token_budget_with_output_target() {
        let budget = TokenBudget::new(ModelContext::Claude200K).with_output_target(8_000);
        assert_eq!(budget.output_target, 8_000);
    }

    #[test]
    fn test_token_budget_prompt_tokens() {
        let budget = TokenBudget::new(ModelContext::Claude200K)
            .with_source_context(10_000)
            .with_rag_context(5_000);
        assert_eq!(budget.prompt_tokens(), 18_500);
    }

    #[test]
    fn test_token_budget_available_margin() {
        let budget = TokenBudget::new(ModelContext::Claude200K);
        let margin = budget.available_margin();
        assert_eq!(margin, 200_000 - 3_500 - 4_000);
    }

    #[test]
    fn test_token_budget_validate_ok() {
        let budget = TokenBudget::new(ModelContext::Claude200K);
        assert!(budget.validate().is_ok());
    }

    #[test]
    fn test_token_budget_validate_exceeded() {
        let budget = TokenBudget::new(ModelContext::Custom(1_000)).with_output_target(2_000);
        let result = budget.validate();
        assert!(result.is_err());
        // 2_000 system + 1_500 few-shot + 2_000 output against a 1_000 window.
        assert!(matches!(
            result,
            Err(ContentError::TokenBudgetExceeded { used: 5_500, limit: 1_000 })
        ));
    }

    #[test]
    fn test_words_to_tokens() {
        assert_eq!(TokenBudget::words_to_tokens(100), 130);
        assert_eq!(TokenBudget::words_to_tokens(0), 0);
    }

    #[test]
    fn test_tokens_to_words() {
        assert_eq!(TokenBudget::tokens_to_words(130), 100);
        assert_eq!(TokenBudget::tokens_to_words(0), 0);
    }

    #[test]
    fn test_token_budget_format_display() {
        let budget = TokenBudget::new(ModelContext::Claude200K);
        let output = budget.format_display("claude-sonnet");
        assert!(output.contains("Token Budget for claude-sonnet"));
        assert!(output.contains("200K context"));
        assert!(output.contains("System prompt"));
        assert!(output.contains("Available margin"));
        assert!(output.contains("✓"));
    }

    #[test]
    fn test_token_budget_format_display_exceeded() {
        let budget = TokenBudget::new(ModelContext::Custom(1_000)).with_output_target(2_000);
        let output = budget.format_display("custom");
        assert!(output.contains("Available margin"));
        // An exceeded budget has no margin left, so it is flagged as failing.
        assert!(output.contains("✗"));
        assert!(!output.contains("✓"));
    }

    #[test]
    fn test_token_budget_serialization() {
        let budget = TokenBudget::new(ModelContext::GeminiPro)
            .with_source_context(5_000)
            .with_rag_context(3_000);
        let json = serde_json::to_string(&budget).expect("json serialize failed");
        let deserialized: TokenBudget =
            serde_json::from_str(&json).expect("json deserialize failed");
        assert_eq!(deserialized, budget);
    }

    #[test]
    fn test_token_budget_builder_chain() {
        let budget = TokenBudget::new(ModelContext::Gpt4Turbo)
            .with_source_context(10_000)
            .with_rag_context(8_000)
            .with_output_target(6_000);
        assert_eq!(budget.context_window, 128_000);
        assert_eq!(budget.source_context, 10_000);
        assert_eq!(budget.rag_context, 8_000);
        assert_eq!(budget.output_target, 6_000);
    }
}