Skip to main content

llmkit_core/
usage.rs

1//! Token usage and per-model cost estimation.
2
3use serde::{Deserialize, Serialize};
4
5/// Normalised token counts for one request.
6#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
7pub struct TokenUsage {
8    /// Input (prompt) tokens.
9    pub prompt: u32,
10    /// Output (completion) tokens.
11    pub completion: u32,
12}
13
14impl TokenUsage {
15    /// Construct from prompt and completion counts.
16    pub const fn new(prompt: u32, completion: u32) -> Self {
17        Self { prompt, completion }
18    }
19
20    /// Total tokens.
21    pub const fn total(&self) -> u32 {
22        self.prompt + self.completion
23    }
24}
25
26impl std::ops::Add for TokenUsage {
27    type Output = TokenUsage;
28    fn add(self, rhs: TokenUsage) -> TokenUsage {
29        TokenUsage { prompt: self.prompt + rhs.prompt, completion: self.completion + rhs.completion }
30    }
31}
32
33impl std::ops::AddAssign for TokenUsage {
34    fn add_assign(&mut self, rhs: TokenUsage) {
35        self.prompt += rhs.prompt;
36        self.completion += rhs.completion;
37    }
38}
39
/// Model pricing, expressed in USD per 1M tokens.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ModelPricing {
    /// Price in USD for 1M input tokens.
    pub input_per_mtok: f64,
    /// Price in USD for 1M output tokens.
    pub output_per_mtok: f64,
}
48
49impl ModelPricing {
50    /// Cost for the given usage under this pricing.
51    pub fn cost_for(&self, usage: TokenUsage) -> CostEstimate {
52        CostEstimate {
53            input_usd: (usage.prompt as f64 / 1_000_000.0) * self.input_per_mtok,
54            output_usd: (usage.completion as f64 / 1_000_000.0) * self.output_per_mtok,
55        }
56    }
57}
58
59/// Computed cost breakdown in USD.
60#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
61pub struct CostEstimate {
62    /// Cost of input tokens.
63    pub input_usd: f64,
64    /// Cost of output tokens.
65    pub output_usd: f64,
66}
67
68impl CostEstimate {
69    /// Total cost (input + output).
70    pub fn total_usd(&self) -> f64 {
71        self.input_usd + self.output_usd
72    }
73}
74
75impl std::ops::Add for CostEstimate {
76    type Output = CostEstimate;
77    fn add(self, rhs: CostEstimate) -> CostEstimate {
78        CostEstimate {
79            input_usd: self.input_usd + rhs.input_usd,
80            output_usd: self.output_usd + rhs.output_usd,
81        }
82    }
83}
84
85/// Per-model pricing lookup. Unknown models (e.g. local Ollama) return `None`.
86pub mod pricing {
87    use super::ModelPricing;
88
89    const fn p(input: f64, output: f64) -> ModelPricing {
90        ModelPricing { input_per_mtok: input, output_per_mtok: output }
91    }
92
93    /// Look up pricing for a model slug by prefix. `None` if unknown.
94    pub fn pricing_for(model: &str) -> Option<ModelPricing> {
95        let m = model.trim().to_ascii_lowercase();
96        let m = m.strip_prefix("anthropic.").unwrap_or(&m);
97
98        let pricing = match () {
99            _ if m.starts_with("gpt-4o-mini") => p(0.15, 0.60),
100            _ if m.starts_with("gpt-4o") => p(2.50, 10.0),
101            _ if m.starts_with("o1-mini") => p(3.0, 12.0),
102            _ if m.starts_with("o1") => p(15.0, 60.0),
103            _ if m.starts_with("text-embedding-3-small") => p(0.02, 0.0),
104            _ if m.starts_with("text-embedding-3-large") => p(0.13, 0.0),
105            _ if m.starts_with("claude-opus-4") => p(5.0, 25.0),
106            _ if m.starts_with("claude-sonnet-4") || m.starts_with("claude-3-5-sonnet") => p(3.0, 15.0),
107            _ if m.starts_with("claude-haiku-4") || m.starts_with("claude-3-5-haiku") => p(1.0, 5.0),
108            _ if m.starts_with("claude-3-opus") => p(15.0, 75.0),
109            _ => return None,
110        };
111        Some(pricing)
112    }
113
114    /// Look up pricing, falling back to `default` when unknown.
115    pub fn pricing_for_or(model: &str, default: ModelPricing) -> ModelPricing {
116        pricing_for(model).unwrap_or(default)
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// `+=` sums each field and `total` adds the two fields together.
    #[test]
    fn usage_arithmetic() {
        let mut running = TokenUsage::new(10, 5);
        running += TokenUsage::new(3, 7);

        let expected = TokenUsage::new(13, 12);
        assert_eq!(running, expected);
        assert_eq!(running.total(), 25);
    }

    /// One million tokens each way costs exactly the two per-million rates.
    #[test]
    fn cost_computation() {
        let rates = ModelPricing { input_per_mtok: 5.0, output_per_mtok: 25.0 };
        let usage = TokenUsage::new(1_000_000, 1_000_000);

        let total = rates.cost_for(usage).total_usd();
        assert!((total - 30.0).abs() < 1e-9);
    }

    /// Known slugs resolve to some pricing; an unrecognised slug does not.
    #[test]
    fn pricing_lookup() {
        for known in ["gpt-4o-mini", "claude-opus-4-8"].iter() {
            assert!(pricing::pricing_for(known).is_some());
        }
        assert!(pricing::pricing_for("llama3.1").is_none());
    }
}