// Per-model prices, in dollars per million tokens. `CostEstimator::estimate_cost`
// multiplies them by `tokens / 1_000_000`.

/// Haiku: price per million input tokens.
const HAIKU_INPUT: f64 = 0.25;
/// Haiku: price per million output tokens.
const HAIKU_OUTPUT: f64 = 1.25;

/// Sonnet: price per million input tokens.
const SONNET_INPUT: f64 = 3.0;
/// Sonnet: price per million output tokens.
const SONNET_OUTPUT: f64 = 15.0;

/// Opus: price per million input tokens.
const OPUS_INPUT: f64 = 15.0;
/// Opus: price per million output tokens.
const OPUS_OUTPUT: f64 = 75.0;
/// Stateless helper for rough token counts and dollar-cost estimates.
///
/// All functionality lives in associated functions, so the type never needs to be
/// constructed; the common derives exist so it can still be stored or printed freely.
#[derive(Debug, Clone, Copy, Default)]
pub struct CostEstimator;
impl CostEstimator {
    /// Estimates how many tokens `text` will occupy.
    ///
    /// Uses the heuristic of one token per four characters, rounded up: `""` is zero
    /// tokens and any non-empty text is at least one. Characters are counted as Unicode
    /// scalar values (`char`s), not bytes. This is an approximation, not a tokenizer.
    #[must_use]
    pub fn estimate_tokens(text: &str) -> usize {
        let chars = text.chars().count();
        // Integer ceiling division: exact for every length, no round-trip through `f64`.
        chars / 4 + usize::from(chars % 4 != 0)
    }

    /// Estimates the dollar cost of `input_tokens` prompt tokens plus `output_tokens`
    /// completion tokens on `model`.
    ///
    /// Prices are per million tokens. `model` is matched case-insensitively by
    /// substring (`haiku`, `sonnet`, `opus`, `gpt-5-nano`/`opencode`); any other name
    /// is priced at Sonnet rates.
    #[must_use]
    pub fn estimate_cost(input_tokens: usize, output_tokens: usize, model: &str) -> f64 {
        let (input_price, output_price) = Self::get_pricing(model);
        let input_cost = (input_tokens as f64 / 1_000_000.0) * input_price;
        let output_cost = (output_tokens as f64 / 1_000_000.0) * output_price;
        input_cost + output_cost
    }

    /// Returns `(input, output)` prices, in dollars per million tokens, for `model`.
    ///
    /// The lowercased name is tested for each family substring in order, so the first
    /// match wins; unrecognized names fall back to Sonnet pricing.
    fn get_pricing(model: &str) -> (f64, f64) {
        // Flat rate applied to `gpt-5-nano` and any `opencode` model.
        const NANO_INPUT: f64 = 0.1;
        const NANO_OUTPUT: f64 = 0.5;

        let model = model.to_lowercase();
        if model.contains("haiku") {
            (HAIKU_INPUT, HAIKU_OUTPUT)
        } else if model.contains("sonnet") {
            (SONNET_INPUT, SONNET_OUTPUT)
        } else if model.contains("opus") {
            (OPUS_INPUT, OPUS_OUTPUT)
        } else if model.contains("gpt-5-nano") || model.contains("opencode") {
            (NANO_INPUT, NANO_OUTPUT)
        } else {
            // Unknown model: default to Sonnet pricing.
            (SONNET_INPUT, SONNET_OUTPUT)
        }
    }

    /// Formats a dollar amount for display, e.g. `$0.0010`, `$0.100`, `$1.50`.
    ///
    /// Smaller amounts get more decimal places so sub-cent costs stay visible: four
    /// places below `0.01`, three below `1.0`, and two otherwise.
    #[must_use]
    pub fn format_cost(cost: f64) -> String {
        let decimals = if cost < 0.01 {
            4
        } else if cost < 1.0 {
            3
        } else {
            2
        };
        format!("${:.*}", decimals, cost)
    }

    /// Estimates the dollar cost of sending `input_text` and receiving `output_text`
    /// on `model`, using [`CostEstimator::estimate_tokens`] for both sides.
    #[must_use]
    pub fn estimate_text_cost(input_text: &str, output_text: &str, model: &str) -> f64 {
        let input_tokens = Self::estimate_tokens(input_text);
        let output_tokens = Self::estimate_tokens(output_text);
        Self::estimate_cost(input_tokens, output_tokens, model)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_estimation() {
        // "Hello, world!" is 13 characters, so ceil(13 / 4) = 4 tokens.
        assert_eq!(CostEstimator::estimate_tokens("Hello, world!"), 4);
    }

    #[test]
    fn test_token_estimation_edges() {
        assert_eq!(CostEstimator::estimate_tokens(""), 0);
        assert_eq!(CostEstimator::estimate_tokens("abcd"), 1);
        assert_eq!(CostEstimator::estimate_tokens("abcde"), 2);
        // Counted in chars, not bytes: five two-byte characters are five chars.
        assert_eq!(CostEstimator::estimate_tokens(&"\u{e9}".repeat(5)), 2);
    }

    #[test]
    fn test_cost_estimation_haiku() {
        let cost = CostEstimator::estimate_cost(1_000_000, 1_000_000, "haiku");
        assert_eq!(cost, HAIKU_INPUT + HAIKU_OUTPUT);
    }

    #[test]
    fn test_cost_estimation_sonnet() {
        let cost = CostEstimator::estimate_cost(1_000_000, 1_000_000, "sonnet");
        assert_eq!(cost, SONNET_INPUT + SONNET_OUTPUT);
    }

    #[test]
    fn test_cost_estimation_opus() {
        let cost = CostEstimator::estimate_cost(1_000_000, 1_000_000, "opus");
        assert_eq!(cost, OPUS_INPUT + OPUS_OUTPUT);
    }

    #[test]
    fn test_model_matching_is_case_insensitive() {
        let cost = CostEstimator::estimate_cost(1_000_000, 1_000_000, "Claude-3-HAIKU");
        assert_eq!(cost, HAIKU_INPUT + HAIKU_OUTPUT);
    }

    #[test]
    fn test_unknown_model_uses_sonnet_pricing() {
        let cost = CostEstimator::estimate_cost(1_000_000, 1_000_000, "some-other-model");
        assert_eq!(cost, SONNET_INPUT + SONNET_OUTPUT);
    }

    #[test]
    fn test_format_cost() {
        assert_eq!(CostEstimator::format_cost(0.001), "$0.0010");
        assert_eq!(CostEstimator::format_cost(0.1), "$0.100");
        assert_eq!(CostEstimator::format_cost(1.5), "$1.50");
        // Threshold values switch to the coarser precision.
        assert_eq!(CostEstimator::format_cost(0.01), "$0.010");
        assert_eq!(CostEstimator::format_cost(1.0), "$1.00");
    }

    #[test]
    fn test_estimate_text_cost_matches_token_counts() {
        assert_eq!(
            CostEstimator::estimate_text_cost("abcd", "abcdefgh", "sonnet"),
            CostEstimator::estimate_cost(1, 2, "sonnet")
        );
    }
}