use crate::error::{NumRs2Error, Result};
use crate::gpu::array::GpuArray;
use crate::gpu::context::GpuContextRef;
use crate::gpu::memory::TransferOptimizer;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Default cap on the number of operations handled by one flush.
const DEFAULT_MAX_BATCH_SIZE: usize = 32;
/// Default batch timeout, in milliseconds.
const DEFAULT_BATCH_TIMEOUT_MS: u64 = 10;
/// Default minimum queue length required for a timeout-triggered auto-flush.
const DEFAULT_MIN_BATCH_SIZE: usize = 4;

/// Tunable parameters of a `BatchQueue`.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Queue length at which an auto-flush always triggers. Also the per-flush
    /// cap when dynamic optimization is off, and the upper bound of the
    /// dynamically tuned batch size.
    pub max_batch_size: usize,
    /// Time since the last flush after which a queue holding at least
    /// `min_batch_size` operations is auto-flushed.
    pub batch_timeout: Duration,
    /// Minimum queue length for a timeout-triggered auto-flush; also the lower
    /// bound used when the dynamic batch size is reduced.
    pub min_batch_size: usize,
    /// When `true`, flushes are capped by the adaptively tuned batch size
    /// instead of `max_batch_size`, and that size is re-tuned after each flush.
    pub enable_dynamic_optimization: bool,
    /// When `true`, every enqueue checks the flush conditions and flushes
    /// immediately if one of them holds.
    pub enable_auto_flush: bool,
    /// Estimated-occupancy set point for the dynamic batch size; the tuner
    /// only reacts when the estimate is more than 0.1 away from this value.
    pub target_occupancy: f32,
}

impl Default for BatchConfig {
    /// Builds the stock configuration: the `DEFAULT_*` sizes and timeout, both
    /// feature switches enabled, and a target occupancy of 0.8.
    fn default() -> Self {
        let batch_timeout = Duration::from_millis(DEFAULT_BATCH_TIMEOUT_MS);
        BatchConfig {
            enable_auto_flush: true,
            enable_dynamic_optimization: true,
            target_occupancy: 0.8,
            max_batch_size: DEFAULT_MAX_BATCH_SIZE,
            min_batch_size: DEFAULT_MIN_BATCH_SIZE,
            batch_timeout,
        }
    }
}
/// Kinds of GPU operation the batch queue knows about.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationType {
    MatMul,
    Add,
    Multiply,
    Subtract,
    Divide,
    Conv2D,
    Exp,
    Log,
    Sqrt,
}

impl OperationType {
    /// Reports whether this operation kind is flagged as batchable.
    ///
    /// The two-input arithmetic kinds, `MatMul` and `Conv2D` are batchable;
    /// `Exp`, `Log` and `Sqrt` are not.
    pub fn is_batchable(&self) -> bool {
        // Listed exhaustively so that adding a variant forces a decision here.
        match self {
            Self::MatMul
            | Self::Add
            | Self::Multiply
            | Self::Subtract
            | Self::Divide
            | Self::Conv2D => true,
            Self::Exp | Self::Log | Self::Sqrt => false,
        }
    }

    /// Returns the relative cost weight of this operation kind
    /// (`Add`/`Subtract` are the 1.0 baseline, `MatMul` is the most expensive).
    pub fn cost_factor(&self) -> f32 {
        // Arms are ordered from cheapest to most expensive.
        match self {
            Self::Add | Self::Subtract => 1.0,
            Self::Multiply | Self::Divide => 2.0,
            Self::Sqrt => 2.5,
            Self::Exp | Self::Log => 3.0,
            Self::Conv2D => 8.0,
            Self::MatMul => 10.0,
        }
    }
}
/// Queued-operation record whose inputs are borrowed (`&'a GpuArray<T>`)
/// rather than reference-counted.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads
/// this type; the queue itself stores `OwnedQueuedOperation`. Confirm whether
/// this borrowed variant is still needed.
struct QueuedOperation<'a, T: bytemuck::Pod + bytemuck::Zeroable> {
/// Kind of operation to run.
op_type: OperationType,
/// First operand.
input_a: &'a GpuArray<T>,
/// Optional second operand.
input_b: Option<&'a GpuArray<T>>,
/// Presumably the enqueue timestamp, as in `OwnedQueuedOperation` — confirm.
queued_at: Instant,
/// Presumably a scheduling priority — no visible code reads it.
priority: i32,
/// Presumably an estimated execution cost — no visible code reads it.
cost: f32,
}
/// Outcome of a single executed operation, as returned by `BatchQueue::flush`.
pub struct BatchResult<T: bytemuck::Pod + bytemuck::Zeroable> {
/// Array produced by the operation.
pub result: GpuArray<T>,
/// Kind of operation that produced `result`.
pub op_type: OperationType,
/// Wall-clock time spent on this one operation, in microseconds.
pub execution_time_us: u64,
}
/// Counters and derived metrics describing a `BatchQueue`'s activity.
///
/// `Default` yields the all-zero value. It is derived rather than hand-written
/// because every field's own default is already zero.
#[derive(Debug, Clone, Copy, Default)]
pub struct BatchStatistics {
    /// Number of operations enqueued so far.
    pub total_operations: u64,
    /// Number of flushes that found a non-empty queue.
    pub total_flushes: u64,
    /// Number of operations belonging to batches that executed successfully.
    pub total_executed: u64,
    /// `total_executed / total_flushes`.
    pub avg_batch_size: f32,
    /// Largest number of operations dequeued by a single flush.
    pub max_batch_size: usize,
    /// Queue length recorded at the last enqueue, flush or clear.
    pub current_queue_depth: usize,
    /// Average time an operation waited in the queue, in microseconds.
    /// NOTE(review): only meaningful if the flush path maintains it — verify.
    pub avg_wait_time_us: u64,
    /// Mean duration of the most recent flushes, in microseconds.
    pub avg_execution_time_us: u64,
    /// `avg_batch_size * 1_000_000 / avg_execution_time_us`.
    pub throughput_ops_per_sec: f32,
    /// Heuristic occupancy estimate, capped at 1.0. It is computed from batch
    /// size and throughput, not measured on the device.
    pub estimated_gpu_occupancy: f32,
    /// Count of flushes attributed to the auto-flush trigger.
    pub auto_flush_count: u64,
    /// Count of flushes attributed to explicit `flush` calls.
    pub manual_flush_count: u64,
}
/// Mutable state of a `BatchQueue`, kept behind the queue's mutex.
struct BatchQueueState<T: bytemuck::Pod + bytemuck::Zeroable> {
/// Pending operations in FIFO order (pushed at the back, popped from the front).
queue: VecDeque<OwnedQueuedOperation<T>>,
/// Running counters and derived metrics.
stats: BatchStatistics,
/// Set when the queue is created and again after each successful flush.
last_flush: Instant,
/// Per-flush cap used when dynamic optimization is enabled; starts at
/// `config.max_batch_size` and is adjusted by `optimize_batch_size`.
dynamic_batch_size: usize,
/// Durations (microseconds) of the most recent flushes; trimmed to 100 entries.
recent_execution_times: VecDeque<u64>,
/// Batch sizes of the same flushes; pushed and trimmed in lockstep with
/// `recent_execution_times`.
recent_batch_sizes: VecDeque<usize>,
}
/// One pending operation, holding shared ownership (`Arc`) of its inputs so
/// they stay alive until the operation is executed.
struct OwnedQueuedOperation<T: bytemuck::Pod + bytemuck::Zeroable> {
/// Kind of operation to run.
op_type: OperationType,
/// First operand.
input_a: Arc<GpuArray<T>>,
/// Second operand; the two-input operations fail at execution time when it is `None`.
input_b: Option<Arc<GpuArray<T>>>,
/// Timestamp taken when the operation was enqueued.
queued_at: Instant,
/// Priority supplied at enqueue time (the public `queue_*` methods pass 0).
/// No visible code consults it.
priority: i32,
/// `op_type.cost_factor()` multiplied by `input_a.size()`. No visible code consults it.
cost: f32,
}
/// Queue that collects GPU operations and executes them in batches on `flush`
/// (or automatically while enqueueing, when auto-flush is enabled).
pub struct BatchQueue<T: bytemuck::Pod + bytemuck::Zeroable> {
/// GPU context handle supplied at construction.
context: GpuContextRef,
/// Configuration fixed at construction.
config: BatchConfig,
/// Pending operations and statistics, guarded by a mutex.
state: Arc<Mutex<BatchQueueState<T>>>,
/// Transfer optimizer created with `TransferStrategy::Batched`.
/// NOTE(review): no method visible in this file uses it — confirm intent.
transfer_optimizer: Arc<Mutex<TransferOptimizer>>,
}
impl<T: bytemuck::Pod + bytemuck::Zeroable> BatchQueue<T> {
/// Creates an empty queue bound to `context` and governed by `config`.
///
/// The dynamic batch size starts at `config.max_batch_size`, and the transfer
/// optimizer is created with `TransferStrategy::Batched`.
pub fn new(context: GpuContextRef, config: BatchConfig) -> Self {
    // The optimizer needs its own handle; this is the only clone required.
    let transfer_optimizer = TransferOptimizer::new(
        context.clone(),
        crate::gpu::memory::TransferStrategy::Batched,
    );
    // Read before `config` is moved into the struct below.
    let initial_batch_size = config.max_batch_size;
    let state = BatchQueueState {
        queue: VecDeque::new(),
        stats: BatchStatistics::default(),
        last_flush: Instant::now(),
        dynamic_batch_size: initial_batch_size,
        // 100 matches the length the flush path trims these histories to.
        recent_execution_times: VecDeque::with_capacity(100),
        recent_batch_sizes: VecDeque::with_capacity(100),
    };
    Self {
        // Both arguments are owned, so they are moved rather than cloned.
        context,
        config,
        state: Arc::new(Mutex::new(state)),
        transfer_optimizer: Arc::new(Mutex::new(transfer_optimizer)),
    }
}
/// Creates a queue bound to `context` using `BatchConfig::default()`.
pub fn with_default_config(context: GpuContextRef) -> Self {
Self::new(context, BatchConfig::default())
}
/// Enqueues a `MatMul` of `a` and `b` with priority 0.
///
/// If auto-flush is enabled and a flush condition holds after the enqueue, a
/// flush runs immediately; its results are not returned to the caller.
///
/// # Errors
/// Fails if the state mutex is poisoned or if a triggered flush fails.
pub fn queue_matmul(&mut self, a: Arc<GpuArray<T>>, b: Arc<GpuArray<T>>) -> Result<()> {
self.queue_operation(OperationType::MatMul, a, Some(b), 0)
}
/// Enqueues an `Add` of `a` and `b` with priority 0.
///
/// If auto-flush is enabled and a flush condition holds after the enqueue, a
/// flush runs immediately; its results are not returned to the caller.
///
/// # Errors
/// Fails if the state mutex is poisoned or if a triggered flush fails.
pub fn queue_add(&mut self, a: Arc<GpuArray<T>>, b: Arc<GpuArray<T>>) -> Result<()> {
self.queue_operation(OperationType::Add, a, Some(b), 0)
}
/// Enqueues a `Multiply` of `a` and `b` with priority 0.
///
/// If auto-flush is enabled and a flush condition holds after the enqueue, a
/// flush runs immediately; its results are not returned to the caller.
///
/// # Errors
/// Fails if the state mutex is poisoned or if a triggered flush fails.
pub fn queue_multiply(&mut self, a: Arc<GpuArray<T>>, b: Arc<GpuArray<T>>) -> Result<()> {
self.queue_operation(OperationType::Multiply, a, Some(b), 0)
}
/// Enqueues a `Subtract` with `a` as first and `b` as second operand, priority 0.
///
/// If auto-flush is enabled and a flush condition holds after the enqueue, a
/// flush runs immediately; its results are not returned to the caller.
///
/// # Errors
/// Fails if the state mutex is poisoned or if a triggered flush fails.
pub fn queue_subtract(&mut self, a: Arc<GpuArray<T>>, b: Arc<GpuArray<T>>) -> Result<()> {
self.queue_operation(OperationType::Subtract, a, Some(b), 0)
}
/// Enqueues a `Divide` with `a` as first and `b` as second operand, priority 0.
///
/// If auto-flush is enabled and a flush condition holds after the enqueue, a
/// flush runs immediately; its results are not returned to the caller.
///
/// # Errors
/// Fails if the state mutex is poisoned or if a triggered flush fails.
pub fn queue_divide(&mut self, a: Arc<GpuArray<T>>, b: Arc<GpuArray<T>>) -> Result<()> {
self.queue_operation(OperationType::Divide, a, Some(b), 0)
}
fn queue_operation(
&mut self,
op_type: OperationType,
input_a: Arc<GpuArray<T>>,
input_b: Option<Arc<GpuArray<T>>>,
priority: i32,
) -> Result<()> {
let mut state = self
.state
.lock()
.map_err(|e| NumRs2Error::RuntimeError(format!("Failed to lock batch queue: {}", e)))?;
let cost = op_type.cost_factor() * (input_a.size() as f32);
let op = OwnedQueuedOperation {
op_type,
input_a,
input_b,
queued_at: Instant::now(),
priority,
cost,
};
state.queue.push_back(op);
state.stats.total_operations += 1;
state.stats.current_queue_depth = state.queue.len();
if self.config.enable_auto_flush && self.should_auto_flush(&state) {
drop(state);
self.flush()?;
}
Ok(())
}
fn should_auto_flush(&self, state: &BatchQueueState<T>) -> bool {
let queue_size = state.queue.len();
let time_since_last_flush = state.last_flush.elapsed();
if queue_size >= self.config.max_batch_size {
return true;
}
if time_since_last_flush >= self.config.batch_timeout
&& queue_size >= self.config.min_batch_size
{
return true;
}
false
}
pub fn flush(&mut self) -> Result<Vec<BatchResult<T>>> {
let mut state = self
.state
.lock()
.map_err(|e| NumRs2Error::RuntimeError(format!("Failed to lock batch queue: {}", e)))?;
if state.queue.is_empty() {
return Ok(Vec::new());
}
let flush_start = Instant::now();
let batch_size = if self.config.enable_dynamic_optimization {
state.dynamic_batch_size.min(state.queue.len())
} else {
self.config.max_batch_size.min(state.queue.len())
};
let mut operations = Vec::with_capacity(batch_size);
for _ in 0..batch_size {
if let Some(op) = state.queue.pop_front() {
operations.push(op);
} else {
break;
}
}
state.stats.current_queue_depth = state.queue.len();
state.stats.total_flushes += 1;
state.stats.manual_flush_count += 1;
let actual_batch_size = operations.len();
state.stats.max_batch_size = state.stats.max_batch_size.max(actual_batch_size);
drop(state);
let results = self.execute_batch(operations)?;
let mut state = self
.state
.lock()
.map_err(|e| NumRs2Error::RuntimeError(format!("Failed to lock batch queue: {}", e)))?;
let execution_time = flush_start.elapsed().as_micros() as u64;
state.stats.total_executed += actual_batch_size as u64;
state.last_flush = Instant::now();
state.recent_execution_times.push_back(execution_time);
state.recent_batch_sizes.push_back(actual_batch_size);
if state.recent_execution_times.len() > 100 {
state.recent_execution_times.pop_front();
state.recent_batch_sizes.pop_front();
}
self.update_statistics(&mut state);
if self.config.enable_dynamic_optimization {
self.optimize_batch_size(&mut state);
}
Ok(results)
}
/// Runs every operation of the batch, in order, through the matching
/// `crate::gpu::ops` function and collects one `BatchResult` per operation.
///
/// Each result carries the time spent on that single operation.
///
/// # Errors
/// Stops at the first failure and returns it, discarding the results gathered
/// so far: an error from the underlying op, `InvalidOperation` when a
/// two-input operation has no second input, or `NotImplemented` for `Conv2D`.
fn execute_batch(
    &self,
    operations: Vec<OwnedQueuedOperation<T>>,
) -> Result<Vec<BatchResult<T>>> {
    let mut results = Vec::with_capacity(operations.len());
    for op in operations {
        let start = Instant::now();
        let lhs = &op.input_a;
        // Second operand of a two-input operation; `missing` is the message
        // reported when it was not supplied.
        let rhs = |missing: &str| {
            op.input_b
                .as_ref()
                .ok_or_else(|| NumRs2Error::InvalidOperation(missing.to_string()))
        };
        let result_array = match op.op_type {
            OperationType::MatMul => {
                crate::gpu::ops::matmul(lhs, rhs("MatMul requires two inputs")?)?
            }
            OperationType::Add => crate::gpu::ops::add(lhs, rhs("Add requires two inputs")?)?,
            OperationType::Multiply => {
                crate::gpu::ops::multiply(lhs, rhs("Multiply requires two inputs")?)?
            }
            OperationType::Subtract => {
                crate::gpu::ops::subtract(lhs, rhs("Subtract requires two inputs")?)?
            }
            OperationType::Divide => {
                crate::gpu::ops::divide(lhs, rhs("Divide requires two inputs")?)?
            }
            OperationType::Exp => crate::gpu::ops::exp(lhs)?,
            OperationType::Log => crate::gpu::ops::log(lhs)?,
            OperationType::Sqrt => crate::gpu::ops::sqrt(lhs)?,
            OperationType::Conv2D => {
                return Err(NumRs2Error::NotImplemented(
                    "Conv2D batching not yet implemented".to_string(),
                ))
            }
        };
        let elapsed_us = start.elapsed().as_micros() as u64;
        results.push(BatchResult {
            result: result_array,
            op_type: op.op_type,
            execution_time_us: elapsed_us,
        });
    }
    Ok(results)
}
/// Recomputes the derived metrics in `state.stats` from the raw counters and
/// the recent-execution history. Does nothing before the first flush.
fn update_statistics(&self, state: &mut BatchQueueState<T>) {
    let flushes = state.stats.total_flushes;
    if flushes == 0 {
        return;
    }

    let mean_batch = state.stats.total_executed as f32 / flushes as f32;
    state.stats.avg_batch_size = mean_batch;

    let samples = state.recent_execution_times.len();
    if samples != 0 {
        let total_us: u64 = state.recent_execution_times.iter().sum();
        state.stats.avg_execution_time_us = total_us / samples as u64;
    }

    // With a zero average the previous throughput value is kept, which also
    // avoids a division by zero.
    let mean_exec_us = state.stats.avg_execution_time_us;
    if mean_exec_us > 0 {
        state.stats.throughput_ops_per_sec = (mean_batch * 1_000_000.0) / mean_exec_us as f32;
    }

    let occupancy = self.estimate_gpu_occupancy(state);
    state.stats.estimated_gpu_occupancy = occupancy;
}
/// Heuristic occupancy estimate, capped at 1.0.
///
/// Blends how full batches are relative to `max_batch_size` (weight 0.6) with
/// throughput relative to 1000 ops/s, itself capped at 1.0 (weight 0.4).
/// Returns 0.0 while there is no execution history. Nothing here is measured
/// on the device.
fn estimate_gpu_occupancy(&self, state: &BatchQueueState<T>) -> f32 {
    const BATCH_WEIGHT: f32 = 0.6;
    const THROUGHPUT_WEIGHT: f32 = 0.4;
    const FULL_THROUGHPUT_OPS_PER_SEC: f32 = 1000.0;

    if state.recent_execution_times.is_empty() {
        return 0.0;
    }
    let stats = &state.stats;
    let fill_ratio = stats.avg_batch_size / self.config.max_batch_size as f32;
    let throughput_score =
        f32::min(stats.throughput_ops_per_sec / FULL_THROUGHPUT_OPS_PER_SEC, 1.0);
    let blended = fill_ratio * BATCH_WEIGHT + throughput_score * THROUGHPUT_WEIGHT;
    blended.min(1.0)
}
/// Nudges `state.dynamic_batch_size` toward the configured target occupancy.
///
/// Does nothing until at least 10 flushes are in the history. If the estimate
/// is more than 0.1 below target the batch size grows by 4; if it is more than
/// 0.1 above target it shrinks by 2. Inside that dead band it is left alone.
///
/// The result always stays within `[floor, ceiling]`, where `ceiling` is
/// `max_batch_size` (at least 1) and `floor` is `min_batch_size` clamped into
/// `1..=ceiling`. For configurations with `1 <= min <= max` these are the
/// configured bounds. The clamping only matters for degenerate configurations,
/// where the batch size could otherwise reach 0 (a flush would then dequeue
/// nothing) or exceed `max_batch_size`.
fn optimize_batch_size(&self, state: &mut BatchQueueState<T>) {
    if state.recent_execution_times.len() < 10 {
        return;
    }

    let ceiling = self.config.max_batch_size.max(1);
    let floor = self.config.min_batch_size.max(1).min(ceiling);

    let current_occupancy = state.stats.estimated_gpu_occupancy;
    let target_occupancy = self.config.target_occupancy;
    if current_occupancy < target_occupancy - 0.1 {
        state.dynamic_batch_size = state.dynamic_batch_size.saturating_add(4).min(ceiling);
    } else if current_occupancy > target_occupancy + 0.1 {
        state.dynamic_batch_size = state.dynamic_batch_size.saturating_sub(2).max(floor);
    }
}
/// Returns a copy of the current statistics.
///
/// # Errors
/// Fails with `RuntimeError` if the state mutex is poisoned.
pub fn statistics(&self) -> Result<BatchStatistics> {
    match self.state.lock() {
        Ok(guard) => Ok(guard.stats),
        Err(e) => Err(NumRs2Error::RuntimeError(format!(
            "Failed to lock batch queue: {}",
            e
        ))),
    }
}
/// Returns the number of operations currently waiting in the queue.
///
/// # Errors
/// Fails with `RuntimeError` if the state mutex is poisoned.
pub fn queue_depth(&self) -> Result<usize> {
    let guard = match self.state.lock() {
        Ok(guard) => guard,
        Err(e) => {
            return Err(NumRs2Error::RuntimeError(format!(
                "Failed to lock batch queue: {}",
                e
            )))
        }
    };
    let pending = guard.queue.len();
    Ok(pending)
}
/// Discards every pending operation without executing it and resets the
/// recorded queue depth to 0. Other statistics are left as they are.
///
/// # Errors
/// Fails with `RuntimeError` if the state mutex is poisoned.
pub fn clear(&mut self) -> Result<()> {
    let mut guard = match self.state.lock() {
        Ok(guard) => guard,
        Err(e) => {
            return Err(NumRs2Error::RuntimeError(format!(
                "Failed to lock batch queue: {}",
                e
            )))
        }
    };
    guard.stats.current_queue_depth = 0;
    guard.queue.clear();
    Ok(())
}
/// Reports whether no operations are waiting in the queue.
///
/// # Errors
/// Fails with `RuntimeError` if the state mutex is poisoned.
pub fn is_empty(&self) -> Result<bool> {
    self.state
        .lock()
        .map(|guard| guard.queue.is_empty())
        .map_err(|e| NumRs2Error::RuntimeError(format!("Failed to lock batch queue: {}", e)))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses the default batch cap and enables both
    /// dynamic optimization and auto-flush.
    #[test]
    fn test_batch_config_default() {
        let BatchConfig {
            max_batch_size,
            enable_dynamic_optimization,
            enable_auto_flush,
            ..
        } = BatchConfig::default();
        assert_eq!(max_batch_size, DEFAULT_MAX_BATCH_SIZE);
        assert!(enable_dynamic_optimization);
        assert!(enable_auto_flush);
    }

    /// Heavier operations carry a larger cost factor than lighter ones.
    #[test]
    fn test_operation_type_cost_factor() {
        let cost = |op: OperationType| op.cost_factor();
        assert!(cost(OperationType::MatMul) > cost(OperationType::Add));
        assert!(cost(OperationType::Conv2D) > cost(OperationType::Multiply));
    }

    /// `MatMul`, `Add` and `Conv2D` are all flagged as batchable.
    #[test]
    fn test_operation_type_is_batchable() {
        let batchable = [
            OperationType::MatMul,
            OperationType::Add,
            OperationType::Conv2D,
        ];
        for op in &batchable {
            assert!(op.is_batchable());
        }
    }

    /// Freshly created statistics start with zeroed counters.
    #[test]
    fn test_batch_statistics_default() {
        let stats = BatchStatistics::default();
        let counters = (
            stats.total_operations,
            stats.total_flushes,
            stats.current_queue_depth,
        );
        assert_eq!(counters.0, 0);
        assert_eq!(counters.1, 0);
        assert_eq!(counters.2, 0);
    }
}