Skip to main content

llm_cost_cap/
cap.rs

1//! Core [`CostCap`] implementation.
2
3use std::collections::HashMap;
4use std::error::Error;
5use std::fmt;
6
7use crate::prices::{builtin_prices, ModelPrice};
8
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
/// Per-call cost estimate for one model, split by billing category.
///
/// Every field is denominated in USD, and `total_usd` is always the sum
/// `input_usd + output_usd + cached_input_usd`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct EstimatedCost {
    /// Sum of the three component costs below.
    pub total_usd: f64,
    /// Cost of the regular input tokens at the model's input rate.
    pub input_usd: f64,
    /// Cost of the maximum requested output tokens at the model's output rate.
    pub output_usd: f64,
    /// Cost of the cached input tokens at the model's cached-input rate.
    pub cached_input_usd: f64,
}
22
/// Error produced when a single call's estimated cost is above the cap.
#[derive(Debug, Clone, PartialEq)]
pub struct CapExceeded {
    /// Estimated cost of the rejected call, in USD.
    pub projected_usd: f64,
    /// The cap the estimate was compared against, in USD.
    pub cap_usd: f64,
    /// Model id the estimate was computed for.
    pub model: String,
}

impl fmt::Display for CapExceeded {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let CapExceeded {
            projected_usd,
            cap_usd,
            model,
        } = self;
        write!(
            f,
            "estimated cost ${:.6} for model {:?} exceeds cap ${:.6}",
            projected_usd, model, cap_usd
        )
    }
}

impl Error for CapExceeded {}
42
/// Error produced when the requested model id has no entry in the price table.
#[derive(Debug, Clone, PartialEq)]
pub struct UnknownModel {
    /// The model id that was looked up and not found.
    pub model: String,
}

impl fmt::Display for UnknownModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let UnknownModel { model } = self;
        write!(
            f,
            "unknown model {:?}: pass a custom price table or call add_model",
            model
        )
    }
}

impl Error for UnknownModel {}
60
61/// Errors returned by [`CostCap::estimate`] when the model is unknown
62/// or token counts are negative. (Negative counts are not representable
63/// with `u64`, but the model lookup can still fail.)
64#[derive(Debug, Clone, PartialEq)]
65pub enum EstimateError {
66    UnknownModel(UnknownModel),
67}
68
69impl fmt::Display for EstimateError {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        match self {
72            EstimateError::UnknownModel(e) => e.fmt(f),
73        }
74    }
75}
76
77impl Error for EstimateError {}
78
79impl From<UnknownModel> for EstimateError {
80    fn from(e: UnknownModel) -> Self {
81        EstimateError::UnknownModel(e)
82    }
83}
84
85/// Errors returned by [`CostCap::check`].
86#[derive(Debug, Clone, PartialEq)]
87pub enum CheckError {
88    UnknownModel(UnknownModel),
89    CapExceeded(CapExceeded),
90}
91
92impl fmt::Display for CheckError {
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        match self {
95            CheckError::UnknownModel(e) => e.fmt(f),
96            CheckError::CapExceeded(e) => e.fmt(f),
97        }
98    }
99}
100
101impl Error for CheckError {}
102
103impl From<UnknownModel> for CheckError {
104    fn from(e: UnknownModel) -> Self {
105        CheckError::UnknownModel(e)
106    }
107}
108
109impl From<CapExceeded> for CheckError {
110    fn from(e: CapExceeded) -> Self {
111        CheckError::CapExceeded(e)
112    }
113}
114
115impl From<EstimateError> for CheckError {
116    fn from(e: EstimateError) -> Self {
117        match e {
118            EstimateError::UnknownModel(u) => CheckError::UnknownModel(u),
119        }
120    }
121}
122
123/// Pre-flight cost gate for a single LLM call.
124///
125/// Build once with the per-call USD cap, then call [`CostCap::check`]
126/// before sending each request. If the estimated cost for the requested
127/// model and token counts exceeds the cap, an error is returned and the
128/// caller can short circuit before paying for the call.
129///
130/// Cost model:
131///   * `input_tokens` are billed at the model's input rate.
132///   * `max_output_tokens` are billed at the model's output rate (worst
133///     case ceiling, because the cap protects against the worst case).
134///   * `cached_input_tokens` are billed at the model's cached rate if
135///     one is published, otherwise treated as zero.
136#[derive(Debug, Clone)]
137pub struct CostCap {
138    cap_usd: f64,
139    prices: HashMap<String, ModelPrice>,
140}
141
142impl CostCap {
143    /// Build a cap with the built-in price table.
144    pub fn new(max_usd: f64) -> Self {
145        assert!(max_usd >= 0.0, "max_usd must be >= 0");
146        Self {
147            cap_usd: max_usd,
148            prices: builtin_prices(),
149        }
150    }
151
152    /// Build a cap with a caller-supplied price table. The built-in
153    /// table is not used; only entries in `prices` are recognized.
154    pub fn with_prices(prices: HashMap<String, ModelPrice>, max_usd: f64) -> Self {
155        assert!(max_usd >= 0.0, "max_usd must be >= 0");
156        Self {
157            cap_usd: max_usd,
158            prices,
159        }
160    }
161
162    pub fn cap_usd(&self) -> f64 {
163        self.cap_usd
164    }
165
166    /// Register or replace one model in the price table.
167    pub fn add_model<S: Into<String>>(&mut self, model: S, price: ModelPrice) {
168        self.prices.insert(model.into(), price);
169    }
170
171    /// Sorted list of model ids registered with this cap.
172    pub fn known_models(&self) -> Vec<String> {
173        let mut ids: Vec<String> = self.prices.keys().cloned().collect();
174        ids.sort();
175        ids
176    }
177
178    /// Return the per-call cost breakdown. Does not raise on overage.
179    pub fn estimate(
180        &self,
181        model: &str,
182        input_tokens: u64,
183        max_output_tokens: u64,
184    ) -> Result<EstimatedCost, EstimateError> {
185        self.estimate_with_cached(model, input_tokens, max_output_tokens, 0)
186    }
187
188    /// Like [`CostCap::estimate`] but includes a separate cached input
189    /// token count. Cached tokens are billed at the model's cached rate
190    /// if one is published; otherwise they are zero in the estimate.
191    pub fn estimate_with_cached(
192        &self,
193        model: &str,
194        input_tokens: u64,
195        max_output_tokens: u64,
196        cached_input_tokens: u64,
197    ) -> Result<EstimatedCost, EstimateError> {
198        let price = self.prices.get(model).ok_or_else(|| UnknownModel {
199            model: model.to_string(),
200        })?;
201        let input_usd = (input_tokens as f64 / 1_000_000.0) * price.input_per_million_usd;
202        let output_usd = (max_output_tokens as f64 / 1_000_000.0) * price.output_per_million_usd;
203        let cached_input_usd = match price.cached_input_per_million_usd {
204            Some(rate) if cached_input_tokens > 0 => {
205                (cached_input_tokens as f64 / 1_000_000.0) * rate
206            }
207            _ => 0.0,
208        };
209        Ok(EstimatedCost {
210            total_usd: input_usd + output_usd + cached_input_usd,
211            input_usd,
212            output_usd,
213            cached_input_usd,
214        })
215    }
216
217    /// Estimate the call. Returns `Ok(cost)` if under cap, `Err` with
218    /// either an [`UnknownModel`] or a [`CapExceeded`] otherwise.
219    pub fn check(
220        &self,
221        model: &str,
222        input_tokens: u64,
223        max_output_tokens: u64,
224    ) -> Result<EstimatedCost, CheckError> {
225        self.check_with_cached(model, input_tokens, max_output_tokens, 0)
226    }
227
228    /// Like [`CostCap::check`] but with a separate cached input token
229    /// count.
230    pub fn check_with_cached(
231        &self,
232        model: &str,
233        input_tokens: u64,
234        max_output_tokens: u64,
235        cached_input_tokens: u64,
236    ) -> Result<EstimatedCost, CheckError> {
237        let est = self.estimate_with_cached(
238            model,
239            input_tokens,
240            max_output_tokens,
241            cached_input_tokens,
242        )?;
243        if est.total_usd > self.cap_usd {
244            return Err(CheckError::CapExceeded(CapExceeded {
245                projected_usd: est.total_usd,
246                cap_usd: self.cap_usd,
247                model: model.to_string(),
248            }));
249        }
250        Ok(est)
251    }
252
253    /// Gate then invoke. Returns whatever `f` returns. `f` is only
254    /// called if [`CostCap::check`] passes.
255    pub fn run<T>(
256        &self,
257        model: &str,
258        input_tokens: u64,
259        max_output_tokens: u64,
260        f: impl FnOnce() -> T,
261    ) -> Result<T, CheckError> {
262        self.check(model, input_tokens, max_output_tokens)?;
263        Ok(f())
264    }
265}