use serde::{Deserialize, Serialize};
use super::tokens::TokenCounts;
/// Monetary cost breakdown for a single token-usage measurement.
///
/// Values are plain `f64` amounts; `currency` names the unit they are
/// expressed in. The type is serializable and deserializable through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostEstimate {
/// Cost attributed to input tokens.
pub input_cost: f64,
/// Cost attributed to output tokens.
pub output_cost: f64,
/// Combined cost; `from_tokens` sets this to `input_cost + output_cost`.
pub total_cost: f64,
/// Currency label for the amounts above. `Default` and `from_tokens` both
/// set it to `"USD"`.
pub currency: String,
}
impl Default for CostEstimate {
    /// Builds an estimate in which every cost is zero and the currency
    /// label is `"USD"`.
    fn default() -> Self {
        let zero = 0.0_f64;
        Self {
            currency: String::from("USD"),
            input_cost: zero,
            output_cost: zero,
            total_cost: zero,
        }
    }
}
impl CostEstimate {
    /// Derives a cost estimate from token counts and per-million-token prices.
    ///
    /// Each token count is divided by one million and multiplied by the
    /// matching price; `total_cost` is the sum of the two partial costs.
    /// The resulting estimate is labelled `"USD"`.
    pub fn from_tokens(tokens: &TokenCounts, input_price: f64, output_price: f64) -> Self {
        // Prices are quoted per this many tokens.
        const TOKENS_PER_PRICE_UNIT: f64 = 1_000_000.0;

        let input_units = tokens.input_tokens as f64 / TOKENS_PER_PRICE_UNIT;
        let output_units = tokens.output_tokens as f64 / TOKENS_PER_PRICE_UNIT;

        let input_cost = input_units * input_price;
        let output_cost = output_units * output_price;
        let total_cost = input_cost + output_cost;

        Self {
            input_cost,
            output_cost,
            total_cost,
            currency: String::from("USD"),
        }
    }

    /// Renders `total_cost` with a leading `$` and exactly four decimal places.
    pub fn format_currency(&self) -> String {
        format!("${amount:.4}", amount = self.total_cost)
    }

    /// Renders `total_cost` with a leading `$`, using four decimal places for
    /// totals below `0.01` and two decimal places otherwise.
    pub fn format_smart(&self) -> String {
        // Sub-cent totals need the extra digits to avoid printing as "$0.00".
        let decimals: usize = if self.total_cost < 0.01 { 4 } else { 2 };
        format!("${amount:.prec$}", amount = self.total_cost, prec = decimals)
    }
}