Skip to main content

obol_core/pricing/
mod.rs

1//! Price tables and the per-message cost kernel.
2
3pub mod refresh;
4pub mod store;
5pub use store::{current_path, embedded, pricing_dir, PriceStore};
6
7use crate::model::MessageUsage;
8use serde::{Deserialize, Serialize};
9
/// USD rates per million tokens for one model.
///
/// The `*_above` fields apply when a single request's input exceeds
/// `tier_boundary`; an `*_above` field left as `None` keeps the matching base
/// rate. Fields marked `#[serde(default)]` may be omitted from serialized
/// data, in which case they deserialize to `0.0` or `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelPrice {
    /// Base rate for uncached input tokens.
    pub input: f64,
    /// Base rate for output tokens.
    pub output: f64,
    /// Base rate for tokens read from the cache.
    #[serde(default)]
    pub cache_read: f64,
    /// Base rate for 5-minute cache writes.
    #[serde(default)]
    pub cache_write: f64,
    /// Rate for 1-hour cache writes. When `None`, 1-hour writes are priced
    /// from the 5-minute write rate.
    #[serde(default)]
    pub cache_write_1h: Option<f64>,
    /// Request input-token count that must be strictly exceeded for the
    /// `*_above` rates to apply. `None` means the model is never tiered.
    #[serde(default)]
    pub tier_boundary: Option<u64>,
    /// Above-boundary replacement for `input`.
    #[serde(default)]
    pub input_above: Option<f64>,
    /// Above-boundary replacement for `output`.
    #[serde(default)]
    pub output_above: Option<f64>,
    /// Above-boundary replacement for `cache_read`.
    #[serde(default)]
    pub cache_read_above: Option<f64>,
    /// Above-boundary replacement for `cache_write`.
    #[serde(default)]
    pub cache_write_above: Option<f64>,
}
33
34/// USD cost of one message's tokens under `price`.
35pub fn cost_for(price: &ModelPrice, u: &MessageUsage) -> f64 {
36    let above = price
37        .tier_boundary
38        .is_some_and(|b| u.request_input_tokens > b);
39    let pick = |base: f64, over: Option<f64>| if above { over.unwrap_or(base) } else { base };
40
41    let r_in = pick(price.input, price.input_above);
42    let r_out = pick(price.output, price.output_above);
43    let r_cr = pick(price.cache_read, price.cache_read_above);
44    let r_cw5 = pick(price.cache_write, price.cache_write_above);
45    let r_cw1 = price.cache_write_1h.unwrap_or(price.cache_write);
46
47    (u.input_uncached as f64 * r_in
48        + u.output as f64 * r_out
49        + u.cache_read as f64 * r_cr
50        + u.cache_write_5m as f64 * r_cw5
51        + u.cache_write_1h as f64 * r_cw1)
52        / 1_000_000.0
53}
54
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::Provider;

    /// Absolute tolerance when comparing a computed cost to a hand-worked sum.
    const EPS: f64 = 1e-9;

    /// Asserts that `got` is within `EPS` of `want`, printing both on failure.
    fn assert_cost(got: f64, want: f64) {
        assert!((got - want).abs() < EPS, "got {got}, want {want}");
    }

    /// Builds a usage record with the given token buckets and no 1-hour cache
    /// writes; tests that need 1-hour writes override the field afterwards.
    fn usage(input: u64, cr: u64, cw5: u64, out: u64, req_in: u64) -> MessageUsage {
        MessageUsage {
            model: "claude-opus-4-8".into(),
            provider: Provider::Anthropic,
            namespace: "litellm".into(),
            input_uncached: input,
            cache_read: cr,
            cache_write_5m: cw5,
            cache_write_1h: 0,
            output: out,
            request_input_tokens: req_in,
            service_tier: Some("standard".into()),
        }
    }

    // claude-opus-4-8: input 5, output 25, cache_read 0.5, cache_write 6.25 per-million.
    fn opus() -> ModelPrice {
        ModelPrice {
            input: 5.0,
            output: 25.0,
            cache_read: 0.5,
            cache_write: 6.25,
            cache_write_1h: Some(10.0),
            tier_boundary: None,
            input_above: None,
            output_above: None,
            cache_read_above: None,
            cache_write_above: None,
        }
    }

    // sonnet-4-5: base input 3, above-200k input 6, output 15/22.5
    fn sonnet_tiered() -> ModelPrice {
        ModelPrice {
            input: 3.0,
            output: 15.0,
            cache_read: 0.3,
            cache_write: 3.75,
            cache_write_1h: None,
            tier_boundary: Some(200_000),
            input_above: Some(6.0),
            output_above: Some(22.5),
            cache_read_above: Some(0.6),
            cache_write_above: Some(7.5),
        }
    }

    #[test]
    fn flat_model_costs_each_bucket() {
        // 1M uncached in, 1M cache_read, 1M cache_write_5m, 1M out
        let u = usage(1_000_000, 1_000_000, 1_000_000, 1_000_000, 3_000_000);
        assert_cost(cost_for(&opus(), &u), 5.0 + 0.5 + 6.25 + 25.0);
    }

    #[test]
    fn one_hour_cache_write_uses_its_own_rate() {
        let u = MessageUsage {
            cache_write_1h: 1_000_000,
            ..usage(0, 0, 0, 0, 1_000_000)
        };
        assert_cost(cost_for(&opus(), &u), 10.0);
    }

    #[test]
    fn one_hour_cache_write_without_rate_uses_five_minute_rate() {
        let price = ModelPrice {
            cache_write_1h: None,
            ..opus()
        };
        let u = MessageUsage {
            cache_write_1h: 1_000_000,
            ..usage(0, 0, 0, 0, 1_000_000)
        };
        assert_cost(cost_for(&price, &u), 6.25);
    }

    #[test]
    fn tiered_model_switches_rates_above_boundary() {
        let price = sonnet_tiered();
        // request input 300k > 200k -> above rates
        let u = usage(1_000_000, 0, 0, 1_000_000, 300_000);
        assert_cost(cost_for(&price, &u), 6.0 + 22.5);
        // request input 100k < 200k -> base rates
        let u2 = usage(1_000_000, 0, 0, 1_000_000, 100_000);
        assert_cost(cost_for(&price, &u2), 3.0 + 15.0);
    }

    #[test]
    fn tiered_model_switches_cache_rates_above_boundary() {
        let price = sonnet_tiered();
        let over = usage(0, 1_000_000, 1_000_000, 0, 300_000);
        assert_cost(cost_for(&price, &over), 0.6 + 7.5);
        let under = usage(0, 1_000_000, 1_000_000, 0, 100_000);
        assert_cost(cost_for(&price, &under), 0.3 + 3.75);
    }

    #[test]
    fn request_exactly_at_boundary_uses_base_rates() {
        // The comparison is strict: input must exceed the boundary to be tiered.
        let u = usage(1_000_000, 0, 0, 1_000_000, 200_000);
        assert_cost(cost_for(&sonnet_tiered(), &u), 3.0 + 15.0);
    }

    #[test]
    fn missing_above_rate_keeps_base_rate() {
        let price = ModelPrice {
            output_above: None,
            ..sonnet_tiered()
        };
        let u = usage(1_000_000, 0, 0, 1_000_000, 300_000);
        assert_cost(cost_for(&price, &u), 6.0 + 15.0);
    }
}