use mod_interface::mod_interface;
#[ cfg( feature = "batching" ) ]
mod private
{
use std::
{
collections ::{ HashMap, VecDeque },
sync ::Arc,
time ::Instant,
};
use core::
{
hash ::Hash,
time ::Duration,
};
use tokio::sync::{ RwLock, Notify };
use blake3::{ Hash as Blake3Hash, Hasher as Blake3Hasher };
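/// Tuning parameters for the request batcher.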
#[ derive( Debug, Clone ) ]
pub struct BatchConfig
{
/// Maximum number of requests coalesced into one batch.
pub max_batch_size : usize,
/// How long the oldest queued request may wait before a flush is forced.
pub flush_timeout : Duration,
/// Upper bound on concurrently running batch workers.
pub max_concurrent_batches : usize,
/// Master switch for signature-based coalescing.
pub enable_smart_batching : bool,
/// Queue depth that triggers immediate batch processing.
pub smart_batch_threshold : usize,
}
impl Default for BatchConfig
{
#[ inline ]
fn default() -> Self
{
Self
{
max_batch_size : 100,
flush_timeout : Duration::from_millis( 50 ),
max_concurrent_batches : 10,
enable_smart_batching : true,
smart_batch_threshold : 5,
}
}
}
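/// Identity used to group requests that can share a batch : same method, same
/// path, and a `blake3` hash over method, path, and body.
///
/// Illustrative sketch ( not compiled as a doctest ; assumes this module is in scope ) :
/// ```rust,ignore
/// let sig = RequestSignature::new( "POST", "embeddings", b"{ \"input\" : \"hi\" }" );
/// assert!( sig.is_batchable_with( &sig ) );
/// ```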
#[ derive( Debug, Clone, PartialEq, Eq, Hash ) ]
pub struct RequestSignature
{
pub method : String,
pub path : String,
pub structure_hash : Blake3Hash,
}
impl RequestSignature
{
#[ inline ]
#[ must_use ]
pub fn new( method : &str, path : &str, body : &[u8] ) -> Self
{
let mut hasher = Blake3Hasher::new();
hasher.update( method.as_bytes() );
hasher.update( path.as_bytes() );
hasher.update( body );
let structure_hash = hasher.finalize();
Self
{
method : method.to_string(),
path : path.to_string(),
structure_hash,
}
}
#[ inline ]
#[ must_use ]
pub fn is_batchable_with( &self, other : &RequestSignature ) -> bool
{
self.method == other.method &&
self.path == other.path &&
self.is_batch_compatible_endpoint()
}
fn is_batch_compatible_endpoint( &self ) -> bool
{
matches!( self.path.as_str(),
"embeddings" |
"chat/completions" |
"moderations" |
"images/generations" |
"files" |
"fine_tuning/jobs"
)
}
}
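/// A single queued request ; `response_sender` delivers the result back to the caller.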
#[ derive( Debug ) ]
pub struct BatchedRequest< T >
where
T : Send + Sync,
{
pub id : String,
pub signature : RequestSignature,
pub payload : T,
pub response_sender : tokio::sync::oneshot::Sender< Result< Vec< u8 >, crate::error::OpenAIError > >,
pub queued_at : Instant,
}
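/// Summary of one processed batch, including how many HTTP calls it actually cost.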
#[ derive( Debug ) ]
pub struct BatchResult
{
pub results : Vec< Result< Vec< u8 >, crate::error::OpenAIError > >,
pub processing_time : Duration,
pub http_requests_count : usize,
pub efficiency_ratio : f64,
}
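/// Coalesces compatible requests into shared batches to cut HTTP round trips.
///
/// Hedged usage sketch ( assumes a `tokio` async context ; the `Vec< u8 >` payload type is illustrative ) :
/// ```rust,ignore
/// let batcher : RequestBatcher< Vec< u8 > > = RequestBatcher::new( BatchConfig::default() );
/// let signature = RequestSignature::new( "POST", "embeddings", b"{}" );
/// let response = batcher.submit_request( signature, b"{}".to_vec() ).await?;
/// ```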
#[ derive( Debug ) ]
pub struct RequestBatcher< T >
where
T : Send + Sync,
{
config : BatchConfig,
pending_requests : Arc< RwLock< HashMap< RequestSignature, VecDeque< BatchedRequest< T > > > > >,
batch_notify : Arc< Notify >,
active_batches : Arc< RwLock< usize > >,
metrics : Arc< RwLock< BatchMetrics > >,
}
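/// Running counters for batching effectiveness, maintained across all batches.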
#[ derive( Debug, Clone, Default ) ]
pub struct BatchMetrics
{
/// Total requests routed through the batcher.
pub total_requests : u64,
/// Total batches processed.
pub total_batches : u64,
/// Running mean of requests per batch.
pub avg_batch_size : f64,
/// HTTP calls avoided : one per extra request folded into a batch.
pub http_requests_saved : u64,
/// Running mean of per-batch processing time.
pub avg_batch_time : Duration,
/// `http_requests_saved / total_requests`, in `[ 0, 1 )`.
pub efficiency_improvement : f64,
}
impl< T > RequestBatcher< T >
where
T : Send + Sync + 'static,
{
#[ inline ]
#[ must_use ]
pub fn new( config : BatchConfig ) -> Self
{
Self
{
config,
pending_requests : Arc::new( RwLock::new( HashMap::new() ) ),
batch_notify : Arc::new( Notify::new() ),
active_batches : Arc::new( RwLock::new( 0 ) ),
metrics : Arc::new( RwLock::new( BatchMetrics::default() ) ),
}
}
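/// Submits a request ; batch-compatible endpoints are queued for coalescing,
/// everything else takes the direct single-request path.
///
/// # Errors
///
/// Returns an error when the batch worker drops the response channel or the
/// underlying request fails.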
#[ inline ]
pub async fn submit_request(
&self,
signature : RequestSignature,
payload : T,
) -> Result< Vec< u8 >, crate::error::OpenAIError >
{
if !self.config.enable_smart_batching || !signature.is_batch_compatible_endpoint()
{
return Ok( Self::execute_single_request( signature, payload ) );
}
let ( tx, rx ) = tokio::sync::oneshot::channel();
let request_id = uuid::Uuid::new_v4().to_string();
let batched_request = BatchedRequest
{
id : request_id,
signature : signature.clone(),
payload,
response_sender : tx,
queued_at : Instant::now(),
};
{
let mut pending = self.pending_requests.write().await;
pending.entry( signature.clone() ).or_default().push_back( batched_request );
}
let should_process = self.should_trigger_batch_processing( &signature ).await;
if should_process
{
self.batch_notify.notify_one();
}
self.ensure_batch_processor_running().await;
rx.await.map_err( | _ | crate::error::OpenAIError::Internal( "Batch processing failed".to_string() ) )?
}
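/// A queue is ready once it reaches the smart-batch threshold or its oldest
/// entry has waited longer than the flush timeout.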
async fn should_trigger_batch_processing( &self, signature : &RequestSignature ) -> bool
{
let pending = self.pending_requests.read().await;
if let Some( queue ) = pending.get( signature )
{
queue.len() >= self.config.smart_batch_threshold ||
queue.front().is_some_and( | req | req.queued_at.elapsed() >= self.config.flush_timeout )
}
else
{
false
}
}
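/// Spawns a batch worker when capacity allows. The capacity check is
/// best-effort : under contention a brief overshoot of
/// `max_concurrent_batches` is possible.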
async fn ensure_batch_processor_running( &self )
{
let active_count = *self.active_batches.read().await;
if active_count < self.config.max_concurrent_batches
{
let pending_requests = Arc::clone( &self.pending_requests );
let batch_notify = Arc::clone( &self.batch_notify );
let active_batches = Arc::clone( &self.active_batches );
let metrics = Arc::clone( &self.metrics );
let config = self.config.clone();
// Worker : drain ready batches until none remain, then exit.
tokio::spawn( async move
{
{
let mut active = active_batches.write().await;
*active += 1;
}
loop
{
batch_notify.notified().await;
let batch_to_process = {
let mut pending = pending_requests.write().await;
Self::extract_ready_batch( &mut pending, &config )
};
if let Some( ( signature, requests ) ) = batch_to_process
{
let start_time = Instant::now();
let batch_size = requests.len();
// Responses are delivered through each request's oneshot sender ; the summary is unused here.
let _results = Self::process_batch_requests( &signature, requests );
let processing_time = start_time.elapsed();
{
// Update running averages while holding the metrics lock.
let mut metrics_guard = metrics.write().await;
metrics_guard.total_requests += batch_size as u64;
metrics_guard.total_batches += 1;
// Incremental mean : new_avg = ( old_avg * ( n - 1 ) + sample ) / n.
metrics_guard.avg_batch_size =
( metrics_guard.avg_batch_size * ( metrics_guard.total_batches - 1 ) as f64 + batch_size as f64 )
/ metrics_guard.total_batches as f64;
// A batch of n requests replaces n HTTP calls with one.
metrics_guard.http_requests_saved += ( batch_size as u64 ).saturating_sub( 1 );
// Same incremental mean over nanoseconds, clamped to fit `Duration::from_nanos`.
let new_avg_nanos = ( metrics_guard.avg_batch_time.as_nanos() * u128::from( metrics_guard.total_batches - 1 ) +
processing_time.as_nanos() ) / u128::from( metrics_guard.total_batches );
let bounded_nanos = new_avg_nanos.min( u128::from( u64::MAX ) );
metrics_guard.avg_batch_time = Duration::from_nanos( u64::try_from( bounded_nanos ).unwrap_or( u64::MAX ) );
if metrics_guard.total_requests > 0
{
metrics_guard.efficiency_improvement = metrics_guard.http_requests_saved as f64 / metrics_guard.total_requests as f64;
}
}
}
else
{
break;
}
}
{
let mut active = active_batches.write().await;
*active = active.saturating_sub( 1 );
}
} );
}
}
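/// Removes and returns the first batch that is ready to process, if any.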
fn extract_ready_batch(
pending : &mut HashMap< RequestSignature, VecDeque< BatchedRequest< T > > >,
config : &BatchConfig,
) -> Option< ( RequestSignature, Vec< BatchedRequest< T > > ) >
{
// Find the first queue that is ready : by size threshold or by age of its oldest entry.
let signature = pending.iter().find_map( | ( signature, queue ) |
{
( queue.len() >= config.smart_batch_threshold ||
queue.front().is_some_and( | req | req.queued_at.elapsed() >= config.flush_timeout ) )
.then( || signature.clone() )
} )?;
let queue = pending.get_mut( &signature )?;
// Drain up to max_batch_size requests and drop the map entry once its queue is empty.
let take = config.max_batch_size.min( queue.len() );
let batch : Vec< _ > = queue.drain( ..take ).collect();
if queue.is_empty()
{
pending.remove( &signature );
}
if batch.is_empty() { None } else { Some( ( signature, batch ) ) }
}
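/// Processes one batch and fans responses back to the waiting callers.
/// Currently a placeholder : every request receives a canned `{ "batched": true }` body.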
fn process_batch_requests(
signature : &RequestSignature,
requests : Vec< BatchedRequest< T > >,
) -> BatchResult
{
let start_time = Instant::now();
let request_count = requests.len();
// Placeholder transport : each request gets a canned response rather than a real
// batched HTTP call ; send errors ( dropped receivers ) are deliberately ignored.
let results : Vec< Result< Vec< u8 >, crate::error::OpenAIError > > = requests.into_iter().map( | request |
{
let mock_response = b"{ \"batched\": true }".to_vec();
let _ = request.response_sender.send( Ok( mock_response.clone() ) );
Ok( mock_response )
} ).collect();
let processing_time = start_time.elapsed();
let http_requests_count = if signature.is_batch_compatible_endpoint() { 1 } else { request_count };
let efficiency_ratio = request_count as f64 / http_requests_count as f64;
BatchResult
{
results,
processing_time,
http_requests_count,
efficiency_ratio,
}
}
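/// Direct path for non-batchable requests. Currently a placeholder that
/// returns a canned `{ "single": true }` body instead of performing the HTTP call.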
fn execute_single_request(
_signature : RequestSignature,
_payload : T,
) -> Vec< u8 >
{
b"{ \"single\": true }".to_vec()
}
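/// Returns a snapshot of the accumulated batching metrics.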
#[ inline ]
pub async fn get_metrics( &self ) -> BatchMetrics
{
self.metrics.read().await.clone()
}
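/// Fails every queued request with an internal error and clears all queues.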
#[ inline ]
pub async fn flush_all_pending( &self )
{
let mut pending = self.pending_requests.write().await;
for queue in pending.values_mut()
{
while let Some( request ) = queue.pop_front()
{
let _ = request.response_sender.send( Err( crate::error::OpenAIError::Internal( "Request flushed".to_string() ) ) );
}
}
pending.clear();
}
}
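/// Offline analyzer that estimates the batching potential of a request workload.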
#[ derive( Debug ) ]
pub struct BatchOptimizer;
impl BatchOptimizer
{
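/// Estimates how many HTTP requests a workload could save if batched.
///
/// Illustrative sketch ( `RequestSignature` derives `Clone`, so the repeat works ) :
/// ```rust,ignore
/// let signatures = vec![ RequestSignature::new( "POST", "embeddings", b"{}" ); 10 ];
/// let analysis = BatchOptimizer::analyze_batching_potential( &signatures );
/// assert_eq!( analysis.potential_batches, 1 ); // ceil( 10 / 100 )
/// assert_eq!( analysis.http_requests_saved, 9 );
/// ```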
#[ inline ]
#[ must_use ]
pub fn analyze_batching_potential( requests : &[ RequestSignature ] ) -> BatchingAnalysis
{
let mut signature_counts = HashMap::new();
let mut total_batchable = 0;
for signature in requests
{
let count = signature_counts.entry( signature.clone() ).or_insert( 0 );
*count += 1;
if signature.is_batch_compatible_endpoint()
{
total_batchable += 1;
}
}
// Ceiling division by 100, mirroring the default `max_batch_size`.
let potential_batches = signature_counts.values().map( | &count | count.div_ceil( 100 ) ).sum::< usize >();
let http_requests_saved = requests.len().saturating_sub( potential_batches );
let efficiency_gain = if requests.is_empty() { 0.0 } else { http_requests_saved as f64 / requests.len() as f64 };
BatchingAnalysis
{
total_requests : requests.len(),
batchable_requests : total_batchable,
potential_batches,
http_requests_saved,
efficiency_gain,
recommended_batch_size : Self::calculate_optimal_batch_size( &signature_counts ),
}
}
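/// Maps the average number of similar requests per signature to a
/// recommended batch-size bucket.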
fn calculate_optimal_batch_size( signature_counts : &HashMap< RequestSignature, usize > ) -> usize
{
if signature_counts.is_empty()
{
return 50;
}
let avg_similar_requests = signature_counts.values().sum::< usize >() as f64 / signature_counts.len() as f64;
#[ allow( clippy::cast_possible_truncation, clippy::cast_sign_loss ) ]
let avg_usize = avg_similar_requests as usize;
match avg_usize
{
1..=5 => 10,
6..=20 => 25,
21..=50 => 50,
51..=100 => 75,
_ => 100,
}
}
}
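/// Result of `BatchOptimizer::analyze_batching_potential`.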
#[ derive( Debug, Clone ) ]
pub struct BatchingAnalysis
{
pub total_requests : usize,
pub batchable_requests : usize,
pub potential_batches : usize,
pub http_requests_saved : usize,
pub efficiency_gain : f64,
pub recommended_batch_size : usize,
}
}
mod_interface!
{
exposed use
{
BatchConfig,
RequestSignature,
BatchedRequest,
BatchResult,
RequestBatcher,
BatchMetrics,
BatchOptimizer,
BatchingAnalysis,
};
}