use serde::{Deserialize, Serialize};
/// A single model entry: identity, capability flag, optional pricing and token limits.
///
/// `cost` and `release_date` are left out of the serialized form when they are `None`.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct Model {
    /// Identifier of the model as passed to its provider (e.g. `"claude-sonnet-4-5-20250929"`).
    pub id: String,
    /// Human-readable name; this is what the `Display` impl prints.
    pub name: String,
    /// Name of the provider serving the model (e.g. `"anthropic"`, `"ollama"`).
    pub provider: String,
    /// Whether the model is flagged as a reasoning model.
    pub reasoning: bool,
    /// Token pricing, or `None` when no pricing is known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<ModelCost>,
    /// Context-window and output limits.
    pub limit: ModelLimit,
    /// Release date as a free-form string; its format is not validated by this type.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub release_date: Option<String>,
}
impl Model {
    /// Builds a fully specified model entry.
    ///
    /// `release_date` always starts out as `None`. Assign the public field afterwards if it is
    /// known.
    pub fn new(
        id: impl Into<String>,
        name: impl Into<String>,
        provider: impl Into<String>,
        reasoning: bool,
        cost: Option<ModelCost>,
        limit: ModelLimit,
    ) -> Self {
        let (id, name, provider): (String, String, String) =
            (id.into(), name.into(), provider.into());
        Self {
            id,
            name,
            provider,
            reasoning,
            cost,
            limit,
            release_date: None,
        }
    }

    /// Builds a minimal entry for a model that is not in any catalogue.
    ///
    /// The id doubles as the display name. The model is marked non-reasoning and has no pricing,
    /// and `ModelLimit::default()` supplies the limits.
    pub fn custom(id: impl Into<String>, provider: impl Into<String>) -> Self {
        let model_id: String = id.into();
        Self::new(model_id.clone(), model_id, provider, false, None, ModelLimit::default())
    }

    /// Returns `true` when pricing information is attached.
    pub fn has_pricing(&self) -> bool {
        self.cost.is_some()
    }

    /// Human-readable name of the model.
    pub fn display_name(&self) -> &str {
        self.name.as_str()
    }

    /// Provider-facing identifier of the model.
    pub fn model_id(&self) -> &str {
        self.id.as_str()
    }

    /// Name of the provider serving the model.
    pub fn provider_name(&self) -> &str {
        self.provider.as_str()
    }
}
impl std::fmt::Display for Model {
    /// Writes the model's human-readable `name`.
    ///
    /// Goes through [`std::fmt::Formatter::pad`], so width, fill, alignment and precision flags
    /// (e.g. `{:<20}`, `{:.3}`) apply just as they do for a plain `&str`. With a bare `{}` the
    /// output is exactly `self.name`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.name)
    }
}
/// Token pricing for a model, expressed as a rate per one million tokens.
///
/// NOTE(review): the currency is not encoded here (presumably USD). Confirm against the pricing
/// data source.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ModelCost {
    /// Rate per million input tokens.
    pub input: f64,
    /// Rate per million output tokens.
    pub output: f64,
    /// Rate per million cache-read tokens. `None` means `calculate_with_cache` charges nothing
    /// for them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read: Option<f64>,
    /// Rate per million cache-write tokens. `None` means `calculate_with_cache` charges nothing
    /// for them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write: Option<f64>,
}
impl ModelCost {
    /// Pricing with only input and output rates; both cache rates are left unset.
    pub fn new(input: f64, output: f64) -> Self {
        Self {
            cache_read: None,
            cache_write: None,
            input,
            output,
        }
    }

    /// Pricing that also carries cache-read and cache-write rates.
    pub fn with_cache(input: f64, output: f64, cache_read: f64, cache_write: f64) -> Self {
        Self {
            cache_read: Some(cache_read),
            cache_write: Some(cache_write),
            ..Self::new(input, output)
        }
    }

    /// Price of `tokens` tokens at `rate` per million tokens.
    fn price_per_million(tokens: u64, rate: f64) -> f64 {
        (tokens as f64 / 1_000_000.0) * rate
    }

    /// Cost of a request from its input and output token counts.
    ///
    /// Cache traffic is not included; use [`ModelCost::calculate_with_cache`] for that.
    pub fn calculate(&self, input_tokens: u64, output_tokens: u64) -> f64 {
        Self::price_per_million(input_tokens, self.input)
            + Self::price_per_million(output_tokens, self.output)
    }

    /// Cost of a request including cache reads and writes.
    ///
    /// Any cache rate that is `None` contributes `0.0`. Those tokens are effectively free rather
    /// than billed at the input rate.
    pub fn calculate_with_cache(
        &self,
        input_tokens: u64,
        output_tokens: u64,
        cache_read_tokens: u64,
        cache_write_tokens: u64,
    ) -> f64 {
        let cached = |rate: Option<f64>, tokens: u64| {
            rate.map_or(0.0, |r| Self::price_per_million(tokens, r))
        };
        // Summed left to right in the same order as before so results are bit-identical.
        self.calculate(input_tokens, output_tokens)
            + cached(self.cache_read, cache_read_tokens)
            + cached(self.cache_write, cache_write_tokens)
    }
}
/// Size limits for a model.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ModelLimit {
    /// Context-window limit (presumably a token count).
    pub context: u64,
    /// Output limit (presumably a token count).
    pub output: u64,
}
impl ModelLimit {
    /// Creates limits from an explicit context-window size and output cap.
    pub fn new(context: u64, output: u64) -> Self {
        Self {
            output,
            context,
        }
    }
}
impl Default for ModelLimit {
    /// Fallback limits, used for example by `Model::custom` when a model's real limits are
    /// unknown: a context of 128 000 and an output cap of 8 192.
    fn default() -> Self {
        let context = 128_000;
        let output = 8_192;
        Self { context, output }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Claude Sonnet 4.5 fixture shared by the creation and serialization tests.
    fn sonnet(cost: ModelCost) -> Model {
        Model::new(
            "claude-sonnet-4-5-20250929",
            "Claude Sonnet 4.5",
            "anthropic",
            true,
            Some(cost),
            ModelLimit::new(200_000, 16_384),
        )
    }

    /// Asserts that two costs agree within the tolerance these tests have always used.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 0.0001);
    }

    #[test]
    fn test_model_creation() {
        let model = sonnet(ModelCost::with_cache(3.0, 15.0, 0.30, 3.75));
        assert_eq!(model.id, "claude-sonnet-4-5-20250929");
        assert_eq!(model.name, "Claude Sonnet 4.5");
        assert_eq!(model.provider, "anthropic");
        assert!(model.reasoning);
        assert!(model.has_pricing());
    }

    #[test]
    fn test_custom_model() {
        let model = Model::custom("llama3", "ollama");
        assert_eq!(
            (model.id.as_str(), model.name.as_str(), model.provider.as_str()),
            ("llama3", "llama3", "ollama")
        );
        assert!(!model.reasoning);
        assert!(!model.has_pricing());
    }

    #[test]
    fn test_cost_calculation() {
        assert_close(ModelCost::new(3.0, 15.0).calculate(1000, 500), 0.0105);
    }

    #[test]
    fn test_cost_with_cache() {
        let cost = ModelCost::with_cache(3.0, 15.0, 0.30, 3.75);
        assert_close(cost.calculate_with_cache(1000, 500, 2000, 1000), 0.01485);
    }

    #[test]
    fn test_model_display() {
        let model = Model::new("gpt-5", "GPT-5", "openai", false, None, ModelLimit::default());
        assert_eq!(model.to_string(), "GPT-5");
    }

    #[test]
    fn test_serialization() {
        let model = sonnet(ModelCost::new(3.0, 15.0));
        let json = serde_json::to_string(&model).unwrap();
        for fragment in [
            "\"id\":\"claude-sonnet-4-5-20250929\"",
            "\"provider\":\"anthropic\"",
        ] {
            assert!(json.contains(fragment));
        }
        let round_tripped: Model = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped, model);
    }
}