#![cfg(all(feature = "async", feature = "pricing"))]
use std::collections::{HashMap, VecDeque};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Mutex;
use crate::pricing::PricingTable;
use crate::types::{ModelId, Usage};
/// Token counts and dollar-cost figures associated with one model.
///
/// The code that populates this struct is not in this module, so the field
/// notes below follow the field names only.
/// NOTE(review): confirm units and meanings against whatever constructs it.
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot build it
/// with a struct literal or match it exhaustively.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub struct CostPreview {
/// Model the figures refer to; passed to the pricing table by `cost_for`.
pub model: ModelId,
/// Input token count; forwarded as `Usage::input_tokens` by `cost_for`.
pub input_tokens: u32,
/// Presumably the upper bound on output tokens for the request — TODO confirm.
pub max_output_tokens: u32,
/// Presumably the USD cost of `input_tokens` — TODO confirm.
pub input_cost_usd: f64,
/// Presumably the USD cost of `max_output_tokens` output tokens — TODO confirm.
pub max_output_cost_usd: f64,
/// Presumably `input_cost_usd + max_output_cost_usd` — TODO confirm.
pub max_total_usd: f64,
}
impl CostPreview {
    /// Prices this preview's input tokens together with a caller-supplied
    /// output token count.
    ///
    /// Builds a [`Usage`] carrying `self.input_tokens` and `output_tokens`
    /// (every other field left at its `Default` value) and returns whatever
    /// `pricing.cost` reports for `self.model` with that usage.
    #[must_use]
    pub fn cost_for(&self, output_tokens: u32, pricing: &PricingTable) -> f64 {
        let usage = Usage {
            input_tokens: self.input_tokens,
            output_tokens,
            ..Usage::default()
        };
        pricing.cost(&self.model, &usage)
    }
}
/// A bounded cache from a `u64` key to a `u32` value, usable through `&self`.
///
/// Eviction is first-in, first-out by *insertion* order: neither [`get`] nor
/// replacing the value of an existing key moves an entry within the queue.
/// All state sits behind a [`Mutex`], so every method locks for its duration.
///
/// [`get`]: CountTokensCache::get
#[derive(Debug)]
pub struct CountTokensCache {
    /// Map and insertion queue, guarded together so they are updated as a unit.
    inner: Mutex<CacheInner>,
    /// Maximum number of entries retained; `0` means nothing is ever stored.
    capacity: usize,
}

/// Lock-protected state of a [`CountTokensCache`].
#[derive(Debug)]
struct CacheInner {
    /// Cached value for each live key.
    map: HashMap<u64, u32>,
    /// Live keys in insertion order, oldest at the front.
    order: VecDeque<u64>,
}

impl CountTokensCache {
    /// Creates an empty cache that retains at most `capacity` entries.
    ///
    /// Room for `capacity` entries is reserved up front in both the map and
    /// the order queue. A `capacity` of `0` produces a cache whose [`put`]
    /// is a no-op.
    ///
    /// [`put`]: CountTokensCache::put
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        Self {
            inner: Mutex::new(CacheInner {
                map: HashMap::with_capacity(capacity),
                order: VecDeque::with_capacity(capacity),
            }),
            capacity,
        }
    }

    /// Returns the number of entries currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.lock().map.len()
    }

    /// Returns `true` when no entries are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.lock().map.is_empty()
    }

    /// Removes every entry and forgets the insertion order.
    pub fn clear(&self) {
        let mut inner = self.lock();
        inner.map.clear();
        inner.order.clear();
    }

    /// Returns the value stored under `key`, or `None` if it is absent.
    ///
    /// A lookup does not change which entry will be evicted next.
    #[must_use]
    pub fn get(&self, key: u64) -> Option<u32> {
        self.lock().map.get(&key).copied()
    }

    /// Stores `value` under `key`.
    ///
    /// * If `key` is already present its value is overwritten in place and
    ///   its position in the eviction queue is left untouched.
    /// * Otherwise, if the cache is full, the oldest inserted key is evicted
    ///   first and `key` is appended as the newest entry.
    /// * With a capacity of `0` the call does nothing.
    pub fn put(&self, key: u64, value: u32) {
        // A zero-capacity cache must stay empty. Without this guard the
        // eviction step below finds nothing to pop and the insert would
        // still go through, leaving one entry in a "zero-sized" cache.
        if self.capacity == 0 {
            return;
        }

        let mut inner = self.lock();

        // Single lookup for the replace path; the queue is deliberately not
        // touched so replacement keeps the original insertion position.
        if let Some(slot) = inner.map.get_mut(&key) {
            *slot = value;
            return;
        }

        // Make room before inserting so the size never exceeds `capacity`.
        if inner.order.len() >= self.capacity {
            let evicted = inner.order.pop_front();
            if let Some(oldest) = evicted {
                inner.map.remove(&oldest);
            }
        }

        inner.map.insert(key, value);
        inner.order.push_back(key);
    }

    /// Acquires the inner lock, recovering the guard if the mutex is poisoned.
    ///
    /// A panic in another thread while it held the lock therefore does not
    /// make the cache unusable; later calls keep operating on whatever state
    /// was left behind.
    fn lock(&self) -> std::sync::MutexGuard<'_, CacheInner> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
/// Derives a `u64` key from the JSON serialization of `value`.
///
/// The value is serialized with `serde_json::to_vec` and the resulting byte
/// vector is fed to a freshly constructed [`DefaultHasher`].
///
/// If serialization fails, an empty byte vector is hashed instead, so every
/// value that cannot be serialized maps to the same key.
#[must_use]
pub fn hash_request<T>(value: &T) -> u64
where
    T: serde::Serialize,
{
    let mut state = DefaultHasher::new();
    // Hash the `Vec<u8>` itself (not raw `write`) so the digest is unchanged.
    serde_json::to_vec(value)
        .unwrap_or_default()
        .hash(&mut state);
    state.finish()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache of the given capacity and inserts `entries` in order.
    fn filled(capacity: usize, entries: &[(u64, u32)]) -> CountTokensCache {
        let cache = CountTokensCache::new(capacity);
        for &(key, value) in entries {
            cache.put(key, value);
        }
        cache
    }

    #[test]
    fn cache_put_and_get_round_trips() {
        let cache = filled(4, &[(1, 100), (2, 200)]);

        assert_eq!(cache.get(1), Some(100));
        assert_eq!(cache.get(2), Some(200));
        // A key that was never inserted is a miss.
        assert_eq!(cache.get(3), None);
    }

    #[test]
    fn cache_evicts_oldest_when_full() {
        // The third insert exceeds capacity 2, pushing out key 1.
        let cache = filled(2, &[(1, 10), (2, 20), (3, 30)]);

        assert_eq!(cache.get(1), None);
        assert_eq!(cache.get(2), Some(20));
        assert_eq!(cache.get(3), Some(30));
    }

    #[test]
    fn cache_replace_does_not_change_eviction_order() {
        // Re-putting key 1 overwrites its value but leaves it oldest, so the
        // insert of key 3 still evicts key 1.
        let cache = filled(2, &[(1, 10), (2, 20), (1, 11), (3, 30)]);

        assert_eq!(cache.get(1), None);
        assert_eq!(cache.get(2), Some(20));
        assert_eq!(cache.get(3), Some(30));
    }

    #[test]
    fn cache_clear_drops_all_entries() {
        let cache = filled(2, &[(1, 10)]);

        cache.clear();

        assert!(cache.is_empty());
    }

    #[test]
    fn hash_request_is_stable_across_calls() {
        let payload = serde_json::json!({"a": 1, "b": [1, 2, 3]});

        let first = hash_request(&payload);
        let second = hash_request(&payload);

        assert_eq!(first, second);
    }

    #[test]
    fn hash_request_distinguishes_different_payloads() {
        let left = serde_json::json!({"a": 1});
        let right = serde_json::json!({"a": 2});

        assert_ne!(hash_request(&left), hash_request(&right));
    }
}