use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
/// Identifies the provider a tracker and its recorded calls are attributed to.
///
/// Each variant selects one of the `DEFAULT_*_PRICING` tables through
/// `Provider::default_pricing`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Provider {
/// Selects `DEFAULT_ANTHROPIC_PRICING`.
Anthropic,
/// Selects `DEFAULT_OPENAI_PRICING`.
OpenAI,
/// Selects `DEFAULT_BEDROCK_PRICING`.
Bedrock,
}
/// Price table used to turn token counts into a cost.
///
/// Rates are expressed per million tokens: `CallMetrics::cost_usd` multiplies
/// token counts by these rates and divides the total by `1_000_000`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Pricing {
/// Rate applied to `Usage::input_tokens`.
pub input_per_mtok: f64,
/// Rate applied to `Usage::cache_read_tokens`.
pub cache_read_per_mtok: f64,
/// Rate applied to `Usage::cache_creation_tokens` by the cost calculations
/// in this file.
pub cache_write_5m_per_mtok: f64,
/// Alternative cache-write rate (presumably for a one-hour cache lifetime —
/// confirm). NOTE(review): no cost calculation visible in this file reads it.
pub cache_write_1h_per_mtok: f64,
/// Rate applied to `Usage::output_tokens`.
pub output_per_mtok: f64,
}
/// Built-in price table returned by `Provider::default_pricing` for
/// `Provider::Anthropic`.
///
/// Values are per-million-token rates (see `Pricing`).
/// NOTE(review): these figures are hard-coded — confirm they still match the
/// provider's published price list before relying on them.
pub const DEFAULT_ANTHROPIC_PRICING: Pricing = Pricing {
input_per_mtok: 3.00,
cache_read_per_mtok: 0.30,
cache_write_5m_per_mtok: 3.75,
cache_write_1h_per_mtok: 6.00,
output_per_mtok: 15.00,
};
/// Built-in price table returned by `Provider::default_pricing` for
/// `Provider::OpenAI`.
///
/// Values are per-million-token rates (see `Pricing`). Both cache-write rates
/// are set equal to the plain input rate in this table.
/// NOTE(review): these figures are hard-coded — confirm they still match the
/// provider's published price list before relying on them.
pub const DEFAULT_OPENAI_PRICING: Pricing = Pricing {
input_per_mtok: 2.50,
cache_read_per_mtok: 1.25,
cache_write_5m_per_mtok: 2.50,
cache_write_1h_per_mtok: 2.50,
output_per_mtok: 10.00,
};
/// Built-in price table returned by `Provider::default_pricing` for
/// `Provider::Bedrock`.
///
/// Values are per-million-token rates (see `Pricing`) and are currently
/// identical to `DEFAULT_ANTHROPIC_PRICING`.
/// NOTE(review): these figures are hard-coded — confirm they still match the
/// provider's published price list before relying on them.
pub const DEFAULT_BEDROCK_PRICING: Pricing = Pricing {
input_per_mtok: 3.00,
cache_read_per_mtok: 0.30,
cache_write_5m_per_mtok: 3.75,
cache_write_1h_per_mtok: 6.00,
output_per_mtok: 15.00,
};
impl Provider {
    /// Returns the built-in price table for this provider.
    ///
    /// The result is a copy of the matching `DEFAULT_*_PRICING` constant, so
    /// callers may modify it freely without affecting the defaults.
    pub fn default_pricing(&self) -> Pricing {
        match *self {
            Self::Anthropic => DEFAULT_ANTHROPIC_PRICING,
            Self::OpenAI => DEFAULT_OPENAI_PRICING,
            Self::Bedrock => DEFAULT_BEDROCK_PRICING,
        }
    }
}
/// Token counts reported for a single call.
///
/// All counters default to zero. The cost calculations on `CallMetrics` price
/// each counter with a different rate from `Pricing`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct Usage {
/// Input tokens billed at `Pricing::input_per_mtok`; counted separately from
/// the two cache counters below.
pub input_tokens: u64,
/// Tokens read from the cache; the numerator of the hit ratio.
pub cache_read_tokens: u64,
/// Tokens written to the cache; together with `cache_read_tokens` they form
/// the denominator of the hit ratio.
pub cache_creation_tokens: u64,
/// Output tokens billed at `Pricing::output_per_mtok`.
pub output_tokens: u64,
}
/// One recorded call: who served it, which prefix it used, and what it consumed.
///
/// Instances are produced by `CacheTracker::record` and kept in the tracker's
/// history.
#[derive(Debug, Clone)]
pub struct CallMetrics {
/// Provider of the tracker that recorded the call.
pub provider: Provider,
/// Caller-supplied identifier; `CacheTracker::by_prefix` groups calls by it.
pub prefix_id: String,
/// Token counts for the call.
pub usage: Usage,
/// Caller-supplied duration, stored as given (presumably the call latency —
/// confirm against callers).
pub elapsed: Duration,
/// Wall-clock time at which `CacheTracker::record` created this entry.
pub timestamp: SystemTime,
}
impl CallMetrics {
    /// Fraction of this call's cacheable tokens that were served from cache.
    ///
    /// "Cacheable" tokens are `cache_read_tokens + cache_creation_tokens`.
    /// Returns `None` when both counters are zero, i.e. the call did not touch
    /// the cache at all, so no ratio is defined.
    pub fn hit_ratio(&self) -> Option<f64> {
        let read = self.usage.cache_read_tokens;
        let created = self.usage.cache_creation_tokens;
        if read == 0 && created == 0 {
            return None;
        }
        // Add in f64 rather than u64: an integer add would panic (debug) or
        // wrap (release) for extreme counters. For realistic counts (< 2^53)
        // the result is bit-identical to the integer sum.
        let read = read as f64;
        Some(read / (read + created as f64))
    }

    /// Cost of this call under `pricing`, in the currency of the price table
    /// (per-million-token rates, hence the final division by `1_000_000`).
    ///
    /// Cache creation is always priced at `cache_write_5m_per_mtok`; `Usage`
    /// carries a single creation counter, so `cache_write_1h_per_mtok` is not
    /// used here.
    pub fn cost_usd(&self, pricing: &Pricing) -> f64 {
        let usage = &self.usage;
        let total = usage.input_tokens as f64 * pricing.input_per_mtok
            + usage.cache_read_tokens as f64 * pricing.cache_read_per_mtok
            + usage.cache_creation_tokens as f64 * pricing.cache_write_5m_per_mtok
            + usage.output_tokens as f64 * pricing.output_per_mtok;
        total / 1_000_000.0
    }

    /// Difference between what the input side of this call would have cost if
    /// every input/cache token were billed at the plain input rate, and what it
    /// cost with cache pricing applied.
    ///
    /// Output tokens are excluded because caching does not change their price.
    /// The result is negative when cache writes (priced at
    /// `cache_write_5m_per_mtok`) outweigh the discount earned on cache reads.
    pub fn cost_saved_usd(&self, pricing: &Pricing) -> f64 {
        // Convert before summing so extreme counters cannot overflow u64.
        let input = self.usage.input_tokens as f64;
        let read = self.usage.cache_read_tokens as f64;
        let created = self.usage.cache_creation_tokens as f64;

        let full = (input + read + created) * pricing.input_per_mtok;
        let actual = input * pricing.input_per_mtok
            + read * pricing.cache_read_per_mtok
            + created * pricing.cache_write_5m_per_mtok;
        (full - actual) / 1_000_000.0
    }
}
/// Totals over every call currently held in a tracker's history, as produced
/// by `CacheTracker::aggregate`.
///
/// The `Default` value (all zeros, `hit_ratio: None`) is what an empty history
/// aggregates to.
#[derive(Debug, Clone, Copy, Default)]
pub struct Aggregate {
/// Number of calls included in the totals.
pub calls: usize,
/// Cache reads divided by cache reads plus cache writes across all calls;
/// `None` when no call touched the cache.
pub hit_ratio: Option<f64>,
/// Sum of `Usage::cache_read_tokens`.
pub tokens_read_from_cache: u64,
/// Sum of `Usage::cache_creation_tokens`.
pub tokens_written_to_cache: u64,
/// Sum of `CallMetrics::cost_usd` under the tracker's pricing.
pub cost_usd: f64,
/// Sum of `CallMetrics::cost_saved_usd` under the tracker's pricing.
pub cost_saved_usd: f64,
}
/// Per-prefix totals, as produced by `CacheTracker::by_prefix` for each
/// distinct `CallMetrics::prefix_id` in the history.
#[derive(Debug, Clone, Copy, Default)]
pub struct PrefixStats {
/// Number of calls recorded with this prefix.
pub calls: usize,
/// Cache reads divided by cache reads plus cache writes for this prefix;
/// `None` when none of its calls touched the cache.
pub hit_ratio: Option<f64>,
/// Sum of `CallMetrics::cost_saved_usd` for this prefix under the tracker's
/// pricing.
pub cost_saved_usd: f64,
}
/// Shared, thread-safe callback that receives the metrics of a call whose hit
/// ratio fell below the tracker's alert threshold.
type AlertHook = Arc<dyn Fn(&CallMetrics) + Send + Sync>;
/// Records per-call cache usage and summarizes it.
///
/// The history lives behind an `Arc<Mutex<..>>`, so cloning a tracker yields a
/// handle onto the same history; the configuration fields are copied per clone.
#[derive(Clone)]
pub struct CacheTracker {
// Provider stamped onto every recorded `CallMetrics`.
provider: Provider,
// Price table used by `aggregate` and `by_prefix`.
pricing: Pricing,
// `record` fires the alert hook when a call's hit ratio is strictly below this.
miss_alert_threshold: f64,
// Optional callback for low-hit-ratio calls; `None` disables alerting.
on_miss_alert: Option<AlertHook>,
// Maximum number of calls retained; `record` evicts the oldest beyond it.
history_size: usize,
// Recorded calls, oldest at the front.
history: Arc<Mutex<VecDeque<CallMetrics>>>,
}
impl CacheTracker {
    /// Creates a tracker for `provider` using that provider's default pricing.
    ///
    /// Defaults: alert threshold `0.6`, no alert hook, and a history bounded to
    /// the most recent `10_000` calls.
    pub fn new(provider: Provider) -> Self {
        Self {
            provider,
            pricing: provider.default_pricing(),
            miss_alert_threshold: 0.6,
            on_miss_alert: None,
            history_size: 10_000,
            history: Arc::new(Mutex::new(VecDeque::with_capacity(1024))),
        }
    }

    /// Replaces the price table used by [`aggregate`](Self::aggregate) and
    /// [`by_prefix`](Self::by_prefix).
    pub fn with_pricing(mut self, pricing: Pricing) -> Self {
        self.pricing = pricing;
        self
    }

    /// Sets the hit ratio below which [`record`](Self::record) fires the alert
    /// hook.
    ///
    /// The value is stored as given, without range validation. A `NaN`
    /// threshold never triggers the hook because `ratio < NaN` is always false.
    pub fn with_alert_threshold(mut self, threshold: f64) -> Self {
        self.miss_alert_threshold = threshold;
        self
    }

    /// Installs the callback invoked for calls whose hit ratio is strictly
    /// below the alert threshold. Replaces any previously installed hook.
    pub fn with_alert_hook<F>(mut self, hook: F) -> Self
    where
        F: Fn(&CallMetrics) + Send + Sync + 'static,
    {
        self.on_miss_alert = Some(Arc::new(hook));
        self
    }

    /// Sets the maximum number of calls kept in the history.
    ///
    /// Existing entries are not trimmed here; the new bound is enforced by the
    /// next [`record`](Self::record). A size of `0` retains nothing.
    pub fn with_history_size(mut self, size: usize) -> Self {
        self.history_size = size;
        self
    }

    /// Records one call and returns its metrics.
    ///
    /// The call is stamped with the tracker's provider and the current system
    /// time, appended to the history, and the oldest entries are evicted until
    /// the history fits `history_size`. If the call touched the cache and its
    /// hit ratio is strictly below the alert threshold, the alert hook (if any)
    /// is invoked. The hook runs after the history lock has been released.
    pub fn record(&self, prefix_id: String, usage: Usage, elapsed: Duration) -> CallMetrics {
        let metrics = CallMetrics {
            provider: self.provider,
            prefix_id,
            usage,
            elapsed,
            timestamp: SystemTime::now(),
        };

        {
            let mut history = self.history.lock();
            history.push_back(metrics.clone());
            while history.len() > self.history_size {
                history.pop_front();
            }
        }

        // `hit_ratio()` is `None` exactly when the call had no cacheable
        // tokens, so no separate "cacheable > 0" check is needed.
        if let (Some(ratio), Some(hook)) = (metrics.hit_ratio(), self.on_miss_alert.as_ref()) {
            if ratio < self.miss_alert_threshold {
                hook(&metrics);
            }
        }

        metrics
    }

    /// Returns a snapshot (deep copy) of the history, oldest call first.
    pub fn calls(&self) -> Vec<CallMetrics> {
        self.history.lock().iter().cloned().collect()
    }

    /// Discards every recorded call.
    pub fn reset(&self) {
        self.history.lock().clear();
    }

    /// Returns the provider this tracker was created for.
    pub fn provider(&self) -> Provider {
        self.provider
    }

    /// Summarizes the whole history under the tracker's pricing.
    ///
    /// An empty history yields the same values as `Aggregate::default()`.
    pub fn aggregate(&self) -> Aggregate {
        // Single pass under the lock: only cheap arithmetic happens here, which
        // avoids deep-cloning every entry (including its `String`) first.
        let history = self.history.lock();
        let mut totals = CacheTotals::default();
        let mut cost_usd = 0.0;
        for call in history.iter() {
            totals.add(call, &self.pricing);
            cost_usd += call.cost_usd(&self.pricing);
        }

        Aggregate {
            calls: totals.calls,
            hit_ratio: totals.hit_ratio(),
            tokens_read_from_cache: totals.read,
            tokens_written_to_cache: totals.written,
            cost_usd,
            cost_saved_usd: totals.cost_saved_usd,
        }
    }

    /// Summarizes the history separately for each distinct `prefix_id`.
    ///
    /// Prefixes with no calls in the current history do not appear in the map.
    pub fn by_prefix(&self) -> HashMap<String, PrefixStats> {
        let history = self.history.lock();

        // Group by borrowed key while the lock is held so each prefix string is
        // copied once per group rather than once per call.
        let mut groups: HashMap<&str, CacheTotals> = HashMap::new();
        for call in history.iter() {
            groups
                .entry(call.prefix_id.as_str())
                .or_default()
                .add(call, &self.pricing);
        }

        groups
            .into_iter()
            .map(|(prefix_id, totals)| {
                (
                    prefix_id.to_owned(),
                    PrefixStats {
                        calls: totals.calls,
                        hit_ratio: totals.hit_ratio(),
                        cost_saved_usd: totals.cost_saved_usd,
                    },
                )
            })
            .collect()
    }
}

/// Running totals shared by `CacheTracker::aggregate` and
/// `CacheTracker::by_prefix`.
#[derive(Debug, Default)]
struct CacheTotals {
    calls: usize,
    read: u64,
    written: u64,
    cost_saved_usd: f64,
}

impl CacheTotals {
    /// Folds one call into the totals.
    fn add(&mut self, call: &CallMetrics, pricing: &Pricing) {
        self.calls += 1;
        // Saturate instead of overflowing on pathological counters.
        self.read = self.read.saturating_add(call.usage.cache_read_tokens);
        self.written = self.written.saturating_add(call.usage.cache_creation_tokens);
        self.cost_saved_usd += call.cost_saved_usd(pricing);
    }

    /// Reads over reads-plus-writes, or `None` when nothing touched the cache.
    fn hit_ratio(&self) -> Option<f64> {
        if self.read == 0 && self.written == 0 {
            None
        } else {
            let read = self.read as f64;
            Some(read / (read + self.written as f64))
        }
    }
}