1use crate::error::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::hash::Hash;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10
/// Tuning knobs for an [`IntelligentCache`].
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// TTL applied by `IntelligentCache::set` when the caller passes `None`.
    pub default_ttl: Duration,
    /// Maximum number of entries; `IntelligentCache::set` evicts entries to stay within it.
    pub max_items: usize,
    /// NOTE(review): no code in this module reads this flag — confirm whether compression is
    /// implemented elsewhere or whether the field is vestigial.
    pub compress: bool,
    /// When `false`, the cache skips all `CacheMetrics` bookkeeping.
    pub enable_metrics: bool,
}

impl Default for CacheConfig {
    /// Five-minute TTL, 1000 entries, compression flag and metrics both enabled.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(5 * 60);
        CacheConfig {
            default_ttl: five_minutes,
            max_items: 1_000,
            compress: true,
            enable_metrics: true,
        }
    }
}
34
/// One cached value plus the bookkeeping used for TTL expiry and LRU selection.
#[derive(Debug, Clone)]
struct CacheEntry<T> {
    /// The cached payload.
    data: T,
    /// Insertion time; expiry is measured from here.
    created_at: Instant,
    /// How long the entry stays valid after `created_at`.
    ttl: Duration,
    /// Number of reads performed through `access`.
    access_count: u64,
    /// Time of the most recent read (or of insertion if never read).
    last_access: Instant,
}

impl<T> CacheEntry<T> {
    /// Wraps `data` in a fresh, never-read entry that lives for `ttl`.
    fn new(data: T, ttl: Duration) -> Self {
        let inserted = Instant::now();
        CacheEntry {
            data,
            ttl,
            created_at: inserted,
            last_access: inserted,
            access_count: 0,
        }
    }

    /// `true` once strictly more than `ttl` has elapsed since creation.
    fn is_expired(&self) -> bool {
        let age = self.created_at.elapsed();
        age > self.ttl
    }

    /// Records a read (bumps the counter, refreshes `last_access`) and borrows the payload.
    fn access(&mut self) -> &T {
        self.last_access = Instant::now();
        self.access_count += 1;
        &self.data
    }
}
67
/// Identifies one cached response: an endpoint path plus an opaque parameter string.
///
/// Two keys are equal only if both components match exactly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    endpoint: String,
    params: String,
}

impl CacheKey {
    /// Builds a key by copying `endpoint` and `params` into owned strings.
    pub fn new(endpoint: &str, params: &str) -> Self {
        let (endpoint, params) = (String::from(endpoint), String::from(params));
        CacheKey { endpoint, params }
    }
}
83
/// Usage counters reported by an [`IntelligentCache`].
#[derive(Debug, Clone, Default)]
pub struct CacheMetrics {
    /// Lookups that returned a cached value.
    pub hits: u64,
    /// Lookups the cache recorded as misses.
    pub misses: u64,
    /// Entries removed to make room for a new insertion.
    pub evictions: u64,
    /// Entries discarded because their TTL had elapsed.
    pub expired_items: u64,
    /// Number of stored entries as of the last operation that refreshed this field.
    pub total_items: usize,
}

impl CacheMetrics {
    /// Fraction of recorded lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        let lookups = self.hits + self.misses;
        match lookups {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }
}
103
104pub struct IntelligentCache<T>
106where
107 T: Clone + Send + Sync,
108{
109 cache: Arc<RwLock<HashMap<CacheKey, CacheEntry<T>>>>,
110 config: CacheConfig,
111 metrics: Arc<RwLock<CacheMetrics>>,
112}
113
114impl<T> IntelligentCache<T>
115where
116 T: Clone + Send + Sync,
117{
118 pub fn new(config: CacheConfig) -> Self {
119 Self {
120 cache: Arc::new(RwLock::new(HashMap::new())),
121 config,
122 metrics: Arc::new(RwLock::new(CacheMetrics::default())),
123 }
124 }
125
126 pub async fn get(&self, key: &CacheKey) -> Option<T> {
128 let mut cache = self.cache.write().await;
129
130 if let Some(entry) = cache.get_mut(key) {
131 if entry.is_expired() {
132 cache.remove(key);
134 if self.config.enable_metrics {
135 let mut metrics = self.metrics.write().await;
136 metrics.expired_items += 1;
137 metrics.total_items = cache.len();
138 }
139 return None;
140 }
141
142 let data = entry.access().clone();
144 if self.config.enable_metrics {
145 let mut metrics = self.metrics.write().await;
146 metrics.hits += 1;
147 }
148 Some(data)
149 } else {
150 if self.config.enable_metrics {
152 let mut metrics = self.metrics.write().await;
153 metrics.misses += 1;
154 }
155 None
156 }
157 }
158
159 pub async fn set(&self, key: CacheKey, value: T, ttl: Option<Duration>) {
161 let mut cache = self.cache.write().await;
162
163 if cache.len() >= self.config.max_items {
165 self.evict_lru(&mut cache).await;
166 }
167
168 let ttl = ttl.unwrap_or(self.config.default_ttl);
169 let entry = CacheEntry::new(value, ttl);
170 cache.insert(key, entry);
171
172 if self.config.enable_metrics {
173 let mut metrics = self.metrics.write().await;
174 metrics.total_items = cache.len();
175 }
176 }
177
178 async fn evict_lru(&self, cache: &mut HashMap<CacheKey, CacheEntry<T>>) {
180 if cache.is_empty() {
181 return;
182 }
183
184 let lru_key = cache
186 .iter()
187 .min_by_key(|(_, entry)| entry.last_access)
188 .map(|(key, _)| key.clone());
189
190 if let Some(key) = lru_key {
191 cache.remove(&key);
192 if self.config.enable_metrics {
193 let mut metrics = self.metrics.write().await;
194 metrics.evictions += 1;
195 }
196 }
197 }
198
199 pub async fn cleanup_expired(&self) {
201 let mut cache = self.cache.write().await;
202 let initial_size = cache.len();
203
204 cache.retain(|_, entry| !entry.is_expired());
205
206 if self.config.enable_metrics {
207 let mut metrics = self.metrics.write().await;
208 metrics.expired_items += (initial_size - cache.len()) as u64;
209 metrics.total_items = cache.len();
210 }
211 }
212
213 pub async fn get_metrics(&self) -> CacheMetrics {
215 self.metrics.read().await.clone()
216 }
217
218 pub async fn clear(&self) {
220 let mut cache = self.cache.write().await;
221 cache.clear();
222
223 if self.config.enable_metrics {
224 let mut metrics = self.metrics.write().await;
225 metrics.total_items = 0;
226 }
227 }
228
229 pub async fn size(&self) -> usize {
231 self.cache.read().await.len()
232 }
233}
234
235pub struct SmartCacheStrategy;
237
238impl SmartCacheStrategy {
239 pub fn get_ttl_for_endpoint(endpoint: &str) -> Duration {
241 match endpoint {
242 path if path.contains("quote") || path.contains("price") => Duration::from_secs(30),
244
245 path if path.contains("market-hours") => Duration::from_secs(3600 * 12),
247
248 path if path.contains("profile") || path.contains("company") => {
250 Duration::from_secs(3600 * 24)
251 }
252
253 path if path.contains("income-statement") || path.contains("balance-sheet") => {
255 Duration::from_secs(3600 * 6)
256 }
257
258 path if path.contains("historical") => Duration::from_secs(3600 * 2),
260
261 path if path.contains("news") || path.contains("calendar") => Duration::from_secs(900),
263
264 path if path.contains("symbols") || path.contains("exchanges") => {
266 Duration::from_secs(3600 * 24 * 7)
267 }
268
269 _ => Duration::from_secs(300),
271 }
272 }
273
274 pub fn should_cache_endpoint(endpoint: &str) -> bool {
276 if endpoint.contains("bulk") {
278 return false;
279 }
280
281 if endpoint.contains("stream") || endpoint.contains("websocket") {
283 return false;
284 }
285
286 true
288 }
289
290 pub fn generate_cache_key(endpoint: &str, query_params: Option<&str>) -> CacheKey {
292 let params = query_params.unwrap_or("");
293 CacheKey::new(endpoint, params)
294 }
295}
296
297pub struct CachedApiClient<T>
299where
300 T: Clone + Send + Sync + for<'de> Deserialize<'de> + Serialize,
301{
302 cache: IntelligentCache<T>,
303}
304
305impl<T> CachedApiClient<T>
306where
307 T: Clone + Send + Sync + for<'de> Deserialize<'de> + Serialize,
308{
309 pub fn new(config: CacheConfig) -> Self {
310 Self {
311 cache: IntelligentCache::new(config),
312 }
313 }
314
315 pub async fn get_cached<F, Fut>(&self, key: CacheKey, fetch_fn: F) -> Result<T>
317 where
318 F: FnOnce() -> Fut,
319 Fut: std::future::Future<Output = Result<T>>,
320 {
321 if let Some(cached_data) = self.cache.get(&key).await {
323 return Ok(cached_data);
324 }
325
326 let data = fetch_fn().await?;
328
329 if SmartCacheStrategy::should_cache_endpoint(&key.endpoint) {
331 let ttl = SmartCacheStrategy::get_ttl_for_endpoint(&key.endpoint);
332 self.cache.set(key, data.clone(), Some(ttl)).await;
333 }
334
335 Ok(data)
336 }
337
338 pub async fn metrics(&self) -> CacheMetrics {
340 self.cache.get_metrics().await
341 }
342
343 pub async fn cleanup(&self) {
345 self.cache.cleanup_expired().await;
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-configured cache holding `String` values.
    fn string_cache() -> IntelligentCache<String> {
        IntelligentCache::new(CacheConfig::default())
    }

    /// The key shared by the cache tests.
    fn sample_key() -> CacheKey {
        CacheKey::new("test", "params")
    }

    #[tokio::test]
    async fn test_cache_basic_operations() {
        let cache = string_cache();
        let key = sample_key();
        let expected = String::from("test_value");

        // Nothing has been stored yet.
        assert_eq!(cache.get(&key).await, None);

        cache.set(key.clone(), expected.clone(), None).await;
        assert_eq!(cache.get(&key).await, Some(expected));
    }

    #[tokio::test]
    async fn test_cache_expiration() {
        let cache = string_cache();
        let key = sample_key();
        let short_ttl = Duration::from_millis(100);

        cache
            .set(key.clone(), String::from("test_value"), Some(short_ttl))
            .await;
        assert!(cache.get(&key).await.is_some(), "entry should be live before its TTL");

        // Outlive the 100 ms TTL with some margin.
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(cache.get(&key).await.is_none(), "entry should be gone after its TTL");
    }

    #[test]
    fn test_smart_cache_strategy() {
        let expected_ttls = [
            ("/quote", Duration::from_secs(30)),
            ("/profile", Duration::from_secs(3600 * 24)),
        ];
        for &(endpoint, ttl) in expected_ttls.iter() {
            assert_eq!(SmartCacheStrategy::get_ttl_for_endpoint(endpoint), ttl);
        }

        assert!(SmartCacheStrategy::should_cache_endpoint("/quote"));
        assert!(!SmartCacheStrategy::should_cache_endpoint("/bulk-data"));
    }
}