use crate::InferenceSession;
use crate::error::{Error, Result};
use ronn_core::tensor::Tensor;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{RwLock, mpsc};
use tokio::time::timeout;
/// Policy used by the worker to group queued requests into one inference call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BatchStrategy {
    /// Gather a fixed number of requests per batch, with no time limit.
    Static {
        /// Number of requests collected before the batch is run.
        batch_size: usize,
    },
    /// Gather requests until either a size cap or a time window is reached.
    Dynamic {
        /// Upper bound on the number of requests in one batch.
        max_batch_size: usize,
        /// Length of the collection window in milliseconds; the window is
        /// measured from the arrival of the first request of the batch.
        timeout_ms: u64,
    },
}

impl Default for BatchStrategy {
    /// Dynamic batching with at most 32 requests and a 10 ms window.
    fn default() -> Self {
        const DEFAULT_MAX_BATCH_SIZE: usize = 32;
        const DEFAULT_TIMEOUT_MS: u64 = 10;

        BatchStrategy::Dynamic {
            max_batch_size: DEFAULT_MAX_BATCH_SIZE,
            timeout_ms: DEFAULT_TIMEOUT_MS,
        }
    }
}
/// Settings used to build a `BatchProcessor`.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// How requests are grouped into batches.
    pub strategy: BatchStrategy,
    /// Capacity requested for the bounded queue between callers and the worker.
    pub queue_capacity: usize,
    /// Requested number of workers.
    ///
    /// NOTE(review): nothing in this file reads this field; `BatchProcessor::new`
    /// spawns a single worker task regardless of its value.
    pub num_workers: usize,
}

impl Default for BatchConfig {
    /// Default strategy, a 1024-slot queue and a single worker.
    fn default() -> Self {
        let strategy = BatchStrategy::default();
        BatchConfig {
            strategy,
            queue_capacity: 1024,
            num_workers: 1,
        }
    }
}
/// One caller's inputs together with the channel used to hand back its result.
pub struct BatchRequest {
    /// Named input tensors supplied by the caller.
    pub inputs: HashMap<String, Tensor>,
    /// One-shot channel on which this request's outcome is delivered.
    response_tx: tokio::sync::oneshot::Sender<Result<HashMap<String, Tensor>>>,
}

impl BatchRequest {
    /// Bundles `inputs` with the sender that will receive the result.
    pub fn new(
        inputs: HashMap<String, Tensor>,
        response_tx: tokio::sync::oneshot::Sender<Result<HashMap<String, Tensor>>>,
    ) -> Self {
        BatchRequest {
            response_tx,
            inputs,
        }
    }

    /// Consumes the request and delivers `result` to the waiting caller.
    fn send_response(self, result: Result<HashMap<String, Tensor>>) {
        // A failed send only means the receiving half was already dropped, so
        // there is nobody left to notify and the result is discarded.
        self.response_tx.send(result).ok();
    }
}
/// Handle for submitting inference requests to a background batching task.
pub struct BatchProcessor {
/// Sending half of the bounded request queue; `process` pushes requests here.
request_tx: mpsc::Sender<BatchRequest>,
/// Join handle of the task spawned in `new`. It is stored but never awaited
/// or aborted anywhere in this file.
_worker_handle: tokio::task::JoinHandle<()>,
/// Configuration passed to `new`, exposed through `config()`.
config: BatchConfig,
}
impl BatchProcessor {
/// Creates a processor and spawns the background task that batches requests.
///
/// `session` is moved into the worker task. `config` is stored unchanged and
/// a clone of it drives the worker.
///
/// A `queue_capacity` of zero is treated as one: `tokio::sync::mpsc::channel`
/// panics when asked for a zero-sized buffer, so the capacity is clamped to
/// the smallest valid value instead of bringing the caller down.
///
/// NOTE(review): `config.num_workers` is not consulted here; exactly one
/// worker task is spawned.
///
/// # Panics
///
/// `tokio::spawn` panics when called outside of a Tokio runtime context.
pub fn new(session: InferenceSession, config: BatchConfig) -> Self {
    // Clamp rather than panic on a zero-capacity configuration.
    let capacity = config.queue_capacity.max(1);
    let (request_tx, request_rx) = mpsc::channel(capacity);

    let worker_config = config.clone();
    let worker_handle = tokio::spawn(async move {
        Self::worker_loop(session, request_rx, worker_config).await;
    });

    Self {
        request_tx,
        _worker_handle: worker_handle,
        config,
    }
}
pub async fn process(
&self,
inputs: HashMap<String, Tensor>,
) -> Result<HashMap<String, Tensor>> {
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
let request = BatchRequest::new(inputs, response_tx);
self.request_tx
.send(request)
.await
.map_err(|_| Error::InferenceError("Batch processor channel closed".to_string()))?;
response_rx
.await
.map_err(|_| Error::InferenceError("Response channel closed".to_string()))?
}
/// Body of the background task: repeatedly collects a batch according to
/// `config.strategy` and runs it through the session.
///
/// The loop ends as soon as a collection step yields an empty batch, which is
/// taken to mean that no further requests will arrive.
async fn worker_loop(
    session: InferenceSession,
    mut request_rx: mpsc::Receiver<BatchRequest>,
    config: BatchConfig,
) {
    let shared_session = Arc::new(RwLock::new(session));

    loop {
        let pending = match config.strategy {
            BatchStrategy::Static { batch_size } => {
                Self::collect_static_batch(&mut request_rx, batch_size).await
            }
            BatchStrategy::Dynamic {
                max_batch_size,
                timeout_ms,
            } => {
                Self::collect_dynamic_batch(&mut request_rx, max_batch_size, timeout_ms).await
            }
        };

        if pending.is_empty() {
            return;
        }

        Self::process_batch(Arc::clone(&shared_session), pending).await;
    }
}
/// Collects requests until `batch_size` of them are available or the queue
/// is closed.
///
/// There is no time limit: the call waits for as long as it takes to fill the
/// batch. A partially filled batch is returned only when the channel closes,
/// and an empty batch is returned only when the channel closed before any
/// request arrived.
///
/// A `batch_size` of zero is treated as one. Without that, a zero size would
/// produce an empty batch immediately, which the worker loop interprets as
/// "queue closed" and uses as its signal to stop.
async fn collect_static_batch(
    request_rx: &mut mpsc::Receiver<BatchRequest>,
    batch_size: usize,
) -> Vec<BatchRequest> {
    let target = batch_size.max(1);
    let mut batch = Vec::with_capacity(target);

    while batch.len() < target {
        match request_rx.recv().await {
            Some(request) => batch.push(request),
            // All senders are gone; hand back whatever was gathered so far.
            None => break,
        }
    }

    batch
}
/// Collects a batch bounded by both size and time.
///
/// Waits without a time limit for the first request, then keeps accepting
/// requests until `max_batch_size` is reached, `timeout_ms` milliseconds have
/// passed since the first request arrived, or the queue is closed.
///
/// Returns an empty vector only when the queue closed before any request
/// arrived. Because the first request is always accepted, the result holds
/// one request even when `max_batch_size` is zero.
async fn collect_dynamic_batch(
    request_rx: &mut mpsc::Receiver<BatchRequest>,
    max_batch_size: usize,
    timeout_ms: u64,
) -> Vec<BatchRequest> {
    let mut pending = Vec::with_capacity(max_batch_size);
    let window = Duration::from_millis(timeout_ms);

    if let Some(first) = request_rx.recv().await {
        pending.push(first);
    } else {
        return pending;
    }

    // The collection window opens once the first request is in hand.
    let opened_at = Instant::now();
    loop {
        if pending.len() >= max_batch_size {
            break;
        }

        let time_left = window.saturating_sub(opened_at.elapsed());
        if time_left.is_zero() {
            break;
        }

        // Stop on either a closed queue (`Ok(None)`) or an expired window (`Err`).
        if let Ok(Some(next)) = timeout(time_left, request_rx.recv()).await {
            pending.push(next);
        } else {
            break;
        }
    }

    pending
}
async fn process_batch(session: Arc<RwLock<InferenceSession>>, batch: Vec<BatchRequest>) {
if batch.is_empty() {
return;
}
let batch_size = batch.len();
let combined_inputs = match Self::combine_inputs(&batch) {
Ok(inputs) => inputs,
Err(e) => {
let err_msg = format!("{}", e);
for request in batch {
request.send_response(Err(Error::InferenceError(err_msg.clone())));
}
return;
}
};
let inputs_ref: HashMap<&str, Tensor> = combined_inputs
.iter()
.map(|(k, v)| (k.as_str(), v.clone()))
.collect();
let session = session.read().await;
let combined_outputs = match session.run(inputs_ref) {
Ok(outputs) => outputs,
Err(e) => {
let err_msg = format!("{}", e);
for request in batch {
request.send_response(Err(Error::InferenceError(err_msg.clone())));
}
return;
}
};
match Self::split_outputs(combined_outputs, batch_size) {
Ok(individual_outputs) => {
for (request, outputs) in batch.into_iter().zip(individual_outputs) {
request.send_response(Ok(outputs));
}
}
Err(e) => {
let err_msg = format!("{}", e);
for request in batch {
request.send_response(Err(Error::InferenceError(err_msg.clone())));
}
}
}
}
fn combine_inputs(batch: &[BatchRequest]) -> Result<HashMap<String, Tensor>> {
if batch.is_empty() {
return Ok(HashMap::new());
}
let input_names: Vec<_> = batch[0].inputs.keys().cloned().collect();
let mut combined = HashMap::new();
for name in input_names {
let tensors: std::result::Result<Vec<_>, Error> = batch
.iter()
.map(|req| {
req.inputs.get(&name).ok_or_else(|| {
Error::InvalidInput(format!("Missing input tensor: {}", name))
})
})
.collect();
let tensors = tensors?;
let batched = Tensor::stack(&tensors, 0)
.map_err(|e| Error::InferenceError(format!("Failed to stack tensors: {}", e)))?;
combined.insert(name, batched);
}
Ok(combined)
}
fn split_outputs(
combined: HashMap<String, Tensor>,
batch_size: usize,
) -> Result<Vec<HashMap<String, Tensor>>> {
let mut results = vec![HashMap::new(); batch_size];
for (name, batched_tensor) in combined {
let individual_tensors = batched_tensor
.split(batch_size, 0)
.map_err(|e| Error::InferenceError(format!("Failed to split tensors: {}", e)))?;
for (i, tensor) in individual_tensors.into_iter().enumerate() {
results[i].insert(name.clone(), tensor);
}
}
Ok(results)
}
pub fn config(&self) -> &BatchConfig {
&self.config
}
}
/// Aggregate counters describing batching activity.
///
/// NOTE(review): nothing in this file fills these fields in; the per-field
/// descriptions below are inferred from the names and from how the two
/// methods use them — confirm against the code that populates them.
#[derive(Debug, Clone, Default)]
pub struct BatchStats {
    /// Presumably the number of batches processed.
    pub total_batches: u64,
    /// Number of requests counted; the numerator of `throughput`.
    pub total_requests: u64,
    /// Average batch size; the numerator of `utilization`.
    pub avg_batch_size: f64,
    /// Presumably the largest batch observed.
    pub max_batch_size: usize,
    /// Presumably the smallest batch observed.
    pub min_batch_size: usize,
    /// Accumulated processing time in milliseconds; the denominator of `throughput`.
    pub total_processing_time_ms: f64,
    /// Presumably the average time per batch in milliseconds.
    pub avg_batch_time_ms: f64,
}

impl BatchStats {
    /// Requests per second: `total_requests * 1000 / total_processing_time_ms`.
    ///
    /// Returns `0.0` when the recorded processing time is exactly zero.
    pub fn throughput(&self) -> f64 {
        let elapsed_ms = self.total_processing_time_ms;
        if elapsed_ms != 0.0 {
            return (self.total_requests as f64 * 1000.0) / elapsed_ms;
        }
        0.0
    }

    /// Ratio of `avg_batch_size` to `max_batch_size`.
    ///
    /// Returns `0.0` when `max_batch_size` is zero.
    pub fn utilization(&self, max_batch_size: usize) -> f64 {
        match max_batch_size {
            0 => 0.0,
            capacity => self.avg_batch_size / capacity as f64,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses a 1024-slot queue, one worker and the
    /// dynamic strategy with its default limits.
    #[test]
    fn test_batch_config_default() {
        let BatchConfig {
            strategy,
            queue_capacity,
            num_workers,
        } = BatchConfig::default();

        assert_eq!(queue_capacity, 1024);
        assert_eq!(num_workers, 1);

        if let BatchStrategy::Dynamic {
            max_batch_size,
            timeout_ms,
        } = strategy
        {
            assert_eq!(max_batch_size, 32);
            assert_eq!(timeout_ms, 10);
        } else {
            panic!("Expected dynamic strategy");
        }
    }

    /// A static strategy keeps the batch size it was built with.
    #[test]
    fn test_batch_strategy_static() {
        let strategy = BatchStrategy::Static { batch_size: 16 };

        if let BatchStrategy::Static { batch_size } = strategy {
            assert_eq!(batch_size, 16);
        } else {
            panic!("Expected static strategy");
        }
    }

    /// 1000 requests in 1000 ms is 1000 requests per second.
    #[test]
    fn test_batch_stats_throughput() {
        let mut stats = BatchStats::default();
        stats.total_requests = 1000;
        stats.total_processing_time_ms = 1000.0;

        assert_eq!(stats.throughput(), 1000.0);
    }

    /// An average batch of 16 against a cap of 32 is half utilised.
    #[test]
    fn test_batch_stats_utilization() {
        let mut stats = BatchStats::default();
        stats.avg_batch_size = 16.0;

        assert_eq!(stats.utilization(32), 0.5);
    }
}