use anyhow::Result;
use crate::claude::ai::AiClientMetadata;
use crate::claude::error::ClaudeError;
/// Divisor that turns an input length into a raw (pre-margin) token estimate.
const CHARS_PER_TOKEN: f64 = 2.5;
/// Multiplier applied to the raw estimate, padding it by 20% before it is rounded up.
const SAFETY_MARGIN: f64 = 1.20;
/// Estimates how many tokens `text` will occupy.
///
/// The length handed to [`estimate_tokens_from_char_count`] is `text.len()`,
/// i.e. the UTF-8 byte length. For non-ASCII text that is larger than the
/// number of characters, so the estimate errs on the high side.
#[must_use]
pub(crate) fn estimate_tokens(text: &str) -> usize {
    let length = text.len();
    estimate_tokens_from_char_count(length)
}
/// Estimates a token count from a length alone, without needing the text.
///
/// The length is divided by `CHARS_PER_TOKEN`, scaled by `SAFETY_MARGIN`,
/// and rounded up to the next whole token. A length of zero yields zero.
#[must_use]
pub(crate) fn estimate_tokens_from_char_count(char_count: usize) -> usize {
    let unpadded = char_count as f64 / CHARS_PER_TOKEN;
    let padded = unpadded * SAFETY_MARGIN;
    // Round up so a fractional token is always counted as a whole one.
    padded.ceil() as usize
}
/// Outcome of a successful `TokenBudget::validate_prompt` call.
#[derive(Debug, Clone)]
pub(crate) struct TokenEstimate {
/// Sum of the estimated token counts of the system and user prompts.
pub estimated_tokens: usize,
/// Input tokens the budget allowed: context length minus the reserved output tokens.
pub available_tokens: usize,
/// `estimated_tokens` as a percentage of `available_tokens`
/// (`f64::INFINITY` when `available_tokens` is zero).
pub utilization_pct: f64,
}
/// Token limits for a single model, used to check whether a prompt fits.
#[derive(Debug, Clone)]
pub(crate) struct TokenBudget {
/// Model name; reported in the error produced when a prompt is too large.
model: String,
/// Maximum context length copied from the client metadata — presumably in tokens.
max_context_length: usize,
/// Amount held back for the response; subtracted from `max_context_length`.
reserved_output_tokens: usize,
}
impl TokenBudget {
    /// Builds a budget from the limits carried by `metadata`.
    ///
    /// The model name and context length are copied over, and the metadata's
    /// maximum response length becomes the reserved output allowance.
    #[must_use]
    pub fn from_metadata(metadata: &AiClientMetadata) -> Self {
        let model = metadata.model.clone();
        let max_context_length = metadata.max_context_length;
        let reserved_output_tokens = metadata.max_response_length;
        Self {
            model,
            max_context_length,
            reserved_output_tokens,
        }
    }

    /// Returns how many tokens remain for input once the output reservation
    /// is taken out of the context length.
    ///
    /// The subtraction saturates, so a reservation larger than the context
    /// length yields `0` rather than underflowing.
    #[must_use]
    pub(crate) fn available_input_tokens(&self) -> usize {
        let reserved = self.reserved_output_tokens;
        self.max_context_length.saturating_sub(reserved)
    }

    /// Checks that the two prompts together fit within the input budget.
    ///
    /// Each prompt is estimated separately with [`estimate_tokens`] and the
    /// results are added. On success the returned [`TokenEstimate`] reports
    /// the estimate, the available budget, and the utilization percentage.
    ///
    /// # Errors
    ///
    /// Returns [`ClaudeError::PromptTooLarge`] (carrying the estimate, the
    /// available budget, and the model name) when the estimate exceeds the
    /// available input tokens.
    pub fn validate_prompt(
        &self,
        system_prompt: &str,
        user_prompt: &str,
    ) -> Result<TokenEstimate> {
        let needed = estimate_tokens(system_prompt) + estimate_tokens(user_prompt);
        let budget = self.available_input_tokens();

        // Guard clause: reject anything that does not fit before building the report.
        if needed > budget {
            let too_large = ClaudeError::PromptTooLarge {
                estimated_tokens: needed,
                max_tokens: budget,
                model: self.model.clone(),
            };
            return Err(too_large.into());
        }

        // A zero budget has no meaningful ratio; report it as infinite utilization.
        let utilization_pct = match budget {
            0 => f64::INFINITY,
            _ => (needed as f64 / budget as f64) * 100.0,
        };

        Ok(TokenEstimate {
            estimated_tokens: needed,
            available_tokens: budget,
            utilization_pct,
        })
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds client metadata for a fictional model with the given limits.
    fn make_metadata(context: usize, response: usize) -> AiClientMetadata {
        AiClientMetadata {
            provider: String::from("test"),
            model: String::from("test-model"),
            max_context_length: context,
            max_response_length: response,
            active_beta: None,
        }
    }

    // --- estimate_tokens -------------------------------------------------

    #[test]
    fn estimate_tokens_empty_string() {
        let tokens = estimate_tokens("");
        assert_eq!(tokens, 0);
    }

    #[test]
    fn estimate_tokens_short_text() {
        // 5 / 2.5 = 2 raw tokens; 2 * 1.2 = 2.4, rounded up to 3.
        assert_eq!(estimate_tokens("hello"), 3);
    }

    #[test]
    fn estimate_tokens_scales_linearly() {
        // 500 / 2.5 = 200 raw tokens; 200 * 1.2 = 240.
        let input = "a".repeat(500);
        assert_eq!(estimate_tokens(&input), 240);
    }

    #[test]
    fn estimate_tokens_includes_safety_margin() {
        // 2500 / 2.5 = 1000 raw tokens; the margin lifts that to 1200.
        let input = "x".repeat(2500);
        let tokens = estimate_tokens(&input);
        assert_eq!(tokens, 1200);
    }

    // --- TokenBudget -----------------------------------------------------

    #[test]
    fn budget_validation_within_limits() {
        let budget = TokenBudget::from_metadata(&make_metadata(200_000, 64_000));
        let estimate = budget.validate_prompt("system", "user").unwrap();
        assert!(estimate.utilization_pct < 1.0);
        assert_eq!(estimate.available_tokens, 136_000);
    }

    #[test]
    fn budget_validation_exceeds_limits() {
        let budget = TokenBudget::from_metadata(&make_metadata(1000, 500));
        let oversized = "x".repeat(1100);
        let outcome = budget.validate_prompt(&oversized, "user");
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Prompt too large"));
        assert!(message.contains("test-model"));
    }

    #[test]
    fn budget_saturates_when_output_exceeds_context() {
        // Reservation (200) exceeds the context (100), so no input tokens are available.
        let budget = TokenBudget::from_metadata(&make_metadata(100, 200));
        let outcome = budget.validate_prompt("a", "b");
        assert!(outcome.is_err());
    }

    #[test]
    fn token_estimate_utilization_percentage() {
        let budget = TokenBudget::from_metadata(&make_metadata(200_000, 0));
        let estimate = budget.validate_prompt("test prompt here", "").unwrap();
        let pct = estimate.utilization_pct;
        assert!(pct > 0.0);
        assert!(pct < 100.0);
    }

    // --- estimate_tokens_from_char_count ---------------------------------

    #[test]
    fn estimate_tokens_from_char_count_matches_estimate_tokens() {
        let sample = "hello world, this is a test string for token estimation";
        let via_text = estimate_tokens(sample);
        let via_count = estimate_tokens_from_char_count(sample.len());
        assert_eq!(via_text, via_count);
    }

    #[test]
    fn estimate_tokens_from_char_count_zero() {
        assert_eq!(estimate_tokens_from_char_count(0), 0);
    }
}