use serde::{Deserialize, Serialize};
use crate::model::ModelPricing;
/// Token counters, split into input, output, cache-read and cache-write buckets.
///
/// Every counter is a `u32` and starts at zero under `Default`.
/// NOTE(review): whether a value describes one request or a running total is not
/// fixed by this type; `Usage::add` lets callers accumulate several values into one.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct Usage {
/// Input token count; priced with `ModelPricing::input_per_mtok` by `Usage::cost`.
pub input_tokens: u32,
/// Output token count; priced with `ModelPricing::output_per_mtok` by `Usage::cost`.
pub output_tokens: u32,
/// Cache-read token count; priced with `ModelPricing::cache_read_per_mtok`.
/// Not part of `Usage::total_tokens`.
pub cache_read_tokens: u32,
/// Cache-write token count; priced with `ModelPricing::cache_write_per_mtok`.
/// Not part of `Usage::total_tokens`.
pub cache_write_tokens: u32,
}
/// Cost breakdown with one component per token bucket plus a total.
///
/// Values are `f64`, which is why the type derives `PartialEq` but not `Eq`.
/// NOTE(review): the currency/unit is whatever the `ModelPricing` rates are
/// expressed in; that is not visible from this file.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct Cost {
/// Cost attributed to input tokens.
pub input: f64,
/// Cost attributed to output tokens.
pub output: f64,
/// Cost attributed to cache-read tokens.
pub cache_read: f64,
/// Cost attributed to cache-write tokens.
pub cache_write: f64,
/// Overall cost. `Usage::cost` sets it to the sum of the four components;
/// `Cost::add` accumulates it as its own field rather than recomputing it.
pub total: f64,
}
impl Usage {
    /// Returns the combined count of input and output tokens.
    ///
    /// Cache read/write tokens are not included. The sum saturates at
    /// `u32::MAX` rather than panicking (debug builds) or wrapping (release
    /// builds) when the two counters together exceed the `u32` range.
    pub fn total_tokens(&self) -> u32 {
        self.input_tokens.saturating_add(self.output_tokens)
    }

    /// Prices this usage with the given per-million-token rates.
    ///
    /// Each bucket is computed as `tokens * rate / 1_000_000`, and `total` is
    /// the sum of the four bucket costs.
    pub fn cost(&self, pricing: &ModelPricing) -> Cost {
        /// Number of tokens the `*_per_mtok` rates refer to.
        const TOKENS_PER_MTOK: f64 = 1_000_000.0;

        // `f64::from(u32)` is a lossless conversion; multiplying before dividing
        // keeps the exact operation order (and therefore rounding) of the formula.
        let price = |tokens: u32, rate_per_mtok: f64| {
            f64::from(tokens) * rate_per_mtok / TOKENS_PER_MTOK
        };

        let input = price(self.input_tokens, pricing.input_per_mtok);
        let output = price(self.output_tokens, pricing.output_per_mtok);
        let cache_read = price(self.cache_read_tokens, pricing.cache_read_per_mtok);
        let cache_write = price(self.cache_write_tokens, pricing.cache_write_per_mtok);
        let total = input + output + cache_read + cache_write;

        Cost {
            input,
            output,
            cache_read,
            cache_write,
            total,
        }
    }

    /// Accumulates the counters of `other` into `self`, bucket by bucket.
    ///
    /// Each counter saturates at `u32::MAX` instead of overflowing, so a long
    /// running total can never panic or wrap around to a small value.
    pub fn add(&mut self, other: &Usage) {
        self.input_tokens = self.input_tokens.saturating_add(other.input_tokens);
        self.output_tokens = self.output_tokens.saturating_add(other.output_tokens);
        self.cache_read_tokens = self
            .cache_read_tokens
            .saturating_add(other.cache_read_tokens);
        self.cache_write_tokens = self
            .cache_write_tokens
            .saturating_add(other.cache_write_tokens);
    }
}
impl Cost {
pub fn add(&mut self, other: &Cost) {
self.input += other.input;
self.output += other.output;
self.cache_read += other.cache_read;
self.cache_write += other.cache_write;
self.total += other.total;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pricing fixture shared by the cost tests.
    fn sample_pricing() -> ModelPricing {
        ModelPricing {
            input_per_mtok: 3.0,
            output_per_mtok: 15.0,
            cache_read_per_mtok: 0.3,
            cache_write_per_mtok: 3.75,
        }
    }

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        let delta = (actual - expected).abs();
        assert!(
            delta < tolerance,
            "expected {}, got {} (tolerance {})",
            expected,
            actual,
            tolerance
        );
    }

    /// Cache buckets must not leak into the input + output total.
    #[test]
    fn total_tokens_sums_input_and_output() {
        let sample = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 200,
            cache_write_tokens: 10,
        };

        assert_eq!(sample.total_tokens(), 150);
    }

    /// Every bucket is priced per million tokens and the total is their sum.
    #[test]
    fn cost_calculation_matches_expected() {
        let sample = Usage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            cache_read_tokens: 200_000,
            cache_write_tokens: 100_000,
        };

        let priced = sample.cost(&sample_pricing());

        assert_close(priced.input, 3.0, f64::EPSILON);
        assert_close(priced.output, 7.5, f64::EPSILON);
        assert_close(priced.cache_read, 0.06, f64::EPSILON);
        assert_close(priced.cache_write, 0.375, f64::EPSILON);
        assert_close(priced.total, 10.935, 1e-10);
    }

    /// An all-zero usage costs nothing regardless of the rates.
    #[test]
    fn cost_zero_for_zero_usage() {
        let priced = Usage::default().cost(&sample_pricing());

        assert_close(priced.total, 0.0, f64::EPSILON);
    }

    /// `Usage::add` sums each of the four counters independently.
    #[test]
    fn add_accumulates_all_fields() {
        let mut running = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 10,
            cache_write_tokens: 5,
        };
        let increment = Usage {
            input_tokens: 200,
            output_tokens: 100,
            cache_read_tokens: 20,
            cache_write_tokens: 10,
        };
        let expected = Usage {
            input_tokens: 300,
            output_tokens: 150,
            cache_read_tokens: 30,
            cache_write_tokens: 15,
        };

        running.add(&increment);

        assert_eq!(running, expected);
    }

    /// `Cost::add` sums each component, including `total`, independently.
    #[test]
    fn cost_add_accumulates_all_fields() {
        let mut running = Cost {
            input: 1.0,
            output: 2.0,
            cache_read: 0.5,
            cache_write: 0.25,
            total: 3.75,
        };
        let increment = Cost {
            input: 0.5,
            output: 1.5,
            cache_read: 0.25,
            cache_write: 0.75,
            total: 3.0,
        };
        let expected = Cost {
            input: 1.5,
            output: 3.5,
            cache_read: 0.75,
            cache_write: 1.0,
            total: 6.75,
        };

        running.add(&increment);

        assert_eq!(running, expected);
    }
}