use crate::config::models::defaults::default_true;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use chrono::{DateTime, Utc};
/// Pricing for a single model, expressed as rates per 1,000 tokens.
///
/// `ModelPricing::calculate_cost` scales token counts by `1/1000` and multiplies them by the
/// two rates below.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ModelPricing {
    /// Model name; `ProviderPricing::set_model_pricing` uses it as the map key.
    pub model: String,
    /// Price of 1,000 input tokens, in `currency`.
    pub input_cost_per_1k: f64,
    /// Price of 1,000 output tokens, in `currency`.
    pub output_cost_per_1k: f64,
    /// Currency code both rates are expressed in (e.g. `"USD"`).
    pub currency: String,
    /// When these rates were last set. If the field is absent from serialized input,
    /// deserialization stamps it with the current time (`Utc::now`).
    #[serde(default = "chrono::Utc::now")]
    pub updated_at: DateTime<Utc>,
    /// Optional free-form remark about these rates.
    pub notes: Option<String>,
}
/// Pricing table for one provider: a provider-wide default plus optional per-model overrides.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderPricing {
    /// Provider identifier (e.g. `"openai"`).
    pub provider: String,
    /// Pricing returned by `get_model_pricing` for models without an explicit entry.
    pub default_pricing: ModelPricing,
    /// Per-model overrides keyed by model name.
    ///
    /// Defaults to an empty map when omitted, so a provider that only specifies
    /// `default_pricing` still deserializes (matching the empty `models` that
    /// `PricingConfig::default()` builds).
    #[serde(default)]
    pub models: HashMap<String, ModelPricing>,
    /// Whether an external pricing API should be consulted; defaults to `false`.
    #[serde(default)]
    pub use_external_api: bool,
    /// Endpoint of the external pricing API, if any.
    pub external_api_url: Option<String>,
    /// Cache lifetime; defaults to `default_cache_ttl()` (3600) when omitted.
    #[serde(default = "default_cache_ttl")]
    pub cache_ttl: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricingConfig {
#[serde(default = "default_currency")]
pub default_currency: String,
#[serde(default)]
pub source_priority: Vec<PricingSource>,
pub providers: HashMap<String, ProviderPricing>,
pub fallback_pricing: ModelPricing,
#[serde(default = "default_true")]
pub enable_cache: bool,
#[serde(default)]
pub auto_update: bool,
#[serde(default = "default_update_interval")]
pub update_interval: u64,
}
/// A place pricing data can come from; see `PricingConfig::source_priority`.
///
/// This is a fieldless enum, so it derives `Copy`, `Eq` and `Hash`. Callers can then
/// compare sources (e.g. `source_priority.contains(&PricingSource::Cache)`) or use
/// them as map keys without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PricingSource {
    /// Serialized as `"config"`.
    Config,
    /// Serialized as `"external_api"`.
    ExternalApi,
    /// Serialized as `"provider_api"`.
    ProviderApi,
    /// Serialized as `"cache"`.
    Cache,
}
/// Describes a pricing change for one model of one provider: the old and new values,
/// when it happened, and which source supplied the new value.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PricingUpdateEvent {
    /// Provider key the change applies to.
    pub provider: String,
    /// Model whose pricing changed.
    pub model: String,
    /// Pricing before the change.
    pub old_pricing: ModelPricing,
    /// Pricing after the change.
    pub new_pricing: ModelPricing,
    /// Timestamp attached to the event.
    pub timestamp: DateTime<Utc>,
    /// Source that supplied `new_pricing`.
    pub source: PricingSource,
}
impl Default for PricingConfig {
fn default() -> Self {
let mut providers = HashMap::new();
providers.insert("openai".to_string(), ProviderPricing {
provider: "openai".to_string(),
default_pricing: ModelPricing {
model: "unknown".to_string(),
input_cost_per_1k: 0.01,
output_cost_per_1k: 0.03,
currency: "USD".to_string(),
updated_at: Utc::now(),
notes: Some("OpenAI default pricing".to_string()),
},
models: HashMap::new(),
use_external_api: false,
external_api_url: None,
cache_ttl: default_cache_ttl(),
});
providers.insert("glm".to_string(), ProviderPricing {
provider: "glm".to_string(),
default_pricing: ModelPricing {
model: "unknown".to_string(),
input_cost_per_1k: 0.0001,
output_cost_per_1k: 0.0003,
currency: "USD".to_string(),
updated_at: Utc::now(),
notes: Some("GLM default pricing (converted from RMB)".to_string()),
},
models: HashMap::new(),
use_external_api: false,
external_api_url: None,
cache_ttl: default_cache_ttl(),
});
Self {
default_currency: default_currency(),
source_priority: vec![
PricingSource::Cache,
PricingSource::Config,
PricingSource::ExternalApi,
PricingSource::ProviderApi,
],
providers,
fallback_pricing: ModelPricing {
model: "fallback".to_string(),
input_cost_per_1k: 0.01,
output_cost_per_1k: 0.01,
currency: "USD".to_string(),
updated_at: Utc::now(),
notes: Some("Global fallback pricing".to_string()),
},
enable_cache: true,
auto_update: false,
update_interval: default_update_interval(),
}
}
}
impl ModelPricing {
    /// Returns the total cost of a request with the given token counts.
    ///
    /// Rates are per 1,000 tokens, so each count is divided by 1000 before being
    /// multiplied by its rate. The result is in `self.currency`.
    pub fn calculate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
        // `f64::from` is lossless for `u32`, unlike general `as` casts.
        let input_cost = f64::from(input_tokens) / 1000.0 * self.input_cost_per_1k;
        let output_cost = f64::from(output_tokens) / 1000.0 * self.output_cost_per_1k;
        input_cost + output_cost
    }

    /// Returns `true` when `updated_at` lies more than `max_age_seconds` in the past.
    ///
    /// Any `max_age_seconds` is handled correctly. Values above `i64::MAX` (such as
    /// `u64::MAX` used as "never expires") can never be exceeded and yield `false`.
    /// A future `updated_at` gives a negative age and is never stale.
    pub fn is_stale(&self, max_age_seconds: u64) -> bool {
        let age = Utc::now().signed_duration_since(self.updated_at);
        // Compare in i128, which holds every i64 and u64 exactly. The previous
        // `max_age_seconds as i64` wrapped large limits to negative numbers, making
        // every entry look stale.
        i128::from(age.num_seconds()) > i128::from(max_age_seconds)
    }

    /// Creates pricing for `model` with the given per-1k rates and currency.
    ///
    /// `updated_at` is set to the current time and `notes` to `None`.
    pub fn new(model: &str, input_cost: f64, output_cost: f64, currency: &str) -> Self {
        Self {
            model: model.to_string(),
            input_cost_per_1k: input_cost,
            output_cost_per_1k: output_cost,
            currency: currency.to_string(),
            updated_at: Utc::now(),
            notes: None,
        }
    }
}
impl ProviderPricing {
    /// Returns the pricing for `model`. Falls back to the provider-wide
    /// `default_pricing` when no model-specific entry exists.
    pub fn get_model_pricing(&self, model: &str) -> &ModelPricing {
        self.models.get(model).unwrap_or(&self.default_pricing)
    }

    /// Inserts `pricing` under its own `model` name, replacing any existing entry.
    pub fn set_model_pricing(&mut self, pricing: ModelPricing) {
        self.models.insert(pricing.model.clone(), pricing);
    }

    /// Removes and returns the model-specific entry for `model`, if there is one.
    ///
    /// Afterwards `get_model_pricing(model)` returns `default_pricing`.
    pub fn remove_model_pricing(&mut self, model: &str) -> Option<ModelPricing> {
        self.models.remove(model)
    }

    /// Lists the models that have explicit pricing, sorted lexicographically.
    ///
    /// `HashMap` iteration order is randomized per process. Sorting makes the listing
    /// stable across runs, e.g. for display, logs and snapshot tests.
    pub fn get_models(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.models.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }
}
/// Currency code used when none is configured: `"USD"`.
fn default_currency() -> String {
    String::from("USD")
}
/// Default cache lifetime for provider pricing: 3600.
///
/// NOTE(review): presumably seconds (one hour); the unit is not visible here, so
/// confirm against the cache implementation.
fn default_cache_ttl() -> u64 {
    60 * 60
}
/// Default interval between automatic pricing updates: 86400.
///
/// NOTE(review): presumably seconds (one day); confirm against the updater.
fn default_update_interval() -> u64 {
    24 * 60 * 60
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider whose only pricing is the provider-wide default.
    fn provider_with_default(input: f64, output: f64) -> ProviderPricing {
        ProviderPricing {
            provider: "test".to_string(),
            default_pricing: ModelPricing::new("default", input, output, "USD"),
            models: HashMap::new(),
            use_external_api: false,
            external_api_url: None,
            cache_ttl: 3600,
        }
    }

    #[test]
    fn test_model_pricing_calculate_cost() {
        let pricing = ModelPricing::new("test-model", 0.001, 0.002, "USD");
        // 1000 input tokens at 0.001/1k plus 500 output tokens at 0.002/1k.
        let cost = pricing.calculate_cost(1000, 500);
        assert!((cost - 0.002).abs() < f64::EPSILON);
    }

    #[test]
    fn test_model_pricing_calculate_cost_zero_tokens() {
        let pricing = ModelPricing::new("test-model", 0.001, 0.002, "USD");
        assert_eq!(pricing.calculate_cost(0, 0), 0.0);
    }

    #[test]
    fn test_pricing_is_stale() {
        let mut pricing = ModelPricing::new("test", 0.001, 0.002, "USD");
        assert!(!pricing.is_stale(3600));
        pricing.updated_at = Utc::now() - chrono::Duration::hours(2);
        assert!(pricing.is_stale(3600));
    }

    #[test]
    fn test_provider_pricing_get_model() {
        let mut provider = provider_with_default(0.01, 0.02);
        let pricing = provider.get_model_pricing("unknown-model");
        assert_eq!(pricing.input_cost_per_1k, 0.01);
        provider.set_model_pricing(ModelPricing::new("specific", 0.005, 0.01, "USD"));
        let pricing = provider.get_model_pricing("specific");
        assert_eq!(pricing.input_cost_per_1k, 0.005);
    }

    #[test]
    fn test_provider_pricing_remove_and_list_models() {
        let mut provider = provider_with_default(0.01, 0.02);
        assert!(provider.get_models().is_empty());

        provider.set_model_pricing(ModelPricing::new("b", 0.3, 0.4, "USD"));
        provider.set_model_pricing(ModelPricing::new("a", 0.1, 0.2, "USD"));
        // Sort locally so the assertion does not depend on listing order.
        let mut models = provider.get_models();
        models.sort_unstable();
        assert_eq!(models, vec!["a", "b"]);

        let removed = provider
            .remove_model_pricing("a")
            .expect("model `a` was inserted above");
        assert_eq!(removed.input_cost_per_1k, 0.1);
        assert!(provider.remove_model_pricing("a").is_none());
        // Once the override is gone, the provider default applies again.
        assert_eq!(provider.get_model_pricing("a").input_cost_per_1k, 0.01);
    }

    #[test]
    fn test_pricing_config_default() {
        let config = PricingConfig::default();
        assert_eq!(config.default_currency, "USD");
        assert!(config.providers.contains_key("openai"));
        assert!(config.providers.contains_key("glm"));
        assert!(config.enable_cache);
        assert!(!config.auto_update);
        assert_eq!(config.update_interval, 86400);
        assert_eq!(config.source_priority.len(), 4);
        assert_eq!(config.fallback_pricing.model, "fallback");
        assert_eq!(config.fallback_pricing.input_cost_per_1k, 0.01);
    }
}