use crate::yahoo_error::YahooError;
use dashmap::DashMap;
use lru::LruCache;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, SystemTime};
use tokio::sync::Mutex;
/// Composite key identifying one cached market-data request: the ticker
/// symbol, the interval/range strings, plus any extra query parameters.
///
/// `params` is a `BTreeMap` so iteration order is deterministic, which keeps
/// the manual `Hash` impl below consistent with the derived `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CacheKey {
// Ticker symbol, e.g. "AAPL".
pub symbol: String,
// Bar interval string, e.g. "1d", "5m".
pub interval: String,
// Lookback range string, e.g. "1mo".
pub range: String,
// Additional request parameters; sorted map for stable ordering.
pub params: BTreeMap<String, String>,
}
impl Hash for CacheKey {
    /// Feeds every identifying field into `state` in a fixed order so that
    /// keys comparing equal under the derived `PartialEq` always hash equal.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.symbol.hash(state);
        self.interval.hash(state);
        self.range.hash(state);
        // BTreeMap iterates in sorted key order, so the stream of hashed
        // bytes is deterministic for equal maps. Hashing the (key, value)
        // tuple hashes each element in order.
        for pair in &self.params {
            pair.hash(state);
        }
    }
}
impl CacheKey {
pub fn new(symbol: &str, interval: &str, range: &str) -> Self {
Self {
symbol: symbol.to_string(),
interval: interval.to_string(),
range: range.to_string(),
params: BTreeMap::new(),
}
}
pub fn with_param(mut self, key: &str, value: &str) -> Self {
self.params.insert(key.to_string(), value.to_string());
self
}
}
/// One cached payload plus its bookkeeping metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry {
// `data` is the stored payload bytes; `created_at` is the insertion time.
pub data: Vec<u8>, pub created_at: SystemTime,
// Absolute expiry; entries with `expires_at <= now` are treated as dead.
pub expires_at: SystemTime,
// Number of times this entry was served from L1.
pub access_count: u64,
// Last time this entry was read from L1.
pub last_access: SystemTime,
// Size of the ORIGINAL (pre-storage) payload in bytes — set from the raw
// input length in `put`, not from `data.len()`.
pub size_bytes: usize,
// Layer the entry is attributed to after its most recent store/promotion.
pub source_layer: CacheLayer,
// 0.0..=1.0 score consumed by `is_market_data_stale`; starts at 1.0.
pub freshness_score: f64,
}
/// Which tier of the cache an entry belongs to (or `None` for "nowhere").
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CacheLayer {
// L1: in-process LRU.
Memory,
// L2: persistent map (currently an in-memory DashMap).
Persistent,
// L3: optional distributed backend (currently stubbed out).
Distributed,
None,
}
/// Top-level configuration bundle for all cache tiers and background work.
#[derive(Debug, Clone)]
pub struct CacheConfig {
pub l1_config: L1CacheConfig,
pub l2_config: L2CacheConfig,
// `None` disables the distributed tier entirely (see `AdvancedCache::get`).
pub l3_config: Option<L3CacheConfig>,
pub warming_config: CacheWarmingConfig,
pub performance_config: CachePerformanceConfig,
}
/// Settings for the in-memory (L1) LRU tier.
#[derive(Debug, Clone)]
pub struct L1CacheConfig {
// Capacity in entries; 0 falls back to 1000 in `AdvancedCache::new`.
pub max_entries: usize,
pub default_ttl: Duration,
// NOTE(review): compression is not implemented yet; these two fields are
// read but currently have no effect on stored bytes.
pub enable_compression: bool,
pub compression_threshold: usize,
}
/// Settings for the persistent (L2) tier.
#[derive(Debug, Clone)]
pub struct L2CacheConfig {
// Target directory for on-disk storage — TODO confirm: current L2 is an
// in-memory DashMap and never touches this path.
pub cache_dir: String,
pub max_size_bytes: u64,
pub default_ttl: Duration,
pub enable_encryption: bool,
}
/// Settings for the optional distributed (L3) tier (backend stubbed).
#[derive(Debug, Clone)]
pub struct L3CacheConfig {
pub connection_string: String,
pub key_prefix: String,
pub default_ttl: Duration,
pub pool_size: usize,
}
/// Controls proactive cache warming of popular symbols.
#[derive(Debug, Clone)]
pub struct CacheWarmingConfig {
pub enabled: bool,
pub popular_symbols: Vec<String>,
// Per-interval refresh cadence, keyed by interval string ("1d", "1h", ...).
pub warming_intervals: HashMap<String, Duration>,
pub max_concurrent_requests: usize,
}
/// Knobs for analytics, cleanup cadence, prefetching, and adaptive TTLs.
#[derive(Debug, Clone)]
pub struct CachePerformanceConfig {
pub enable_analytics: bool,
pub cleanup_interval: Duration,
pub enable_prefetching: bool,
// When true, `calculate_adaptive_ttl` derives TTLs from the interval.
pub adaptive_ttl: bool,
}
/// Aggregate statistics snapshot returned by `AdvancedCache::get_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
pub l1_stats: LayerStats,
pub l2_stats: LayerStats,
// Present only when an L3 tier is configured.
pub l3_stats: Option<LayerStats>,
// total hits / (total hits + misses) across L1+L2; see
// `update_cache_statistics`.
pub overall_hit_rate: f64,
pub total_size_bytes: u64,
pub warming_stats: WarmingStats,
}
/// Per-tier counters.
#[derive(Debug, Clone)]
pub struct LayerStats {
pub hits: u64,
pub misses: u64,
pub hit_rate: f64,
pub entry_count: usize,
pub size_bytes: u64,
// NOTE(review): measured in `update_hit_stats` but never written back —
// stays 0.0 for now.
pub avg_access_time_us: f64,
pub evictions: u64,
}
/// Counters describing cache-warming activity.
#[derive(Debug, Clone)]
pub struct WarmingStats {
pub requests_initiated: u64,
pub successful_warms: u64,
pub failed_warms: u64,
pub avg_warming_time_ms: f64,
}
/// Multi-tier cache: an async-locked LRU for L1, a concurrent map for L2,
/// and an optional (currently stubbed) distributed L3.
pub struct AdvancedCache {
// tokio Mutex so the guard may be held across `.await` points.
l1_cache: Arc<Mutex<LruCache<CacheKey, CacheEntry>>>,
l2_cache: Arc<DashMap<CacheKey, CacheEntry>>,
config: CacheConfig,
// std RwLock: guards must NOT be held across `.await`.
stats: Arc<RwLock<CacheStats>>,
// Per-symbol hit counters feeding warming decisions.
popularity_tracker: Arc<DashMap<String, AtomicU64>>,
warming_manager: Arc<CacheWarmingManager>,
}
/// Tracks which symbols are queued / in flight for warming.
pub struct CacheWarmingManager {
warming_queue: Arc<DashMap<String, WarmingPriority>>,
// Symbols currently being warmed, with start time. NOTE(review): nothing
// in this file inserts into it yet.
active_warmings: Arc<DashMap<String, SystemTime>>,
warming_stats: Arc<RwLock<WarmingStats>>,
}
/// Priority of a queued warming request. The explicit discriminants make the
/// derived `PartialOrd` rank `Critical` highest and `Low` lowest.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub enum WarmingPriority {
Critical = 4,
High = 3,
Normal = 2,
Low = 1,
}
impl Default for CacheConfig {
    /// Conservative defaults: a 10k-entry / 5-minute L1, a 100 MiB / 1-hour
    /// L2, no distributed L3, and warming enabled for a fixed list of
    /// high-volume tickers.
    fn default() -> Self {
        // Symbols pre-warmed out of the box.
        let popular = [
            "AAPL", "GOOGL", "MSFT", "AMZN", "TSLA", "META", "NVDA", "BRK.B",
            "V", "JNJ", "WMT", "PG",
        ];
        // Shorter intervals are refreshed more aggressively.
        let intervals = HashMap::from([
            ("1d".to_string(), Duration::from_secs(60)),
            ("1h".to_string(), Duration::from_secs(300)),
            ("5m".to_string(), Duration::from_secs(30)),
        ]);
        Self {
            l1_config: L1CacheConfig {
                max_entries: 10_000,
                default_ttl: Duration::from_secs(300),
                enable_compression: true,
                compression_threshold: 1024,
            },
            l2_config: L2CacheConfig {
                cache_dir: "./cache".to_string(),
                max_size_bytes: 100 * 1024 * 1024,
                default_ttl: Duration::from_secs(3600),
                enable_encryption: false,
            },
            l3_config: None,
            warming_config: CacheWarmingConfig {
                enabled: true,
                popular_symbols: popular.iter().map(|s| s.to_string()).collect(),
                warming_intervals: intervals,
                max_concurrent_requests: 10,
            },
            performance_config: CachePerformanceConfig {
                enable_analytics: true,
                cleanup_interval: Duration::from_secs(3600),
                enable_prefetching: true,
                adaptive_ttl: true,
            },
        }
    }
}
impl AdvancedCache {
pub fn new(config: CacheConfig) -> Self {
let l1_size = NonZeroUsize::new(config.l1_config.max_entries)
.unwrap_or(NonZeroUsize::new(1000).unwrap());
Self {
l1_cache: Arc::new(Mutex::new(LruCache::new(l1_size))),
l2_cache: Arc::new(DashMap::new()),
warming_manager: Arc::new(CacheWarmingManager::new()),
config,
stats: Arc::new(RwLock::new(CacheStats::default())),
popularity_tracker: Arc::new(DashMap::new()),
}
}
/// Looks `key` up tier by tier (L1 → L2 → L3). On a lower-tier hit the
/// entry is promoted into the faster tiers and relabeled `Memory` before
/// being returned. Records hit/miss stats and per-symbol popularity.
pub async fn get(&self, key: &CacheKey) -> Option<CacheEntry> {
let start_time = SystemTime::now();
// L1: fastest path; `get_from_l1` already filters expired entries.
if let Some(entry) = self.get_from_l1(key).await {
self.update_hit_stats(CacheLayer::Memory, start_time).await;
self.track_popularity(&key.symbol).await;
return Some(entry);
}
// L2 hit: promote a copy into L1 so the next lookup is fast.
if let Some(mut entry) = self.get_from_l2(key).await {
self.update_hit_stats(CacheLayer::Persistent, start_time)
.await;
entry.source_layer = CacheLayer::Memory;
self.store_in_l1(key.clone(), entry.clone()).await;
self.track_popularity(&key.symbol).await;
return Some(entry);
}
// L3 is consulted only when configured (currently a stub returning None).
if self.config.l3_config.is_some() {
if let Some(mut entry) = self.get_from_l3(key).await {
self.update_hit_stats(CacheLayer::Distributed, start_time)
.await;
entry.source_layer = CacheLayer::Memory;
// Promote into both faster tiers.
self.store_in_l2(key.clone(), entry.clone()).await;
self.store_in_l1(key.clone(), entry.clone()).await;
self.track_popularity(&key.symbol).await;
return Some(entry);
}
}
// Complete miss across every tier.
self.update_miss_stats(start_time).await;
None
}
/// Stores `data` under `key` in every configured tier with a TTL chosen by
/// `calculate_adaptive_ttl`.
///
/// Currently always returns `Ok(())`; the `Result` exists for future tiers
/// that can fail (disk/network).
pub async fn put(&self, key: CacheKey, data: Vec<u8>) -> Result<(), YahooError> {
let now = SystemTime::now();
let ttl = self.calculate_adaptive_ttl(&key, &data).await;
let entry = CacheEntry {
// Stored bytes (compression is currently a pass-through copy).
data: self.compress_data_if_needed(&data).await,
created_at: now,
expires_at: now + ttl,
access_count: 0,
last_access: now,
// Records the RAW input length, not the stored length.
size_bytes: data.len(),
source_layer: CacheLayer::Memory,
freshness_score: 1.0,
};
// Write-through to L1 and L2; L3 only when configured.
self.store_in_l1(key.clone(), entry.clone()).await;
self.store_in_l2(key.clone(), entry.clone()).await;
if self.config.l3_config.is_some() {
self.store_in_l3(key.clone(), entry.clone()).await;
}
Ok(())
}
/// Queues every given symbol for warming at `High` priority. A no-op when
/// warming is disabled in the configuration; always returns `Ok(())`.
pub async fn warm_cache(&self, symbols: Vec<String>) -> Result<(), YahooError> {
    if self.config.warming_config.enabled {
        for sym in symbols {
            self.warming_manager
                .schedule_warming(sym, WarmingPriority::High)
                .await;
        }
    }
    Ok(())
}
/// Returns a snapshot (clone) of the current statistics.
///
/// NOTE: `unwrap` panics if the stats lock was poisoned by a panicking
/// writer.
pub async fn get_stats(&self) -> CacheStats {
self.stats.read().unwrap().clone()
}
/// Housekeeping pass: evict expired entries, then refresh layout and
/// statistics. Presumably driven on `performance_config.cleanup_interval`
/// by a background task — confirm with the caller.
pub async fn maintenance(&self) -> Result<(), YahooError> {
self.cleanup_expired_entries().await?;
self.optimize_cache_layout().await?;
self.update_cache_statistics().await?;
Ok(())
}
pub async fn invalidate_stale_market_data(&self) -> Result<u64, YahooError> {
let mut invalidated_count = 0;
let market_hours = self.get_market_hours().await;
for entry_ref in self.l1_cache.lock().await.iter() {
let (key, entry) = entry_ref;
if self.is_market_data_stale(key, entry, &market_hours).await {
invalidated_count += 1;
}
}
Ok(invalidated_count)
}
/// L1 lookup. A live hit bumps the access counters and returns a clone;
/// an expired hit is evicted lazily and reported as a miss.
async fn get_from_l1(&self, key: &CacheKey) -> Option<CacheEntry> {
    let mut guard = self.l1_cache.lock().await;
    // Decide inside the match, evict after it, so the mutable borrow taken
    // by `get_mut` has ended before `pop` re-borrows the cache.
    let expired = match guard.get_mut(key) {
        Some(entry) => {
            if entry.expires_at > SystemTime::now() {
                entry.access_count += 1;
                entry.last_access = SystemTime::now();
                return Some(entry.clone());
            }
            true
        }
        None => false,
    };
    if expired {
        guard.pop(key);
    }
    None
}
/// L2 lookup with lazy eviction of expired entries.
///
/// The `is_expired` flag exists so that `remove` runs only AFTER the read
/// guard returned by `DashMap::get` has been dropped (at the end of the
/// `if let`); removing while still holding that guard could deadlock on the
/// same shard. Keep this two-phase shape.
async fn get_from_l2(&self, key: &CacheKey) -> Option<CacheEntry> {
let is_expired = if let Some(entry) = self.l2_cache.get(key) {
if entry.expires_at > SystemTime::now() {
// Live hit: return a clone; L2 does not track access counts.
return Some(entry.clone());
} else {
true }
} else {
false };
// Guard dropped above — safe to mutate the map now.
if is_expired {
self.l2_cache.remove(key);
}
None
}
/// Stub: the distributed tier has no backend yet, so every lookup misses.
/// TODO: implement against `L3CacheConfig::connection_string`.
async fn get_from_l3(&self, _key: &CacheKey) -> Option<CacheEntry> {
None
}
/// Inserts (or replaces) an entry in the in-memory LRU tier; the LRU may
/// evict its least-recently-used entry when at capacity.
async fn store_in_l1(&self, key: CacheKey, entry: CacheEntry) {
    self.l1_cache.lock().await.put(key, entry);
}

/// Inserts (or replaces) an entry in the persistent-tier map.
async fn store_in_l2(&self, key: CacheKey, entry: CacheEntry) {
    self.l2_cache.insert(key, entry);
}
/// Stub: writes to the distributed tier are silently dropped until an L3
/// backend exists. TODO: implement alongside `get_from_l3`.
async fn store_in_l3(&self, _key: CacheKey, _entry: CacheEntry) {
}
/// Chooses a TTL for a new entry. With adaptive TTLs enabled, shorter bar
/// intervals (which go stale faster) get shorter lifetimes; unknown
/// intervals fall back to the configured L1 default. `_data` is reserved
/// for future size/content-based heuristics.
async fn calculate_adaptive_ttl(&self, key: &CacheKey, _data: &[u8]) -> Duration {
    let base_ttl = self.config.l1_config.default_ttl;
    // Respect an explicitly short configured TTL (e.g. in tests) and the
    // adaptive-TTL kill switch.
    if !self.config.performance_config.adaptive_ttl || base_ttl < Duration::from_secs(10) {
        return base_ttl;
    }
    let secs = match key.interval.as_str() {
        "1m" | "5m" => 60,
        "15m" | "30m" => 300,
        "1h" => 900,
        "1d" => 3600,
        _ => return base_ttl,
    };
    Duration::from_secs(secs)
}
/// Returns the bytes to store for `data`.
///
/// Compression is NOT implemented: the previous version branched on
/// `enable_compression` / `compression_threshold` but both branches returned
/// an identical uncompressed copy, so the dead branch has been removed.
/// TODO: honor `l1_config.enable_compression` and `compression_threshold`
/// once a codec is chosen (no compression crate is available here yet).
async fn compress_data_if_needed(&self, data: &[u8]) -> Vec<u8> {
    data.to_vec()
}
/// Bumps the per-symbol hit counter used by warming heuristics. The DashMap
/// entry API creates the counter on first sight of a symbol.
async fn track_popularity(&self, symbol: &str) {
    let counter = self
        .popularity_tracker
        .entry(symbol.to_owned())
        .or_insert_with(AtomicU64::default);
    // Relaxed is enough: this is a statistic, not a synchronization point.
    counter.fetch_add(1, Ordering::Relaxed);
}
async fn update_hit_stats(&self, layer: CacheLayer, start_time: SystemTime) {
let _access_time = start_time.elapsed().unwrap_or_default();
let mut stats = self.stats.write().unwrap();
match layer {
CacheLayer::Memory => {
stats.l1_stats.hits += 1;
stats.l1_stats.hit_rate = stats.l1_stats.hits as f64
/ (stats.l1_stats.hits + stats.l1_stats.misses) as f64;
}
CacheLayer::Persistent => {
stats.l2_stats.hits += 1;
stats.l2_stats.hit_rate = stats.l2_stats.hits as f64
/ (stats.l2_stats.hits + stats.l2_stats.misses) as f64;
}
CacheLayer::Distributed => {
if let Some(ref mut l3_stats) = stats.l3_stats {
l3_stats.hits += 1;
l3_stats.hit_rate =
l3_stats.hits as f64 / (l3_stats.hits + l3_stats.misses) as f64;
}
}
CacheLayer::None => {}
}
}
async fn update_miss_stats(&self, _start_time: SystemTime) {
let mut stats = self.stats.write().unwrap();
stats.l1_stats.misses += 1;
stats.l2_stats.misses += 1;
if let Some(ref mut l3_stats) = stats.l3_stats {
l3_stats.misses += 1;
}
}
/// Drops every entry whose TTL has elapsed from L1 and L2.
async fn cleanup_expired_entries(&self) -> Result<(), YahooError> {
    let cutoff = SystemTime::now();
    // L1: LruCache has no retain, so collect the dead keys first, then pop
    // them. The scope drops the lock before the L2 pass.
    {
        let mut l1 = self.l1_cache.lock().await;
        let dead: Vec<CacheKey> = l1
            .iter()
            .filter_map(|(k, e)| (e.expires_at <= cutoff).then(|| k.clone()))
            .collect();
        for k in &dead {
            l1.pop(k);
        }
    }
    // L2: DashMap supports in-place retention.
    self.l2_cache.retain(|_, entry| entry.expires_at > cutoff);
    Ok(())
}
/// Stub: reserved for future re-tiering / layout work (e.g. demoting cold
/// L1 entries). Currently a no-op that always succeeds.
async fn optimize_cache_layout(&self) -> Result<(), YahooError> {
Ok(())
}
async fn update_cache_statistics(&self) -> Result<(), YahooError> {
let mut stats = self.stats.write().unwrap();
let l1_cache = futures::executor::block_on(self.l1_cache.lock());
stats.l1_stats.entry_count = l1_cache.len();
drop(l1_cache);
stats.l2_stats.entry_count = self.l2_cache.len();
let total_hits = stats.l1_stats.hits + stats.l2_stats.hits;
let total_misses = stats.l1_stats.misses + stats.l2_stats.misses;
stats.overall_hit_rate = if total_hits + total_misses > 0 {
total_hits as f64 / (total_hits + total_misses) as f64
} else {
0.0
};
Ok(())
}
/// Stub: returns the default placeholder hours (market always "open").
/// TODO: wire up a real market-calendar source.
async fn get_market_hours(&self) -> MarketHours {
MarketHours::default()
}
/// An entry is stale when its freshness score has dropped below 0.5 or its
/// TTL has elapsed. `_key` and `_market_hours` are currently unused —
/// presumably reserved for session-aware staleness rules; confirm intent
/// before removing them.
async fn is_market_data_stale(
&self,
_key: &CacheKey,
entry: &CacheEntry,
_market_hours: &MarketHours,
) -> bool {
entry.freshness_score < 0.5 || entry.expires_at <= SystemTime::now()
}
}
impl CacheWarmingManager {
    /// Creates an empty manager: nothing queued, nothing in flight,
    /// zeroed statistics.
    pub fn new() -> Self {
        Self {
            warming_queue: Arc::new(DashMap::new()),
            active_warmings: Arc::new(DashMap::new()),
            warming_stats: Arc::new(RwLock::new(WarmingStats::default())),
        }
    }

    /// Queues `symbol` for warming, overwriting any previously queued
    /// priority for the same symbol.
    ///
    /// Bug fix: `warming_stats.requests_initiated` was never incremented
    /// anywhere, so the counter stayed at zero forever; count each
    /// scheduling request here. A poisoned lock skips the count rather than
    /// panicking — the queue insert is the essential effect.
    pub async fn schedule_warming(&self, symbol: String, priority: WarmingPriority) {
        self.warming_queue.insert(symbol, priority);
        if let Ok(mut stats) = self.warming_stats.write() {
            stats.requests_initiated += 1;
        }
    }
}
/// Trading-session snapshot consumed by staleness checks.
#[derive(Debug, Clone)]
pub struct MarketHours {
// Whether the market is currently in session.
pub market_open: bool,
pub next_open: SystemTime,
pub next_close: SystemTime,
}
impl Default for MarketHours {
fn default() -> Self {
let now = SystemTime::now();
Self {
market_open: true, next_open: now,
next_close: now + Duration::from_secs(8 * 3600), }
}
}
impl Default for CacheStats {
fn default() -> Self {
Self {
l1_stats: LayerStats::default(),
l2_stats: LayerStats::default(),
l3_stats: None,
overall_hit_rate: 0.0,
total_size_bytes: 0,
warming_stats: WarmingStats::default(),
}
}
}
impl Default for LayerStats {
fn default() -> Self {
Self {
hits: 0,
misses: 0,
hit_rate: 0.0,
entry_count: 0,
size_bytes: 0,
avg_access_time_us: 0.0,
evictions: 0,
}
}
}
impl Default for WarmingStats {
fn default() -> Self {
Self {
requests_initiated: 0,
successful_warms: 0,
failed_warms: 0,
avg_warming_time_ms: 0.0,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
// A fresh cache must report empty, zeroed statistics.
#[tokio::test]
async fn test_advanced_cache_creation() {
let config = CacheConfig::default();
let cache = AdvancedCache::new(config);
let stats = cache.get_stats().await;
assert_eq!(stats.l1_stats.entry_count, 0);
assert_eq!(stats.overall_hit_rate, 0.0);
}
// Round trip: a stored payload comes back byte-identical and is attributed
// to the Memory layer.
#[tokio::test]
async fn test_cache_put_and_get() {
let config = CacheConfig::default();
let cache = AdvancedCache::new(config);
let key = CacheKey {
symbol: "AAPL".to_string(),
interval: "1d".to_string(),
range: "1mo".to_string(),
params: BTreeMap::new(),
};
let test_data = b"test market data".to_vec();
cache.put(key.clone(), test_data.clone()).await.unwrap();
let retrieved = cache.get(&key).await;
assert!(retrieved.is_some());
let entry = retrieved.unwrap();
assert_eq!(entry.data, test_data);
assert_eq!(entry.source_layer, CacheLayer::Memory);
}
// Expiration: a 100 ms TTL entry must be served while fresh and reported
// gone after it lapses. Each step runs under a 3 s timeout so a locking
// regression shows up as a test failure instead of a hang; adaptive TTL is
// disabled so the short configured TTL is honored.
#[tokio::test]
async fn test_cache_expiration() {
println!("🔍 Testing AdvancedCache methods step by step...");
let mut config = CacheConfig::default();
config.performance_config.adaptive_ttl = false;
config.l1_config.default_ttl = Duration::from_millis(100);
let cache = AdvancedCache::new(config);
println!("✅ AdvancedCache created");
let key = CacheKey {
symbol: "TSLA".to_string(),
interval: "5m".to_string(),
range: "1d".to_string(),
params: BTreeMap::new(),
};
let test_data = b"expired data".to_vec();
println!("🔍 STEP 1: Testing PUT method...");
match tokio::time::timeout(Duration::from_secs(3), cache.put(key.clone(), test_data)).await
{
Ok(Ok(())) => println!("✅ PUT completed successfully"),
Ok(Err(e)) => panic!("❌ PUT failed: {:?}", e),
Err(_) => panic!("🚨 PUT DEADLOCKED after 3 seconds"),
}
println!("🔍 STEP 2: Testing GET method (fresh item)...");
match tokio::time::timeout(Duration::from_secs(3), cache.get(&key)).await {
Ok(Some(_)) => println!("✅ GET fresh item completed successfully"),
Ok(None) => panic!("❌ Fresh item not found"),
Err(_) => panic!("🚨 GET FRESH ITEM DEADLOCKED after 3 seconds"),
}
println!("🔍 STEP 3: Waiting for expiration...");
tokio::time::sleep(Duration::from_millis(150)).await;
println!("🔍 STEP 4: Testing GET method (expired item)...");
match tokio::time::timeout(Duration::from_secs(3), cache.get(&key)).await {
Ok(None) => println!("✅ GET expired item completed - item properly removed"),
Ok(Some(_)) => panic!("❌ Expired item still exists"),
Err(_) => panic!("🚨 GET EXPIRED ITEM DEADLOCKED after 3 seconds - FOUND THE BUG!"),
}
println!("🎉 All steps completed without deadlock");
}
// Scheduling warmings for a list of symbols must succeed (warming is
// enabled in the default config).
#[tokio::test]
async fn test_cache_warming() {
let config = CacheConfig::default();
let cache = AdvancedCache::new(config);
let symbols = vec!["AAPL".to_string(), "GOOGL".to_string(), "MSFT".to_string()];
let result = cache.warm_cache(symbols).await;
assert!(result.is_ok());
}
// Hash/Eq contract: equal keys hash identically; different symbols differ.
#[test]
fn test_cache_key_hashing() {
let key1 = CacheKey {
symbol: "AAPL".to_string(),
interval: "1d".to_string(),
range: "1mo".to_string(),
params: BTreeMap::new(),
};
let key2 = CacheKey {
symbol: "AAPL".to_string(),
interval: "1d".to_string(),
range: "1mo".to_string(),
params: BTreeMap::new(),
};
let key3 = CacheKey {
symbol: "GOOGL".to_string(),
interval: "1d".to_string(),
range: "1mo".to_string(),
params: BTreeMap::new(),
};
assert_eq!(key1, key2);
assert_ne!(key1, key3);
let mut hasher1 = std::collections::hash_map::DefaultHasher::new();
let mut hasher2 = std::collections::hash_map::DefaultHasher::new();
key1.hash(&mut hasher1);
key2.hash(&mut hasher2);
assert_eq!(hasher1.finish(), hasher2.finish());
}
}