use std::sync::atomic::{AtomicU64, Ordering};
/// Thread-safe counters tracking how often an embedding cache is hit or missed.
///
/// Both counters are [`AtomicU64`]s updated through `&self`, so one instance can be
/// shared between threads (e.g. behind an `Arc`, or in a `static` via the
/// `const fn` [`CacheMetrics::new`]).
#[derive(Debug, Default)]
pub struct CacheMetrics {
    /// Number of lookups answered from the cache.
    cache_hits: AtomicU64,
    /// Number of lookups that missed; each miss counts as one generated embedding.
    cache_misses: AtomicU64,
}

impl CacheMetrics {
    /// Estimated price in USD of generating a single embedding.
    ///
    /// Shared by the cost and savings estimates so the two can never drift apart.
    const COST_PER_EMBEDDING_USD: f64 = 0.00002;

    /// Creates a metrics instance with both counters at zero.
    ///
    /// This is a `const fn`, so it can initialise a `static`.
    pub const fn new() -> Self {
        Self {
            cache_hits: AtomicU64::new(0),
            cache_misses: AtomicU64::new(0),
        }
    }

    /// Records one cache hit.
    pub fn record_hit(&self) {
        // Relaxed is enough: the counters are standalone statistics and are not
        // used to publish or synchronise any other data.
        self.cache_hits.fetch_add(1, Ordering::Relaxed);
    }

    /// Records one cache miss (and therefore one generated embedding).
    pub fn record_miss(&self) {
        self.cache_misses.fetch_add(1, Ordering::Relaxed);
    }

    /// Returns the number of hits recorded so far.
    pub fn hits(&self) -> u64 {
        self.cache_hits.load(Ordering::Relaxed)
    }

    /// Returns the number of misses recorded so far.
    pub fn misses(&self) -> u64 {
        self.cache_misses.load(Ordering::Relaxed)
    }

    /// Returns the fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when nothing has been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        let (hits, misses) = self.snapshot();
        Self::ratio(hits, hits.saturating_add(misses))
    }

    /// Estimated USD saved by serving hits from the cache instead of regenerating.
    pub fn estimated_savings_usd(&self) -> f64 {
        self.hits() as f64 * Self::COST_PER_EMBEDDING_USD
    }

    /// Number of embeddings that had to be generated (one per miss).
    pub fn embeddings_generated(&self) -> u64 {
        self.misses()
    }

    /// Estimated USD spent generating embeddings for cache misses.
    pub fn estimated_cost_usd(&self) -> f64 {
        self.embeddings_generated() as f64 * Self::COST_PER_EMBEDDING_USD
    }

    /// Renders a multi-line, human-readable summary of the current counters.
    ///
    /// Every figure (totals, percentages, cost) is derived from a single read of
    /// each counter, so the lines are mutually consistent even while other
    /// threads keep recording.
    pub fn report(&self) -> String {
        let (hits, misses) = self.snapshot();
        let total = hits.saturating_add(misses);
        let hit_rate = Self::ratio(hits, total) * 100.0;
        let miss_rate = Self::ratio(misses, total) * 100.0;
        let cost = misses as f64 * Self::COST_PER_EMBEDDING_USD;
        format!(
            "Cache metrics:\n \
             - Chunks processed: {total}\n \
             - Cache hits: {hits} ({hit_rate:.1}%)\n \
             - Cache misses: {misses} ({miss_rate:.1}%)\n \
             - Embeddings generated: {misses}\n \
             - Estimated cost: ${cost:.4}"
        )
    }

    /// Sets both counters back to zero.
    ///
    /// The two stores are independent, so a concurrent reader may briefly observe
    /// one counter reset and the other not.
    pub fn reset(&self) {
        self.cache_hits.store(0, Ordering::Relaxed);
        self.cache_misses.store(0, Ordering::Relaxed);
    }

    /// Reads `(hits, misses)` once each.
    ///
    /// The two loads are separate, not one atomic pair.
    fn snapshot(&self) -> (u64, u64) {
        (self.hits(), self.misses())
    }

    /// Returns `part / total` as `f64`, or `0.0` when `total` is zero (avoids NaN).
    fn ratio(part: u64, total: u64) -> f64 {
        if total == 0 {
            0.0
        } else {
            part as f64 / total as f64
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Returns a fresh instance with `hits` hits and `misses` misses recorded.
    fn metrics_with(hits: u64, misses: u64) -> CacheMetrics {
        let metrics = CacheMetrics::new();
        for _ in 0..hits {
            metrics.record_hit();
        }
        for _ in 0..misses {
            metrics.record_miss();
        }
        metrics
    }

    #[test]
    fn test_new_metrics() {
        let metrics = CacheMetrics::new();
        assert_eq!(metrics.hits(), 0);
        assert_eq!(metrics.misses(), 0);
        assert_eq!(metrics.hit_rate(), 0.0);
    }

    #[test]
    fn test_default_matches_new() {
        let metrics = CacheMetrics::default();
        assert_eq!(metrics.hits(), 0);
        assert_eq!(metrics.misses(), 0);
        assert_eq!(metrics.hit_rate(), 0.0);
    }

    #[test]
    fn test_record_hit() {
        let metrics = CacheMetrics::new();
        metrics.record_hit();
        assert_eq!(metrics.hits(), 1);
        assert_eq!(metrics.misses(), 0);
        metrics.record_hit();
        assert_eq!(metrics.hits(), 2);
    }

    #[test]
    fn test_record_miss() {
        let metrics = CacheMetrics::new();
        metrics.record_miss();
        assert_eq!(metrics.hits(), 0);
        assert_eq!(metrics.misses(), 1);
        metrics.record_miss();
        assert_eq!(metrics.misses(), 2);
    }

    #[test]
    fn test_hit_rate() {
        assert_eq!(CacheMetrics::new().hit_rate(), 0.0);
        let metrics = metrics_with(8, 2);
        assert_eq!(metrics.hits(), 8);
        assert_eq!(metrics.misses(), 2);
        assert_eq!(metrics.hit_rate(), 0.8);
    }

    #[test]
    fn test_hit_rate_extremes() {
        assert_eq!(metrics_with(5, 0).hit_rate(), 1.0);
        assert_eq!(metrics_with(0, 5).hit_rate(), 0.0);
    }

    #[test]
    fn test_estimated_cost() {
        let metrics = metrics_with(0, 2000);
        assert_eq!(metrics.embeddings_generated(), 2000);
        assert!((metrics.estimated_cost_usd() - 0.04).abs() < 1e-9);
    }

    #[test]
    fn test_estimated_savings() {
        let metrics = metrics_with(8000, 0);
        assert!((metrics.estimated_savings_usd() - 0.16).abs() < 1e-9);
    }

    #[test]
    fn test_report_format() {
        let report = metrics_with(8000, 2000).report();
        // Compare whole lines: substring checks like `contains("8000")` would also
        // accept wrong values such as "80000".
        let lines: Vec<&str> = report.lines().collect();
        assert_eq!(
            lines,
            [
                "Cache metrics:",
                " - Chunks processed: 10000",
                " - Cache hits: 8000 (80.0%)",
                " - Cache misses: 2000 (20.0%)",
                " - Embeddings generated: 2000",
                " - Estimated cost: $0.0400",
            ]
        );
    }

    #[test]
    fn test_report_empty() {
        let report = CacheMetrics::new().report();
        assert!(!report.contains("NaN"));
        assert!(report.contains(" - Cache hits: 0 (0.0%)"));
        assert!(report.contains(" - Cache misses: 0 (0.0%)"));
        assert!(report.contains(" - Estimated cost: $0.0000"));
    }

    #[test]
    fn test_reset() {
        let metrics = metrics_with(1, 1);
        assert_eq!(metrics.hits(), 1);
        assert_eq!(metrics.misses(), 1);
        metrics.reset();
        assert_eq!(metrics.hits(), 0);
        assert_eq!(metrics.misses(), 0);
        assert_eq!(metrics.hit_rate(), 0.0);
    }

    #[test]
    fn test_thread_safety() {
        let metrics = Arc::new(CacheMetrics::new());
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let metrics = Arc::clone(&metrics);
                thread::spawn(move || {
                    for _ in 0..100 {
                        metrics.record_hit();
                        metrics.record_miss();
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("worker thread panicked");
        }
        assert_eq!(metrics.hits(), 1000);
        assert_eq!(metrics.misses(), 1000);
    }
}