use crate::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Pricing and token limits for a single model, as stored in a `CostCalculator`.
///
/// Serialized with camelCase field names (serde `rename_all = "camelCase"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelPricing {
    /// Model identifier; `CostCalculator::set_pricing` also uses it as the lookup key.
    pub model_name: String,
    /// Price charged per 1,000,000 input (prompt) tokens.
    pub input_cost_per_1m_tokens: f64,
    /// Price charged per 1,000,000 output (completion) tokens.
    pub output_cost_per_1m_tokens: f64,
    /// Largest total token count accepted by `fits_in_context`.
    pub context_window: usize,
    /// Output-token cap for the model; informational only, since no calculation in this
    /// file reads it.
    pub max_output_tokens: usize,
}
impl ModelPricing {
    /// Returns the price of a request with the given input and output token counts.
    ///
    /// Each side is converted to millions of tokens and multiplied by its per-1M rate;
    /// the two products are summed.
    pub fn calculate_cost(&self, input_tokens: usize, output_tokens: usize) -> f64 {
        const TOKENS_PER_UNIT: f64 = 1_000_000.0;
        let input_millions = input_tokens as f64 / TOKENS_PER_UNIT;
        let output_millions = output_tokens as f64 / TOKENS_PER_UNIT;
        input_millions * self.input_cost_per_1m_tokens
            + output_millions * self.output_cost_per_1m_tokens
    }

    /// Reports whether `total_tokens` does not exceed this model's context window
    /// (the boundary value itself fits).
    pub fn fits_in_context(&self, total_tokens: usize) -> bool {
        self.context_window >= total_tokens
    }
}
/// Result of `CostCalculator::calculate_cost`: token counts, the total price, and the
/// pricing entry that produced it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CostEstimate {
    /// Input (prompt) tokens the estimate was computed for.
    pub input_tokens: usize,
    /// Output (completion) tokens the estimate was computed for.
    pub output_tokens: usize,
    /// `input_tokens + output_tokens`.
    pub total_tokens: usize,
    /// Total estimated price; the same value as `breakdown.total_cost`.
    pub estimated_cost_usd: f64,
    /// Model name exactly as passed to `calculate_cost`.
    pub model_used: String,
    /// Copy of the pricing entry used for the calculation.
    pub pricing: ModelPricing,
    /// Per-component split of the total price.
    pub breakdown: CostBreakdown,
}
/// Component-wise split of a `CostEstimate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CostBreakdown {
    /// Price of the input tokens alone.
    pub input_cost: f64,
    /// Price of the output tokens alone.
    pub output_cost: f64,
    /// `input_cost + output_cost`.
    pub total_cost: f64,
    /// `total_cost` divided by the total token count; `0.0` when there are no tokens.
    pub cost_per_token: f64,
}
/// Registry of per-model pricing with helpers for estimating and comparing costs.
///
/// The table sits behind an `RwLock`, so both lookups and `set_pricing` work through
/// `&self`.
pub struct CostCalculator {
    /// Pricing entries keyed by model name.
    pricing_db: Arc<RwLock<HashMap<String, ModelPricing>>>,
}
impl CostCalculator {
    /// Creates a calculator pre-populated with the built-in pricing table.
    pub fn new() -> Self {
        let calculator = Self {
            pricing_db: Arc::new(RwLock::new(HashMap::new())),
        };
        calculator.load_default_pricing();
        calculator
    }

    /// Acquires the pricing table for reading.
    ///
    /// # Panics
    ///
    /// Panics if the lock was poisoned by a writer that panicked while holding it.
    fn read_db(&self) -> std::sync::RwLockReadGuard<'_, HashMap<String, ModelPricing>> {
        self.pricing_db.read().expect("pricing table lock poisoned")
    }

    /// Acquires the pricing table for writing.
    ///
    /// # Panics
    ///
    /// Panics if the lock was poisoned by a writer that panicked while holding it.
    fn write_db(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<String, ModelPricing>> {
        self.pricing_db.write().expect("pricing table lock poisoned")
    }

    /// Mean of the input and output per-1M-token prices: the single "price level" used
    /// to rank models against each other.
    fn avg_cost_per_1m(pricing: &ModelPricing) -> f64 {
        (pricing.input_cost_per_1m_tokens + pricing.output_cost_per_1m_tokens) / 2.0
    }

    /// Total order over pricing entries: cheapest average price first, ties broken by
    /// model name.
    ///
    /// `f64::total_cmp` replaces `partial_cmp(..).unwrap()` so that a NaN price supplied
    /// via `set_pricing` cannot panic a sort. The name tie-break makes results independent
    /// of `HashMap` iteration order; several built-in models share an average price.
    fn cmp_by_avg_cost(a: &ModelPricing, b: &ModelPricing) -> std::cmp::Ordering {
        Self::avg_cost_per_1m(a)
            .total_cmp(&Self::avg_cost_per_1m(b))
            .then_with(|| a.model_name.cmp(&b.model_name))
    }

    /// Inserts the built-in pricing entries, overwriting any existing entry of the same
    /// name.
    ///
    /// The prices are a hard-coded snapshot; callers can override any of them with
    /// `set_pricing`.
    fn load_default_pricing(&self) {
        // (model name, input price per 1M, output price per 1M, context window, max output)
        const DEFAULTS: [(&str, f64, f64, usize, usize); 10] = [
            ("gpt-4", 30.0, 60.0, 8192, 4096),
            ("gpt-4-turbo", 10.0, 30.0, 128_000, 4096),
            ("gpt-4o", 5.0, 15.0, 128_000, 16384),
            ("gpt-3.5-turbo", 0.5, 1.5, 16385, 4096),
            ("claude-3-opus", 15.0, 75.0, 200_000, 4096),
            ("claude-3-sonnet", 3.0, 15.0, 200_000, 4096),
            ("claude-3-haiku", 0.25, 1.25, 200_000, 4096),
            ("claude-3.5-sonnet", 3.0, 15.0, 200_000, 8192),
            ("llama", 0.0, 0.0, 4096, 2048),
            ("local", 0.0, 0.0, 8192, 4096),
        ];

        self.write_db().extend(DEFAULTS.iter().map(
            |&(name, input, output, context_window, max_output_tokens)| {
                (
                    name.to_string(),
                    ModelPricing {
                        model_name: name.to_string(),
                        input_cost_per_1m_tokens: input,
                        output_cost_per_1m_tokens: output,
                        context_window,
                        max_output_tokens,
                    },
                )
            },
        ));
    }

    /// Adds or replaces the entry for `pricing.model_name`.
    pub fn set_pricing(&self, pricing: ModelPricing) {
        self.write_db().insert(pricing.model_name.clone(), pricing);
    }

    /// Returns a copy of the pricing entry for `model_name`, if one exists.
    pub fn get_pricing(&self, model_name: &str) -> Option<ModelPricing> {
        self.read_db().get(model_name).cloned()
    }

    /// Estimates the price of a request to `model_name`.
    ///
    /// # Errors
    ///
    /// Returns `ZoeyError::Other("Unknown model: …")` when `model_name` has no pricing
    /// entry.
    pub fn calculate_cost(
        &self,
        model_name: &str,
        input_tokens: usize,
        output_tokens: usize,
    ) -> Result<CostEstimate> {
        let pricing = self
            .get_pricing(model_name)
            .ok_or_else(|| crate::ZoeyError::Other(format!("Unknown model: {}", model_name)))?;

        let input_cost = (input_tokens as f64 / 1_000_000.0) * pricing.input_cost_per_1m_tokens;
        let output_cost =
            (output_tokens as f64 / 1_000_000.0) * pricing.output_cost_per_1m_tokens;
        let total_cost = input_cost + output_cost;
        let total_tokens = input_tokens + output_tokens;
        // Avoid 0/0 = NaN for an empty request.
        let cost_per_token = if total_tokens > 0 {
            total_cost / total_tokens as f64
        } else {
            0.0
        };

        Ok(CostEstimate {
            input_tokens,
            output_tokens,
            total_tokens,
            estimated_cost_usd: total_cost,
            model_used: model_name.to_string(),
            pricing,
            breakdown: CostBreakdown {
                input_cost,
                output_cost,
                total_cost,
                cost_per_token,
            },
        })
    }

    /// Returns the cheapest model whose average price is strictly below
    /// `current_model`'s and whose context window is at least `min_context`.
    ///
    /// Models with equal average price are resolved by name, so the answer is
    /// deterministic. Returns `None` when `current_model` is unknown or nothing qualifies.
    pub fn find_cheaper_model(&self, current_model: &str, min_context: usize) -> Option<String> {
        let db = self.read_db();
        let current_avg_cost = Self::avg_cost_per_1m(db.get(current_model)?);
        db.values()
            .filter(|p| {
                Self::avg_cost_per_1m(p) < current_avg_cost && p.context_window >= min_context
            })
            .min_by(|a, b| Self::cmp_by_avg_cost(a, b))
            .map(|p| p.model_name.clone())
    }

    /// Returns every pricing entry, cheapest average price first.
    ///
    /// Equal average prices are ordered by model name.
    pub fn get_models_by_cost(&self) -> Vec<ModelPricing> {
        let mut models: Vec<ModelPricing> = self.read_db().values().cloned().collect();
        models.sort_by(Self::cmp_by_avg_cost);
        models
    }

    /// Recommends a model for a request of roughly `estimated_tokens` tokens.
    ///
    /// Candidates must have a context window of at least `min_context`. Their estimated
    /// cost must also fit within `budget_usd`, with tokens split evenly between input and
    /// output. Among the candidates, the one with the HIGHEST average price wins.
    /// NOTE(review): this presumably treats price as a proxy for capability; confirm that
    /// is intended. Equal prices resolve to the alphabetically first name. Returns `None`
    /// when nothing fits.
    pub fn recommend_model(
        &self,
        budget_usd: f64,
        estimated_tokens: usize,
        min_context: usize,
    ) -> Option<String> {
        // Give an odd leftover token to the output side. In the default table output is
        // never cheaper than input, so the estimate is not undercounted; the old
        // `n / 2 + n / 2` split silently dropped that token.
        let estimated_input = estimated_tokens / 2;
        let estimated_output = estimated_tokens - estimated_input;

        let db = self.read_db();
        db.values()
            .filter(|p| {
                p.context_window >= min_context
                    && p.calculate_cost(estimated_input, estimated_output) <= budget_usd
            })
            .max_by(|a, b| {
                Self::avg_cost_per_1m(a)
                    .total_cmp(&Self::avg_cost_per_1m(b))
                    // Reversed so that, among equal prices, the smaller name is the "max".
                    .then_with(|| b.model_name.cmp(&a.model_name))
            })
            .map(|p| p.model_name.clone())
    }
}
impl Default for CostCalculator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mean of a model's input and output per-1M-token prices.
    fn mean_price(p: &ModelPricing) -> f64 {
        (p.input_cost_per_1m_tokens + p.output_cost_per_1m_tokens) / 2.0
    }

    #[test]
    fn test_cost_calculation() {
        let estimate = CostCalculator::new()
            .calculate_cost("gpt-4", 1000, 500)
            .unwrap();

        assert_eq!(estimate.input_tokens, 1000);
        assert_eq!(estimate.output_tokens, 500);
        assert!(estimate.estimated_cost_usd > 0.0);
        // 1000 input tokens at 30 per 1M plus 500 output tokens at 60 per 1M = 0.06.
        assert!((estimate.estimated_cost_usd - 0.06).abs() < 0.001);
    }

    #[test]
    fn test_cheaper_model() {
        let calculator = CostCalculator::new();
        let cheaper_model = calculator
            .find_cheaper_model("gpt-4", 8000)
            .expect("some model should undercut gpt-4");
        assert_ne!(cheaper_model, "gpt-4");

        let baseline = calculator.get_pricing("gpt-4").unwrap();
        let candidate = calculator.get_pricing(&cheaper_model).unwrap();
        assert!(mean_price(&candidate) < mean_price(&baseline));
    }

    #[test]
    fn test_model_recommendation() {
        let calculator = CostCalculator::new();
        let model_name = calculator
            .recommend_model(0.001, 1000, 4000)
            .expect("some model should fit the budget");

        let pricing = calculator.get_pricing(&model_name).unwrap();
        assert!(pricing.input_cost_per_1m_tokens < 5.0);
    }

    #[test]
    fn test_local_models_free() {
        let estimate = CostCalculator::new()
            .calculate_cost("local", 10000, 5000)
            .unwrap();
        assert_eq!(estimate.estimated_cost_usd, 0.0);
    }

    #[test]
    fn test_models_sorted_by_cost() {
        let models = CostCalculator::new().get_models_by_cost();
        assert!(!models.is_empty());
        for pair in models.windows(2) {
            assert!(mean_price(&pair[0]) <= mean_price(&pair[1]));
        }
    }

    #[test]
    fn test_custom_pricing() {
        let calculator = CostCalculator::new();
        calculator.set_pricing(ModelPricing {
            model_name: "custom-model".to_string(),
            input_cost_per_1m_tokens: 1.0,
            output_cost_per_1m_tokens: 2.0,
            context_window: 4096,
            max_output_tokens: 2048,
        });

        let retrieved = calculator.get_pricing("custom-model").unwrap();
        assert_eq!(retrieved.model_name, "custom-model");
        assert_eq!(retrieved.input_cost_per_1m_tokens, 1.0);
    }
}