use super::*;
/// Baseline pricing shared by the generic tests: $3 per million input tokens
/// and $15 per million output tokens.
fn test_pricing() -> PricingInfo {
    let input_price_per_million = 3.0;
    let output_price_per_million = 15.0;
    PricingInfo {
        input_price_per_million,
        output_price_per_million,
    }
}
/// One plain call (no cache traffic) is billed at the flat input/output rates,
/// and each tracker counter advances by exactly that call's numbers.
#[test]
fn test_basic_cost_tracking() {
    let mut tracker = CostTracker::new();
    let pricing = test_pricing();
    let usage = TokenUsage {
        prompt_tokens: 1000,
        completion_tokens: 500,
        ..Default::default()
    };
    let cost = tracker.record_usage(&usage, Some(&pricing));

    // 1_000 prompt tokens at $3/M plus 500 completion tokens at $15/M.
    let input_cost = 0.003;
    let output_cost = 0.0075;
    let expected = input_cost + output_cost;
    assert!((cost - expected).abs() < 1e-9);

    assert_eq!(tracker.total_input_tokens, 1000);
    assert_eq!(tracker.total_output_tokens, 500);
    assert_eq!(tracker.call_count, 1);
}
/// A 250K-token prompt crosses the 200K tier boundary. The expected value bills
/// the first 200K prompt tokens at the base $3/M rate, the 50K excess at 1.5x
/// that rate, and the 100 completion tokens at the normal $15/M.
#[test]
fn test_tiered_pricing_over_200k() {
    let mut tracker = CostTracker::new();
    let usage = TokenUsage {
        prompt_tokens: 250_000,
        completion_tokens: 100,
        ..Default::default()
    };
    let cost = tracker.record_usage(&usage, Some(&test_pricing()));

    let base_tier = 0.60; // 200_000 x $3/M
    let premium_tier = 0.225; // 50_000 x $4.50/M
    let output = 0.0015; // 100 x $15/M
    let expected = base_tier + premium_tier + output;
    assert!((cost - expected).abs() < 1e-9);
}
/// Cache-read tokens are billed in addition to the regular prompt tokens. The
/// expected value prices the 5_000 cached tokens at one tenth of the $3/M
/// input rate.
#[test]
fn test_cache_read_tokens() {
    let mut tracker = CostTracker::new();
    let usage = TokenUsage {
        prompt_tokens: 1000,
        completion_tokens: 100,
        cache_read_input_tokens: 5000,
        ..Default::default()
    };
    let cost = tracker.record_usage(&usage, Some(&test_pricing()));

    let input = 0.003; // 1_000 x $3/M
    let output = 0.0015; // 100 x $15/M
    let cache_read = 0.0015; // 5_000 x $0.30/M
    let expected = input + output + cache_read;
    assert!((cost - expected).abs() < 1e-9);
}
/// Without pricing info a call adds nothing to cost, but its token counts are
/// still recorded.
#[test]
fn test_no_pricing_tracks_tokens_only() {
    let mut tracker = CostTracker::new();
    let usage = TokenUsage {
        prompt_tokens: 1000,
        completion_tokens: 500,
        ..Default::default()
    };
    let unpriced: Option<&PricingInfo> = None;

    let cost = tracker.record_usage(&usage, unpriced);
    assert_eq!(cost, 0.0);

    assert_eq!(tracker.total_input_tokens, 1000);
    assert_eq!(tracker.total_output_tokens, 500);
    assert_eq!(tracker.total_cost_usd, 0.0);
}
/// Two priced calls on one tracker: token counters, call count, and total cost
/// must all accumulate across calls.
#[test]
fn test_cumulative_tracking() {
    let mut tracker = CostTracker::new();
    let pricing = test_pricing();
    let usage1 = TokenUsage {
        prompt_tokens: 1000,
        completion_tokens: 200,
        ..Default::default()
    };
    let usage2 = TokenUsage {
        prompt_tokens: 2000,
        completion_tokens: 300,
        ..Default::default()
    };
    let cost1 = tracker.record_usage(&usage1, Some(&pricing));
    let cost2 = tracker.record_usage(&usage2, Some(&pricing));

    assert_eq!(tracker.total_input_tokens, 3000);
    assert_eq!(tracker.total_output_tokens, 500);
    assert_eq!(tracker.call_count, 2);

    // Previously only the counters were checked; cost accumulation is the
    // tracker's main job, so pin it too.
    //   call 1: 1_000 x $3/M + 200 x $15/M = 0.003 + 0.003
    //   call 2: 2_000 x $3/M + 300 x $15/M = 0.006 + 0.0045
    let expected_total = 0.003 + 0.003 + 0.006 + 0.0045;
    assert!((cost1 + cost2 - expected_total).abs() < 1e-9);
    assert!((tracker.total_cost_usd - expected_total).abs() < 1e-9);
}
/// A small total (0.005) is rendered with four decimal places.
#[test]
fn test_format_cost_small() {
    let tracker = {
        let mut seeded = CostTracker::new();
        seeded.total_cost_usd = 0.005;
        seeded
    };
    assert_eq!(tracker.format_cost(), "$0.0050");
}
/// A total above one dollar (1.234) is rendered with two decimal places.
#[test]
fn test_format_cost_large() {
    let tracker = {
        let mut seeded = CostTracker::new();
        seeded.total_cost_usd = 1.234;
        seeded
    };
    assert_eq!(tracker.format_cost(), "$1.23");
}
/// Counters exported by `to_metadata` and nested under the `cost_tracking` key
/// are recovered by `restore_from_metadata` on a fresh tracker.
#[test]
fn test_to_metadata_and_restore() {
    let mut original = CostTracker::new();
    original.total_input_tokens = 5000;
    original.total_output_tokens = 2000;
    original.total_cost_usd = 0.123456;
    original.call_count = 3;
    let snapshot = original.to_metadata();

    let mut restored = CostTracker::new();
    let envelope = serde_json::json!({ "cost_tracking": snapshot });
    restored.restore_from_metadata(&envelope);

    assert_eq!(restored.total_input_tokens, 5000);
    assert_eq!(restored.total_output_tokens, 2000);
    assert!((restored.total_cost_usd - 0.123456).abs() < 1e-9);
    assert_eq!(restored.call_count, 3);
}
/// Metadata without a `cost_tracking` entry leaves existing counters untouched.
#[test]
fn test_restore_missing_cost_tracking() {
    let empty_metadata = serde_json::json!({});
    let mut tracker = CostTracker::new();
    tracker.total_input_tokens = 100;

    tracker.restore_from_metadata(&empty_metadata);

    assert_eq!(tracker.total_input_tokens, 100);
}
/// Token counts are read from their JSON keys. A key absent from the payload
/// (here `cache_creation_input_tokens`) comes back as zero.
#[test]
fn test_token_usage_from_json() {
    let payload = serde_json::json!({
        "prompt_tokens": 1500,
        "completion_tokens": 300,
        "cache_read_input_tokens": 800,
    });

    let parsed = TokenUsage::from_json(&payload);

    assert_eq!(parsed.prompt_tokens, 1500);
    assert_eq!(parsed.completion_tokens, 300);
    assert_eq!(parsed.cache_read_input_tokens, 800);
    assert_eq!(parsed.cache_creation_input_tokens, 0);
}
/// `round_f64(value, places)` rounds to the requested number of decimal places.
#[test]
fn test_round_f64() {
    // (value, decimal places, expected result)
    let cases = [(1.23456789, 6, 1.234568), (0.0, 2, 0.0)];
    for (value, places, expected) in cases {
        assert_eq!(round_f64(value, places), expected);
    }
}
/// Pricing fixture named for Anthropic Sonnet: $3 per million input tokens,
/// $15 per million output tokens.
fn anthropic_sonnet_pricing() -> PricingInfo {
    let (input_price_per_million, output_price_per_million) = (3.0, 15.0);
    PricingInfo {
        input_price_per_million,
        output_price_per_million,
    }
}
/// Pricing fixture named for OpenAI GPT-4o: $2.50 per million input tokens,
/// $10 per million output tokens.
fn openai_gpt4o_pricing() -> PricingInfo {
    let (input_price_per_million, output_price_per_million) = (2.50, 10.0);
    PricingInfo {
        input_price_per_million,
        output_price_per_million,
    }
}
/// Cache reads are billed at a tenth of the input rate. 50K cached tokens
/// cost the same as 5K regular prompt tokens.
#[test]
fn test_anthropic_cache_discount_accuracy() {
    let pricing = anthropic_sonnet_pricing();
    let usage = TokenUsage {
        prompt_tokens: 10_000,
        completion_tokens: 1_000,
        cache_read_input_tokens: 50_000,
        cache_creation_input_tokens: 0,
    };
    let mut tracker = CostTracker::new();
    let cost = tracker.record_usage(&usage, Some(&pricing));

    let input = 0.03; // 10_000 x $3/M
    let output = 0.015; // 1_000 x $15/M
    let cache_read = 0.015; // 50_000 x $0.30/M
    let expected = input + output + cache_read;
    assert!(
        (cost - expected).abs() < 1e-9,
        "Anthropic cache cost mismatch: got {cost}, expected {expected}"
    );
}
/// Pins the tier boundary:
/// - exactly 200_000 prompt tokens are billed at the base rate;
/// - one token past it is billed at 1.5x the base rate.
#[test]
fn test_over_200k_tier_multiplier() {
    let pricing = anthropic_sonnet_pricing();
    // Each probe prices a prompt-only request on its own fresh tracker, so the
    // two measurements cannot influence each other.
    let price_prompt = |prompt_tokens| {
        let mut tracker = CostTracker::new();
        let usage = TokenUsage {
            prompt_tokens,
            completion_tokens: 0,
            ..Default::default()
        };
        tracker.record_usage(&usage, Some(&pricing))
    };

    let cost_at = price_prompt(200_000);
    let expected_at = 200_000.0 / 1_000_000.0 * 3.0;
    assert!(
        (cost_at - expected_at).abs() < 1e-9,
        "At 200K: got {cost_at}, expected {expected_at}"
    );

    let cost_over = price_prompt(200_001);
    let expected_over = 0.60 + (1.0 / 1_000_000.0 * 3.0 * 1.5);
    assert!(
        (cost_over - expected_over).abs() < 1e-9,
        "At 200_001: got {cost_over}, expected {expected_over}"
    );
}
/// GPT-4o fixture pricing with no cache traffic: $2.50/M input, $10/M output.
#[test]
fn test_openai_pricing_accuracy() {
    let pricing = openai_gpt4o_pricing();
    let usage = TokenUsage {
        prompt_tokens: 5_000,
        completion_tokens: 2_000,
        cache_read_input_tokens: 0,
        cache_creation_input_tokens: 0,
    };
    let mut tracker = CostTracker::new();
    let cost = tracker.record_usage(&usage, Some(&pricing));

    let input = 0.0125; // 5_000 x $2.50/M
    let output = 0.02; // 2_000 x $10/M
    let expected = input + output;
    assert!(
        (cost - expected).abs() < 1e-9,
        "OpenAI cost mismatch: got {cost}, expected {expected}"
    );
}
/// Feeds a mixed batch of calls through a single tracker: plain calls,
/// cache-heavy calls, and one prompt above the 200K tier.
///
/// Checks that the running totals match both the per-call results and a
/// hand-computed cost figure.
#[test]
fn test_cost_sum_across_multiple_calls() {
    let mut tracker = CostTracker::new();
    let pricing = anthropic_sonnet_pricing();
    // A fixed-size array suffices here; `vec!` would only add a heap
    // allocation (clippy::useless_vec).
    let calls = [
        TokenUsage {
            prompt_tokens: 1_000,
            completion_tokens: 500,
            cache_read_input_tokens: 0,
            cache_creation_input_tokens: 0,
        },
        TokenUsage {
            prompt_tokens: 10_000,
            completion_tokens: 2_000,
            cache_read_input_tokens: 30_000,
            cache_creation_input_tokens: 0,
        },
        TokenUsage {
            prompt_tokens: 50_000,
            completion_tokens: 5_000,
            cache_read_input_tokens: 100_000,
            cache_creation_input_tokens: 0,
        },
        TokenUsage {
            // Above the 200K prompt-token tier threshold.
            prompt_tokens: 250_000,
            completion_tokens: 3_000,
            cache_read_input_tokens: 0,
            cache_creation_input_tokens: 0,
        },
        TokenUsage {
            prompt_tokens: 500,
            completion_tokens: 100,
            cache_read_input_tokens: 0,
            cache_creation_input_tokens: 0,
        },
    ];

    let mut sum = 0.0;
    for usage in &calls {
        sum += tracker.record_usage(usage, Some(&pricing));
    }
    assert!(
        (tracker.total_cost_usd - sum).abs() < 1e-9,
        "Cumulative cost {:.6} != sum of incremental costs {:.6}",
        tracker.total_cost_usd,
        sum
    );

    let expected_input: u64 = calls.iter().map(|u| u.prompt_tokens).sum();
    let expected_output: u64 = calls.iter().map(|u| u.completion_tokens).sum();
    assert_eq!(tracker.total_input_tokens, expected_input);
    assert_eq!(tracker.total_output_tokens, expected_output);
    assert_eq!(tracker.call_count, 5);

    // Exact total instead of a loose 0.50..5.00 range. The rates below are the
    // ones pinned by the single-call tests in this module:
    // - $3/M input, $15/M output;
    // - cache reads at $0.30/M;
    // - prompt tokens beyond 200K at 1.5x the input rate, evaluated per call
    //   as in `test_tiered_pricing_over_200k`.
    //   call 1: 0.003  + 0.0075         = 0.0105
    //   call 2: 0.03   + 0.03  + 0.009  = 0.069
    //   call 3: 0.15   + 0.075 + 0.03   = 0.255
    //   call 4: 0.60   + 0.225 + 0.045  = 0.87
    //   call 5: 0.0015 + 0.0015         = 0.003
    let expected_total = 0.0105 + 0.069 + 0.255 + 0.87 + 0.003;
    assert!(
        (tracker.total_cost_usd - expected_total).abs() < 1e-9,
        "Total cost {:.6} != expected {:.6}",
        tracker.total_cost_usd,
        expected_total
    );
}
/// A model priced at zero still accrues token counts, but never any cost.
#[test]
fn test_zero_price_model() {
    let free = PricingInfo {
        input_price_per_million: 0.0,
        output_price_per_million: 0.0,
    };
    let usage = TokenUsage {
        prompt_tokens: 100_000,
        completion_tokens: 50_000,
        ..Default::default()
    };
    let mut tracker = CostTracker::new();

    let cost = tracker.record_usage(&usage, Some(&free));

    assert_eq!(cost, 0.0);
    assert_eq!(tracker.total_cost_usd, 0.0);
    assert_eq!(tracker.total_input_tokens, 100_000);
    assert_eq!(tracker.total_output_tokens, 50_000);
}
/// With no budget configured, a tracker is never over budget and reports no
/// remaining amount.
#[test]
fn test_no_budget_not_over() {
    let fresh = CostTracker::new();
    let over = fresh.is_over_budget();
    assert!(!over);
    assert_eq!(fresh.remaining_budget(), None);
}
/// `set_budget` stores the limit. On a fresh tracker the whole amount remains
/// available.
#[test]
fn test_set_budget() {
    let mut tracker = CostTracker::new();
    tracker.set_budget(1.0);

    assert_eq!(tracker.budget_usd, Some(1.0));
    assert!(!tracker.is_over_budget());
    let remaining = tracker.remaining_budget().unwrap();
    assert!((remaining - 1.0).abs() < 1e-9);
}
/// Spending half the budget is not over budget and leaves the other half.
#[test]
fn test_budget_not_exceeded() {
    let mut tracker = CostTracker::new();
    tracker.set_budget(1.0);
    tracker.total_cost_usd = 0.5;

    assert!(!tracker.is_over_budget());
    let remaining = tracker.remaining_budget().unwrap();
    assert!((remaining - 0.5).abs() < 1e-9);
}
/// Spending exactly the budget already counts as over budget (the check is
/// inclusive) and leaves nothing remaining.
#[test]
fn test_budget_exactly_met() {
    let mut tracker = CostTracker::new();
    tracker.set_budget(1.0);
    tracker.total_cost_usd = 1.0;

    assert!(tracker.is_over_budget());
    let remaining = tracker.remaining_budget().unwrap();
    assert!(remaining.abs() < 1e-9);
}
/// Overspending is reported as over budget. The remaining budget is exactly
/// 0.0, not a negative amount.
#[test]
fn test_budget_exceeded() {
    let mut tracker = CostTracker::new();
    tracker.set_budget(0.50);
    tracker.total_cost_usd = 0.75;

    assert!(tracker.is_over_budget());
    let remaining = tracker.remaining_budget().unwrap();
    assert_eq!(remaining, 0.0);
}
/// Spending recorded through `record_usage` counts toward the budget. A single
/// call costing 0.03 + 0.075 at the baseline rates exceeds a $0.05 limit.
#[test]
fn test_budget_exceeded_after_usage() {
    let pricing = test_pricing();
    let usage = TokenUsage {
        prompt_tokens: 10_000,
        completion_tokens: 5_000,
        ..Default::default()
    };
    let mut tracker = CostTracker::new();
    tracker.set_budget(0.05);

    let _ = tracker.record_usage(&usage, Some(&pricing));

    assert!(tracker.is_over_budget());
    let remaining = tracker.remaining_budget().unwrap();
    assert_eq!(remaining, 0.0);
}
/// The budget survives a serde JSON round trip of the whole tracker, together
/// with the spend it is measured against.
#[test]
fn test_budget_serialization() {
    let mut original = CostTracker::new();
    original.set_budget(2.50);
    original.total_cost_usd = 0.75;

    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: CostTracker = serde_json::from_str(&encoded).unwrap();

    assert_eq!(decoded.budget_usd, Some(2.50));
    assert!(!decoded.is_over_budget());
    let remaining = decoded.remaining_budget().unwrap();
    assert!((remaining - 1.75).abs() < 1e-9);
}