1use std::collections::HashMap;
4use std::error::Error;
5use std::fmt;
6
7use crate::prices::{builtin_prices, ModelPrice};
8
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
/// Dollar cost estimate for one request, broken down by component.
///
/// All amounts are in US dollars. When built by `CostCap::estimate_with_cached`,
/// `total_usd` is the sum of the three component fields.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct EstimatedCost {
    /// Sum of the input, output and cached-input components.
    pub total_usd: f64,
    /// Cost attributed to the input tokens.
    pub input_usd: f64,
    /// Cost attributed to the (maximum) output tokens.
    pub output_usd: f64,
    /// Cost attributed to cached input tokens; `0.0` when none were billed.
    pub cached_input_usd: f64,
}
22
/// Error payload describing an estimate that went over the configured cap.
#[derive(Clone, Debug, PartialEq)]
pub struct CapExceeded {
    /// The estimated total cost, in US dollars, that triggered the error.
    pub projected_usd: f64,
    /// The cap, in US dollars, that the estimate was compared against.
    pub cap_usd: f64,
    /// Identifier of the model the estimate was computed for.
    pub model: String,
}
30
31impl fmt::Display for CapExceeded {
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 write!(
34 f,
35 "estimated cost ${:.6} for model {:?} exceeds cap ${:.6}",
36 self.projected_usd, self.model, self.cap_usd
37 )
38 }
39}
40
41impl Error for CapExceeded {}
42
/// Error payload for a model identifier that has no entry in the price table.
#[derive(Clone, Debug, PartialEq)]
pub struct UnknownModel {
    /// The model identifier that was looked up and not found.
    pub model: String,
}
48
49impl fmt::Display for UnknownModel {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 write!(
52 f,
53 "unknown model {:?}: pass a custom price table or call add_model",
54 self.model
55 )
56 }
57}
58
59impl Error for UnknownModel {}
60
61#[derive(Debug, Clone, PartialEq)]
65pub enum EstimateError {
66 UnknownModel(UnknownModel),
67}
68
69impl fmt::Display for EstimateError {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 match self {
72 EstimateError::UnknownModel(e) => e.fmt(f),
73 }
74 }
75}
76
77impl Error for EstimateError {}
78
79impl From<UnknownModel> for EstimateError {
80 fn from(e: UnknownModel) -> Self {
81 EstimateError::UnknownModel(e)
82 }
83}
84
85#[derive(Debug, Clone, PartialEq)]
87pub enum CheckError {
88 UnknownModel(UnknownModel),
89 CapExceeded(CapExceeded),
90}
91
92impl fmt::Display for CheckError {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 match self {
95 CheckError::UnknownModel(e) => e.fmt(f),
96 CheckError::CapExceeded(e) => e.fmt(f),
97 }
98 }
99}
100
101impl Error for CheckError {}
102
103impl From<UnknownModel> for CheckError {
104 fn from(e: UnknownModel) -> Self {
105 CheckError::UnknownModel(e)
106 }
107}
108
109impl From<CapExceeded> for CheckError {
110 fn from(e: CapExceeded) -> Self {
111 CheckError::CapExceeded(e)
112 }
113}
114
115impl From<EstimateError> for CheckError {
116 fn from(e: EstimateError) -> Self {
117 match e {
118 EstimateError::UnknownModel(u) => CheckError::UnknownModel(u),
119 }
120 }
121}
122
123#[derive(Debug, Clone)]
137pub struct CostCap {
138 cap_usd: f64,
139 prices: HashMap<String, ModelPrice>,
140}
141
142impl CostCap {
143 pub fn new(max_usd: f64) -> Self {
145 assert!(max_usd >= 0.0, "max_usd must be >= 0");
146 Self {
147 cap_usd: max_usd,
148 prices: builtin_prices(),
149 }
150 }
151
152 pub fn with_prices(prices: HashMap<String, ModelPrice>, max_usd: f64) -> Self {
155 assert!(max_usd >= 0.0, "max_usd must be >= 0");
156 Self {
157 cap_usd: max_usd,
158 prices,
159 }
160 }
161
162 pub fn cap_usd(&self) -> f64 {
163 self.cap_usd
164 }
165
166 pub fn add_model<S: Into<String>>(&mut self, model: S, price: ModelPrice) {
168 self.prices.insert(model.into(), price);
169 }
170
171 pub fn known_models(&self) -> Vec<String> {
173 let mut ids: Vec<String> = self.prices.keys().cloned().collect();
174 ids.sort();
175 ids
176 }
177
178 pub fn estimate(
180 &self,
181 model: &str,
182 input_tokens: u64,
183 max_output_tokens: u64,
184 ) -> Result<EstimatedCost, EstimateError> {
185 self.estimate_with_cached(model, input_tokens, max_output_tokens, 0)
186 }
187
188 pub fn estimate_with_cached(
192 &self,
193 model: &str,
194 input_tokens: u64,
195 max_output_tokens: u64,
196 cached_input_tokens: u64,
197 ) -> Result<EstimatedCost, EstimateError> {
198 let price = self.prices.get(model).ok_or_else(|| UnknownModel {
199 model: model.to_string(),
200 })?;
201 let input_usd = (input_tokens as f64 / 1_000_000.0) * price.input_per_million_usd;
202 let output_usd = (max_output_tokens as f64 / 1_000_000.0) * price.output_per_million_usd;
203 let cached_input_usd = match price.cached_input_per_million_usd {
204 Some(rate) if cached_input_tokens > 0 => {
205 (cached_input_tokens as f64 / 1_000_000.0) * rate
206 }
207 _ => 0.0,
208 };
209 Ok(EstimatedCost {
210 total_usd: input_usd + output_usd + cached_input_usd,
211 input_usd,
212 output_usd,
213 cached_input_usd,
214 })
215 }
216
217 pub fn check(
220 &self,
221 model: &str,
222 input_tokens: u64,
223 max_output_tokens: u64,
224 ) -> Result<EstimatedCost, CheckError> {
225 self.check_with_cached(model, input_tokens, max_output_tokens, 0)
226 }
227
228 pub fn check_with_cached(
231 &self,
232 model: &str,
233 input_tokens: u64,
234 max_output_tokens: u64,
235 cached_input_tokens: u64,
236 ) -> Result<EstimatedCost, CheckError> {
237 let est = self.estimate_with_cached(
238 model,
239 input_tokens,
240 max_output_tokens,
241 cached_input_tokens,
242 )?;
243 if est.total_usd > self.cap_usd {
244 return Err(CheckError::CapExceeded(CapExceeded {
245 projected_usd: est.total_usd,
246 cap_usd: self.cap_usd,
247 model: model.to_string(),
248 }));
249 }
250 Ok(est)
251 }
252
253 pub fn run<T>(
256 &self,
257 model: &str,
258 input_tokens: u64,
259 max_output_tokens: u64,
260 f: impl FnOnce() -> T,
261 ) -> Result<T, CheckError> {
262 self.check(model, input_tokens, max_output_tokens)?;
263 Ok(f())
264 }
265}