use crate::ui::UI;
use colored::Colorize;
// Multiplier applied to the input price for cache-read tokens: cached
// prompt tokens are billed at 10% of the normal input rate (see
// `UI::calculate_cost`).
const CACHE_READ_RATE: f64 = 0.10;
// Multiplier applied to the input price for cache-creation (write)
// tokens: writing to the prompt cache is billed at 125% of the normal
// input rate (see `UI::calculate_cost`).
const CACHE_CREATION_RATE: f64 = 1.25;
/// True when `model` names an OpenAI model.
///
/// Detection is a plain, case-sensitive "gpt" prefix test; any other
/// identifier is treated as non-OpenAI.
fn is_openai_model(model: &str) -> bool {
    matches!(model.as_bytes(), [b'g', b'p', b't', ..])
}
impl UI {
    /// Prints a colored end-of-session token/cost summary to stdout.
    ///
    /// Returns `true` if a summary was printed, or `false` (printing
    /// nothing) when the session recorded zero input AND zero output
    /// tokens.
    ///
    /// NOTE(review): a session consisting only of cache activity
    /// (input == output == 0 but cache counters > 0) is also suppressed
    /// even though it may have incurred cost — confirm this is intended.
    pub fn display_session_summary(
        model: &str,
        total_input_tokens: u32,
        total_output_tokens: u32,
        total_cache_read_tokens: u32,
        total_cache_creation_tokens: u32,
        peak_single_turn_input_tokens: u32,
    ) -> bool {
        if total_input_tokens == 0 && total_output_tokens == 0 {
            return false;
        }
        println!();
        println!("{}", "─".repeat(50).bright_cyan());
        println!("{}", "Session Summary".bright_cyan().bold());
        println!("{}", "─".repeat(50).bright_cyan());
        let estimated_cost = Self::calculate_cost(
            model,
            total_input_tokens,
            total_output_tokens,
            total_cache_read_tokens,
            total_cache_creation_tokens,
            peak_single_turn_input_tokens,
        );
        // Total prompt tokens the model actually processed, regardless of
        // how the provider reports them. `saturating_add` avoids a u32
        // overflow panic (debug) / wraparound (release) on very long
        // sessions; counts clamp at u32::MAX instead.
        let total_input_seen =
            Self::total_input_seen_by_model(model, total_input_tokens, total_cache_read_tokens)
                .saturating_add(total_cache_creation_tokens);
        let cache_hit_pct = if total_input_seen > 0 {
            (total_cache_read_tokens as f64 / total_input_seen as f64) * 100.0
        } else {
            0.0
        };
        println!(
            "{:<20} {}",
            "Input tokens:".bright_white(),
            Self::format_number(total_input_seen).bright_green()
        );
        // Only show the cache breakdown when there was any cache activity.
        if total_cache_read_tokens > 0 || total_cache_creation_tokens > 0 {
            println!(
                "{:<20} {} {}",
                " cache read:".bright_white(),
                Self::format_number(total_cache_read_tokens).bright_green(),
                format!("({:.0}% hit)", cache_hit_pct).dimmed()
            );
            if total_cache_creation_tokens > 0 {
                println!(
                    "{:<20} {}",
                    " cache write:".bright_white(),
                    Self::format_number(total_cache_creation_tokens).bright_green()
                );
            }
        }
        println!(
            "{:<20} {}",
            "Output tokens:".bright_white(),
            Self::format_number(total_output_tokens).bright_green()
        );
        println!(
            "{:<20} {}",
            "Total tokens:".bright_white(),
            // Same overflow guard as above.
            Self::format_number(total_input_seen.saturating_add(total_output_tokens))
                .bright_green()
        );
        println!();
        println!(
            "{:<20} {}",
            "Estimated cost:".bright_white().bold(),
            format!("${:.4}", estimated_cost).bright_yellow().bold()
        );
        println!("{}", "─".repeat(50).bright_cyan());
        println!();
        true
    }

    /// Normalizes `total_input_tokens` into "total prompt tokens seen".
    ///
    /// OpenAI-style usage already counts cache reads inside
    /// `total_input_tokens` (`calculate_cost` subtracts them back out);
    /// Anthropic-style usage reports cache reads separately, so they are
    /// added here.
    fn total_input_seen_by_model(
        model: &str,
        total_input_tokens: u32,
        cache_read_tokens: u32,
    ) -> u32 {
        if is_openai_model(model) {
            total_input_tokens
        } else {
            total_input_tokens.saturating_add(cache_read_tokens)
        }
    }

    /// Estimates session cost in USD from the model's per-million-token
    /// prices.
    ///
    /// Premium-tier pricing applies when the largest single-turn prompt
    /// (`peak_single_turn_input_tokens`) exceeded the tier's input
    /// threshold. Cache reads are billed at `CACHE_READ_RATE` of the
    /// input price, cache writes at `CACHE_CREATION_RATE`. For OpenAI
    /// models the cache reads are carved out of `input_tokens`
    /// (`saturating_sub` guards against provider counts where reads
    /// exceed input); for other providers `input_tokens` is already net
    /// of cache reads.
    fn calculate_cost(
        model: &str,
        input_tokens: u32,
        output_tokens: u32,
        cache_read_tokens: u32,
        cache_creation_tokens: u32,
        peak_single_turn_input_tokens: u32,
    ) -> f64 {
        let info = crate::api::model_info::lookup(model);
        let (input_price, output_price) = match info.premium_tier {
            Some(tier) if peak_single_turn_input_tokens > tier.input_threshold => {
                (tier.price_input_per_m, tier.price_output_per_m)
            }
            _ => (info.price_input_per_m, info.price_output_per_m),
        };
        let uncached = if is_openai_model(model) {
            input_tokens.saturating_sub(cache_read_tokens)
        } else {
            input_tokens
        };
        let uncached_cost = (uncached as f64 / 1_000_000.0) * input_price;
        let cached_cost = (cache_read_tokens as f64 / 1_000_000.0) * input_price * CACHE_READ_RATE;
        let creation_cost =
            (cache_creation_tokens as f64 / 1_000_000.0) * input_price * CACHE_CREATION_RATE;
        let output_cost = (output_tokens as f64 / 1_000_000.0) * output_price;
        uncached_cost + cached_cost + creation_cost + output_cost
    }

    /// Formats an elapsed duration as "Finished in …" using the coarsest
    /// sensible unit: "<1s", "Ns", "Nm Ns", or "Nh Nm" (sub-minute
    /// remainder is dropped at the hour scale).
    pub fn format_turn_finished(elapsed: std::time::Duration) -> String {
        let total_secs = elapsed.as_secs();
        if total_secs < 1 {
            "Finished in <1s".to_string()
        } else if total_secs < 60 {
            format!("Finished in {}s", total_secs)
        } else if total_secs < 3600 {
            format!("Finished in {}m {}s", total_secs / 60, total_secs % 60)
        } else {
            format!(
                "Finished in {}h {}m",
                total_secs / 3600,
                (total_secs % 3600) / 60
            )
        }
    }

    /// Renders `n` with comma thousands separators,
    /// e.g. 1234567 -> "1,234,567".
    fn format_number(n: u32) -> String {
        let s = n.to_string();
        let mut result = String::new();
        // Walk the digits right-to-left, inserting a comma before every
        // group of three, then reverse back into reading order.
        for (i, c) in s.chars().rev().enumerate() {
            if i > 0 && i % 3 == 0 {
                result.push(',');
            }
            result.push(c);
        }
        result.chars().rev().collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Float comparison helper: asserts |a - b| < 1e-9 with a readable
    // failure message.
    fn approx(a: f64, b: f64) {
        assert!(
            (a - b).abs() < 1e-9,
            "expected ≈{}, got {} (delta {})",
            b,
            a,
            (a - b).abs()
        );
    }

    // NOTE(review): the expected values below assume the pricing table in
    // crate::api::model_info — gpt-5.5 at $5/M input ($10/M premium tier,
    // $30/M output), claude-opus-4-7 at $5/M input, and a $3/M-input /
    // $15/M-output fallback for unknown models. Confirm against
    // model_info if prices change.

    // No cache activity: full input rate for all input tokens.
    #[test]
    fn openai_cost_uses_full_rate_when_no_cache() {
        let cost = UI::calculate_cost("gpt-5.5", 100_000, 5_000, 0, 0, 100_000);
        approx(cost, 100_000.0 / 1e6 * 5.0 + 5_000.0 / 1e6 * 30.0);
    }

    // OpenAI counts cache reads inside input_tokens: 25k uncached @ $5/M
    // (= 0.125) + 75k cached @ 10% (= 0.0375) -> 0.1625 input cost.
    #[test]
    fn openai_cost_discounts_cache_reads_at_10pct() {
        let cost = UI::calculate_cost("gpt-5.5", 100_000, 5_000, 75_000, 0, 100_000);
        approx(cost, 0.1625 + 0.15);
    }

    // Regression guard: at a 75% cache-hit rate the discounted input cost
    // should be roughly one third of charging everything at full rate.
    #[test]
    fn openai_cost_3x_lower_than_pre_fix_at_75pct_hit_input_only() {
        let pre_fix_input = 100_000.0 / 1e6 * 5.0;
        let post_fix_input = UI::calculate_cost("gpt-5.5", 100_000, 0, 75_000, 0, 100_000);
        let ratio = pre_fix_input / post_fix_input;
        assert!(
            (2.9..=3.2).contains(&ratio),
            "expected pre/post ratio ≈3x at 75% hit, got {:.2}x",
            ratio
        );
    }

    // Anthropic reports cache reads separately, so input_tokens must NOT
    // have the reads subtracted again: 25k @ full rate + 75k @ 10%.
    #[test]
    fn anthropic_cost_input_tokens_already_excludes_cache() {
        let cost = UI::calculate_cost("claude-opus-4-7", 25_000, 5_000, 75_000, 0, 100_000);
        approx(cost, 0.1625 + 0.125);
    }

    // Cache writes are billed at 125% of the input price.
    #[test]
    fn anthropic_cost_charges_creation_at_125pct() {
        let cost = UI::calculate_cost("claude-opus-4-7", 0, 0, 0, 50_000, 0);
        approx(cost, 50_000.0 / 1e6 * 5.0 * 1.25);
    }

    // saturating_sub keeps the uncached count at 0 (rather than wrapping)
    // when the provider reports more cache reads than input tokens.
    #[test]
    fn cache_hit_does_not_underflow_when_read_exceeds_input() {
        let cost = UI::calculate_cost("gpt-5.5", 50_000, 0, 100_000, 0, 100_000);
        approx(cost, 100_000.0 / 1e6 * 5.0 * 0.10);
    }

    // Crossing the premium tier's single-turn input threshold switches to
    // the premium price, doubling the input rate for gpt-5.5.
    #[test]
    fn cliff_crossing_doubles_input_rate_for_gpt_5_5() {
        let standard = UI::calculate_cost("gpt-5.5", 100_000, 0, 0, 0, 200_000);
        approx(standard, 100_000.0 / 1e6 * 5.0);
        let premium = UI::calculate_cost("gpt-5.5", 100_000, 0, 0, 0, 300_000);
        approx(premium, 100_000.0 / 1e6 * 10.0);
        assert!((premium / standard - 2.0).abs() < 0.01);
    }

    // Covers all four formatting branches, including the 60s boundary.
    #[test]
    fn turn_finished_format_picks_unit_by_magnitude() {
        use std::time::Duration;
        assert_eq!(
            UI::format_turn_finished(Duration::from_millis(400)),
            "Finished in <1s"
        );
        assert_eq!(
            UI::format_turn_finished(Duration::from_secs(7)),
            "Finished in 7s"
        );
        assert_eq!(
            UI::format_turn_finished(Duration::from_secs(94)),
            "Finished in 1m 34s"
        );
        assert_eq!(
            UI::format_turn_finished(Duration::from_secs(60)),
            "Finished in 1m 0s"
        );
        assert_eq!(
            UI::format_turn_finished(Duration::from_secs(3725)),
            "Finished in 1h 2m"
        );
    }

    // Unknown model ids use the fallback pricing instead of panicking.
    #[test]
    fn unknown_model_falls_back_without_panic() {
        let cost = UI::calculate_cost("some-future-model", 1_000, 1_000, 0, 0, 1_000);
        approx(cost, 1_000.0 / 1e6 * 3.0 + 1_000.0 / 1e6 * 15.0);
    }
}