use lru::LruCache;
use serde_json::Value;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use log::{debug, info};
/// A single cached JSON payload plus bookkeeping metadata.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached JSON payload.
    data: Value,
    /// Creation timestamp, Unix epoch milliseconds.
    created_at: u64,
    /// Timestamp of the most recent read, Unix epoch milliseconds.
    last_accessed: u64,
    /// Number of reads; starts at 1 on creation.
    access_count: u64,
    /// Approximate in-memory footprint (from `estimate_json_size`).
    size_bytes: usize,
}

impl CacheEntry {
    /// Wrap `data` with metadata stamped "now"; size is estimated once here.
    fn new(data: Value) -> Self {
        let now = current_time_millis();
        let size_bytes = estimate_json_size(&data);
        Self {
            data,
            created_at: now,
            last_accessed: now,
            access_count: 1,
            size_bytes,
        }
    }

    /// Record a read (touch timestamp + counter) and return a clone of the payload.
    fn access(&mut self) -> Value {
        self.last_accessed = current_time_millis();
        self.access_count += 1;
        self.data.clone()
    }

    /// True when the entry is older than `ttl_ms`. A TTL of 0 means "never expires".
    fn is_expired(&self, ttl_ms: u64) -> bool {
        if ttl_ms == 0 {
            return false;
        }
        // saturating_sub: a backwards system-clock step must not underflow
        // (plain u64 subtraction panics in debug builds).
        current_time_millis().saturating_sub(self.created_at) > ttl_ms
    }

    /// Milliseconds since creation (0 if the clock moved backwards).
    fn age_ms(&self) -> u64 {
        current_time_millis().saturating_sub(self.created_at)
    }

    /// Milliseconds since the last read (0 if the clock moved backwards).
    #[allow(dead_code)]
    fn idle_time_ms(&self) -> u64 {
        current_time_millis().saturating_sub(self.last_accessed)
    }
}
/// Cache lookup key: the request URL plus a digest of the query parameters.
///
/// NOTE(review): equality compares only the URL and the 64-bit params digest,
/// so two distinct parameter maps that happen to collide on the digest would
/// compare equal. With `DefaultHasher` this is astronomically unlikely but
/// not impossible — confirm this trade-off is acceptable for this cache.
#[derive(Debug, Clone, PartialEq, Eq)]
struct CacheKey {
    url: String,
    params_hash: u64,
}

impl CacheKey {
    /// Build a key from `url` and an optional parameter map.
    ///
    /// Parameters are hashed in sorted-key order so that maps with identical
    /// contents produce identical digests regardless of insertion order.
    /// `None` is encoded as a digest of 0.
    fn new(url: String, params: Option<&HashMap<String, String>>) -> Self {
        let params_hash = params.map_or(0, |map| {
            let mut entries: Vec<_> = map.iter().collect();
            entries.sort_by_key(|&(name, _)| name);
            let mut hasher = std::collections::hash_map::DefaultHasher::new();
            for (name, value) in entries {
                name.hash(&mut hasher);
                value.hash(&mut hasher);
            }
            hasher.finish()
        });
        Self { url, params_hash }
    }
}

impl Hash for CacheKey {
    /// Hash both components so `HashMap`/`LruCache` lookups see the full key.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.url.hash(state);
        self.params_hash.hash(state);
    }
}
/// Tuning knobs for [`ResponseCache`].
///
/// All TTLs are in milliseconds; a TTL of 0 means entries never expire.
/// `max_entries == 0` disables caching entirely.
#[derive(Debug, Clone)]
pub struct ResponseCacheConfig {
    /// LRU capacity; 0 disables the cache.
    pub max_entries: usize,
    /// TTL for URLs that match no specific category.
    pub default_ttl_ms: u64,
    /// TTL for short-interval quote data (1m/5m charts).
    pub quote_ttl_ms: u64,
    /// TTL for search responses.
    pub search_ttl_ms: u64,
    /// TTL for longer-interval historical chart data.
    pub history_ttl_ms: u64,
    /// Approximate memory budget for cached payloads.
    pub max_memory_bytes: usize,
    /// Whether to evict LRU entries when the memory budget is exceeded.
    pub enable_size_eviction: bool,
    /// Interval for periodic expiry sweeps.
    pub cleanup_interval_ms: u64,
    /// Whether error responses may be cached.
    pub cache_errors: bool,
    /// TTL applied to cached error responses.
    pub error_ttl_ms: u64,
}

impl Default for ResponseCacheConfig {
    /// Balanced production defaults: 1000 entries, 50 MiB, 5-minute base TTL.
    fn default() -> Self {
        Self {
            max_entries: 1000,
            default_ttl_ms: 300_000,
            quote_ttl_ms: 60_000,
            search_ttl_ms: 3_600_000,
            history_ttl_ms: 1_800_000,
            max_memory_bytes: 50 * 1024 * 1024,
            enable_size_eviction: true,
            cleanup_interval_ms: 300_000,
            cache_errors: false,
            error_ttl_ms: 30_000,
        }
    }
}

impl ResponseCacheConfig {
    /// Large, long-lived cache that also caches errors; for read-heavy workloads.
    pub fn aggressive() -> Self {
        Self {
            max_entries: 5000,
            default_ttl_ms: 900_000,
            quote_ttl_ms: 300_000,
            search_ttl_ms: 7_200_000,
            history_ttl_ms: 3_600_000,
            max_memory_bytes: 200 * 1024 * 1024,
            enable_size_eviction: true,
            cleanup_interval_ms: 600_000,
            cache_errors: true,
            error_ttl_ms: 120_000,
        }
    }

    /// Small cache with short TTLs; favors freshness over hit rate.
    pub fn conservative() -> Self {
        Self {
            max_entries: 200,
            default_ttl_ms: 60_000,
            quote_ttl_ms: 30_000,
            search_ttl_ms: 900_000,
            history_ttl_ms: 600_000,
            max_memory_bytes: 10 * 1024 * 1024,
            enable_size_eviction: true,
            cleanup_interval_ms: 120_000,
            cache_errors: false,
            error_ttl_ms: 10_000,
        }
    }

    /// Tiny cache with very short TTLs, suitable for local development.
    pub fn development() -> Self {
        Self {
            max_entries: 50,
            default_ttl_ms: 10_000,
            quote_ttl_ms: 5_000,
            search_ttl_ms: 30_000,
            history_ttl_ms: 20_000,
            max_memory_bytes: 5 * 1024 * 1024,
            enable_size_eviction: true,
            cleanup_interval_ms: 30_000,
            cache_errors: false,
            error_ttl_ms: 2_000,
        }
    }

    /// Fully disabled cache: every field zeroed; get/put become no-ops.
    pub fn disabled() -> Self {
        Self {
            max_entries: 0,
            default_ttl_ms: 0,
            quote_ttl_ms: 0,
            search_ttl_ms: 0,
            history_ttl_ms: 0,
            max_memory_bytes: 0,
            enable_size_eviction: false,
            cleanup_interval_ms: 0,
            cache_errors: false,
            error_ttl_ms: 0,
        }
    }
}
/// Snapshot of cache activity counters, exposed via [`ResponseCache::stats`].
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Total lookups (hits + misses).
    pub total_requests: u64,
    /// Lookups served from cache.
    pub cache_hits: u64,
    /// Lookups that found nothing usable.
    pub cache_misses: u64,
    /// Current number of cached entries.
    pub cache_entries: usize,
    /// Approximate bytes held by cached payloads.
    pub total_memory_bytes: usize,
    /// Entries removed for any reason.
    pub cache_evictions: u64,
    /// Entries removed because their TTL elapsed.
    pub expired_evictions: u64,
    /// Entries removed to stay under the memory budget.
    pub size_evictions: u64,
    /// total_memory_bytes / cache_entries.
    pub average_entry_size: f64,
    /// cache_hits / total_requests, in [0, 1].
    pub hit_rate: f64,
    /// Age of the oldest entry, milliseconds.
    pub oldest_entry_age_ms: u64,
}

impl CacheStats {
    /// Recompute `hit_rate` from the counters; left untouched when there have
    /// been no requests (avoids a division by zero).
    pub fn update_hit_rate(&mut self) {
        match self.total_requests {
            0 => {}
            total => self.hit_rate = self.cache_hits as f64 / total as f64,
        }
    }
}
/// Async LRU cache for JSON responses with per-endpoint TTLs and an
/// approximate memory budget. All shared state is behind `Arc<RwLock<_>>`
/// so the cache can be used concurrently from async tasks.
pub struct ResponseCache {
    /// Tuning knobs; `max_entries == 0` disables caching (get/put no-op).
    config: ResponseCacheConfig,
    /// The LRU store itself, keyed by URL + params digest.
    cache: Arc<RwLock<LruCache<CacheKey, CacheEntry>>>,
    /// Hit/miss/eviction counters.
    stats: Arc<RwLock<CacheStats>>,
    /// Approximate total payload bytes (sum of entry `size_bytes`).
    current_memory_usage: Arc<RwLock<usize>>,
}
impl ResponseCache {
pub fn new(config: ResponseCacheConfig) -> Self {
let capacity = if config.max_entries > 0 {
NonZeroUsize::new(config.max_entries).unwrap()
} else {
NonZeroUsize::new(1).unwrap() };
Self {
config,
cache: Arc::new(RwLock::new(LruCache::new(capacity))),
stats: Arc::new(RwLock::new(CacheStats::default())),
current_memory_usage: Arc::new(RwLock::new(0)),
}
}
pub fn with_default_config() -> Self {
Self::new(ResponseCacheConfig::default())
}
pub async fn get(&self, url: &str, params: Option<&HashMap<String, String>>) -> Option<Value> {
if self.config.max_entries == 0 {
return None; }
let key = CacheKey::new(url.to_string(), params);
{
let mut stats = self.stats.write().await;
stats.total_requests += 1;
}
let mut cache = self.cache.write().await;
if let Some(entry) = cache.get_mut(&key) {
let ttl = self.determine_ttl(url);
if !entry.is_expired(ttl) {
let result = entry.access();
{
let mut stats = self.stats.write().await;
stats.cache_hits += 1;
stats.update_hit_rate();
}
debug!("Cache hit for URL: {} (age: {}ms)", url, entry.age_ms());
return Some(result);
} else {
let removed = cache.pop(&key);
if let Some(removed_entry) = removed {
let mut memory = self.current_memory_usage.write().await;
*memory = memory.saturating_sub(removed_entry.size_bytes);
let mut stats = self.stats.write().await;
stats.expired_evictions += 1;
}
debug!("Expired cache entry removed for URL: {}", url);
}
}
{
let mut stats = self.stats.write().await;
stats.cache_misses += 1;
stats.update_hit_rate();
}
None
}
pub async fn put(&self, url: &str, params: Option<&HashMap<String, String>>, response: Value) {
if self.config.max_entries == 0 {
return; }
let key = CacheKey::new(url.to_string(), params);
let entry = CacheEntry::new(response);
let entry_size = entry.size_bytes;
if self.config.enable_size_eviction {
let current_memory = *self.current_memory_usage.read().await;
if current_memory + entry_size > self.config.max_memory_bytes {
self.evict_by_size(entry_size).await;
}
}
{
let mut cache = self.cache.write().await;
if let Some(evicted) = cache.push(key, entry) {
let mut memory = self.current_memory_usage.write().await;
*memory = memory.saturating_sub(evicted.1.size_bytes);
let mut stats = self.stats.write().await;
stats.cache_evictions += 1;
}
}
{
let mut memory = self.current_memory_usage.write().await;
*memory += entry_size;
}
self.update_stats().await;
debug!("Cached response for URL: {} (size: {} bytes)", url, entry_size);
}
fn determine_ttl(&self, url: &str) -> u64 {
if url.contains("/chart/") || url.contains("interval=") {
if url.contains("interval=1m") || url.contains("interval=5m") {
self.config.quote_ttl_ms } else {
self.config.history_ttl_ms }
} else if url.contains("/search") {
self.config.search_ttl_ms } else {
self.config.default_ttl_ms }
}
async fn evict_by_size(&self, needed_bytes: usize) {
let mut cache = self.cache.write().await;
let mut memory = self.current_memory_usage.write().await;
let mut freed_bytes = 0;
let mut evicted_count = 0;
let mut keys_to_remove = Vec::new();
while freed_bytes < needed_bytes && !cache.is_empty() {
if let Some((key, _)) = cache.peek_lru() {
keys_to_remove.push(key.clone());
if let Some((_, entry)) = cache.pop_lru() {
freed_bytes += entry.size_bytes;
evicted_count += 1;
}
} else {
break;
}
}
*memory = memory.saturating_sub(freed_bytes);
drop(cache);
drop(memory);
{
let mut stats = self.stats.write().await;
stats.size_evictions += evicted_count;
stats.cache_evictions += evicted_count;
}
info!("Evicted {} entries to free {} bytes", evicted_count, freed_bytes);
}
pub async fn cleanup_expired(&self) {
let cache = self.cache.write().await;
let memory = self.current_memory_usage.write().await;
let _expired_keys: Vec<String> = Vec::new();
let _freed_bytes = 0;
let _current_time = current_time_millis();
drop(cache);
drop(memory);
self.update_stats().await;
}
async fn update_stats(&self) {
let cache = self.cache.read().await;
let memory = self.current_memory_usage.read().await;
let mut stats = self.stats.write().await;
stats.cache_entries = cache.len();
stats.total_memory_bytes = *memory;
if stats.cache_entries > 0 {
stats.average_entry_size = stats.total_memory_bytes as f64 / stats.cache_entries as f64;
}
}
pub async fn stats(&self) -> CacheStats {
self.update_stats().await;
self.stats.read().await.clone()
}
pub async fn clear(&self) {
let mut cache = self.cache.write().await;
let mut memory = self.current_memory_usage.write().await;
cache.clear();
*memory = 0;
drop(cache);
drop(memory);
self.update_stats().await;
info!("Cache cleared");
}
pub async fn memory_usage(&self) -> usize {
*self.current_memory_usage.read().await
}
pub fn is_enabled(&self) -> bool {
self.config.max_entries > 0
}
}
/// Roughly estimate the in-memory footprint of a JSON value, in bytes.
///
/// The numbers are heuristic: fixed costs for scalars, a flat container
/// overhead, and a per-key overhead for objects — good enough to drive the
/// cache's memory budget, not an exact measurement.
fn estimate_json_size(value: &Value) -> usize {
    const CONTAINER_OVERHEAD: usize = 24;
    const KEY_OVERHEAD: usize = 16;
    match value {
        Value::Null => 4,
        Value::Bool(_) => 5,
        Value::Number(_) => 8,
        Value::String(text) => CONTAINER_OVERHEAD + text.len(),
        Value::Array(items) => {
            let children: usize = items.iter().map(estimate_json_size).sum();
            CONTAINER_OVERHEAD + children
        }
        Value::Object(map) => {
            let children: usize = map
                .iter()
                .map(|(name, child)| name.len() + KEY_OVERHEAD + estimate_json_size(child))
                .sum();
            CONTAINER_OVERHEAD + children
        }
    }
}
/// Current wall-clock time as Unix epoch milliseconds.
///
/// Falls back to 0 if the system clock reports a time before the epoch,
/// rather than panicking.
fn current_time_millis() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::ZERO);
    since_epoch.as_millis() as u64
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio::time::{sleep, Duration};

    #[tokio::test]
    async fn test_cache_put_and_get() {
        let cache = ResponseCache::with_default_config();
        let data = json!({"symbol": "AAPL", "price": 150.0});
        assert!(cache.get("https://test.com", None).await.is_none());
        cache.put("https://test.com", None, data.clone()).await;
        let cached = cache.get("https://test.com", None).await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap(), data);
    }

    #[tokio::test]
    async fn test_cache_expiration() {
        let config = ResponseCacheConfig {
            quote_ttl_ms: 50,
            ..Default::default()
        };
        let cache = ResponseCache::new(config);
        let data = json!({"test": true});
        cache.put("https://chart.test.com/interval=1m", None, data.clone()).await;
        assert!(cache.get("https://chart.test.com/interval=1m", None).await.is_some());
        sleep(Duration::from_millis(60)).await;
        assert!(cache.get("https://chart.test.com/interval=1m", None).await.is_none());
    }

    #[tokio::test]
    async fn test_cache_with_parameters() {
        let cache = ResponseCache::with_default_config();
        let data1 = json!({"result": "data1"});
        let data2 = json!({"result": "data2"});
        let mut params1 = HashMap::new();
        params1.insert("symbol".to_string(), "AAPL".to_string());
        let mut params2 = HashMap::new();
        params2.insert("symbol".to_string(), "MSFT".to_string());
        // Fixed: the argument expressions here were mojibake ("&para;ms1")
        // from an HTML-entity round-trip of `&params1` / `&params2`.
        cache.put("https://test.com", Some(&params1), data1.clone()).await;
        cache.put("https://test.com", Some(&params2), data2.clone()).await;
        let cached1 = cache.get("https://test.com", Some(&params1)).await;
        let cached2 = cache.get("https://test.com", Some(&params2)).await;
        assert_eq!(cached1.unwrap(), data1);
        assert_eq!(cached2.unwrap(), data2);
    }

    #[tokio::test]
    async fn test_ttl_determination() {
        let cache = ResponseCache::with_default_config();
        assert_eq!(cache.determine_ttl("https://chart.yahoo.com/interval=1m"), cache.config.quote_ttl_ms);
        assert_eq!(cache.determine_ttl("https://search.yahoo.com"), cache.config.search_ttl_ms);
        assert_eq!(cache.determine_ttl("https://chart.yahoo.com/interval=1d"), cache.config.history_ttl_ms);
        assert_eq!(cache.determine_ttl("https://other.yahoo.com"), cache.config.default_ttl_ms);
    }

    #[tokio::test]
    async fn test_cache_stats() {
        let cache = ResponseCache::with_default_config();
        let data = json!({"test": true});
        let stats = cache.stats().await;
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.cache_hits, 0);
        assert_eq!(stats.cache_misses, 0);
        cache.get("https://test.com", None).await;
        let stats = cache.stats().await;
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.cache_misses, 1);
        cache.put("https://test.com", None, data).await;
        cache.get("https://test.com", None).await;
        let stats = cache.stats().await;
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.cache_hits, 1);
        assert!(stats.hit_rate > 0.0);
    }

    #[test]
    fn test_json_size_estimation() {
        assert_eq!(estimate_json_size(&json!(null)), 4);
        assert_eq!(estimate_json_size(&json!(true)), 5);
        assert!(estimate_json_size(&json!("hello")) > 5);
        assert!(estimate_json_size(&json!({"key": "value"})) > 10);
    }
}