enact-core 0.0.2

Core agent runtime for Enact - Graph-Native AI agents
Documentation
//! Token Usage and Cost Calculation
//!
//! This module provides structures for tracking token usage and calculating
//! costs for LLM API calls. It supports both per-call and cumulative tracking.
//!
//! ## Features
//!
//! - `TokenUsage`: Per-call token counts (prompt, completion, total)
//! - `CostCalculator`: Calculates costs based on token usage and pricing
//! - `UsageAccumulator`: Tracks cumulative usage across multiple calls

use serde::{Deserialize, Serialize};

// =============================================================================
// Token Usage
// =============================================================================

/// Token usage for a single LLM call.
///
/// A plain data record: all fields are public and nothing in this type
/// checks them against each other, so `total_tokens` holds whatever the
/// constructor (or a deserialized payload) put there.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct TokenUsage {
    /// Number of tokens in the prompt/input
    pub prompt_tokens: u32,
    /// Number of tokens in the completion/output
    pub completion_tokens: u32,
    /// Total tokens (prompt + completion when built via `TokenUsage::new`;
    /// set directly, with a zero breakdown, by `TokenUsage::from_total`)
    pub total_tokens: u32,
}

impl TokenUsage {
    /// Create a usage record from a prompt/completion breakdown.
    ///
    /// `total_tokens` is the sum of the two counts. The addition saturates at
    /// `u32::MAX`, so implausibly large counts neither panic (debug builds)
    /// nor wrap around to a small, misleading total (release builds).
    pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens.saturating_add(completion_tokens),
        }
    }

    /// Create from total tokens only (when breakdown not available).
    ///
    /// `prompt_tokens` and `completion_tokens` are both left at zero.
    pub fn from_total(total_tokens: u32) -> Self {
        Self {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens,
        }
    }

    /// Returns `true` when no tokens were recorded, i.e. `total_tokens` is zero.
    ///
    /// Only the total is inspected; the prompt/completion breakdown is ignored.
    pub fn is_empty(&self) -> bool {
        self.total_tokens == 0
    }
}

// =============================================================================
// Cost Calculator
// =============================================================================

/// Model pricing configuration (per 1M tokens).
///
/// Holds separate USD rates for input and output tokens. The rates are stored
/// exactly as given; nothing in this type checks that they are non-negative
/// or finite.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
    /// Cost per 1 million input tokens (USD)
    pub input_cost_per_1m: f64,
    /// Cost per 1 million output tokens (USD)
    pub output_cost_per_1m: f64,
}

impl ModelPricing {
    /// Create pricing for a model
    pub fn new(input_cost_per_1m: f64, output_cost_per_1m: f64) -> Self {
        Self {
            input_cost_per_1m,
            output_cost_per_1m,
        }
    }

    /// Default pricing (conservative estimate)
    pub fn default_pricing() -> Self {
        Self {
            input_cost_per_1m: 3.0,   // $3 per 1M input tokens
            output_cost_per_1m: 15.0, // $15 per 1M output tokens
        }
    }

    /// GPT-4 pricing
    pub fn gpt4() -> Self {
        Self {
            input_cost_per_1m: 30.0,
            output_cost_per_1m: 60.0,
        }
    }

    /// GPT-4 Turbo pricing
    pub fn gpt4_turbo() -> Self {
        Self {
            input_cost_per_1m: 10.0,
            output_cost_per_1m: 30.0,
        }
    }

    /// GPT-4o pricing
    pub fn gpt4o() -> Self {
        Self {
            input_cost_per_1m: 2.5,
            output_cost_per_1m: 10.0,
        }
    }

    /// GPT-4o mini pricing
    pub fn gpt4o_mini() -> Self {
        Self {
            input_cost_per_1m: 0.15,
            output_cost_per_1m: 0.60,
        }
    }

    /// Claude 3 Opus pricing
    pub fn claude3_opus() -> Self {
        Self {
            input_cost_per_1m: 15.0,
            output_cost_per_1m: 75.0,
        }
    }

    /// Claude 3.5 Sonnet pricing
    pub fn claude35_sonnet() -> Self {
        Self {
            input_cost_per_1m: 3.0,
            output_cost_per_1m: 15.0,
        }
    }

    /// Claude 3 Haiku pricing
    pub fn claude3_haiku() -> Self {
        Self {
            input_cost_per_1m: 0.25,
            output_cost_per_1m: 1.25,
        }
    }
}

/// Cost calculator for token usage.
///
/// Carries no state; all functionality is exposed as associated functions.
pub struct CostCalculator;

impl CostCalculator {
    /// Calculate the USD cost of `usage` under `pricing`.
    ///
    /// Prompt tokens are billed at the input rate and completion tokens at the
    /// output rate, both scaled down from the per-1M-token figures.
    ///
    /// Only the prompt/completion breakdown is priced: `total_tokens` is not
    /// consulted, so a usage record that carries a total but a zero breakdown
    /// costs `0.0`.
    pub fn calculate_cost(usage: &TokenUsage, pricing: &ModelPricing) -> f64 {
        let input_millions = f64::from(usage.prompt_tokens) / 1_000_000.0;
        let output_millions = f64::from(usage.completion_tokens) / 1_000_000.0;
        input_millions * pricing.input_cost_per_1m + output_millions * pricing.output_cost_per_1m
    }

    /// Calculate cost using default pricing
    pub fn calculate_cost_default(usage: &TokenUsage) -> f64 {
        Self::calculate_cost(usage, &ModelPricing::default_pricing())
    }

    /// Get pricing for a model by name.
    ///
    /// Matching is case-insensitive and substring-based, so versioned or
    /// provider-prefixed identifiers resolve to the same preset as the bare
    /// family name. Unrecognised names fall back to
    /// `ModelPricing::default_pricing()`.
    ///
    /// The "mini" family accepts both the dashed (`gpt-4-1-mini`) and dotted
    /// (`gpt-4.1-mini`) spellings; without the dotted form such a name would
    /// fall through to the much more expensive plain `gpt-4` preset.
    pub fn pricing_for_model(model: &str) -> ModelPricing {
        let name = model.to_lowercase();

        // Order matters: the more specific GPT-4 variants must be tested before
        // the generic "gpt-4" substring, which every one of them also contains.
        if Self::contains_any(&name, &["gpt-4o-mini", "gpt-4-1-mini", "gpt-4.1-mini"]) {
            ModelPricing::gpt4o_mini()
        } else if name.contains("gpt-4o") {
            ModelPricing::gpt4o()
        } else if name.contains("gpt-4-turbo") {
            ModelPricing::gpt4_turbo()
        } else if name.contains("gpt-4") {
            ModelPricing::gpt4()
        } else if name.contains("opus") {
            // The bare family name also covers "claude-3-opus" and similar.
            ModelPricing::claude3_opus()
        } else if name.contains("sonnet") {
            // Covers "claude-3.5-sonnet" and "claude-3-5-sonnet" as well.
            ModelPricing::claude35_sonnet()
        } else if name.contains("haiku") {
            // Covers "claude-3-haiku" as well.
            ModelPricing::claude3_haiku()
        } else {
            // Default to a conservative estimate
            ModelPricing::default_pricing()
        }
    }

    /// Returns `true` when `haystack` contains at least one of `needles`.
    fn contains_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|needle| haystack.contains(*needle))
    }
}

// =============================================================================
// Usage Accumulator
// =============================================================================

/// Running totals of token usage and spend over a series of LLM calls.
///
/// Counters are 64-bit, wider than the 32-bit per-call counts they sum.
#[derive(Debug, Clone, Default)]
pub struct UsageAccumulator {
    /// Sum of prompt tokens over every recorded call
    pub total_prompt_tokens: u64,
    /// Sum of completion tokens over every recorded call
    pub total_completion_tokens: u64,
    /// Sum of the per-call `total_tokens` values
    pub total_tokens: u64,
    /// Sum of the per-call costs (USD)
    pub total_cost_usd: f64,
    /// How many calls have been recorded
    pub call_count: u64,
}

impl UsageAccumulator {
    /// Start with every counter at zero.
    pub fn new() -> Self {
        UsageAccumulator::default()
    }

    /// Record one call: fold its token counts and its `cost` into the totals
    /// and bump the call counter.
    pub fn add(&mut self, usage: &TokenUsage, cost: f64) {
        self.call_count += 1;
        self.total_cost_usd += cost;
        self.total_tokens += u64::from(usage.total_tokens);
        self.total_completion_tokens += u64::from(usage.completion_tokens);
        self.total_prompt_tokens += u64::from(usage.prompt_tokens);
    }

    /// Record one call, deriving its cost from `pricing` via
    /// `CostCalculator::calculate_cost`.
    pub fn add_with_pricing(&mut self, usage: &TokenUsage, pricing: &ModelPricing) {
        self.add(usage, CostCalculator::calculate_cost(usage, pricing));
    }

    /// Record one call, looking the pricing up from the model name via
    /// `CostCalculator::pricing_for_model`.
    pub fn add_for_model(&mut self, usage: &TokenUsage, model: &str) {
        self.add_with_pricing(usage, &CostCalculator::pricing_for_model(model));
    }

    /// Mean of `total_tokens` per recorded call; `0.0` before any call is recorded.
    pub fn avg_tokens_per_call(&self) -> f64 {
        match self.call_count {
            0 => 0.0,
            calls => self.total_tokens as f64 / calls as f64,
        }
    }

    /// Mean cost (USD) per recorded call; `0.0` before any call is recorded.
    pub fn avg_cost_per_call(&self) -> f64 {
        match self.call_count {
            0 => 0.0,
            calls => self.total_cost_usd / calls as f64,
        }
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-difference comparison for floating-point dollar amounts.
    fn approx_eq(actual: f64, expected: f64, tolerance: f64) -> bool {
        (actual - expected).abs() < tolerance
    }

    /// Builds an accumulator that has already recorded the given calls.
    fn accumulator_with(calls: &[(u32, u32, f64)]) -> UsageAccumulator {
        let mut acc = UsageAccumulator::new();
        for &(prompt, completion, cost) in calls {
            acc.add(&TokenUsage::new(prompt, completion), cost);
        }
        acc
    }

    // `new` stores the breakdown and derives the total from it.
    #[test]
    fn test_token_usage_new() {
        let expected = TokenUsage {
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
        };
        assert_eq!(TokenUsage::new(100, 50), expected);
    }

    // `from_total` keeps the total and zeroes the breakdown.
    #[test]
    fn test_token_usage_from_total() {
        let expected = TokenUsage {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 200,
        };
        assert_eq!(TokenUsage::from_total(200), expected);
    }

    // A default record is empty; one with tokens is not.
    #[test]
    fn test_token_usage_is_empty() {
        assert!(TokenUsage::default().is_empty());
        assert!(!TokenUsage::new(10, 5).is_empty());
    }

    // 1M input tokens at $3/1M   = $3.00
    // 500K output tokens at $15/1M = $7.50
    // Total                        = $10.50
    #[test]
    fn test_cost_calculator() {
        let cost = CostCalculator::calculate_cost(
            &TokenUsage::new(1_000_000, 500_000),
            &ModelPricing::new(3.0, 15.0),
        );
        assert!(approx_eq(cost, 10.50, 0.001));
    }

    // 1K input tokens at $0.15/1M  = $0.00015
    // 500 output tokens at $0.60/1M = $0.0003
    // Total                         = $0.00045
    #[test]
    fn test_cost_calculator_small_usage() {
        let cost =
            CostCalculator::calculate_cost(&TokenUsage::new(1000, 500), &ModelPricing::gpt4o_mini());
        assert!(cost > 0.0);
        assert!(cost < 0.001);
    }

    // Each model name must resolve to the preset with the listed input rate.
    #[test]
    fn test_pricing_for_model() {
        let expectations = [
            ("gpt-4", 30.0),
            ("gpt-4o", 2.5),
            ("claude-3.5-sonnet", 3.0),
        ];
        for (model, input_rate) in expectations {
            let pricing = CostCalculator::pricing_for_model(model);
            assert_eq!(pricing.input_cost_per_1m, input_rate);
        }
    }

    // Two recorded calls sum field by field.
    #[test]
    fn test_usage_accumulator() {
        let acc = accumulator_with(&[(100, 50, 0.01), (200, 100, 0.02)]);

        assert_eq!(acc.total_prompt_tokens, 300);
        assert_eq!(acc.total_completion_tokens, 150);
        assert_eq!(acc.total_tokens, 450);
        assert!(approx_eq(acc.total_cost_usd, 0.03, 0.0001));
        assert_eq!(acc.call_count, 2);
    }

    // Averages divide the totals by the number of recorded calls.
    #[test]
    fn test_usage_accumulator_averages() {
        let acc = accumulator_with(&[(100, 100, 0.02), (200, 200, 0.04)]);

        assert_eq!(acc.avg_tokens_per_call(), 300.0);
        assert!(approx_eq(acc.avg_cost_per_call(), 0.03, 0.0001));
    }

    // With no calls recorded, both averages are zero rather than NaN.
    #[test]
    fn test_usage_accumulator_empty() {
        let acc = accumulator_with(&[]);
        assert_eq!(acc.avg_tokens_per_call(), 0.0);
        assert_eq!(acc.avg_cost_per_call(), 0.0);
    }
}