use std::collections::HashMap;
use super::error::LlmError;
/// Estimates the cost of an LLM request from token counts, using a table of
/// per-model rates, and counts tokens in text with `tiktoken_rs` encodings.
#[derive(Debug, Clone)]
pub struct CostEstimator {
// Rate table keyed by model name. `estimate` multiplies the looked-up rate by
// the token total divided by 1,000, so each value is a rate per 1,000 tokens.
model_rates: HashMap<String, f64>,
}
impl Default for CostEstimator {
fn default() -> Self {
Self::new()
}
}
impl CostEstimator {
pub fn new() -> Self {
let mut model_rates = HashMap::new();
model_rates.insert("gpt-4".to_string(), 0.03);
model_rates.insert("gpt-4o".to_string(), 0.005);
model_rates.insert("gpt-3.5-turbo".to_string(), 0.002);
model_rates.insert("kimi-k2".to_string(), 0.008);
model_rates.insert("kimi-k1.5".to_string(), 0.008);
model_rates.insert("claude-3-opus".to_string(), 0.015);
model_rates.insert("claude-3-sonnet".to_string(), 0.003);
model_rates.insert("default".to_string(), 0.005);
Self { model_rates }
}
pub fn estimate(&self, prompt_tokens: usize, completion_tokens: usize, model: &str) -> f64 {
let rate = self
.model_rates
.get(model)
.or_else(|| self.model_rates.get("default"))
.copied()
.unwrap_or(0.005);
let total = prompt_tokens + completion_tokens;
(total as f64 / 1_000.0) * rate
}
pub fn count_tokens(&self, text: &str, model: &str) -> Result<usize, LlmError> {
let lower = model.to_lowercase();
let bpe = if lower.contains("gpt-4")
|| lower.contains("gpt-3.5-turbo")
|| lower.contains("kimi")
|| lower.contains("text-embedding")
|| lower.contains("claude")
{
tiktoken_rs::cl100k_base_singleton()
} else if lower.contains("text-davinci-003")
|| lower.contains("text-davinci-002")
|| lower.contains("code-davinci")
|| lower.contains("code-cushman")
{
tiktoken_rs::p50k_base_singleton()
} else if lower.contains("text-davinci-001")
|| lower.contains("davinci")
|| lower.contains("curie")
|| lower.contains("babbage")
|| lower.contains("ada")
|| lower.contains("gpt2")
{
tiktoken_rs::r50k_base_singleton()
} else {
tiktoken_rs::cl100k_base_singleton()
};
let tokens = bpe.encode_with_special_tokens(text);
Ok(tokens.len())
}
pub fn set_rate(&mut self, model: String, rate: f64) {
self.model_rates.insert(model, rate);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing computed costs.
    const TOLERANCE: f64 = f64::EPSILON * 10.0;

    /// Returns `true` when `actual` lies within `TOLERANCE` of `expected`.
    fn is_close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// "hello world" is two tokens under the gpt-4 encoding.
    #[test]
    fn test_cost_estimator_count_tokens() {
        let count = CostEstimator::new()
            .count_tokens("hello world", "gpt-4")
            .unwrap();
        assert_eq!(count, 2);
    }

    /// Empty text yields zero tokens.
    #[test]
    fn test_cost_estimator_count_tokens_empty() {
        let count = CostEstimator::new().count_tokens("", "gpt-4").unwrap();
        assert_eq!(count, 0);
    }

    /// 2,000 tokens at the gpt-4 rate of 0.03 per 1,000 cost 0.06.
    #[test]
    fn test_cost_estimator_estimate_usd() {
        let cost = CostEstimator::new().estimate(1_000, 1_000, "gpt-4");
        assert!(is_close(cost, 0.06));
    }

    /// A model without its own entry is charged at the "default" rate.
    #[test]
    fn test_cost_estimator_estimate_unknown_model() {
        let cost = CostEstimator::new().estimate(1_000, 0, "unknown-model");
        assert!(is_close(cost, 0.005));
    }

    /// A rate registered through `set_rate` is used by `estimate`.
    #[test]
    fn test_cost_estimator_custom_rate() {
        let mut subject = CostEstimator::new();
        subject.set_rate("custom".to_string(), 0.01);
        let cost = subject.estimate(1_000, 0, "custom");
        assert!(is_close(cost, 0.01));
    }

    /// Legacy davinci models are tokenized and produce a non-empty count.
    #[test]
    fn test_cost_estimator_count_tokens_davinci() {
        let count = CostEstimator::new()
            .count_tokens("hello world", "text-davinci-003")
            .unwrap();
        assert!(count > 0);
    }

    /// Unknown models are counted the same way as gpt-4 for this input.
    #[test]
    fn test_cost_estimator_count_tokens_unknown_model() {
        let subject = CostEstimator::new();
        let unknown_count = subject
            .count_tokens("hello world", "unknown-model")
            .unwrap();
        let gpt4_count = subject.count_tokens("hello world", "gpt-4").unwrap();
        assert_eq!(unknown_count, gpt4_count);
    }

    /// `Default::default()` produces a usable estimator with built-in rates.
    #[test]
    fn test_cost_estimator_default() {
        let subject: CostEstimator = Default::default();
        let cost = subject.estimate(1_000, 0, "gpt-4");
        assert!(cost > 0.0);
    }
}