1#![ allow( clippy::missing_inline_in_public_items, clippy::unused_async ) ]
8
9mod private
11{
12 use std::
13 {
14 collections ::HashMap,
15 sync ::{ Arc, Mutex },
16 time ::Instant,
17 };
18 use core::
19 {
20 hash ::{ Hash, Hasher },
21 time ::Duration,
22 };
23 use tokio::sync::RwLock;
24 use core::sync::atomic::{ AtomicU32, AtomicU64, Ordering };
25 use serde::{ Serialize, Deserialize };
26 use std::collections::hash_map::DefaultHasher;
27
  /// Configuration parameters for a request cache instance.
  #[ derive( Debug, Clone ) ]
  pub struct CacheConfig
  {
    /// Maximum number of entries the cache may hold before LRU eviction kicks in.
    pub max_size : usize,
    /// Time-to-live applied to entries inserted without an explicit TTL.
    pub default_ttl : Duration,
    /// Whether periodic cleanup of expired entries is enabled.
    // NOTE(review): not consulted anywhere in this file — presumably read by a
    // background cleanup task elsewhere; confirm against callers.
    pub enable_cleanup : bool,
    /// Interval between cleanup passes when cleanup is enabled.
    // NOTE(review): also not read in this file — see note above on `enable_cleanup`.
    pub cleanup_interval : Duration,
  }
41
42 impl Default for CacheConfig
43 {
44 #[ inline ]
45 fn default() -> Self
46 {
47 Self
48 {
49 max_size : 1000,
50 default_ttl : Duration::from_secs( 300 ), enable_cleanup : true,
52 cleanup_interval : Duration::from_secs( 60 ), }
54 }
55 }
56
  /// Thread-safe hit/miss/eviction counters; clones share the same counters
  /// because every field is behind an `Arc`.
  #[ derive( Debug, Clone ) ]
  pub struct CacheStatistics
  {
    /// Number of lookups that returned a live entry.
    pub hits : Arc< AtomicU64 >,
    /// Number of lookups that found no entry or an expired one.
    pub misses : Arc< AtomicU64 >,
    /// Number of entries removed by eviction, cleanup, or `clear`.
    pub evictions : Arc< AtomicU64 >,
    /// Current number of entries tracked by the cache.
    pub entries : Arc< AtomicU32 >,
  }
70
71 impl Default for CacheStatistics
72 {
73 #[ inline ]
74 fn default() -> Self
75 {
76 Self
77 {
78 hits : Arc::new( AtomicU64::new( 0 ) ),
79 misses : Arc::new( AtomicU64::new( 0 ) ),
80 evictions : Arc::new( AtomicU64::new( 0 ) ),
81 entries : Arc::new( AtomicU32::new( 0 ) ),
82 }
83 }
84 }
85
86 impl CacheStatistics
87 {
88 #[ inline ]
90 #[ must_use ]
91 pub fn hit_rate( &self ) -> f64
92 {
93 let hits = self.hits.load( Ordering::Relaxed );
94 let misses = self.misses.load( Ordering::Relaxed );
95 let total = hits + misses;
96
97 if total == 0
98 {
99 0.0
100 }
101 else
102 {
103 ( hits as f64 / total as f64 ) * 100.0
104 }
105 }
106 }
107
  /// A single cached value together with its expiry and access metadata.
  #[ derive( Debug, Clone ) ]
  pub struct CacheEntry< T >
  {
    /// The cached value itself.
    pub value : T,
    /// Instant at which the entry was created; expiry is measured from here.
    pub timestamp : Instant,
    /// How long after `timestamp` the entry remains valid.
    pub ttl : Duration,
    /// Number of times the entry has been read ( shared across clones ).
    pub access_count : Arc< AtomicU32 >,
    /// Most recent read time, used by LRU eviction.
    pub last_accessed : Arc< Mutex< Instant > >,
  }
123
124 impl< T > CacheEntry< T >
125 {
126 #[ inline ]
128 pub fn new( value : T, ttl : Duration ) -> Self
129 {
130 let now = Instant::now();
131 Self
132 {
133 value,
134 timestamp : now,
135 ttl,
136 access_count : Arc::new( AtomicU32::new( 0 ) ),
137 last_accessed : Arc::new( Mutex::new( now ) ),
138 }
139 }
140
141 #[ inline ]
143 pub async fn is_expired( &self ) -> bool
144 {
145 self.timestamp.elapsed() > self.ttl
146 }
147
148 #[ inline ]
150 pub async fn touch( &self )
151 {
152 self.access_count.fetch_add( 1, Ordering::Relaxed );
153 if let Ok( mut last_accessed ) = self.last_accessed.lock()
154 {
155 *last_accessed = Instant::now();
156 }
157 }
158
159 #[ inline ]
161 pub fn age( &self ) -> Duration
162 {
163 self.timestamp.elapsed()
164 }
165 }
166
  /// Hashable identity of an API request, used as the cache lookup key.
  #[ derive( Debug, Clone, PartialEq, Eq, Hash ) ]
  pub struct RequestCacheKey
  {
    /// Request endpoint path.
    pub endpoint : String,
    /// HTTP method ( e.g. "GET", "POST" ).
    pub method : String,
    /// Hash of the serialized request body; `0` when there is no body.
    pub body_hash : u64,
    /// Hash of the serialized cache-relevant request headers.
    pub headers_hash : u64,
  }
180
181 impl RequestCacheKey
182 {
183 #[ inline ]
189 pub fn new< T: Serialize >(
190 endpoint : &str,
191 method : &str,
192 body : Option< &T >,
193 headers : &HashMap< String, String >
194 ) -> crate::error::Result< Self >
195 {
196 let body_hash = if let Some( body ) = body
197 {
198 let json = serde_json::to_string( body ).map_err( |e|
199 crate ::error::OpenAIError::Internal( format!( "Failed to serialize body for cache key : {e}" ) )
200 )?;
201 Self::hash_string( &json )
202 }
203 else
204 {
205 0
206 };
207
208 let relevant_headers : HashMap< String, String > = headers
210 .iter()
211 .filter( |( key, _ )| Self::is_relevant_header( key ) )
212 .map( |( k, v )| ( k.clone(), v.clone() ) )
213 .collect();
214
215 let headers_json = serde_json::to_string( &relevant_headers ).map_err( |e|
216 crate ::error::OpenAIError::Internal( format!( "Failed to serialize headers for cache key : {e}" ) )
217 )?;
218
219 Ok( Self
220 {
221 endpoint : endpoint.to_string(),
222 method : method.to_string(),
223 body_hash,
224 headers_hash : Self::hash_string( &headers_json ),
225 })
226 }
227
228 fn is_relevant_header( key : &str ) -> bool
230 {
231 matches!( key.to_lowercase().as_str(),
233 "content-type" | "accept" | "openai-organization" | "openai-project"
234 )
235 }
236
237 fn hash_string( s : &str ) -> u64
239 {
240 let mut hasher = DefaultHasher::new();
241 s.hash( &mut hasher );
242 hasher.finish()
243 }
244 }
245
  /// An async, TTL-aware in-memory cache with LRU eviction at capacity.
  #[ derive( Debug ) ]
  pub struct RequestCache< K, V >
  where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
  {
    /// The stored entries, guarded by an async read/write lock.
    entries : Arc< RwLock< HashMap< K, CacheEntry< V > > > >,
    /// Maximum number of entries before LRU eviction kicks in.
    max_size : usize,
    /// TTL applied by `insert` when no explicit TTL is given.
    default_ttl : Duration,
    /// Shared hit/miss/eviction counters.
    statistics : CacheStatistics,
    /// Full configuration the cache was built from.
    // NOTE(review): stored but never read in this file, hence the allow;
    // presumably kept for a cleanup task elsewhere — confirm.
    #[ allow( dead_code ) ]
    config : CacheConfig,
  }
265
266 impl< K, V > RequestCache< K, V >
267 where
268 K: Hash + Eq + Clone + Send + Sync + 'static,
269 V: Clone + Send + Sync + 'static,
270 {
271 #[ inline ]
273 #[ must_use ]
274 pub fn new( max_size : usize, default_ttl : Duration ) -> Self
275 {
276 Self
277 {
278 entries : Arc::new( RwLock::new( HashMap::new() ) ),
279 max_size,
280 default_ttl,
281 statistics : CacheStatistics::default(),
282 config : CacheConfig
283 {
284 max_size,
285 default_ttl,
286 ..Default::default()
287 },
288 }
289 }
290
291 #[ inline ]
293 #[ must_use ]
294 pub fn with_config( config : CacheConfig ) -> Self
295 {
296 Self
297 {
298 entries : Arc::new( RwLock::new( HashMap::new() ) ),
299 max_size : config.max_size,
300 default_ttl : config.default_ttl,
301 statistics : CacheStatistics::default(),
302 config,
303 }
304 }
305
306 #[ inline ]
308 pub async fn get( &self, key : &K ) -> Option< V >
309 {
310 let entries = self.entries.read().await;
311
312 if let Some( entry ) = entries.get( key )
313 {
314 if entry.is_expired().await
315 {
316 drop( entries );
317 let mut entries = self.entries.write().await;
319 entries.remove( key );
320 self.statistics.entries.fetch_sub( 1, Ordering::Relaxed );
321 self.statistics.misses.fetch_add( 1, Ordering::Relaxed );
322 None
323 }
324 else
325 {
326 entry.touch().await;
327 self.statistics.hits.fetch_add( 1, Ordering::Relaxed );
328 Some( entry.value.clone() )
329 }
330 }
331 else
332 {
333 self.statistics.misses.fetch_add( 1, Ordering::Relaxed );
334 None
335 }
336 }
337
338 #[ inline ]
340 pub async fn insert( &self, key : K, value : V ) -> Option< V >
341 {
342 self.insert_with_ttl( key, value, self.default_ttl ).await
343 }
344
345 #[ inline ]
347 pub async fn insert_with_ttl( &self, key : K, value : V, ttl : Duration ) -> Option< V >
348 {
349 let mut entries = self.entries.write().await;
350
351 if entries.len() >= self.max_size && !entries.contains_key( &key )
353 {
354 self.evict_lru( &mut entries ).await;
355 }
356
357 let entry = CacheEntry::new( value, ttl );
358 let old_value = entries.insert( key, entry ).map( |e| e.value );
359
360 if old_value.is_none()
361 {
362 self.statistics.entries.fetch_add( 1, Ordering::Relaxed );
363 }
364
365 old_value
366 }
367
368 #[ inline ]
370 pub async fn remove( &self, key : &K ) -> Option< V >
371 {
372 let mut entries = self.entries.write().await;
373 if let Some( entry ) = entries.remove( key )
374 {
375 self.statistics.entries.fetch_sub( 1, Ordering::Relaxed );
376 Some( entry.value )
377 }
378 else
379 {
380 None
381 }
382 }
383
384 #[ inline ]
386 pub async fn contains_key( &self, key : &K ) -> bool
387 {
388 let entries = self.entries.read().await;
389 if let Some( entry ) = entries.get( key )
390 {
391 !entry.is_expired().await
392 }
393 else
394 {
395 false
396 }
397 }
398
399 #[ inline ]
401 pub async fn len( &self ) -> usize
402 {
403 let entries = self.entries.read().await;
404 entries.len()
405 }
406
407 #[ inline ]
409 pub async fn is_empty( &self ) -> bool
410 {
411 let entries = self.entries.read().await;
412 entries.is_empty()
413 }
414
415 #[ inline ]
417 pub async fn clear( &self )
418 {
419 let mut entries = self.entries.write().await;
420 let count = u32::try_from( entries.len() ).unwrap_or( u32::MAX );
421 entries.clear();
422 self.statistics.entries.store( 0, Ordering::Relaxed );
423 self.statistics.evictions.fetch_add( u64::from( count ), Ordering::Relaxed );
424 }
425
426 #[ inline ]
428 #[ must_use ]
429 pub fn statistics( &self ) -> &CacheStatistics
430 {
431 &self.statistics
432 }
433
434 #[ inline ]
436 pub async fn cleanup_expired( &self ) -> usize
437 {
438 let mut entries = self.entries.write().await;
439 let mut keys_to_remove = Vec::new();
440
441 for ( key, entry ) in entries.iter()
442 {
443 if entry.is_expired().await
444 {
445 keys_to_remove.push( key.clone() );
446 }
447 }
448
449 let removed_count = keys_to_remove.len();
450 for key in keys_to_remove
451 {
452 entries.remove( &key );
453 }
454
455 if removed_count > 0
456 {
457 self.statistics.entries.fetch_sub( u32::try_from( removed_count ).unwrap_or( u32::MAX ), Ordering::Relaxed );
458 self.statistics.evictions.fetch_add( u64::try_from( removed_count ).unwrap_or( u64::MAX ), Ordering::Relaxed );
459 }
460
461 removed_count
462 }
463
464 async fn evict_lru( &self, entries : &mut HashMap< K, CacheEntry< V > > )
466 {
467 if entries.is_empty()
468 {
469 return;
470 }
471
472 let mut oldest_key = None;
474 let mut oldest_time = Instant::now();
475
476 for ( key, entry ) in entries.iter()
477 {
478 if let Ok( last_accessed ) = entry.last_accessed.lock()
479 {
480 if oldest_key.is_none() || *last_accessed < oldest_time
481 {
482 oldest_time = *last_accessed;
483 oldest_key = Some( key.clone() );
484 }
485 }
486 }
487
488 if let Some( key ) = oldest_key
489 {
490 entries.remove( &key );
491 self.statistics.entries.fetch_sub( 1, Ordering::Relaxed );
492 self.statistics.evictions.fetch_add( 1, Ordering::Relaxed );
493 }
494 }
495 }
496
  /// Cache keyed by request identity, storing responses as raw JSON values.
  pub type ApiRequestCache = RequestCache< RequestCacheKey, serde_json::Value >;
499
500 impl ApiRequestCache
501 {
502 #[ inline ]
504 #[ must_use ]
505 pub fn new_api_cache() -> Self
506 {
507 Self::new( 1000, Duration::from_secs( 300 ) )
508 }
509
510 #[ inline ]
516 pub async fn cache_response< I: Serialize, O: Serialize >(
517 &self,
518 endpoint : &str,
519 method : &str,
520 request_body : Option< &I >,
521 headers : &HashMap< String, String >,
522 response : &O,
523 ) -> crate::error::Result< () >
524 {
525 let key = RequestCacheKey::new( endpoint, method, request_body, headers )?;
526 let value = serde_json::to_value( response ).map_err( |e|
527 crate ::error::OpenAIError::Internal( format!( "Failed to serialize response for caching : {e}" ) )
528 )?;
529
530 self.insert( key, value ).await;
531 Ok( () )
532 }
533
534 #[ inline ]
540 pub async fn get_response< I: Serialize, O: for< 'de > Deserialize< 'de > >(
541 &self,
542 endpoint : &str,
543 method : &str,
544 request_body : Option< &I >,
545 headers : &HashMap< String, String >,
546 ) -> crate::error::Result< Option< O > >
547 {
548 let key = RequestCacheKey::new( endpoint, method, request_body, headers )?;
549
550 if let Some( value ) = self.get( &key ).await
551 {
552 let response = serde_json::from_value( value ).map_err( |e|
553 crate ::error::OpenAIError::Internal( format!( "Failed to deserialize cached response : {e}" ) )
554 )?;
555 Ok( Some( response ) )
556 }
557 else
558 {
559 Ok( None )
560 }
561 }
562 }
563
} crate ::mod_interface!
{
  // Re-export the cache API at the crate's "exposed" interface level via the
  // mod_interface facade; comments here are stripped by the lexer before the
  // macro sees its tokens.
  exposed use
  {
    CacheConfig,
    CacheStatistics,
    CacheEntry,
    RequestCacheKey,
    RequestCache,
    ApiRequestCache,
  };
}