ai-agent-sdk 0.5.0

Idiomatic agent SDK inspired by the Claude Code source leak.
Documentation
//! Model cost calculation.
//!
//! Provides cost estimation for different AI models, similar to Claude Code.

use serde::{Deserialize, Serialize};

/// Model cost configuration. All token prices are US dollars per one
/// million tokens; `web_search_requests` is dollars per individual search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCosts {
    /// Input tokens cost per million
    pub input_tokens: f64,
    /// Output tokens cost per million
    pub output_tokens: f64,
    /// Prompt cache write tokens cost per million
    pub prompt_cache_write_tokens: f64,
    /// Prompt cache read tokens cost per million
    pub prompt_cache_read_tokens: f64,
    /// Web search requests cost per search.
    /// NOTE(review): not included in `total_cost`, which only sums the four
    /// token categories (`TokenUsage` carries no search count) — confirm
    /// whether search costs should be accounted for elsewhere.
    pub web_search_requests: f64,
}

impl ModelCosts {
    /// Dollar price for `tokens` at `rate` dollars per one million tokens.
    fn price(rate: f64, tokens: u32) -> f64 {
        (f64::from(tokens) / 1_000_000.0) * rate
    }

    /// Calculate cost for input tokens
    pub fn input_cost(&self, tokens: u32) -> f64 {
        Self::price(self.input_tokens, tokens)
    }

    /// Calculate cost for output tokens
    pub fn output_cost(&self, tokens: u32) -> f64 {
        Self::price(self.output_tokens, tokens)
    }

    /// Calculate cost for cache write tokens
    pub fn cache_write_cost(&self, tokens: u32) -> f64 {
        Self::price(self.prompt_cache_write_tokens, tokens)
    }

    /// Calculate cost for cache read tokens
    pub fn cache_read_cost(&self, tokens: u32) -> f64 {
        Self::price(self.prompt_cache_read_tokens, tokens)
    }

    /// Total dollar cost for a usage record: input + output + cache write
    /// + cache read (summed in that order, matching the field order).
    pub fn total_cost(&self, usage: &TokenUsage) -> f64 {
        let input = self.input_cost(usage.input_tokens);
        let output = self.output_cost(usage.output_tokens);
        let cache_write = self.cache_write_cost(usage.prompt_cache_write_tokens);
        let cache_read = self.cache_read_cost(usage.prompt_cache_read_tokens);
        input + output + cache_write + cache_read
    }
}

/// Token usage from an API response.
///
/// NOTE(review): the two cache fields deserialize from camelCase
/// ("promptCacheWriteTokens") while `input_tokens` / `output_tokens` keep
/// snake_case — verify this mixed naming matches the actual wire format.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    // Input tokens reported by the API.
    pub input_tokens: u32,
    // Output tokens reported by the API.
    pub output_tokens: u32,
    #[serde(rename = "promptCacheWriteTokens")]
    pub prompt_cache_write_tokens: u32,
    #[serde(rename = "promptCacheReadTokens")]
    pub prompt_cache_read_tokens: u32,
}

impl TokenUsage {
    /// Sum of all four token categories in this usage record.
    pub fn total(&self) -> u32 {
        [
            self.input_tokens,
            self.output_tokens,
            self.prompt_cache_write_tokens,
            self.prompt_cache_read_tokens,
        ]
        .iter()
        .sum()
    }
}

/// Common cost tiers. In every tier in this group the cache-write price is
/// 1.25x the input price and the cache-read price is 0.1x the input price.

/// Standard pricing: $3 input / $15 output per M tokens
pub const COST_TIER_3_15: ModelCosts = ModelCosts {
    input_tokens: 3.0,
    output_tokens: 15.0,
    prompt_cache_write_tokens: 3.75,
    prompt_cache_read_tokens: 0.3,
    web_search_requests: 0.01,
};

/// Opus pricing: $15 input / $75 output per M tokens
pub const COST_TIER_15_75: ModelCosts = ModelCosts {
    input_tokens: 15.0,
    output_tokens: 75.0,
    prompt_cache_write_tokens: 18.75,
    prompt_cache_read_tokens: 1.5,
    web_search_requests: 0.01,
};

/// Mid-tier pricing: $5 input / $25 output per M tokens
pub const COST_TIER_5_25: ModelCosts = ModelCosts {
    input_tokens: 5.0,
    output_tokens: 25.0,
    prompt_cache_write_tokens: 6.25,
    prompt_cache_read_tokens: 0.5,
    web_search_requests: 0.01,
};

/// Fast mode pricing: $30 input / $150 output per M tokens
pub const COST_TIER_30_150: ModelCosts = ModelCosts {
    input_tokens: 30.0,
    output_tokens: 150.0,
    prompt_cache_write_tokens: 37.5,
    prompt_cache_read_tokens: 3.0,
    web_search_requests: 0.01,
};

/// Haiku 3.5 pricing: $0.80 input / $4 output per M tokens
pub const COST_HAIKU_35: ModelCosts = ModelCosts {
    input_tokens: 0.8,
    output_tokens: 4.0,
    prompt_cache_write_tokens: 1.0,
    prompt_cache_read_tokens: 0.08,
    web_search_requests: 0.01,
};

/// Haiku 4.5 pricing: $1 input / $5 output per M tokens
pub const COST_HAIKU_45: ModelCosts = ModelCosts {
    input_tokens: 1.0,
    output_tokens: 5.0,
    prompt_cache_write_tokens: 1.25,
    prompt_cache_read_tokens: 0.1,
    web_search_requests: 0.01,
};

/// Default cost for unknown models (the $5 / $25 mid tier).
pub const COST_DEFAULT: ModelCosts = COST_TIER_5_25;

/// Model cost registry: maps model-name strings to their pricing.
pub struct ModelCostRegistry {
    // Keyed by model identifier, e.g. "claude-sonnet-4-5".
    costs: std::collections::HashMap<String, ModelCosts>,
}

impl ModelCostRegistry {
    /// Build a registry pre-populated with known model pricing.
    pub fn new() -> Self {
        let mut costs = std::collections::HashMap::new();

        // Anthropic models
        costs.insert("claude-opus-4-6".to_string(), COST_TIER_5_25);
        costs.insert("claude-opus-4-5".to_string(), COST_TIER_5_25);
        costs.insert("claude-opus-4-1".to_string(), COST_TIER_15_75);
        costs.insert("claude-opus-4".to_string(), COST_TIER_15_75);
        costs.insert("claude-sonnet-4-6".to_string(), COST_TIER_3_15);
        costs.insert("claude-sonnet-4-5".to_string(), COST_TIER_3_15);
        costs.insert("claude-sonnet-4".to_string(), COST_TIER_3_15);
        costs.insert("claude-sonnet-3-5".to_string(), COST_TIER_3_15);
        costs.insert("claude-haiku-4-5".to_string(), COST_HAIKU_45);
        costs.insert("claude-haiku-3-5".to_string(), COST_HAIKU_35);

        // MiniMax models
        costs.insert("MiniMaxAI/MiniMax-M2.5".to_string(), COST_TIER_3_15);
        costs.insert("MiniMaxAI/MiniMax-M2".to_string(), COST_TIER_3_15);

        // OpenAI models (for compatibility)
        costs.insert("gpt-4o".to_string(), COST_TIER_5_25);
        costs.insert("gpt-4o-mini".to_string(), COST_HAIKU_35);
        costs.insert("gpt-4-turbo".to_string(), COST_TIER_10_30);
        costs.insert("gpt-4".to_string(), COST_TIER_30_60);

        Self { costs }
    }

    /// Look up pricing for `model`.
    ///
    /// Resolution order:
    /// 1. exact key match;
    /// 2. the LONGEST registered key that is a prefix of `model` (so a dated
    ///    id such as "claude-opus-4-1-20250805" resolves to the
    ///    "claude-opus-4-1" entry, not the cheaper/dearer "claude-opus-4");
    /// 3. the longest registered key that `model` itself is a prefix of;
    /// 4. `COST_DEFAULT`.
    pub fn get(&self, model: &str) -> &ModelCosts {
        // Try exact match first.
        if let Some(cost) = self.costs.get(model) {
            return cost;
        }

        // Bug fix: the previous implementation returned the FIRST prefix
        // match while iterating the HashMap, which is nondeterministic when
        // several keys match (HashMap iteration order is unspecified).
        // Prefer the longest — i.e. most specific — match instead.
        if let Some(cost) = self
            .costs
            .iter()
            .filter(|(key, _)| model.starts_with(key.as_str()))
            .max_by_key(|(key, _)| key.len())
            .map(|(_, cost)| cost)
        {
            return cost;
        }

        // Fallback kept from the original: a model name that is itself a
        // prefix of a registered key (e.g. "gpt-4-t" -> "gpt-4-turbo").
        if let Some(cost) = self
            .costs
            .iter()
            .filter(|(key, _)| key.starts_with(model))
            .max_by_key(|(key, _)| key.len())
            .map(|(_, cost)| cost)
        {
            return cost;
        }

        &COST_DEFAULT
    }

    /// Register (or override) pricing for a custom model.
    pub fn register(&mut self, model: &str, costs: ModelCosts) {
        self.costs.insert(model.to_string(), costs);
    }
}

impl Default for ModelCostRegistry {
    // Equivalent to `ModelCostRegistry::new()`.
    fn default() -> Self {
        Self::new()
    }
}

/// Pricing tier for GPT-4: $30 input / $60 output per M tokens.
/// NOTE(review): the cache ratios here (write = input, read = input / 3)
/// differ from the 1.25x / 0.1x ratios used by the tiers above — confirm
/// against the provider's actual caching prices.
pub const COST_TIER_30_60: ModelCosts = ModelCosts {
    input_tokens: 30.0,
    output_tokens: 60.0,
    prompt_cache_write_tokens: 30.0,
    prompt_cache_read_tokens: 10.0,
    web_search_requests: 0.01,
};

/// Pricing tier for GPT-4 Turbo: $10 input / $30 output per M tokens.
/// NOTE(review): same cache-ratio caveat as `COST_TIER_30_60`.
pub const COST_TIER_10_30: ModelCosts = ModelCosts {
    input_tokens: 10.0,
    output_tokens: 30.0,
    prompt_cache_write_tokens: 10.0,
    prompt_cache_read_tokens: 3.0,
    web_search_requests: 0.01,
};

/// Calculate cost from model name and usage
pub fn calculate_cost(model: &str, usage: &TokenUsage) -> f64 {
    let registry = ModelCostRegistry::new();
    let costs = registry.get(model);
    costs.total_cost(usage)
}

/// Format a dollar cost for display.
///
/// Sub-cent amounts get four decimal places; amounts under a dollar get
/// two. NOTE(review): amounts of $1.00 and above ALSO get four decimal
/// places ("$1.5000") — this looks unintentional but is pinned by the
/// existing unit tests, so it is preserved here; confirm before changing.
pub fn format_cost(cost: f64) -> String {
    match cost {
        c if c < 0.01 => format!("${c:.4}"),
        c if c < 1.0 => format!("${c:.2}"),
        c => format!("${c:.4}"),
    }
}

/// Cost summary for display: per-category dollar costs plus their total.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostSummary {
    // Dollar cost of input tokens.
    pub input_cost: f64,
    // Dollar cost of output tokens.
    pub output_cost: f64,
    // Dollar cost of prompt-cache writes.
    pub cache_write_cost: f64,
    // Dollar cost of prompt-cache reads.
    pub cache_read_cost: f64,
    // Sum of the four categories above.
    pub total_cost: f64,
}

impl CostSummary {
    pub fn from_usage(model: &str, usage: &TokenUsage) -> Self {
        let registry = ModelCostRegistry::new();
        let costs = registry.get(model);

        Self {
            input_cost: costs.input_cost(usage.input_tokens),
            output_cost: costs.output_cost(usage.output_tokens),
            cache_write_cost: costs.cache_write_cost(usage.prompt_cache_write_tokens),
            cache_read_cost: costs.cache_read_cost(usage.prompt_cache_read_tokens),
            total_cost: costs.total_cost(usage),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Per-million pricing: 1M input tokens at the $3/$15 tier costs $3.
    #[test]
    fn test_model_costs_input() {
        let costs = COST_TIER_3_15;
        assert_eq!(costs.input_cost(1_000_000), 3.0);
        assert_eq!(costs.input_cost(500_000), 1.5);
    }

    #[test]
    fn test_model_costs_output() {
        let costs = COST_TIER_3_15;
        assert_eq!(costs.output_cost(1_000_000), 15.0);
    }

    // `total()` sums all four token categories.
    #[test]
    fn test_token_usage_total() {
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            prompt_cache_write_tokens: 25,
            prompt_cache_read_tokens: 75,
        };
        assert_eq!(usage.total(), 250);
    }

    // Exact-name lookups resolve to the registered tier.
    #[test]
    fn test_model_cost_registry() {
        let registry = ModelCostRegistry::new();

        let costs = registry.get("claude-sonnet-4-6");
        assert_eq!(costs.input_tokens, 3.0);

        let costs = registry.get("claude-haiku-4-5");
        assert_eq!(costs.input_tokens, 1.0);
    }

    // Unknown models fall back to COST_DEFAULT.
    #[test]
    fn test_model_cost_registry_unknown() {
        let registry = ModelCostRegistry::new();
        let costs = registry.get("unknown-model");
        assert_eq!(costs.input_tokens, COST_DEFAULT.input_tokens);
    }

    #[test]
    fn test_calculate_cost() {
        let usage = TokenUsage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            prompt_cache_write_tokens: 0,
            prompt_cache_read_tokens: 0,
        };

        let cost = calculate_cost("claude-sonnet-4-6", &usage);
        // $3 * 1 + $15 * 0.5 = $3 + $7.50 = $10.50
        assert!((cost - 10.5).abs() < 0.01);
    }

    // NOTE(review): "$1.5000" pins the current >= $1 behavior of
    // `format_cost` (four decimal places). If that branch is ever changed
    // to two decimals, this expectation must change with it.
    #[test]
    fn test_format_cost() {
        assert_eq!(format_cost(0.001), "$0.0010");
        assert_eq!(format_cost(0.5), "$0.50");
        assert_eq!(format_cost(1.5), "$1.5000");
    }

    #[test]
    fn test_cost_summary() {
        let usage = TokenUsage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            prompt_cache_write_tokens: 100_000,
            prompt_cache_read_tokens: 200_000,
        };

        let summary = CostSummary::from_usage("claude-sonnet-4-6", &usage);

        // Input: 1M * $3/M = $3
        assert!((summary.input_cost - 3.0).abs() < 0.01);
        // Output: 500K * $15/M = $7.50
        assert!((summary.output_cost - 7.5).abs() < 0.01);
        // Cache write: 100K * $3.75/M = $0.375
        assert!((summary.cache_write_cost - 0.375).abs() < 0.01);
        // Cache read: 200K * $0.3/M = $0.06
        assert!((summary.cache_read_cost - 0.06).abs() < 0.01);
    }
}