1use serde::{Deserialize, Serialize};
2
3use crate::model::ModelPricing;
4
5#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
7pub struct Usage {
8 pub input_tokens: u32,
10 pub output_tokens: u32,
12 pub cache_read_tokens: u32,
14 pub cache_write_tokens: u32,
16}
17
18#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
20pub struct Cost {
21 pub input: f64,
23 pub output: f64,
25 pub cache_read: f64,
27 pub cache_write: f64,
29 pub total: f64,
31}
32
33impl Usage {
34 pub fn total_tokens(&self) -> u32 {
36 self.input_tokens + self.output_tokens
37 }
38
39 pub fn cost(&self, pricing: &ModelPricing) -> Cost {
41 let input = self.input_tokens as f64 * pricing.input_per_mtok / 1_000_000.0;
42 let output = self.output_tokens as f64 * pricing.output_per_mtok / 1_000_000.0;
43 let cache_read = self.cache_read_tokens as f64 * pricing.cache_read_per_mtok / 1_000_000.0;
44 let cache_write =
45 self.cache_write_tokens as f64 * pricing.cache_write_per_mtok / 1_000_000.0;
46 let total = input + output + cache_read + cache_write;
47 Cost {
48 input,
49 output,
50 cache_read,
51 cache_write,
52 total,
53 }
54 }
55
56 pub fn add(&mut self, other: &Usage) {
58 self.input_tokens += other.input_tokens;
59 self.output_tokens += other.output_tokens;
60 self.cache_read_tokens += other.cache_read_tokens;
61 self.cache_write_tokens += other.cache_write_tokens;
62 }
63}
64
65impl Cost {
66 pub fn add(&mut self, other: &Cost) {
68 self.input += other.input;
69 self.output += other.output;
70 self.cache_read += other.cache_read;
71 self.cache_write += other.cache_write;
72 self.total += other.total;
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Pricing table shared by the cost tests.
    fn sample_pricing() -> ModelPricing {
        ModelPricing {
            input_per_mtok: 3.0,
            output_per_mtok: 15.0,
            cache_read_per_mtok: 0.3,
            cache_write_per_mtok: 3.75,
        }
    }

    /// Asserts that `actual` is strictly within `tolerance` of `expected`,
    /// naming the checked quantity `what` in the failure message.
    fn assert_close(what: &str, actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() < tolerance,
            "{}: expected {}, got {}",
            what,
            expected,
            actual
        );
    }

    #[test]
    fn total_tokens_sums_input_and_output() {
        // The cache counters are non-zero to show they are left out of the total.
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 200,
            cache_write_tokens: 10,
        };
        assert_eq!(usage.total_tokens(), 150);
    }

    #[test]
    fn cost_calculation_matches_expected() {
        let cost = Usage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            cache_read_tokens: 200_000,
            cache_write_tokens: 100_000,
        }
        .cost(&sample_pricing());

        assert_close("input", cost.input, 3.0, f64::EPSILON);
        assert_close("output", cost.output, 7.5, f64::EPSILON);
        assert_close("cache_read", cost.cache_read, 0.06, f64::EPSILON);
        assert_close("cache_write", cost.cache_write, 0.375, f64::EPSILON);
        assert_close("total", cost.total, 10.935, 1e-10);
    }

    #[test]
    fn cost_zero_for_zero_usage() {
        let cost = Usage::default().cost(&sample_pricing());
        assert_close("total", cost.total, 0.0, f64::EPSILON);
    }

    #[test]
    fn add_accumulates_all_fields() {
        let mut acc = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 10,
            cache_write_tokens: 5,
        };
        let extra = Usage {
            input_tokens: 200,
            output_tokens: 100,
            cache_read_tokens: 20,
            cache_write_tokens: 10,
        };
        let expected = Usage {
            input_tokens: 300,
            output_tokens: 150,
            cache_read_tokens: 30,
            cache_write_tokens: 15,
        };

        acc.add(&extra);
        assert_eq!(acc, expected);
    }

    #[test]
    fn cost_add_accumulates_all_fields() {
        let mut acc = Cost {
            input: 1.0,
            output: 2.0,
            cache_read: 0.5,
            cache_write: 0.25,
            total: 3.75,
        };
        let extra = Cost {
            input: 0.5,
            output: 1.5,
            cache_read: 0.25,
            cache_write: 0.75,
            total: 3.0,
        };
        // All values are exact binary fractions, so exact equality is safe here.
        let expected = Cost {
            input: 1.5,
            output: 3.5,
            cache_read: 0.75,
            cache_write: 1.0,
            total: 6.75,
        };

        acc.add(&extra);
        assert_eq!(acc, expected);
    }
}