enact-core 0.0.2

Core agent runtime for Enact - Graph-Native AI agents
Documentation
//! Cost tracking tool for monitoring API usage

use crate::tool::Tool;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;

/// Cost tracking for LLM API calls
///
/// A plain, serializable snapshot of the counters accumulated by
/// `CostTracker::track_request`. `Default` yields all-zero counters and
/// empty per-model maps.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CostMetrics {
    /// Number of tracked requests (one per `track_request` call).
    pub total_requests: u64,
    /// Sum of the input token counts of all tracked requests.
    pub total_input_tokens: u64,
    /// Sum of the output token counts of all tracked requests.
    pub total_output_tokens: u64,
    /// Accumulated cost in USD across all tracked requests.
    pub total_cost_usd: f64,
    /// Request count keyed by the model name passed to `track_request`.
    pub requests_by_model: HashMap<String, u64>,
    /// Accumulated cost in USD keyed by the model name passed to `track_request`.
    pub cost_by_model: HashMap<String, f64>,
}

/// Cost tracker with thread-safe counters
///
/// The running [`CostMetrics`] live behind a mutex so the tracker can be
/// updated through a shared reference; the price table is filled once in
/// `CostTracker::new` and only read afterwards.
pub struct CostTracker {
    /// Accumulated metrics, guarded by a standard-library mutex.
    metrics: Arc<std::sync::Mutex<CostMetrics>>,
    /// Per-model prices, keyed by the exact model name.
    model_pricing: HashMap<String, ModelPricing>,
}

/// Price of a single model, expressed per 1,000 tokens and split by
/// direction (input vs. output).
#[derive(Debug, Clone)]
struct ModelPricing {
    input_price_per_1k: f64,  // USD per 1K input tokens
    output_price_per_1k: f64, // USD per 1K output tokens
}

impl CostTracker {
    pub fn new() -> Self {
        let mut model_pricing = HashMap::new();

        // OpenAI pricing (as of 2024)
        model_pricing.insert(
            "gpt-4o".to_string(),
            ModelPricing {
                input_price_per_1k: 0.0025,
                output_price_per_1k: 0.01,
            },
        );
        model_pricing.insert(
            "gpt-4o-mini".to_string(),
            ModelPricing {
                input_price_per_1k: 0.00015,
                output_price_per_1k: 0.0006,
            },
        );
        model_pricing.insert(
            "gpt-4-turbo".to_string(),
            ModelPricing {
                input_price_per_1k: 0.01,
                output_price_per_1k: 0.03,
            },
        );

        // Anthropic pricing
        model_pricing.insert(
            "claude-3-opus".to_string(),
            ModelPricing {
                input_price_per_1k: 0.015,
                output_price_per_1k: 0.075,
            },
        );
        model_pricing.insert(
            "claude-3-sonnet".to_string(),
            ModelPricing {
                input_price_per_1k: 0.003,
                output_price_per_1k: 0.015,
            },
        );

        // Gemini pricing
        model_pricing.insert(
            "gemini-pro".to_string(),
            ModelPricing {
                input_price_per_1k: 0.0005,
                output_price_per_1k: 0.0015,
            },
        );

        Self {
            metrics: Arc::new(std::sync::Mutex::new(CostMetrics::default())),
            model_pricing,
        }
    }

    pub fn track_request(&self, model: &str, input_tokens: u64, output_tokens: u64) {
        let mut metrics = self.metrics.lock().unwrap();

        metrics.total_requests += 1;
        metrics.total_input_tokens += input_tokens;
        metrics.total_output_tokens += output_tokens;

        // Track by model
        *metrics
            .requests_by_model
            .entry(model.to_string())
            .or_insert(0) += 1;

        // Calculate cost
        let cost = if let Some(pricing) = self.model_pricing.get(model) {
            let input_cost = (input_tokens as f64 / 1000.0) * pricing.input_price_per_1k;
            let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_price_per_1k;
            input_cost + output_cost
        } else {
            // Default pricing if model unknown
            let input_cost = (input_tokens as f64 / 1000.0) * 0.001;
            let output_cost = (output_tokens as f64 / 1000.0) * 0.002;
            input_cost + output_cost
        };

        metrics.total_cost_usd += cost;
        *metrics
            .cost_by_model
            .entry(model.to_string())
            .or_insert(0.0) += cost;
    }

    pub fn get_metrics(&self) -> CostMetrics {
        self.metrics.lock().unwrap().clone()
    }

    pub fn reset(&self) {
        let mut metrics = self.metrics.lock().unwrap();
        *metrics = CostMetrics::default();
    }
}

impl Default for CostTracker {
    fn default() -> Self {
        Self::new()
    }
}

/// Cost tool for querying cost metrics
///
/// Wraps a shared [`CostTracker`] and exposes it through the `Tool` trait.
pub struct CostTool {
    /// Shared tracker that this tool reads from and resets.
    tracker: Arc<CostTracker>,
}

impl CostTool {
    pub fn new(tracker: Arc<CostTracker>) -> Self {
        Self { tracker }
    }
}

#[async_trait]
impl Tool for CostTool {
    /// Name under which the tool is exposed.
    fn name(&self) -> &str {
        "cost"
    }

    /// Short human-readable summary of what the tool does.
    fn description(&self) -> &str {
        "Get cost metrics and usage statistics for LLM API calls"
    }

    /// JSON schema of the accepted arguments: a single optional `action`
    /// string, either `"get"` or `"reset"`.
    fn parameters_schema(&self) -> serde_json::Value {
        let action_schema = json!({
            "type": "string",
            "enum": ["get", "reset"],
            "description": "Action to perform",
            "default": "get"
        });

        json!({
            "type": "object",
            "properties": {
                "action": action_schema
            }
        })
    }

    /// The tool only reads in-process counters, so no network is needed.
    fn requires_network(&self) -> bool {
        false
    }

    /// Runs the requested action.
    ///
    /// `"reset"` clears the tracker and returns a confirmation message.
    /// Anything else, including a missing or non-string `action`, is
    /// treated as `"get"` and returns the current metrics.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<serde_json::Value> {
        let requested = args.get("action").and_then(serde_json::Value::as_str);

        if matches!(requested, Some("reset")) {
            self.tracker.reset();
            return Ok(json!({
                "success": true,
                "message": "Cost metrics reset"
            }));
        }

        // Every other value falls back to reporting the metrics.
        let snapshot = self.tracker.get_metrics();
        Ok(json!({
            "success": true,
            "metrics": snapshot
        }))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Model name used by every test in this module.
    const MODEL: &str = "gpt-4o-mini";

    /// Two tracked requests are summed into the totals and the per-model count.
    #[test]
    fn test_cost_tracker() {
        let tracker = CostTracker::new();

        for (input_tokens, output_tokens) in [(1000, 500), (2000, 1000)] {
            tracker.track_request(MODEL, input_tokens, output_tokens);
        }

        let snapshot = tracker.get_metrics();
        assert_eq!(snapshot.total_requests, 2);
        assert_eq!(snapshot.total_input_tokens, 3000);
        assert_eq!(snapshot.total_output_tokens, 1500);
        assert!(snapshot.total_cost_usd > 0.0);

        // Check model breakdown
        assert_eq!(snapshot.requests_by_model.get(MODEL), Some(&2));
    }

    /// `reset` brings the request count and total cost back to zero.
    #[test]
    fn test_cost_tracker_reset() {
        let tracker = CostTracker::new();
        tracker.track_request(MODEL, 1000, 500);

        tracker.reset();

        let snapshot = tracker.get_metrics();
        assert_eq!(snapshot.total_requests, 0);
        assert_eq!(snapshot.total_cost_usd, 0.0);
    }

    /// The tool's `get` action reports success and the tracked request count.
    #[tokio::test]
    async fn test_cost_tool_get() {
        let shared_tracker = Arc::new(CostTracker::new());
        shared_tracker.track_request(MODEL, 1000, 500);

        let tool = CostTool::new(shared_tracker);
        let response = tool.execute(json!({"action": "get"})).await.unwrap();

        assert_eq!(response["success"], true);
        assert!(response["metrics"]["total_requests"].as_u64().unwrap() > 0);
    }
}