use mod_interface::mod_interface;
#[ cfg( feature = "caching" ) ]
mod private
{
use crate::
{
environment ::{ EnvironmentInterface, OpenaiEnvironment },
error ::{ OpenAIError, Result },
};
use std::
{
collections ::HashMap,
sync ::Arc,
};
use core::
{
time ::Duration,
hash ::Hash,
};
use std::time::Instant;
use tokio::sync::RwLock;
use serde::{ Serialize, Deserialize };
use sha2::{ Sha256, Digest };
/// Configuration options controlling cache capacity, entry lifetime and cleanup.
#[ derive( Debug, Clone ) ]
pub struct CacheConfig
{
/// Maximum number of entries held before the oldest entry is evicted.
pub max_entries : usize,
/// Time-to-live applied when `put` is called without an explicit TTL.
pub default_ttl : Duration,
/// Largest response body (in bytes) that will be accepted for caching.
pub max_response_size : usize,
/// Whether cached bodies should be compressed. NOTE(review): not consulted by the code visible here — confirm.
pub enable_compression : bool,
/// Whether error responses are cached. NOTE(review): not consulted by the code visible here — confirm.
pub cache_errors : bool,
/// Interval between background sweeps of expired entries; `Duration::ZERO` disables the sweep task.
pub cleanup_interval : Duration,
}
impl Default for CacheConfig
{
/// Returns a configuration suitable for typical API usage :
/// 1000 entries, 5-minute TTL, 1 MiB maximum response size,
/// compression enabled, errors not cached, cleanup every 60 seconds.
#[ inline ]
fn default() -> Self
{
Self
{
max_entries : 1000,
default_ttl : Duration::from_secs( 300 ),
max_response_size : 1024 * 1024,
enable_compression : true,
cache_errors : false,
cleanup_interval : Duration::from_secs( 60 ),
}
}
}
/// A single cached response body together with bookkeeping metadata.
#[ derive( Debug, Clone ) ]
pub struct CacheEntry
{
/// Raw (serialized) response bytes.
pub data : Vec< u8 >,
/// Moment the entry was inserted; expiry is measured from here.
pub created_at : Instant,
/// How long this entry remains valid after `created_at`.
pub ttl : Duration,
/// Size of `data` in bytes, kept for statistics accounting.
pub size_bytes : usize,
/// Number of times this entry has been served from the cache.
pub hit_count : u64,
/// HTTP method of the originating request.
pub method : String,
/// Request path of the originating request.
pub path : String,
}
impl CacheEntry
{
/// True when the entry has outlived its time-to-live.
#[ inline ]
#[ must_use ]
pub fn is_expired( &self ) -> bool
{
let age = self.created_at.elapsed();
age > self.ttl
}
/// Bumps the hit counter for this entry by one.
#[ inline ]
pub fn record_hit( &mut self )
{
self.hit_count = self.hit_count + 1;
}
}
/// Identifies a request for cache-lookup purposes.
///
/// Large request components (body, query string) are stored as SHA-256
/// hex digests rather than verbatim, keeping keys small and hashable.
#[ derive( Debug, Clone, PartialEq, Eq, Hash ) ]
pub struct CacheKey
{
/// HTTP method, normalised to upper-case by `new`.
pub method : String,
/// Request path, stored verbatim.
pub path : String,
/// SHA-256 hex digest of the request body, if a body was supplied.
pub body_hash : Option< String >,
/// SHA-256 hex digest of the query string, if one was supplied.
pub query_hash : Option< String >,
}
impl CacheKey
{
#[ inline ]
#[ must_use ]
pub fn new( method : &str, path : &str, body : Option< &[u8] >, query : Option< &str > ) -> Self
{
let body_hash = body.map( Self::hash_bytes );
let query_hash = query.map( Self::hash_string );
Self
{
method : method.to_uppercase(),
path : path.to_string(),
body_hash,
query_hash,
}
}
#[ inline ]
#[ must_use ]
pub fn to_cache_key( &self ) -> String
{
let mut hasher = Sha256::new();
hasher.update( &self.method );
hasher.update( &self.path );
if let Some( ref body_hash ) = self.body_hash
{
hasher.update( body_hash );
}
if let Some( ref query_hash ) = self.query_hash
{
hasher.update( query_hash );
}
format!( "{:x}", hasher.finalize() )
}
fn hash_bytes( data : &[u8] ) -> String
{
let mut hasher = Sha256::new();
hasher.update( data );
format!( "{:x}", hasher.finalize() )
}
fn hash_string( data : &str ) -> String
{
Self::hash_bytes( data.as_bytes() )
}
}
impl core::fmt::Display for CacheKey
{
/// Writes the hex digest produced by [`CacheKey::to_cache_key`].
#[ inline ]
fn fmt( &self, f : &mut core::fmt::Formatter< '_ > ) -> core::fmt::Result
{
f.write_str( &self.to_cache_key() )
}
}
/// Aggregated runtime metrics for a `ResponseCache`.
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct CacheStatistics
{
/// Total `get` calls observed (hits + misses).
pub total_requests : u64,
/// Lookups served from the cache.
pub cache_hits : u64,
/// Lookups that missed, including ones that found only an expired entry.
pub cache_misses : u64,
/// `cache_hits / total_requests`; 0.0 before any request.
pub hit_ratio : f64,
/// Number of live entries currently stored.
pub current_entries : usize,
/// Sum of `size_bytes` over stored entries.
pub total_cached_bytes : usize,
/// Seeded from `CacheConfig::default_ttl` at construction.
/// NOTE(review): never updated afterwards — confirm whether that is intended.
pub average_ttl_seconds : f64,
/// Count of expired entries removed (lazily in `get` or by the sweeper).
pub expired_entries_cleaned : u64,
/// Mean stored-entry size in bytes.
pub average_response_size : f64,
}
/// An in-memory, TTL-based response cache whose state is shared behind
/// `Arc`/`RwLock` so it can be reached from concurrent tokio tasks.
#[ derive( Debug ) ]
pub struct ResponseCache
{
// Keyed by the hex digest produced by `CacheKey::to_cache_key`.
cache : Arc< RwLock< HashMap< String, CacheEntry > > >,
config : CacheConfig,
stats : Arc< RwLock< CacheStatistics > >,
// Background task that periodically sweeps expired entries; aborted on drop.
cleanup_handle : Option< tokio::task::JoinHandle< () > >,
}
impl ResponseCache
{
/// Creates a cache with [`CacheConfig::default`].
///
/// Must be called from within a tokio runtime when the cleanup interval
/// is non-zero, because a background task is spawned.
#[ inline ]
#[ must_use ]
pub fn new() -> Self
{
Self::with_config( CacheConfig::default() )
}
/// Creates a cache from an explicit configuration, spawning the periodic
/// cleanup task unless `cleanup_interval` is zero.
#[ inline ]
#[ must_use ]
pub fn with_config( config : CacheConfig ) -> Self
{
let cache = Arc::new( RwLock::new( HashMap::new() ) );
let stats = Arc::new( RwLock::new( CacheStatistics
{
total_requests : 0,
cache_hits : 0,
cache_misses : 0,
hit_ratio : 0.0,
current_entries : 0,
total_cached_bytes : 0,
average_ttl_seconds : config.default_ttl.as_secs_f64(),
expired_entries_cleaned : 0,
average_response_size : 0.0,
} ) );
let mut instance = Self
{
cache,
config,
stats,
cleanup_handle : None,
};
if instance.config.cleanup_interval > Duration::ZERO
{
instance.start_cleanup_task();
}
instance
}
/// Looks up `key`, returning a copy of the cached bytes on a hit.
///
/// An expired entry encountered here is removed eagerly and counted as a miss.
#[ inline ]
pub async fn get( &self, key : &CacheKey ) -> Option< Vec< u8 > >
{
let key_str = key.to_cache_key();
// Lock order is always cache first, then stats — same everywhere, so no deadlock.
let mut cache = self.cache.write().await;
let mut stats = self.stats.write().await;
stats.total_requests += 1;
if let Some( entry ) = cache.get_mut( &key_str )
{
if !entry.is_expired()
{
entry.record_hit();
stats.cache_hits += 1;
stats.hit_ratio = stats.cache_hits as f64 / stats.total_requests as f64;
return Some( entry.data.clone() );
}
// Fix : subtract the expired entry's size so `total_cached_bytes`
// stays accurate (it was previously never decremented here).
if let Some( expired ) = cache.remove( &key_str )
{
stats.total_cached_bytes = stats.total_cached_bytes.saturating_sub( expired.size_bytes );
}
stats.current_entries = cache.len();
stats.expired_entries_cleaned += 1;
}
stats.cache_misses += 1;
stats.hit_ratio = stats.cache_hits as f64 / stats.total_requests as f64;
None
}
/// Stores `data` under `key`, evicting the oldest entry when at capacity.
///
/// `ttl` falls back to `CacheConfig::default_ttl` when `None`.
///
/// # Errors
/// Returns an error when `data` exceeds `max_response_size`.
#[ inline ]
pub async fn put( &self, key : &CacheKey, data : Vec< u8 >, ttl : Option< Duration > ) -> Result< () >
{
let data_size = data.len();
if data_size > self.config.max_response_size
{
return Err( OpenAIError::Internal( format!(
"Response too large for caching : {} bytes (max : {})",
data_size,
self.config.max_response_size
) ).into() );
}
let key_str = key.to_cache_key();
let ttl = ttl.unwrap_or( self.config.default_ttl );
let entry = CacheEntry
{
data,
created_at : Instant::now(),
ttl,
size_bytes : data_size,
hit_count : 0,
method : key.method.clone(),
path : key.path.clone(),
};
let mut cache = self.cache.write().await;
let mut stats = self.stats.write().await;
if cache.len() >= self.config.max_entries
{
Self::evict_oldest_entry( &mut cache, &mut stats );
}
// Fix : when overwriting an existing key, subtract the replaced entry's
// size — previously `total_cached_bytes` grew without bound on updates.
if let Some( replaced ) = cache.insert( key_str, entry )
{
stats.total_cached_bytes = stats.total_cached_bytes.saturating_sub( replaced.size_bytes );
}
stats.current_entries = cache.len();
stats.total_cached_bytes += data_size;
// Fix : guard the division by the actual divisor (`current_entries`);
// the previous guard checked the unrelated `total_requests` counter.
if stats.current_entries > 0
{
stats.average_response_size = stats.total_cached_bytes as f64 / stats.current_entries as f64;
}
Ok( () )
}
/// Removes every entry and resets the size counters
/// (hit/miss history is deliberately preserved).
#[ inline ]
pub async fn clear( &self )
{
let mut cache = self.cache.write().await;
let mut stats = self.stats.write().await;
cache.clear();
stats.current_entries = 0;
stats.total_cached_bytes = 0;
}
/// Returns a point-in-time copy of the statistics.
#[ inline ]
pub async fn get_statistics( &self ) -> CacheStatistics
{
let stats = self.stats.read().await;
stats.clone()
}
/// Removes every expired entry immediately; returns how many were removed.
#[ inline ]
pub async fn cleanup_expired( &self ) -> usize
{
let mut cache = self.cache.write().await;
let mut stats = self.stats.write().await;
let initial_count = cache.len();
cache.retain( | _, entry | !entry.is_expired() );
let final_count = cache.len();
let cleaned_count = initial_count - final_count;
stats.current_entries = final_count;
stats.expired_entries_cleaned += cleaned_count as u64;
// Recompute from scratch — simpler than tracking each removal.
stats.total_cached_bytes = cache.values().map( | e | e.size_bytes ).sum();
cleaned_count
}
/// Spawns the periodic background sweep of expired entries.
fn start_cleanup_task( &mut self )
{
let cache = Arc::clone( &self.cache );
let stats = Arc::clone( &self.stats );
let cleanup_interval = self.config.cleanup_interval;
let handle = tokio::spawn( async move
{
let mut interval = tokio::time::interval( cleanup_interval );
loop
{
interval.tick().await;
// Same lock order as the rest of the cache : cache, then stats.
let mut cache_guard = cache.write().await;
let mut stats_guard = stats.write().await;
let initial_count = cache_guard.len();
cache_guard.retain( | _, entry | !entry.is_expired() );
let final_count = cache_guard.len();
let cleaned_count = initial_count - final_count;
if cleaned_count > 0
{
stats_guard.current_entries = final_count;
stats_guard.expired_entries_cleaned += cleaned_count as u64;
stats_guard.total_cached_bytes = cache_guard.values().map( | e | e.size_bytes ).sum();
}
}
} );
self.cleanup_handle = Some( handle );
}
/// Evicts the entry with the oldest `created_at`, updating byte accounting.
///
/// Fix : previously the whole entry — including its `Vec< u8 >` payload —
/// was cloned just to read `size_bytes`; now only the key is cloned.
fn evict_oldest_entry( cache : &mut HashMap< String, CacheEntry >, stats : &mut CacheStatistics )
{
let oldest_key = cache.iter()
.min_by_key( | ( _, entry ) | entry.created_at )
.map( | ( key, _ ) | key.clone() );
if let Some( key ) = oldest_key
{
if let Some( removed ) = cache.remove( &key )
{
stats.total_cached_bytes = stats.total_cached_bytes.saturating_sub( removed.size_bytes );
}
}
}
}
impl Drop for ResponseCache
{
/// Aborts the background cleanup task so it does not outlive the cache.
#[ inline ]
fn drop( &mut self )
{
let task = self.cleanup_handle.take();
if let Some( task ) = task
{
task.abort();
}
}
}
/// A wrapper around `crate::client::Client` that transparently caches
/// GET responses (and POST responses when the caller opts in).
#[ derive( Debug ) ]
pub struct CachedClient< E >
where
E: OpenaiEnvironment + EnvironmentInterface + Send + Sync + 'static,
{
client : crate::client::Client< E >,
cache : ResponseCache,
// NOTE(review): duplicates the config already cloned into `cache`;
// kept so `cache_config()` can hand out a reference.
config : CacheConfig,
}
impl< E > CachedClient< E >
where
E: OpenaiEnvironment + EnvironmentInterface + Send + Sync + 'static,
{
/// Wraps `client` with a cache using [`CacheConfig::default`].
#[ inline ]
pub fn new( client : crate::client::Client< E > ) -> Self
{
Self::with_cache_config( client, CacheConfig::default() )
}
/// Wraps `client` with a cache configured by `config`.
#[ inline ]
pub fn with_cache_config( client : crate::client::Client< E >, config : CacheConfig ) -> Self
{
let cache = ResponseCache::with_config( config.clone() );
Self
{
client,
cache,
config,
}
}
/// Performs a GET request, serving from the cache when possible.
///
/// Fix : a cached payload that fails to deserialize is now treated as a
/// miss and the request is re-issued. Previously a corrupt entry returned
/// an error on every call until its TTL expired (cache poisoning).
///
/// # Errors
/// Returns an error when the underlying request fails.
#[ inline ]
pub async fn get_cached< T >( &self, path : &str, ttl : Option< Duration > ) -> Result< T >
where
T: serde::de::DeserializeOwned + serde::Serialize,
{
let cache_key = CacheKey::new( "GET", path, None, None );
if let Some( cached_data ) = self.cache.get( &cache_key ).await
{
if let Ok( result ) = serde_json::from_slice::< T >( &cached_data )
{
return Ok( result );
}
// Corrupt entry : fall through and fetch a fresh response.
}
let response : T = self.client.get( path ).await?;
if let Ok( serialized ) = serde_json::to_vec( &response )
{
// Failure to cache is deliberately non-fatal.
let _ = self.cache.put( &cache_key, serialized, ttl ).await;
}
Ok( response )
}
/// Performs a POST request; the response is cached ONLY when an explicit
/// `ttl` is supplied, since POST is not idempotent in general.
///
/// As in `get_cached`, a corrupt cached payload is treated as a miss.
///
/// # Errors
/// Returns an error when the request body cannot be serialized or the
/// underlying request fails.
#[ inline ]
pub async fn post_cached< I, O >( &self, path : &str, body : &I, ttl : Option< Duration > ) -> Result< O >
where
I: serde::Serialize + Send + Sync,
O: serde::de::DeserializeOwned + serde::Serialize,
{
let body_bytes = serde_json::to_vec( body )
.map_err( | e | OpenAIError::Internal( format!( "Failed to serialize request body : {e}" ) ) )?;
let cache_key = CacheKey::new( "POST", path, Some( &body_bytes ), None );
if let Some( cached_data ) = self.cache.get( &cache_key ).await
{
if let Ok( result ) = serde_json::from_slice::< O >( &cached_data )
{
return Ok( result );
}
// Corrupt entry : fall through and re-issue the request.
}
let response : O = self.client.post( path, body ).await?;
if ttl.is_some()
{
if let Ok( serialized ) = serde_json::to_vec( &response )
{
let _ = self.cache.put( &cache_key, serialized, ttl ).await;
}
}
Ok( response )
}
/// Snapshot of the current cache statistics.
#[ inline ]
pub async fn get_cache_statistics( &self ) -> CacheStatistics
{
self.cache.get_statistics().await
}
/// Removes every cached entry.
#[ inline ]
pub async fn clear_cache( &self )
{
self.cache.clear().await;
}
/// Borrows the wrapped client.
#[ inline ]
pub fn client( &self ) -> &crate::client::Client< E >
{
&self.client
}
/// Borrows the cache configuration this client was built with.
#[ inline ]
pub fn cache_config( &self ) -> &CacheConfig
{
&self.config
}
}
impl Default for ResponseCache
{
#[ inline ]
fn default() -> Self
{
Self::new()
}
}
#[ cfg( test ) ]
mod tests
{
use super::*;
// Identical requests must map to the same key; a different method must not.
#[ test ]
fn test_cache_key_generation()
{
let key1 = CacheKey::new( "GET", "/test", None, None );
let key2 = CacheKey::new( "GET", "/test", None, None );
let key3 = CacheKey::new( "POST", "/test", None, None );
assert_eq!( key1.to_cache_key(), key2.to_cache_key() );
assert_ne!( key1.to_cache_key(), key3.to_cache_key() );
}
// Expiry is measured from `created_at` : back-dating the entry makes it stale.
#[ test ]
fn test_cache_entry_expiration()
{
let mut entry = CacheEntry
{
data : vec![ 1, 2, 3 ],
// Created 10 s in the past with a 5 s TTL, so already expired.
created_at : Instant::now().checked_sub( Duration::from_secs( 10 ) ).unwrap(),
ttl : Duration::from_secs( 5 ),
size_bytes : 3,
hit_count : 0,
method : "GET".to_string(),
path : "/test".to_string(),
};
assert!( entry.is_expired() );
entry.created_at = Instant::now();
assert!( !entry.is_expired() );
}
// Full round trip : miss, insert, hit — statistics must reflect each step.
#[ tokio::test ]
async fn test_cache_basic_operations()
{
let cache = ResponseCache::new();
let key = CacheKey::new( "GET", "/test", None, None );
let data = vec![ 1, 2, 3, 4, 5 ];
assert!( cache.get( &key ).await.is_none() );
cache.put( &key, data.clone(), None ).await.unwrap();
let cached = cache.get( &key ).await.unwrap();
assert_eq!( cached, data );
let stats = cache.get_statistics().await;
assert_eq!( stats.cache_hits, 1 );
assert_eq!( stats.cache_misses, 1 );
assert_eq!( stats.current_entries, 1 );
}
// An entry stored with a 1 ms TTL must be gone after sleeping 10 ms.
#[ tokio::test ]
async fn test_cache_ttl_expiration()
{
let cache = ResponseCache::new();
let key = CacheKey::new( "GET", "/test", None, None );
let data = vec![ 1, 2, 3 ];
cache.put( &key, data, Some( Duration::from_millis( 1 ) ) ).await.unwrap();
tokio ::time::sleep( Duration::from_millis( 10 ) ).await;
assert!( cache.get( &key ).await.is_none() );
}
}
}
// Re-export the caching API from `private` with `orphan` visibility,
// so these items are available one level up while implementation details stay hidden.
mod_interface!
{
orphan use private::
{
CacheConfig,
CacheEntry,
CacheKey,
CacheStatistics,
ResponseCache,
CachedClient,
};
}