use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::agentlog::Record;
use crate::diff::axes::{Axis, AxisStat, Flag};
use crate::diff::bootstrap::{median, paired_ci};
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ModelPricing {
pub input: f64,
pub output: f64,
#[serde(default)]
pub cached_input: f64,
#[serde(default)]
pub cached_write_5m: f64,
#[serde(default)]
pub cached_write_1h: f64,
#[serde(default)]
pub reasoning: f64,
#[serde(default)]
pub batch_discount: f64,
}
impl ModelPricing {
pub fn simple(input: f64, output: f64) -> Self {
Self {
input,
output,
cached_input: 0.0,
cached_write_5m: 0.0,
cached_write_1h: 0.0,
reasoning: 0.0,
batch_discount: 0.0,
}
}
}
pub type Pricing = HashMap<String, ModelPricing>;
fn lookup_with_snapshot_fallback<'a>(
pricing: &'a Pricing,
model: &str,
) -> Option<&'a ModelPricing> {
if let Some(p) = pricing.get(model) {
return Some(p);
}
if let Some(base) = strip_snapshot_tail(model) {
return pricing.get(base);
}
None
}
fn strip_snapshot_tail(model: &str) -> Option<&str> {
let bytes = model.as_bytes();
if bytes.len() > 11 && bytes[bytes.len() - 11] == b'-' {
let tail = &bytes[bytes.len() - 10..];
if tail.len() == 10
&& tail[..4].iter().all(u8::is_ascii_digit)
&& tail[4] == b'-'
&& tail[5..7].iter().all(u8::is_ascii_digit)
&& tail[7] == b'-'
&& tail[8..10].iter().all(u8::is_ascii_digit)
{
return Some(&model[..model.len() - 11]);
}
}
if bytes.len() > 9 && bytes[bytes.len() - 9] == b'-' {
let tail = &bytes[bytes.len() - 8..];
if tail.len() == 8 && tail.iter().all(u8::is_ascii_digit) {
return Some(&model[..model.len() - 9]);
}
}
None
}
pub(crate) fn cost_of(r: &Record, pricing: &Pricing) -> Option<f64> {
let model = r.payload.get("model")?.as_str()?;
let usage = r.payload.get("usage")?;
let input = usage.get("input_tokens")?.as_f64()?;
let output = usage.get("output_tokens")?.as_f64()?;
let cached_input = usage
.get("cached_input_tokens")
.and_then(|v| v.as_f64())
.unwrap_or(0.0);
let cached_write_5m = usage
.get("cached_write_5m_tokens")
.and_then(|v| v.as_f64())
.unwrap_or(0.0);
let cached_write_1h = usage
.get("cached_write_1h_tokens")
.and_then(|v| v.as_f64())
.unwrap_or(0.0);
let thinking = usage
.get("thinking_tokens")
.and_then(|v| v.as_f64())
.unwrap_or(0.0);
if !(input.is_finite()
&& output.is_finite()
&& cached_input.is_finite()
&& cached_write_5m.is_finite()
&& cached_write_1h.is_finite()
&& thinking.is_finite())
{
return Some(0.0);
}
let Some(p) = lookup_with_snapshot_fallback(pricing, model) else {
return Some(0.0);
};
let cached_rate = if p.cached_input > 0.0 {
p.cached_input
} else {
p.input
};
let reasoning_rate = if p.reasoning > 0.0 {
p.reasoning
} else {
p.output
};
let write_5m_rate = if p.cached_write_5m > 0.0 {
p.cached_write_5m
} else {
p.input
};
let write_1h_rate = if p.cached_write_1h > 0.0 {
p.cached_write_1h
} else {
p.input
};
let mut cost = input * p.input
+ cached_input * cached_rate
+ cached_write_5m * write_5m_rate
+ cached_write_1h * write_1h_rate
+ output * p.output
+ thinking * reasoning_rate;
let batch = r
.payload
.get("batch")
.and_then(|v| v.as_bool())
.unwrap_or(false);
if batch && p.batch_discount > 0.0 {
cost *= p.batch_discount;
}
Some(cost)
}
fn pair_is_priced(br: &Record, cr: &Record, pricing: &Pricing) -> bool {
fn model_in_table(r: &Record, pricing: &Pricing) -> bool {
r.payload
.get("model")
.and_then(|m| m.as_str())
.is_some_and(|m| lookup_with_snapshot_fallback(pricing, m).is_some())
}
model_in_table(br, pricing) && model_in_table(cr, pricing)
}
pub fn compute(pairs: &[(&Record, &Record)], pricing: &Pricing, seed: Option<u64>) -> AxisStat {
let mut b = Vec::with_capacity(pairs.len());
let mut c = Vec::with_capacity(pairs.len());
let mut priced_pairs = 0usize;
for (br, cr) in pairs {
if let (Some(bv), Some(cv)) = (cost_of(br, pricing), cost_of(cr, pricing)) {
b.push(bv);
c.push(cv);
if pair_is_priced(br, cr, pricing) {
priced_pairs += 1;
}
}
}
if b.is_empty() {
let mut stat = AxisStat::empty(Axis::Cost);
if !pairs.is_empty() {
stat.flags.push(Flag::NoPricing);
}
return stat;
}
let bm = median(&b);
let cm = median(&c);
let delta = cm - bm;
let ci = paired_ci(&b, &c, |bs, cs| median(cs) - median(bs), 0, seed);
let mut stat = AxisStat::new_value(Axis::Cost, bm, cm, delta, ci.low, ci.high, b.len());
if priced_pairs * 2 < pairs.len() {
stat.flags.push(Flag::NoPricing);
}
stat
}
#[cfg(test)]
mod tests {
use super::*;
use crate::agentlog::Kind;
use crate::diff::axes::Severity;
use serde_json::json;
fn response(model: &str, input: u64, output: u64) -> Record {
Record::new(
Kind::ChatResponse,
json!({
"model": model,
"content": [],
"stop_reason": "end_turn",
"latency_ms": 0,
"usage": {"input_tokens": input, "output_tokens": output, "thinking_tokens": 0},
}),
"2026-04-21T10:00:00Z",
None,
)
}
fn response_with_usage(model: &str, usage: serde_json::Value) -> Record {
Record::new(
Kind::ChatResponse,
json!({
"model": model,
"content": [],
"stop_reason": "end_turn",
"latency_ms": 0,
"usage": usage,
}),
"2026-04-21T10:00:00Z",
None,
)
}
#[test]
fn pricing_lookup_drives_cost() {
let mut pricing = Pricing::new();
pricing.insert("opus".to_string(), ModelPricing::simple(0.000015, 0.000075));
pricing.insert(
"haiku".to_string(),
ModelPricing::simple(0.0000008, 0.000004),
);
let baseline: Vec<Record> = (0..10).map(|_| response("opus", 1000, 500)).collect();
let candidate: Vec<Record> = (0..10).map(|_| response("haiku", 1000, 500)).collect();
let pairs: Vec<(&Record, &Record)> = baseline.iter().zip(candidate.iter()).collect();
let stat = compute(&pairs, &pricing, Some(1));
assert!(stat.delta < 0.0);
assert_eq!(stat.severity, Severity::Severe);
}
#[test]
fn unknown_model_costs_zero() {
let pricing = Pricing::new();
let r = response("mystery", 1000, 500);
let pairs = [(&r, &r)];
let stat = compute(&pairs, &pricing, Some(1));
assert_eq!(stat.baseline_median, 0.0);
}
#[test]
fn no_pricing_flag_when_table_is_empty_but_pairs_exist() {
let pricing = Pricing::new();
let r = response("mystery", 1000, 500);
let pairs = [(&r, &r), (&r, &r), (&r, &r)];
let stat = compute(&pairs, &pricing, Some(1));
assert!(stat.flags.contains(&Flag::NoPricing));
assert_eq!(stat.delta, 0.0);
}
#[test]
fn no_pricing_flag_when_most_models_unpriced() {
let mut pricing = Pricing::new();
pricing.insert("opus".to_string(), ModelPricing::simple(0.000015, 0.000075));
let priced = response("opus", 1000, 500);
let unpriced1 = response("sonnet-unlisted", 1000, 500);
let unpriced2 = response("gpt-x-unlisted", 1000, 500);
let pairs = [
(&priced, &priced),
(&unpriced1, &unpriced1),
(&unpriced2, &unpriced2),
];
let stat = compute(&pairs, &pricing, Some(1));
assert!(stat.flags.contains(&Flag::NoPricing));
}
#[test]
fn no_pricing_flag_absent_when_all_pairs_priced() {
let mut pricing = Pricing::new();
pricing.insert("opus".to_string(), ModelPricing::simple(0.000015, 0.000075));
let r = response("opus", 1000, 500);
let pairs = [(&r, &r), (&r, &r)];
let stat = compute(&pairs, &pricing, Some(1));
assert!(!stat.flags.contains(&Flag::NoPricing));
}
#[test]
fn no_pricing_flag_absent_when_pairs_empty() {
let pricing = Pricing::new();
let pairs: Vec<(&Record, &Record)> = Vec::new();
let stat = compute(&pairs, &pricing, Some(1));
assert!(!stat.flags.contains(&Flag::NoPricing));
assert_eq!(stat.n, 0);
}
#[test]
fn cached_input_tokens_billed_at_cheaper_rate() {
let mut pricing = Pricing::new();
pricing.insert(
"opus".to_string(),
ModelPricing {
input: 0.000015,
output: 0.000075,
cached_input: 0.0000015, cached_write_5m: 0.0,
cached_write_1h: 0.0,
reasoning: 0.0,
batch_discount: 0.0,
},
);
let r = response_with_usage(
"opus",
json!({
"input_tokens": 1000,
"output_tokens": 500,
"thinking_tokens": 0,
"cached_input_tokens": 1000,
}),
);
let pairs = [(&r, &r)];
let stat = compute(&pairs, &pricing, Some(1));
assert!((stat.baseline_median - 0.054).abs() < 1e-9);
}
#[test]
fn reasoning_tokens_billed_at_reasoning_rate() {
let mut pricing = Pricing::new();
pricing.insert(
"gpt-5".to_string(),
ModelPricing {
input: 0.000010,
output: 0.000040,
cached_input: 0.0,
cached_write_5m: 0.0,
cached_write_1h: 0.0,
reasoning: 0.000060, batch_discount: 0.0,
},
);
let r = response_with_usage(
"gpt-5",
json!({
"input_tokens": 100,
"output_tokens": 100,
"thinking_tokens": 500,
}),
);
let pairs = [(&r, &r)];
let stat = compute(&pairs, &pricing, Some(1));
assert!((stat.baseline_median - 0.035).abs() < 1e-6);
}
#[test]
fn anthropic_cache_write_tiers_are_billed_separately() {
let mut pricing = Pricing::new();
pricing.insert(
"opus".to_string(),
ModelPricing {
input: 0.000015,
output: 0.000075,
cached_input: 0.0000015,
cached_write_5m: 0.00001875,
cached_write_1h: 0.00003,
reasoning: 0.0,
batch_discount: 0.0,
},
);
let r = Record::new(
Kind::ChatResponse,
json!({
"model": "opus",
"content": [],
"stop_reason": "end_turn",
"latency_ms": 0,
"usage": {
"input_tokens": 1000,
"output_tokens": 200,
"thinking_tokens": 0,
"cached_input_tokens": 500,
"cached_write_5m_tokens": 200,
"cached_write_1h_tokens": 100,
},
}),
"2026-04-21T10:00:00Z",
None,
);
let pairs = [(&r, &r)];
let stat = compute(&pairs, &pricing, Some(1));
assert!(
(stat.baseline_median - 0.0375).abs() < 1e-6,
"got {}",
stat.baseline_median
);
}
#[test]
fn nan_usage_values_produce_zero_cost_not_phantom_inf() {
let mut pricing = Pricing::new();
pricing.insert("m".to_string(), ModelPricing::simple(0.001, 0.002));
let r = Record::new(
Kind::ChatResponse,
json!({
"model": "m",
"content": [],
"stop_reason": "end_turn",
"latency_ms": 0,
"usage": {
"input_tokens": 100.0,
"output_tokens": 100.0,
"thinking_tokens": 0,
"cached_input_tokens": f64::NAN,
},
}),
"2026-04-21T10:00:00Z",
None,
);
let pairs = [(&r, &r)];
let stat = compute(&pairs, &pricing, Some(1));
assert!(stat.baseline_median.is_finite());
assert_eq!(stat.severity, Severity::None);
}
#[test]
fn batch_flag_applies_discount() {
let mut pricing = Pricing::new();
pricing.insert(
"opus".to_string(),
ModelPricing {
input: 0.000015,
output: 0.000075,
cached_input: 0.0,
cached_write_5m: 0.0,
cached_write_1h: 0.0,
reasoning: 0.0,
batch_discount: 0.5, },
);
let batched = Record::new(
Kind::ChatResponse,
json!({
"model": "opus",
"content": [],
"stop_reason": "end_turn",
"latency_ms": 0,
"batch": true,
"usage": {"input_tokens": 1000, "output_tokens": 500, "thinking_tokens": 0},
}),
"2026-04-21T10:00:00Z",
None,
);
let non_batched = response("opus", 1000, 500);
let pairs_batched = [(&batched, &batched)];
let pairs_normal = [(&non_batched, &non_batched)];
let stat_b = compute(&pairs_batched, &pricing, Some(1));
let stat_n = compute(&pairs_normal, &pricing, Some(1));
assert!((stat_b.baseline_median - stat_n.baseline_median * 0.5).abs() < 1e-9);
}
#[test]
fn snapshot_tail_strips_iso_dates() {
assert_eq!(strip_snapshot_tail("gpt-5-2025-08-07"), Some("gpt-5"));
assert_eq!(
strip_snapshot_tail("gpt-4o-mini-2024-07-18"),
Some("gpt-4o-mini"),
);
assert_eq!(
strip_snapshot_tail("claude-opus-4-7-20250219"),
Some("claude-opus-4-7"),
);
assert_eq!(strip_snapshot_tail("gpt-5"), None);
assert_eq!(strip_snapshot_tail("gpt-4o-mini"), None);
assert_eq!(strip_snapshot_tail("o1"), None);
}
#[test]
fn cost_resolves_dated_snapshot_to_bare_alias() {
let mut pricing = Pricing::new();
pricing.insert(
"gpt-5".to_string(),
ModelPricing {
input: 0.000010,
output: 0.000040,
cached_input: 0.0,
cached_write_5m: 0.0,
cached_write_1h: 0.0,
reasoning: 0.0,
batch_discount: 0.0,
},
);
let r = response("gpt-5-2025-08-07", 100, 50);
let cost = cost_of(&r, &pricing).unwrap();
assert!((cost - 0.003).abs() < 1e-9, "got {}", cost);
let pairs = [(&r, &r)];
let stat = compute(&pairs, &pricing, Some(42));
assert!(
!stat.flags.contains(&Flag::NoPricing),
"pair_is_priced should accept dated snapshots"
);
}
}