ai-dispatch 8.99.5

Multi-AI CLI team orchestrator
// Cost estimation for AI agent tasks.
// Maps model names to per-token pricing, computes task cost from token counts.
// Deps: cmd::config, store::Store, types::AgentKind

mod pricing_builtin;

use crate::cmd::config;
use crate::store::Store;
use crate::types::AgentKind;
use std::collections::HashMap;
use std::sync::OnceLock;

/// Price per 1M tokens (input, output) in USD.
///
/// A small `Copy` value type; `Debug`/`PartialEq` let it be printed and compared directly
/// (e.g. in `assert_eq!`) instead of field by field.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct ModelPricing {
    /// USD charged per 1M input tokens.
    pub(crate) input_per_m: f64,
    /// USD charged per 1M output tokens.
    pub(crate) output_per_m: f64,
}

/// User-configured pricing overrides keyed by `(agent, lowercased model name)`.
/// Built lazily, once per process, by `pricing_overrides` from `config::load_pricing_overrides`.
static PRICING_OVERRIDES: OnceLock<HashMap<(AgentKind, String), ModelPricing>> = OnceLock::new();

/// Most recent completed model name for Gemini from the task DB (`None` = checked, no hits).
/// Unset means warm has not run yet (`gemini_fallback_pricing` uses static fallback pricing).
/// Written only by `warm_gemini_default_from_store`.
static GEMINI_DEFAULT_MODEL_CACHE: OnceLock<Option<String>> = OnceLock::new();

/// Populate [`GEMINI_DEFAULT_MODEL_CACHE`] once per process from [`Store::latest_default_model`].
///
/// Only a successful query is cached: `Some(name)` for the latest completed Gemini model, or
/// `None` when the store has no hits. A store error leaves the cache unset. Pricing then keeps
/// using the static Gemini fallback, and a later call may retry the query. An error is never
/// recorded as "checked, no hits".
pub fn warm_gemini_default_from_store(store: &Store) {
    // Already warmed: avoid touching the store again.
    if GEMINI_DEFAULT_MODEL_CACHE.get().is_some() {
        return;
    }
    if let Ok(model) = store.latest_default_model(AgentKind::Gemini) {
        // If a concurrent warm stored a value first, `set` fails and that first value is kept;
        // both came from the same query, so ignoring the error is safe.
        let _ = GEMINI_DEFAULT_MODEL_CACHE.set(model);
    }
}

/// Estimate the USD cost of a task from its total token count.
///
/// Only the total is known here (no input/output split), so the per-1M price is blended as
/// 70% of the input rate plus 30% of the output rate. Returns `None` when no pricing can be
/// resolved for `model` / `agent`.
pub fn estimate_cost(tokens: i64, model: Option<&str>, agent: AgentKind) -> Option<f64> {
    const INPUT_SHARE: f64 = 0.7;
    const OUTPUT_SHARE: f64 = 0.3;
    const TOKENS_PER_PRICE_UNIT: f64 = 1_000_000.0;

    resolve_pricing(model, agent).map(|rates| {
        let per_million = rates.input_per_m * INPUT_SHARE + rates.output_per_m * OUTPUT_SHARE;
        (tokens as f64) * per_million / TOKENS_PER_PRICE_UNIT
    })
}

/// Render a cost for display.
///
/// - `None` → empty string
/// - below $0.0001 (including zero and negatives) → `"free"`
/// - below $0.01 → four decimals, e.g. `"$0.0012"`
/// - otherwise → two decimals, e.g. `"$1.23"`
pub fn format_cost(cost_usd: Option<f64>) -> String {
    let Some(amount) = cost_usd else {
        return String::new();
    };
    if amount < 0.0001 {
        String::from("free")
    } else if amount < 0.01 {
        format!("${amount:.4}")
    } else {
        format!("${amount:.2}")
    }
}

/// Format a cost for display, accounting for agents that are not billed per token.
///
/// - Cursor / Copilot: a positive cost is rendered with [`format_cost`]; any other value
///   (zero, negative, NaN, or `None`) is shown as `"subscription"`.
/// - Kilo: exactly `Some(0.0)` is shown as `"included"`; anything else uses [`format_cost`].
/// - All other agents: [`format_cost`].
pub fn format_cost_label(cost_usd: Option<f64>, agent: AgentKind) -> String {
    match agent {
        AgentKind::Cursor | AgentKind::Copilot => match cost_usd {
            Some(c) if c > 0.0 => format_cost(cost_usd),
            _ => "subscription".to_string(),
        },
        AgentKind::Kilo if cost_usd == Some(0.0) => "included".to_string(),
        _ => format_cost(cost_usd),
    }
}

/// Resolve per-1M-token pricing for a task.
///
/// A non-blank `model` is priced directly via [`model_pricing`] (config overrides, then the
/// built-in table); an unpriced model yields `None`. When `model` is `None` or blank
/// (empty / whitespace only), a per-agent default is used instead:
/// - Gemini: [`gemini_fallback_pricing`]
/// - Qwen: `coder-model`; Codex: `gpt-4.1`
/// - Copilot, Cursor, Kilo: zero per-token price
/// - every other agent: `None`
fn resolve_pricing(model: Option<&str>, agent: AgentKind) -> Option<ModelPricing> {
    const ZERO: ModelPricing = ModelPricing {
        input_per_m: 0.0,
        output_per_m: 0.0,
    };

    // A blank name carries no model information. Treat it like `None` so the agent default
    // applies instead of looking up an empty model name; this matches the empty-name filter
    // in `gemini_fallback_pricing`.
    if let Some(m) = model.filter(|m| !m.trim().is_empty()) {
        return model_pricing(m, agent);
    }
    // Kept exhaustive (no `_` arm) so adding an AgentKind variant forces a pricing decision.
    match agent {
        AgentKind::Gemini => gemini_fallback_pricing(agent),
        AgentKind::Qwen => model_pricing("coder-model", agent),
        AgentKind::Codex => model_pricing("gpt-4.1", agent),
        AgentKind::Copilot | AgentKind::Cursor | AgentKind::Kilo => Some(ZERO),
        AgentKind::OpenCode
        | AgentKind::Claude
        | AgentKind::Codebuff
        | AgentKind::Droid
        | AgentKind::Oz
        | AgentKind::Custom => None,
    }
}

/// Pricing for a Gemini task with no explicit model name.
///
/// If [`GEMINI_DEFAULT_MODEL_CACHE`] has been warmed and holds a non-empty model name, that
/// model is priced (which may yield `None` if it is unpriced). Otherwise the
/// `gemini-3-flash-preview` price is used.
fn gemini_fallback_pricing(agent: AgentKind) -> Option<ModelPricing> {
    let cached = GEMINI_DEFAULT_MODEL_CACHE
        .get()
        .and_then(|slot| slot.as_deref());
    match cached {
        Some(name) if !name.is_empty() => model_pricing(name, agent),
        _ => model_pricing("gemini-3-flash-preview", agent),
    }
}

/// Lazily build the override table from config, once per process.
///
/// Entries whose agent name does not parse are skipped, and model names are lowercased for
/// case-insensitive lookup. If loading fails, the table is empty. When two entries share a
/// key, the later one wins.
fn pricing_overrides() -> &'static HashMap<(AgentKind, String), ModelPricing> {
    PRICING_OVERRIDES.get_or_init(|| {
        let mut table = HashMap::new();
        for entry in config::load_pricing_overrides().unwrap_or_default() {
            let Some(agent) = AgentKind::parse_str(&entry.agent) else {
                continue;
            };
            let pricing = ModelPricing {
                input_per_m: entry.input_per_m,
                output_per_m: entry.output_per_m,
            };
            table.insert((agent, entry.model.to_lowercase()), pricing);
        }
        table
    })
}

/// Look up a user-configured pricing override for `model` under `agent`.
///
/// Matching is case-insensitive. The full model name is tried first. If the name contains a
/// `/` (e.g. `opencode/mimo-v2-flash-free`), the segment after the last `/` is tried next.
fn override_pricing(model: &str, agent: AgentKind) -> Option<ModelPricing> {
    let overrides = pricing_overrides();
    let full = model.to_lowercase();
    // Take the basename from the already-lowercased name, so the case fold happens once.
    // Without a '/', the basename equals `full`, so the second lookup is skipped entirely.
    let basename = full.rsplit_once('/').map(|(_, tail)| tail.to_owned());
    if let Some(pricing) = overrides.get(&(agent, full)) {
        return Some(*pricing);
    }
    basename.and_then(|tail| overrides.get(&(agent, tail)).copied())
}

/// Price a named model: a configured override for `agent` takes precedence, otherwise the
/// built-in table is consulted with the lowercased model name.
fn model_pricing(model: &str, agent: AgentKind) -> Option<ModelPricing> {
    override_pricing(model, agent)
        .or_else(|| pricing_builtin::for_model_lower(&model.to_lowercase()))
}

#[cfg(test)]
mod tests {
    // Unit tests for cost estimation, model pricing lookup, and cost/label formatting.
    use super::*;

    // Free-tier model names and Kilo's no-model default must all price at exactly zero.
    #[test]
    fn kilo_and_free_models_zero_cost() {
        assert_eq!(
            estimate_cost(
                100_000,
                Some("opencode/mimo-v2-flash-free"),
                AgentKind::OpenCode
            ),
            Some(0.0)
        );
        assert_eq!(estimate_cost(100_000, None, AgentKind::Kilo), Some(0.0));
        assert_eq!(
            estimate_cost(100_000, Some("kilo/kilo/auto-free"), AgentKind::Kilo),
            Some(0.0)
        );
    }

    // 1M tokens on gpt-4.1 is expected to cost ~$3.80 at the 70/30 blended rate.
    #[test]
    fn gpt41_cost_estimate() {
        let cost = estimate_cost(1_000_000, Some("gpt-4.1"), AgentKind::Codex).unwrap();
        assert!((cost - 3.8).abs() < 0.01);
    }

    // 1M tokens on composer-2 (Cursor) is expected to cost ~$1.10 at the blended rate.
    #[test]
    fn composer2_cost_estimate() {
        let cost = estimate_cost(1_000_000, Some("composer-2"), AgentKind::Cursor).unwrap();
        assert!((cost - 1.10).abs() < 0.01);
    }

    // An explicit but unpriced model yields no estimate rather than a guessed one.
    #[test]
    fn unknown_model_returns_none() {
        let cost = estimate_cost(1000, Some("unknown-model"), AgentKind::OpenCode);
        assert!(cost.is_none());
    }

    // Covers each format_cost branch: "free", 4-decimal, 2-decimal, and None → "".
    #[test]
    fn format_cost_variants() {
        assert_eq!(format_cost(Some(0.0)), "free");
        assert_eq!(format_cost(Some(0.0038)), "$0.0038");
        assert_eq!(format_cost(Some(1.23)), "$1.23");
        assert_eq!(format_cost(None), "");
    }

    // Subscription agents show "subscription" without a positive cost; Kilo zero is "included".
    #[test]
    fn format_cost_label_special_cases() {
        assert_eq!(format_cost_label(Some(1.0), AgentKind::Cursor), "$1.00");
        assert_eq!(format_cost_label(None, AgentKind::Cursor), "subscription");
        assert_eq!(format_cost_label(None, AgentKind::Copilot), "subscription");
        assert_eq!(format_cost_label(Some(0.0), AgentKind::Kilo), "included");
    }

    // Agents without special handling fall through to plain format_cost.
    #[test]
    fn format_cost_label_codebuff() {
        assert_eq!(format_cost_label(Some(1.5), AgentKind::Codebuff), "$1.50");
    }

    // NOTE: relies on GEMINI_DEFAULT_MODEL_CACHE never having been warmed in this test
    // process (nothing in this module calls warm_gemini_default_from_store); otherwise the
    // fallback would use the cached model instead of gemini-3-flash-preview.
    #[test]
    fn gemini_estimate_fallback_without_explicit_model_matches_gemini_three_flash_blend() {
        let blended =
            estimate_cost(1_000_000, None, AgentKind::Gemini).expect("gemini default pricing present");
        let expected = model_pricing("gemini-3-flash-preview", AgentKind::Gemini).unwrap();
        let blended_per_m = expected.input_per_m * 0.7 + expected.output_per_m * 0.3;
        assert!((blended - blended_per_m).abs() < 0.001);
    }

    // Pins exact per-1M input/output prices for the Gemini 3 preview models.
    #[test]
    fn gemini_3_preview_model_pricing() {
        let p = model_pricing("gemini-3.1-pro-preview", AgentKind::Gemini).unwrap();
        assert_eq!(p.input_per_m, 1.25);
        assert_eq!(p.output_per_m, 10.0);
        let p = model_pricing("gemini-3-flash-preview", AgentKind::Gemini).unwrap();
        assert_eq!(p.input_per_m, 0.30);
        assert_eq!(p.output_per_m, 2.50);
        let p = model_pricing("gemini-3-flash-lite-preview", AgentKind::Gemini).unwrap();
        assert_eq!(p.input_per_m, 0.10);
        assert_eq!(p.output_per_m, 0.40);
    }

    // Pins exact per-1M prices for a set of built-in Claude / GPT / o-series entries.
    #[test]
    fn new_model_pricing_entries() {
        let pricing = model_pricing("claude-sonnet-4", AgentKind::Custom).unwrap();
        assert_eq!(pricing.input_per_m, 3.0);
        assert_eq!(pricing.output_per_m, 15.0);
        let pricing = model_pricing("gpt-5", AgentKind::Codex).unwrap();
        assert_eq!(pricing.input_per_m, 1.25);
        assert_eq!(pricing.output_per_m, 10.0);
        let pricing = model_pricing("gpt-5.4", AgentKind::Codex).unwrap();
        assert_eq!(pricing.input_per_m, 2.5);
        assert_eq!(pricing.output_per_m, 15.0);
        let pricing = model_pricing("gpt-5-mini", AgentKind::Codex).unwrap();
        assert_eq!(pricing.input_per_m, 0.25);
        assert_eq!(pricing.output_per_m, 2.0);
        let pricing = model_pricing("o3-mini", AgentKind::Custom).unwrap();
        assert_eq!(pricing.input_per_m, 1.10);
        assert_eq!(pricing.output_per_m, 4.40);
    }
}