use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use crate::prices::{builtin_prices, ModelPrice};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Dollar breakdown of one cost estimate.
///
/// All amounts are in USD. The type is `Copy`, so it can be passed around
/// by value freely.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct EstimatedCost {
    /// Overall estimate; `CostCap::estimate_with_cached` fills this with
    /// `input_usd + output_usd + cached_input_usd`.
    pub total_usd: f64,
    /// Portion attributed to input tokens.
    pub input_usd: f64,
    /// Portion attributed to output tokens.
    pub output_usd: f64,
    /// Portion attributed to cached input tokens.
    pub cached_input_usd: f64,
}
/// Error describing an estimate that is larger than the configured cap.
#[derive(Debug, Clone, PartialEq)]
pub struct CapExceeded {
    /// Estimated cost of the request, in USD.
    pub projected_usd: f64,
    /// The cap the estimate was compared against, in USD.
    pub cap_usd: f64,
    /// Identifier of the model the estimate was computed for.
    pub model: String,
}

impl fmt::Display for CapExceeded {
    /// Renders both dollar amounts with six decimal places and the model
    /// name in its quoted (`Debug`) form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            projected_usd,
            cap_usd,
            model,
        } = self;
        write!(
            f,
            "estimated cost ${:.6} for model {:?} exceeds cap ${:.6}",
            projected_usd, model, cap_usd
        )
    }
}

impl Error for CapExceeded {}
/// Error for a model identifier that has no entry in the price table.
#[derive(Debug, Clone, PartialEq)]
pub struct UnknownModel {
    /// The identifier that could not be found.
    pub model: String,
}

impl fmt::Display for UnknownModel {
    /// Prints the model name in its quoted (`Debug`) form followed by a
    /// hint on how to make the model known.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { model } = self;
        write!(
            f,
            "unknown model {:?}: pass a custom price table or call add_model",
            model
        )
    }
}

impl Error for UnknownModel {}
/// Failure modes of the estimate functions.
#[derive(Debug, Clone, PartialEq)]
pub enum EstimateError {
    /// The requested model has no entry in the price table.
    UnknownModel(UnknownModel),
}

impl fmt::Display for EstimateError {
    /// Forwards to the wrapped error's `Display` output unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Name the trait explicitly: the payload implements both `Debug`
            // and `Display`, each with a `fmt` method, and a bare `e.fmt(f)`
            // only resolves if exactly one of those traits happens to be
            // imported by name.
            EstimateError::UnknownModel(e) => fmt::Display::fmt(e, f),
        }
    }
}

impl Error for EstimateError {}

impl From<UnknownModel> for EstimateError {
    /// Wraps the error in [`EstimateError::UnknownModel`].
    fn from(e: UnknownModel) -> Self {
        EstimateError::UnknownModel(e)
    }
}
/// Failure modes of the check functions: either the estimate could not be
/// computed, or it was computed and is over the cap.
#[derive(Debug, Clone, PartialEq)]
pub enum CheckError {
    /// The requested model has no entry in the price table.
    UnknownModel(UnknownModel),
    /// The estimate is larger than the configured cap.
    CapExceeded(CapExceeded),
}

impl fmt::Display for CheckError {
    /// Forwards to the wrapped error's `Display` output unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Name the trait explicitly: both payload types implement `Debug`
        // and `Display`, each with a `fmt` method, and a bare `e.fmt(f)`
        // only resolves if exactly one of those traits happens to be
        // imported by name.
        match self {
            CheckError::UnknownModel(e) => fmt::Display::fmt(e, f),
            CheckError::CapExceeded(e) => fmt::Display::fmt(e, f),
        }
    }
}

impl Error for CheckError {}

impl From<UnknownModel> for CheckError {
    /// Wraps the error in [`CheckError::UnknownModel`].
    fn from(e: UnknownModel) -> Self {
        CheckError::UnknownModel(e)
    }
}

impl From<CapExceeded> for CheckError {
    /// Wraps the error in [`CheckError::CapExceeded`].
    fn from(e: CapExceeded) -> Self {
        CheckError::CapExceeded(e)
    }
}

impl From<EstimateError> for CheckError {
    /// Maps each estimate failure onto the matching check failure so `?`
    /// can be used on estimate results inside check functions.
    fn from(e: EstimateError) -> Self {
        // Exhaustive on purpose: a new `EstimateError` variant must be
        // mapped here explicitly.
        match e {
            EstimateError::UnknownModel(u) => CheckError::UnknownModel(u),
        }
    }
}
/// Pre-flight spending guard: prices a request from its token counts and
/// rejects it when the estimate is above a fixed USD cap.
#[derive(Debug, Clone)]
pub struct CostCap {
    /// Largest acceptable estimate in USD. Both constructors assert
    /// `>= 0`, so this is never negative and never NaN.
    cap_usd: f64,
    /// Price table keyed by model identifier.
    prices: HashMap<String, ModelPrice>,
}

impl CostCap {
    /// Creates a guard with a cap of `max_usd` dollars, using the table
    /// returned by `builtin_prices()`.
    ///
    /// # Panics
    ///
    /// Panics if `max_usd` is negative or NaN.
    pub fn new(max_usd: f64) -> Self {
        assert!(max_usd >= 0.0, "max_usd must be >= 0");
        Self {
            cap_usd: max_usd,
            prices: builtin_prices(),
        }
    }

    /// Creates a guard with a cap of `max_usd` dollars, using the
    /// caller-supplied `prices` table instead of the built-in one.
    ///
    /// # Panics
    ///
    /// Panics if `max_usd` is negative or NaN.
    pub fn with_prices(prices: HashMap<String, ModelPrice>, max_usd: f64) -> Self {
        assert!(max_usd >= 0.0, "max_usd must be >= 0");
        Self {
            cap_usd: max_usd,
            prices,
        }
    }

    /// Returns the configured cap in USD.
    pub fn cap_usd(&self) -> f64 {
        self.cap_usd
    }

    /// Registers `price` for `model`, replacing any existing entry for the
    /// same identifier.
    pub fn add_model<S: Into<String>>(&mut self, model: S, price: ModelPrice) {
        self.prices.insert(model.into(), price);
    }

    /// Returns every model identifier in the price table, sorted ascending.
    pub fn known_models(&self) -> Vec<String> {
        let mut ids: Vec<String> = self.prices.keys().cloned().collect();
        // Map keys are unique, so an unstable sort yields the same order.
        ids.sort_unstable();
        ids
    }

    /// Estimates the cost of a request with no cached input tokens.
    ///
    /// Shorthand for [`CostCap::estimate_with_cached`] with
    /// `cached_input_tokens = 0`.
    ///
    /// # Errors
    ///
    /// Returns [`EstimateError::UnknownModel`] if `model` is not in the
    /// price table.
    pub fn estimate(
        &self,
        model: &str,
        input_tokens: u64,
        max_output_tokens: u64,
    ) -> Result<EstimatedCost, EstimateError> {
        self.estimate_with_cached(model, input_tokens, max_output_tokens, 0)
    }

    /// Estimates the cost of a request from its token counts.
    ///
    /// Each count is divided by one million and multiplied by the model's
    /// matching per-million rate. The cached component is `0.0` when the
    /// model has no cached-input rate or when `cached_input_tokens` is
    /// zero.
    ///
    /// `input_tokens` is always billed in full at the regular input rate;
    /// the cached component is added on top of it, not subtracted from it.
    /// NOTE(review): confirm callers pass cached tokens separately from
    /// `input_tokens` rather than as a subset of it.
    ///
    /// # Errors
    ///
    /// Returns [`EstimateError::UnknownModel`] if `model` is not in the
    /// price table.
    pub fn estimate_with_cached(
        &self,
        model: &str,
        input_tokens: u64,
        max_output_tokens: u64,
        cached_input_tokens: u64,
    ) -> Result<EstimatedCost, EstimateError> {
        let price = self.prices.get(model).ok_or_else(|| UnknownModel {
            model: model.to_string(),
        })?;

        // Rates are quoted per million tokens. Divide first, then multiply,
        // so results stay bit-identical to the historical computation.
        let usd = |tokens: u64, usd_per_million: f64| -> f64 {
            (tokens as f64 / 1_000_000.0) * usd_per_million
        };

        let input_usd = usd(input_tokens, price.input_per_million_usd);
        let output_usd = usd(max_output_tokens, price.output_per_million_usd);
        let cached_input_usd = match price.cached_input_per_million_usd {
            // The zero-token guard keeps the component at exactly 0.0 even
            // if the rate is not finite (0 * inf would be NaN).
            Some(rate) if cached_input_tokens > 0 => usd(cached_input_tokens, rate),
            _ => 0.0,
        };

        Ok(EstimatedCost {
            total_usd: input_usd + output_usd + cached_input_usd,
            input_usd,
            output_usd,
            cached_input_usd,
        })
    }

    /// Estimates a request with no cached input tokens and verifies the
    /// estimate against the cap.
    ///
    /// Shorthand for [`CostCap::check_with_cached`] with
    /// `cached_input_tokens = 0`.
    ///
    /// # Errors
    ///
    /// See [`CostCap::check_with_cached`].
    pub fn check(
        &self,
        model: &str,
        input_tokens: u64,
        max_output_tokens: u64,
    ) -> Result<EstimatedCost, CheckError> {
        self.check_with_cached(model, input_tokens, max_output_tokens, 0)
    }

    /// Estimates a request and verifies the estimate against the cap.
    ///
    /// Returns the estimate when its total is less than or equal to the
    /// cap.
    ///
    /// # Errors
    ///
    /// * [`CheckError::UnknownModel`] if `model` is not in the price table.
    /// * [`CheckError::CapExceeded`] if the total is greater than the cap,
    ///   or if the total is NaN and therefore cannot be compared.
    pub fn check_with_cached(
        &self,
        model: &str,
        input_tokens: u64,
        max_output_tokens: u64,
        cached_input_tokens: u64,
    ) -> Result<EstimatedCost, CheckError> {
        let est = self.estimate_with_cached(
            model,
            input_tokens,
            max_output_tokens,
            cached_input_tokens,
        )?;

        // `NaN > cap` is false, so a NaN total (e.g. from a NaN rate in a
        // custom price table) would pass a plain `>` test. Treat it as over
        // the cap so the guard fails closed.
        if est.total_usd.is_nan() || est.total_usd > self.cap_usd {
            return Err(CheckError::CapExceeded(CapExceeded {
                projected_usd: est.total_usd,
                cap_usd: self.cap_usd,
                model: model.to_string(),
            }));
        }
        Ok(est)
    }

    /// Runs `f` only if the request passes [`CostCap::check`], returning
    /// the closure's result.
    ///
    /// # Errors
    ///
    /// Returns the [`CheckError`] from the check; in that case `f` is never
    /// called.
    pub fn run<T>(
        &self,
        model: &str,
        input_tokens: u64,
        max_output_tokens: u64,
        f: impl FnOnce() -> T,
    ) -> Result<T, CheckError> {
        self.check(model, input_tokens, max_output_tokens)?;
        Ok(f())
    }
}