use crate::tool::Tool;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
/// Snapshot of accumulated LLM usage and cost counters.
///
/// All fields start at zero / empty (via `Default`) and are updated by
/// `CostTracker::track_request`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CostMetrics {
    /// Number of requests recorded so far.
    pub total_requests: u64,
    /// Sum of input tokens over all recorded requests.
    pub total_input_tokens: u64,
    /// Sum of output tokens over all recorded requests.
    pub total_output_tokens: u64,
    /// Sum of the computed cost of all recorded requests, in USD.
    pub total_cost_usd: f64,
    /// Number of recorded requests, keyed by model name.
    pub requests_by_model: HashMap<String, u64>,
    /// Accumulated computed cost in USD, keyed by model name.
    pub cost_by_model: HashMap<String, f64>,
}
/// Accumulates request counts, token counts and computed cost per model.
///
/// The metrics live behind a mutex so they can be updated through `&self`.
pub struct CostTracker {
    /// Mutable counters, guarded by a standard-library mutex.
    metrics: Arc<std::sync::Mutex<CostMetrics>>,
    /// Per-model prices, keyed by model name; filled in by `CostTracker::new`.
    model_pricing: HashMap<String, ModelPricing>,
}
/// Price of a single model, expressed per 1000 tokens.
#[derive(Debug, Clone)]
struct ModelPricing {
    /// Price charged for every 1000 input tokens.
    input_price_per_1k: f64,
    /// Price charged for every 1000 output tokens.
    output_price_per_1k: f64,
}
impl CostTracker {
pub fn new() -> Self {
let mut model_pricing = HashMap::new();
model_pricing.insert(
"gpt-4o".to_string(),
ModelPricing {
input_price_per_1k: 0.0025,
output_price_per_1k: 0.01,
},
);
model_pricing.insert(
"gpt-4o-mini".to_string(),
ModelPricing {
input_price_per_1k: 0.00015,
output_price_per_1k: 0.0006,
},
);
model_pricing.insert(
"gpt-4-turbo".to_string(),
ModelPricing {
input_price_per_1k: 0.01,
output_price_per_1k: 0.03,
},
);
model_pricing.insert(
"claude-3-opus".to_string(),
ModelPricing {
input_price_per_1k: 0.015,
output_price_per_1k: 0.075,
},
);
model_pricing.insert(
"claude-3-sonnet".to_string(),
ModelPricing {
input_price_per_1k: 0.003,
output_price_per_1k: 0.015,
},
);
model_pricing.insert(
"gemini-pro".to_string(),
ModelPricing {
input_price_per_1k: 0.0005,
output_price_per_1k: 0.0015,
},
);
Self {
metrics: Arc::new(std::sync::Mutex::new(CostMetrics::default())),
model_pricing,
}
}
pub fn track_request(&self, model: &str, input_tokens: u64, output_tokens: u64) {
let mut metrics = self.metrics.lock().unwrap();
metrics.total_requests += 1;
metrics.total_input_tokens += input_tokens;
metrics.total_output_tokens += output_tokens;
*metrics
.requests_by_model
.entry(model.to_string())
.or_insert(0) += 1;
let cost = if let Some(pricing) = self.model_pricing.get(model) {
let input_cost = (input_tokens as f64 / 1000.0) * pricing.input_price_per_1k;
let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_price_per_1k;
input_cost + output_cost
} else {
let input_cost = (input_tokens as f64 / 1000.0) * 0.001;
let output_cost = (output_tokens as f64 / 1000.0) * 0.002;
input_cost + output_cost
};
metrics.total_cost_usd += cost;
*metrics
.cost_by_model
.entry(model.to_string())
.or_insert(0.0) += cost;
}
pub fn get_metrics(&self) -> CostMetrics {
self.metrics.lock().unwrap().clone()
}
pub fn reset(&self) {
let mut metrics = self.metrics.lock().unwrap();
*metrics = CostMetrics::default();
}
}
impl Default for CostTracker {
fn default() -> Self {
Self::new()
}
}
/// Tool adapter that reports and resets the metrics of a shared `CostTracker`.
pub struct CostTool {
    /// Tracker whose metrics this tool reads and resets.
    tracker: Arc<CostTracker>,
}
impl CostTool {
pub fn new(tracker: Arc<CostTracker>) -> Self {
Self { tracker }
}
}
#[async_trait]
impl Tool for CostTool {
    /// Identifier under which this tool is registered.
    fn name(&self) -> &str {
        "cost"
    }

    /// Human-readable summary of what the tool does.
    fn description(&self) -> &str {
        "Get cost metrics and usage statistics for LLM API calls"
    }

    /// JSON schema of the accepted arguments: an optional `action` that is
    /// either `"get"` (the default) or `"reset"`.
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": ["get", "reset"],
                    "description": "Action to perform",
                    "default": "get"
                }
            }
        })
    }

    /// The tool only reads in-process state, so no network access is needed.
    fn requires_network(&self) -> bool {
        false
    }

    /// Runs the requested action against the shared tracker.
    ///
    /// * `"get"` (also used when `action` is absent or `null`) returns the
    ///   current metrics under the `"metrics"` key.
    /// * `"reset"` clears all metrics.
    ///
    /// # Errors
    ///
    /// Returns an error when `action` is not a string or is not one of the
    /// values listed in the schema, so a mistyped action is reported instead
    /// of being silently treated as `"get"`.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<serde_json::Value> {
        let action = match args.get("action") {
            // Missing or null falls back to the default declared in the schema.
            None | Some(serde_json::Value::Null) => "get",
            Some(value) => match value.as_str() {
                Some(name) => name,
                None => anyhow::bail!("`action` must be a string, got: {}", value),
            },
        };

        match action {
            "get" => {
                let metrics = self.tracker.get_metrics();
                Ok(json!({
                    "success": true,
                    "metrics": metrics
                }))
            }
            "reset" => {
                self.tracker.reset();
                Ok(json!({
                    "success": true,
                    "message": "Cost metrics reset"
                }))
            }
            other => anyhow::bail!(
                "unknown action `{}`; expected \"get\" or \"reset\"",
                other
            ),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing accumulated floating-point costs.
    const EPSILON: f64 = 1e-12;

    /// Totals, per-model counters and the exact cost are all accumulated.
    #[test]
    fn test_cost_tracker() {
        let tracker = CostTracker::new();
        tracker.track_request("gpt-4o-mini", 1000, 500);
        tracker.track_request("gpt-4o-mini", 2000, 1000);

        let metrics = tracker.get_metrics();
        assert_eq!(metrics.total_requests, 2);
        assert_eq!(metrics.total_input_tokens, 3000);
        assert_eq!(metrics.total_output_tokens, 1500);
        assert_eq!(metrics.requests_by_model.get("gpt-4o-mini"), Some(&2));

        // 3k input tokens at 0.00015 per 1k plus 1.5k output tokens at 0.0006 per 1k.
        let expected_cost = 3.0 * 0.00015 + 1.5 * 0.0006;
        assert!((metrics.total_cost_usd - expected_cost).abs() < EPSILON);

        let model_cost = metrics
            .cost_by_model
            .get("gpt-4o-mini")
            .copied()
            .expect("cost entry for the tracked model");
        assert!((model_cost - expected_cost).abs() < EPSILON);
    }

    /// Models without a pricing entry are charged at the fallback rates.
    #[test]
    fn test_cost_tracker_unknown_model_uses_fallback_pricing() {
        let tracker = CostTracker::new();
        tracker.track_request("some-unlisted-model", 1000, 1000);

        let metrics = tracker.get_metrics();
        assert_eq!(
            metrics.requests_by_model.get("some-unlisted-model"),
            Some(&1)
        );
        // 1k input tokens at 0.001 plus 1k output tokens at 0.002.
        assert!((metrics.total_cost_usd - 0.003).abs() < EPSILON);
    }

    /// `reset` clears scalar totals as well as the per-model maps.
    #[test]
    fn test_cost_tracker_reset() {
        let tracker = CostTracker::new();
        tracker.track_request("gpt-4o-mini", 1000, 500);
        tracker.reset();

        let metrics = tracker.get_metrics();
        assert_eq!(metrics.total_requests, 0);
        assert_eq!(metrics.total_input_tokens, 0);
        assert_eq!(metrics.total_output_tokens, 0);
        assert_eq!(metrics.total_cost_usd, 0.0);
        assert!(metrics.requests_by_model.is_empty());
        assert!(metrics.cost_by_model.is_empty());
    }

    /// The `get` action reports the tracker's current metrics.
    #[tokio::test]
    async fn test_cost_tool_get() {
        let tracker = Arc::new(CostTracker::new());
        tracker.track_request("gpt-4o-mini", 1000, 500);
        let tool = CostTool::new(tracker);

        let result = tool.execute(json!({"action": "get"})).await.unwrap();
        assert_eq!(result["success"], true);
        assert_eq!(result["metrics"]["total_requests"].as_u64(), Some(1));
    }

    /// The `reset` action clears the metrics of the shared tracker.
    #[tokio::test]
    async fn test_cost_tool_reset() {
        let tracker = Arc::new(CostTracker::new());
        tracker.track_request("gpt-4o-mini", 1000, 500);
        let tool = CostTool::new(Arc::clone(&tracker));

        let result = tool.execute(json!({"action": "reset"})).await.unwrap();
        assert_eq!(result["success"], true);
        assert_eq!(tracker.get_metrics().total_requests, 0);
    }
}