use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use scirs2_core::parallel_ops::*; use torsh_core::{device::DeviceType, dtype::TensorElement, error::Result};
use crate::Tensor;
/// Tuning knobs for the automatic operation batcher.
#[derive(Debug, Clone)]
pub struct BatchingConfig {
    /// Smallest batch that may flush before the wait timeout expires.
    pub min_batch_size: usize,
    /// Hard cap on ops per batch; a full batch flushes immediately.
    pub max_batch_size: usize,
    /// Longest time a partially filled batch may wait before flushing.
    pub max_wait_time: Duration,
    /// Run the ops of a batch in parallel when true.
    pub parallel_execution: bool,
    /// Ops with fewer elements than this are considered small enough to batch.
    pub small_op_threshold: usize,
    /// Master switch; when false, every op executes immediately.
    pub enabled: bool,
}

impl Default for BatchingConfig {
    /// Balanced defaults intended for mixed workloads.
    fn default() -> Self {
        Self {
            min_batch_size: 4,
            max_batch_size: 32,
            max_wait_time: Duration::from_micros(100),
            parallel_execution: true,
            small_op_threshold: 1000,
            enabled: true,
        }
    }
}

impl BatchingConfig {
    /// Preset tuned for many tiny ops: larger batches, shorter waits,
    /// and a lower "small" threshold. All other knobs keep their defaults.
    pub fn small_ops() -> Self {
        Self {
            min_batch_size: 8,
            max_batch_size: 64,
            max_wait_time: Duration::from_micros(50),
            small_op_threshold: 500,
            ..Self::default()
        }
    }

    /// Preset for large ops, where batching rarely pays off:
    /// batching disabled, sequential execution, tiny batch limits.
    pub fn large_ops() -> Self {
        Self {
            min_batch_size: 2,
            max_batch_size: 8,
            max_wait_time: Duration::from_micros(20),
            parallel_execution: false,
            small_op_threshold: 10000,
            enabled: false,
        }
    }

    /// Default settings with batching switched off entirely.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Self::default()
        }
    }
}
/// An element-wise tensor operation that the auto-batcher can queue.
///
/// Tensor operands are held behind `Arc` so a queued op shares the
/// tensor rather than copying its data.
#[derive(Debug, Clone)]
pub enum BatchableOp<T: TensorElement> {
    /// Element-wise `a + b`.
    Add(Arc<Tensor<T>>, Arc<Tensor<T>>),
    /// Element-wise `a * b`.
    Mul(Arc<Tensor<T>>, Arc<Tensor<T>>),
    /// Element-wise `a - b`.
    Sub(Arc<Tensor<T>>, Arc<Tensor<T>>),
    /// Element-wise `a / b`.
    Div(Arc<Tensor<T>>, Arc<Tensor<T>>),
    /// Add a scalar to every element.
    AddScalar(Arc<Tensor<T>>, T),
    /// Multiply every element by a scalar.
    MulScalar(Arc<Tensor<T>>, T),
    /// Element-wise rectified linear unit.
    ReLU(Arc<Tensor<T>>),
    /// Element-wise logistic sigmoid.
    Sigmoid(Arc<Tensor<T>>),
    /// Element-wise hyperbolic tangent.
    Tanh(Arc<Tensor<T>>),
}
impl<T: TensorElement> BatchableOp<T> {
pub fn size(&self) -> usize {
match self {
BatchableOp::Add(a, _)
| BatchableOp::Mul(a, _)
| BatchableOp::Sub(a, _)
| BatchableOp::Div(a, _)
| BatchableOp::AddScalar(a, _)
| BatchableOp::MulScalar(a, _)
| BatchableOp::ReLU(a)
| BatchableOp::Sigmoid(a)
| BatchableOp::Tanh(a) => a.numel(),
}
}
pub fn device(&self) -> DeviceType {
match self {
BatchableOp::Add(a, _)
| BatchableOp::Mul(a, _)
| BatchableOp::Sub(a, _)
| BatchableOp::Div(a, _)
| BatchableOp::AddScalar(a, _)
| BatchableOp::MulScalar(a, _)
| BatchableOp::ReLU(a)
| BatchableOp::Sigmoid(a)
| BatchableOp::Tanh(a) => a.device,
}
}
pub fn should_batch(&self, config: &BatchingConfig) -> bool {
config.enabled && self.size() < config.small_op_threshold
}
}
/// A group of queued operations pinned to a single device, stamped with
/// the time the group was opened (used to enforce `max_wait_time`).
struct OperationBatch<T: TensorElement> {
    // Queued ops in submission order.
    operations: Vec<BatchableOp<T>>,
    // When the batch was opened; the flush timeout counts from here.
    created_at: Instant,
    // Device every op in this batch must target.
    device: DeviceType,
}
impl<T: TensorElement> OperationBatch<T> {
    /// Open an empty batch pinned to `device`; the wait timer starts now.
    fn new(device: DeviceType) -> Self {
        Self {
            operations: Vec::new(),
            created_at: Instant::now(),
            device,
        }
    }

    /// Queue one more operation.
    fn add(&mut self, op: BatchableOp<T>) {
        self.operations.push(op);
    }

    /// A batch is ready to flush once it is full, or once it holds at
    /// least `min_batch_size` ops and has waited out `max_wait_time`.
    fn is_ready(&self, config: &BatchingConfig) -> bool {
        let queued = self.operations.len();
        queued >= config.max_batch_size
            || (queued >= config.min_batch_size
                && self.created_at.elapsed() >= config.max_wait_time)
    }

    /// True while there is still room for another operation.
    fn can_add(&self, config: &BatchingConfig) -> bool {
        self.operations.len() < config.max_batch_size
    }

    /// Number of queued operations.
    fn len(&self) -> usize {
        self.operations.len()
    }

    /// True when nothing has been queued yet.
    fn is_empty(&self) -> bool {
        self.operations.is_empty()
    }
}
/// Queues small element-wise ops and executes them in groups.
///
/// State lives behind `Arc<Mutex<...>>`, giving interior mutability
/// through `&self` so submission can happen from shared references.
pub struct AutoBatcher<T: TensorElement> {
    // The batch currently being filled, if any.
    current_batch: Arc<Mutex<Option<OperationBatch<T>>>>,
    // Batching policy for this batcher.
    config: BatchingConfig,
    // Execution counters; cloned out for cheap snapshots.
    stats: Arc<Mutex<BatchingStats>>,
}
impl<
        T: TensorElement
            + Copy
            + std::ops::Add<Output = T>
            + std::ops::Sub<Output = T>
            + std::ops::Mul<Output = T>
            + std::ops::Div<Output = T>
            + torsh_core::FloatElement
            + Send
            + Sync,
    > AutoBatcher<T>
{
    /// Create a batcher with the default configuration.
    pub fn new() -> Self {
        Self::with_config(BatchingConfig::default())
    }

    /// Create a batcher driven by `config`.
    pub fn with_config(config: BatchingConfig) -> Self {
        Self {
            current_batch: Arc::new(Mutex::new(None)),
            config,
            stats: Arc::new(Mutex::new(BatchingStats::default())),
        }
    }

    /// Submit an operation for (possibly deferred) execution.
    ///
    /// Ops too large to batch — or any op when batching is disabled —
    /// run immediately and return [`BatchHandle::Immediate`] with the
    /// result. Otherwise the op is queued and [`BatchHandle::Batched`]
    /// is returned; the pending batch is flushed as soon as it becomes
    /// ready (full, or min size reached and the wait timeout elapsed).
    ///
    /// # Errors
    /// Propagates any error produced by the underlying tensor kernels.
    pub fn submit(&self, op: BatchableOp<T>) -> Result<BatchHandle<T>> {
        if !self.config.enabled || !op.should_batch(&self.config) {
            return Ok(BatchHandle::Immediate(self.execute_single(op)?));
        }
        let mut batch_lock = self
            .current_batch
            .lock()
            .expect("lock should not be poisoned");
        let batch = batch_lock.get_or_insert_with(|| OperationBatch::new(op.device()));
        if !batch.can_add(&self.config) || batch.device != op.device() {
            // Current batch is full or targets a different device:
            // flush it, then start a fresh batch for this op.
            let ready_batch = batch_lock
                .take()
                .expect("batch should exist after get_or_insert_with");
            drop(batch_lock); // release the lock before executing kernels
            self.execute_batch(ready_batch)?;
            let mut new_batch_lock = self
                .current_batch
                .lock()
                .expect("lock should not be poisoned");
            let new_batch = new_batch_lock.get_or_insert_with(|| OperationBatch::new(op.device()));
            new_batch.add(op);
        } else {
            batch.add(op);
            if batch.is_ready(&self.config) {
                let ready_batch = batch_lock
                    .take()
                    .expect("batch should exist after is_ready check");
                drop(batch_lock); // release the lock before executing kernels
                self.execute_batch(ready_batch)?;
            }
        }
        Ok(BatchHandle::Batched)
    }

    /// Force execution of whatever is currently queued, if anything.
    pub fn flush(&self) -> Result<()> {
        let batch = self
            .current_batch
            .lock()
            .expect("lock should not be poisoned")
            .take();
        if let Some(batch) = batch {
            if !batch.is_empty() {
                self.execute_batch(batch)?;
            }
        }
        Ok(())
    }

    /// Execute one op on the immediate (non-batched) path, counting it
    /// in `single_ops_executed`.
    fn execute_single(&self, op: BatchableOp<T>) -> Result<Tensor<T>> {
        self.stats
            .lock()
            .expect("lock should not be poisoned")
            .single_ops_executed += 1;
        self.run_op(op)
    }

    /// Dispatch one op to the underlying tensor kernels.
    ///
    /// Deliberately does NOT touch statistics, so it can serve both the
    /// immediate path (counted in `single_ops_executed`) and the batched
    /// path (counted in `total_ops_batched`). Previously `execute_batch`
    /// went through `execute_single`, which double-counted every batched
    /// op as a single op and skewed `batching_efficiency()`.
    fn run_op(&self, op: BatchableOp<T>) -> Result<Tensor<T>> {
        match op {
            BatchableOp::Add(a, b) => a.add_op(&b),
            BatchableOp::Mul(a, b) => a.mul_op(&b),
            BatchableOp::Sub(a, b) => a.sub(&b),
            BatchableOp::Div(a, b) => a.div(&b),
            BatchableOp::AddScalar(a, s) => a.add_scalar(s),
            BatchableOp::MulScalar(a, s) => a.mul_scalar(s),
            BatchableOp::ReLU(a) => a.relu(),
            BatchableOp::Sigmoid(a) => a.sigmoid(),
            BatchableOp::Tanh(a) => a.tanh(),
        }
    }

    /// Execute every op in `batch` and update the batch statistics.
    ///
    /// Results are discarded: callers of the batched path received
    /// [`BatchHandle::Batched`] and cannot observe individual outputs.
    fn execute_batch(&self, batch: OperationBatch<T>) -> Result<()> {
        let batch_size = batch.len();
        {
            let mut stats = self.stats.lock().expect("lock should not be poisoned");
            stats.batches_executed += 1;
            stats.total_ops_batched += batch_size;
            // Incremental running mean over all executed batches;
            // `batches_executed - 1` is the previous batch count.
            stats.avg_batch_size = (stats.avg_batch_size * (stats.batches_executed - 1) as f64
                + batch_size as f64)
                / stats.batches_executed as f64;
        }
        if self.config.parallel_execution && batch_size > 1 {
            let results: Vec<Result<()>> = batch
                .operations
                .into_par_iter()
                .map(|op| {
                    self.run_op(op)?;
                    Ok(())
                })
                .collect();
            for result in results {
                result?;
            }
        } else {
            for op in batch.operations {
                self.run_op(op)?;
            }
        }
        Ok(())
    }

    /// Snapshot of the current statistics.
    pub fn stats(&self) -> BatchingStats {
        self.stats
            .lock()
            .expect("lock should not be poisoned")
            .clone()
    }

    /// Reset all statistics counters to zero.
    pub fn reset_stats(&self) {
        *self.stats.lock().expect("lock should not be poisoned") = BatchingStats::default();
    }
}
impl<
        T: TensorElement
            + Copy
            + std::ops::Add<Output = T>
            + std::ops::Sub<Output = T>
            + std::ops::Mul<Output = T>
            + std::ops::Div<Output = T>
            + torsh_core::FloatElement
            + Send
            + Sync,
    > Default for AutoBatcher<T>
{
    /// Equivalent to [`AutoBatcher::new`]: a batcher with the default
    /// `BatchingConfig`.
    fn default() -> Self {
        Self::new()
    }
}
/// Outcome of submitting an op to the [`AutoBatcher`].
pub enum BatchHandle<T: TensorElement> {
    /// The op ran immediately; its result is attached.
    Immediate(Tensor<T>),
    /// The op was queued into a batch; no individual result is returned.
    Batched,
}
/// Counters describing how effectively the auto-batcher is grouping ops.
///
/// `Default` is derived: the hand-written impl it replaces set every
/// field to its zero value, which is exactly what the derive produces.
#[derive(Debug, Clone, Default)]
pub struct BatchingStats {
    /// Number of batches flushed.
    pub batches_executed: usize,
    /// Total ops executed through the batched path.
    pub total_ops_batched: usize,
    /// Running average number of ops per executed batch.
    pub avg_batch_size: f64,
    /// Ops executed immediately, bypassing batching.
    pub single_ops_executed: usize,
}

impl BatchingStats {
    /// Percentage of all executed ops that went through batches.
    ///
    /// Returns `0.0` when no ops have been executed at all, avoiding a
    /// division by zero.
    pub fn batching_efficiency(&self) -> f64 {
        let total_ops = self.total_ops_batched + self.single_ops_executed;
        if total_ops == 0 {
            0.0
        } else {
            (self.total_ops_batched as f64 / total_ops as f64) * 100.0
        }
    }

    /// Dispatch overheads avoided by batching: one per batched op minus
    /// one per batch actually issued. `0.0` when no batches have run.
    pub fn ops_saved(&self) -> f64 {
        if self.batches_executed == 0 {
            0.0
        } else {
            self.total_ops_batched as f64 - self.batches_executed as f64
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::creation::*;

    /// Preset constructors expose the documented knob values.
    #[test]
    fn test_batching_config_presets() {
        let default_config = BatchingConfig::default();
        assert!(default_config.enabled);
        assert_eq!(default_config.min_batch_size, 4);
        let small_ops = BatchingConfig::small_ops();
        assert_eq!(small_ops.min_batch_size, 8);
        assert_eq!(small_ops.max_batch_size, 64);
        let large_ops = BatchingConfig::large_ops();
        assert!(!large_ops.enabled);
        let disabled = BatchingConfig::disabled();
        assert!(!disabled.enabled);
    }

    /// `size()` reports the element count of the primary operand.
    #[test]
    fn test_batchable_op_size() {
        let a = tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[2.0f32, 2.0, 2.0, 2.0]).expect("tensor_1d creation should succeed");
        let op = BatchableOp::Add(Arc::new(a), Arc::new(b));
        assert_eq!(op.size(), 4);
    }

    /// A 100-element op is below the default threshold (1000), so it is
    /// batchable — unless batching is disabled in the config.
    #[test]
    fn test_batchable_op_should_batch() {
        let a = tensor_1d(&[1.0f32; 100]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[2.0f32; 100]).expect("tensor_1d creation should succeed");
        let op = BatchableOp::Add(Arc::new(a), Arc::new(b));
        let config = BatchingConfig::default();
        assert!(op.should_batch(&config));
        let disabled_config = BatchingConfig::disabled();
        assert!(!op.should_batch(&disabled_config));
    }

    /// Basic queue bookkeeping: empty on creation, counts added ops.
    #[test]
    fn test_operation_batch() {
        let a = tensor_1d(&[1.0f32, 2.0]).expect("tensor_1d creation should succeed");
        let op = BatchableOp::AddScalar(Arc::new(a), 1.0);
        let mut batch = OperationBatch::new(DeviceType::Cpu);
        assert!(batch.is_empty());
        batch.add(op);
        assert!(!batch.is_empty());
        assert_eq!(batch.len(), 1);
    }

    /// A batch becomes ready only when full (5 ops here); below
    /// `min_batch_size`, or before the wait timeout, it is not ready.
    #[test]
    fn test_batch_readiness() {
        let config = BatchingConfig {
            min_batch_size: 2,
            max_batch_size: 5,
            max_wait_time: Duration::from_millis(10),
            ..Default::default()
        };
        let mut batch = OperationBatch::<f32>::new(DeviceType::Cpu);
        // Empty batch is never ready.
        assert!(!batch.is_ready(&config));
        let a = tensor_1d(&[1.0f32]).expect("tensor_1d creation should succeed");
        batch.add(BatchableOp::AddScalar(Arc::new(a), 1.0));
        // One op is below min_batch_size.
        assert!(!batch.is_ready(&config));
        let b = tensor_1d(&[2.0f32]).expect("tensor_1d creation should succeed");
        batch.add(BatchableOp::AddScalar(Arc::new(b), 1.0));
        // Fill up to max_batch_size (5 ops total) — ready regardless of time.
        for _ in 0..3 {
            let c = tensor_1d(&[3.0f32]).expect("tensor_1d creation should succeed");
            batch.add(BatchableOp::AddScalar(Arc::new(c), 1.0));
        }
        assert!(batch.is_ready(&config));
    }

    /// 50 batched ops out of 60 total → ~83.33% efficiency; 50 ops in
    /// 10 batches → 40 dispatches saved.
    #[test]
    fn test_batching_stats() {
        let mut stats = BatchingStats::default();
        stats.batches_executed = 10;
        stats.total_ops_batched = 50;
        stats.single_ops_executed = 10;
        let efficiency = stats.batching_efficiency();
        assert!((efficiency - 83.33).abs() < 0.1);
        let ops_saved = stats.ops_saved();
        assert_eq!(ops_saved, 40.0);
    }

    /// A fresh batcher starts with all counters at zero.
    #[test]
    fn test_auto_batcher_creation() {
        let batcher = AutoBatcher::<f32>::new();
        let stats = batcher.stats();
        assert_eq!(stats.batches_executed, 0);
        assert_eq!(stats.total_ops_batched, 0);
        assert_eq!(stats.single_ops_executed, 0);
    }

    /// With batching disabled, submit executes immediately and counts
    /// the op as a single (non-batched) execution.
    #[test]
    fn test_auto_batcher_disabled() {
        let config = BatchingConfig::disabled();
        let batcher = AutoBatcher::<f32>::with_config(config);
        let a = tensor_1d(&[1.0f32, 2.0]).expect("tensor_1d creation should succeed");
        let op = BatchableOp::AddScalar(Arc::new(a), 1.0);
        let handle = batcher.submit(op).expect("submit should succeed");
        assert!(matches!(handle, BatchHandle::Immediate(_)));
        let stats = batcher.stats();
        assert_eq!(stats.single_ops_executed, 1);
        assert_eq!(stats.total_ops_batched, 0);
    }
}