agentcore 0.1.0

A minimal Rust crate that gives any application agentic capabilities.
Documentation
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use serde::{Deserialize, Serialize};

use super::types::TokenUsage;

#[derive(Debug, Clone)]
pub struct ModelCosts {
    pub input_per_million: f64,
    pub output_per_million: f64,
    pub cache_read_per_million: f64,
    pub cache_write_per_million: f64,
}

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelUsage {
    pub input_tokens: u64,
    pub output_tokens: u64,
    pub cache_read_tokens: u64,
    pub cache_write_tokens: u64,
    pub cost_usd: f64,
    pub request_count: u64,
}

#[derive(Clone)]
pub struct CostTracker {
    inner: Arc<Mutex<CostTrackerInner>>,
}

impl CostTracker {
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(CostTrackerInner {
                pricing: default_pricing(),
                usage: HashMap::new(),
                total_tool_calls: 0,
            })),
        }
    }

    pub fn model_pricing(&self, model: &str, costs: ModelCosts) {
        self.inner.lock().unwrap().pricing.insert(model.to_string(), costs);
    }

    pub fn record_usage(&self, model: &str, usage: &TokenUsage) {
        let mut inner = self.inner.lock().unwrap();

        let cost = inner.pricing.get(model).map_or(0.0, |p| {
            (usage.input_tokens as f64 * p.input_per_million
                + usage.output_tokens as f64 * p.output_per_million
                + usage.cache_read_input_tokens as f64 * p.cache_read_per_million
                + usage.cache_creation_input_tokens as f64 * p.cache_write_per_million)
                / 1_000_000.0
        });

        let entry = inner.usage.entry(model.to_string()).or_default();
        entry.input_tokens += usage.input_tokens;
        entry.output_tokens += usage.output_tokens;
        entry.cache_read_tokens += usage.cache_read_input_tokens;
        entry.cache_write_tokens += usage.cache_creation_input_tokens;
        entry.cost_usd += cost;
        entry.request_count += 1;
    }

    pub fn record_tool_calls(&self, count: u64) {
        self.inner.lock().unwrap().total_tool_calls += count;
    }

    pub fn total_cost_usd(&self) -> f64 {
        self.inner.lock().unwrap().usage.values().map(|u| u.cost_usd).sum()
    }

    pub fn total_requests(&self) -> u64 {
        self.inner.lock().unwrap().usage.values().map(|u| u.request_count).sum()
    }

    pub fn total_tool_calls(&self) -> u64 {
        self.inner.lock().unwrap().total_tool_calls
    }

    pub fn model_usage(&self) -> HashMap<String, ModelUsage> {
        self.inner.lock().unwrap().usage.clone()
    }

    pub fn summary(&self) -> String {
        let inner = self.inner.lock().unwrap();
        let total_cost: f64 = inner.usage.values().map(|u| u.cost_usd).sum();
        let mut result = format!("Total cost:            ${total_cost:.4}\n");

        let mut models: Vec<_> = inner.usage.iter().collect();
        models.sort_by(|(a, _), (b, _)| a.cmp(b));

        for (model, usage) in models {
            result.push_str(&format!(
                "{model}:  {} input, {} output, {} cache read (${:.4})\n",
                format_tokens(usage.input_tokens),
                format_tokens(usage.output_tokens),
                format_tokens(usage.cache_read_tokens),
                usage.cost_usd,
            ));
        }
        result
    }
}

#[derive(Debug)]
struct CostTrackerInner {
    pricing: HashMap<String, ModelCosts>,
    usage: HashMap<String, ModelUsage>,
    total_tool_calls: u64,
}

fn default_pricing() -> HashMap<String, ModelCosts> {
    HashMap::from([
        ("claude-haiku-4-5-20251001".into(), ModelCosts {
            input_per_million: 0.80,
            output_per_million: 4.0,
            cache_read_per_million: 0.08,
            cache_write_per_million: 1.0,
        }),
        ("claude-sonnet-4-20250514".into(), ModelCosts {
            input_per_million: 3.0,
            output_per_million: 15.0,
            cache_read_per_million: 0.30,
            cache_write_per_million: 3.75,
        }),
        ("claude-opus-4-20250514".into(), ModelCosts {
            input_per_million: 15.0,
            output_per_million: 75.0,
            cache_read_per_million: 1.50,
            cache_write_per_million: 18.75,
        }),
        ("mistral-large-latest".into(), ModelCosts {
            input_per_million: 2.0,
            output_per_million: 6.0,
            cache_read_per_million: 0.0,
            cache_write_per_million: 0.0,
        }),
        ("mistral-small-latest".into(), ModelCosts {
            input_per_million: 0.10,
            output_per_million: 0.30,
            cache_read_per_million: 0.0,
            cache_write_per_million: 0.0,
        }),
        ("codestral-latest".into(), ModelCosts {
            input_per_million: 0.30,
            output_per_million: 0.90,
            cache_read_per_million: 0.0,
            cache_write_per_million: 0.0,
        }),
        ("mistral-medium-2508".into(), ModelCosts {
            input_per_million: 2.75,
            output_per_million: 8.10,
            cache_read_per_million: 0.0,
            cache_write_per_million: 0.0,
        }),
    ])
}

fn format_tokens(count: u64) -> String {
    if count >= 1_000_000 {
        format!("{:.1}M", count as f64 / 1_000_000.0)
    } else if count >= 1_000 {
        format!("{:.1}k", count as f64 / 1_000.0)
    } else {
        count.to_string()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_tracker_zero_cost() {
        let tracker = CostTracker::new();
        assert_eq!(tracker.total_cost_usd(), 0.0);
        assert_eq!(tracker.total_requests(), 0);
        assert_eq!(tracker.total_tool_calls(), 0);
    }

    #[test]
    fn record_usage_accumulates() {
        let tracker = CostTracker::new();
        let usage = TokenUsage {
            input_tokens: 1000,
            output_tokens: 500,
            cache_read_input_tokens: 0,
            cache_creation_input_tokens: 0,
        };
        tracker.record_usage("claude-sonnet-4-20250514", &usage);
        tracker.record_usage("claude-sonnet-4-20250514", &usage);

        let model_usage = tracker.model_usage();
        let sonnet = &model_usage["claude-sonnet-4-20250514"];
        assert_eq!(sonnet.input_tokens, 2000);
        assert_eq!(sonnet.output_tokens, 1000);
        assert_eq!(sonnet.request_count, 2);
    }

    #[test]
    fn multiple_models_tracked_separately() {
        let tracker = CostTracker::new();
        tracker.record_usage(
            "claude-sonnet-4-20250514",
            &TokenUsage {
                input_tokens: 100,
                output_tokens: 50,
                ..Default::default()
            },
        );
        tracker.record_usage(
            "claude-opus-4-20250514",
            &TokenUsage {
                input_tokens: 200,
                output_tokens: 100,
                ..Default::default()
            },
        );

        let usage = tracker.model_usage();
        assert_eq!(usage.len(), 2);
        assert_eq!(usage["claude-sonnet-4-20250514"].input_tokens, 100);
        assert_eq!(usage["claude-opus-4-20250514"].input_tokens, 200);
    }

    #[test]
    fn custom_pricing_applied() {
        let tracker = CostTracker::new();
        tracker.model_pricing(
            "custom-model",
            ModelCosts {
                input_per_million: 1.0,
                output_per_million: 0.0,
                cache_read_per_million: 0.0,
                cache_write_per_million: 0.0,
            },
        );
        tracker.record_usage(
            "custom-model",
            &TokenUsage {
                input_tokens: 1_000_000,
                output_tokens: 0,
                ..Default::default()
            },
        );
        let cost = tracker.total_cost_usd();
        assert!((cost - 1.0).abs() < 0.0001, "Expected $1.00, got ${cost}");
    }

    #[test]
    fn tool_calls_tracked() {
        let tracker = CostTracker::new();
        tracker.record_tool_calls(5);
        tracker.record_tool_calls(3);
        assert_eq!(tracker.total_tool_calls(), 8);
    }

    #[test]
    fn summary_contains_model_name() {
        let tracker = CostTracker::new();
        tracker.record_usage(
            "claude-sonnet-4-20250514",
            &TokenUsage {
                input_tokens: 100,
                output_tokens: 50,
                ..Default::default()
            },
        );
        let summary = tracker.summary();
        assert!(summary.contains("claude-sonnet-4-20250514"));
    }

    #[tokio::test]
    async fn concurrent_recording_thread_safe() {
        let tracker = CostTracker::new();
        let mut handles = Vec::new();

        for _ in 0..10 {
            let t = tracker.clone();
            handles.push(tokio::spawn(async move {
                for _ in 0..100 {
                    t.record_usage(
                        "claude-sonnet-4-20250514",
                        &TokenUsage {
                            input_tokens: 1,
                            output_tokens: 1,
                            ..Default::default()
                        },
                    );
                }
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(tracker.total_requests(), 1000);
        let usage = tracker.model_usage();
        assert_eq!(usage["claude-sonnet-4-20250514"].input_tokens, 1000);
    }
}