Skip to main content

imp_llm/
usage.rs

1use serde::{Deserialize, Serialize};
2
3use crate::model::ModelPricing;
4
5/// Token usage from a single LLM request.
6#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
7pub struct Usage {
8    /// Tokens consumed by the input prompt.
9    pub input_tokens: u32,
10    /// Tokens generated in the output.
11    pub output_tokens: u32,
12    /// Tokens served from the prompt cache.
13    pub cache_read_tokens: u32,
14    /// Tokens written into the prompt cache.
15    pub cache_write_tokens: u32,
16}
17
18/// Dollar cost breakdown for a request.
19#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
20pub struct Cost {
21    /// Cost of input tokens.
22    pub input: f64,
23    /// Cost of output tokens.
24    pub output: f64,
25    /// Cost of cache-read tokens.
26    pub cache_read: f64,
27    /// Cost of cache-write tokens.
28    pub cache_write: f64,
29    /// Sum of all cost components.
30    pub total: f64,
31}
32
33impl Usage {
34    /// Total tokens across input and output (excludes cache).
35    pub fn total_tokens(&self) -> u32 {
36        self.input_tokens + self.output_tokens
37    }
38
39    /// Calculate dollar cost given a model's pricing.
40    pub fn cost(&self, pricing: &ModelPricing) -> Cost {
41        let input = self.input_tokens as f64 * pricing.input_per_mtok / 1_000_000.0;
42        let output = self.output_tokens as f64 * pricing.output_per_mtok / 1_000_000.0;
43        let cache_read = self.cache_read_tokens as f64 * pricing.cache_read_per_mtok / 1_000_000.0;
44        let cache_write =
45            self.cache_write_tokens as f64 * pricing.cache_write_per_mtok / 1_000_000.0;
46        let total = input + output + cache_read + cache_write;
47        Cost {
48            input,
49            output,
50            cache_read,
51            cache_write,
52            total,
53        }
54    }
55
56    /// Accumulate another usage into this one.
57    pub fn add(&mut self, other: &Usage) {
58        self.input_tokens += other.input_tokens;
59        self.output_tokens += other.output_tokens;
60        self.cache_read_tokens += other.cache_read_tokens;
61        self.cache_write_tokens += other.cache_write_tokens;
62    }
63}
64
65impl Cost {
66    /// Accumulate another cost breakdown into this one.
67    pub fn add(&mut self, other: &Cost) {
68        self.input += other.input;
69        self.output += other.output;
70        self.cache_read += other.cache_read;
71        self.cache_write += other.cache_write;
72        self.total += other.total;
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Pricing fixture shared by the cost tests: $3 input, $15 output,
    /// $0.30 cache-read and $3.75 cache-write per million tokens.
    fn sample_pricing() -> ModelPricing {
        ModelPricing {
            input_per_mtok: 3.0,
            output_per_mtok: 15.0,
            cache_read_per_mtok: 0.3,
            cache_write_per_mtok: 3.75,
        }
    }

    #[test]
    fn total_tokens_sums_input_and_output() {
        // Cache counters are set to non-zero values to show they are excluded.
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 200,
            cache_write_tokens: 10,
        };
        assert_eq!(usage.total_tokens(), 150);
    }

    #[test]
    fn cost_calculation_matches_expected() {
        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            cache_read_tokens: 200_000,
            cache_write_tokens: 100_000,
        };
        let cost = usage.cost(&sample_pricing());

        // 1M input * $3/Mtok = $3.00
        assert!((cost.input - 3.0).abs() < f64::EPSILON);
        // 500k output * $15/Mtok = $7.50
        assert!((cost.output - 7.5).abs() < f64::EPSILON);
        // 200k cache_read * $0.30/Mtok = $0.06
        assert!((cost.cache_read - 0.06).abs() < f64::EPSILON);
        // 100k cache_write * $3.75/Mtok = $0.375
        assert!((cost.cache_write - 0.375).abs() < f64::EPSILON);
        // total = 3.0 + 7.5 + 0.06 + 0.375 = 10.935
        assert!((cost.total - 10.935).abs() < 1e-10);
    }

    #[test]
    fn cost_zero_for_zero_usage() {
        let cost = Usage::default().cost(&sample_pricing());
        assert!((cost.total).abs() < f64::EPSILON);
    }

    #[test]
    fn add_accumulates_all_fields() {
        let mut a = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 10,
            cache_write_tokens: 5,
        };
        let b = Usage {
            input_tokens: 200,
            output_tokens: 100,
            cache_read_tokens: 20,
            cache_write_tokens: 10,
        };
        a.add(&b);
        assert_eq!(
            a,
            Usage {
                input_tokens: 300,
                output_tokens: 150,
                cache_read_tokens: 30,
                cache_write_tokens: 15,
            }
        );
    }

    #[test]
    fn cost_add_accumulates_all_fields() {
        // All values are exact binary fractions, so exact float equality is safe.
        let mut a = Cost {
            input: 1.0,
            output: 2.0,
            cache_read: 0.5,
            cache_write: 0.25,
            total: 3.75,
        };
        let b = Cost {
            input: 0.5,
            output: 1.5,
            cache_read: 0.25,
            cache_write: 0.75,
            total: 3.0,
        };
        a.add(&b);
        assert_eq!(
            a,
            Cost {
                input: 1.5,
                output: 3.5,
                cache_read: 0.75,
                cache_write: 1.0,
                total: 6.75,
            }
        );
    }
}