use std::collections::HashMap;
/// Suggests batch sizes per named operation, learning from recorded performance.
///
/// Until a positive-throughput measurement has been recorded for an operation,
/// suggestions come from a built-in table derived from `memory_limit`. After
/// that, the batch size of the highest-throughput record in the retained
/// history is used.
#[derive(Debug)]
pub struct BatchSizeOptimizer {
    /// Best-known batch size per operation name, updated by
    /// [`BatchSizeOptimizer::record_performance`].
    pub optimalsizes: HashMap<String, usize>,
    /// Budget from which default batch sizes are derived. The unit is not
    /// defined in this file; it is only divided and capped per operation.
    pub memory_limit: usize,
    /// Every record passed to `record_performance`, in insertion order.
    performance_history: Vec<BatchPerformanceRecord>,
}

/// One performance measurement for a batch of a given size.
///
/// Only `operation`, `batchsize` and `throughput` are interpreted by
/// [`BatchSizeOptimizer`]; the remaining fields are stored as-is.
#[derive(Debug, Clone)]
pub struct BatchPerformanceRecord {
    /// Name of the operation that was measured (the optimizer's lookup key).
    pub operation: String,
    /// Batch size that was used for the measurement.
    pub batchsize: usize,
    /// Measured execution time (unit chosen by the caller).
    pub execution_time: f64,
    /// Measured memory usage (unit chosen by the caller).
    pub memory_usage: usize,
    /// Measured throughput; larger is treated as better, and only values
    /// greater than zero can influence the learned optimum.
    pub throughput: f64,
}

impl BatchSizeOptimizer {
    /// Creates an optimizer with no learned sizes and an empty history.
    pub fn new(memory_limit: usize) -> Self {
        Self {
            optimalsizes: HashMap::new(),
            memory_limit,
            performance_history: Vec::new(),
        }
    }

    /// Returns the batch size to use for `operation` on `datasize` items.
    ///
    /// A learned optimum takes precedence over the built-in default. The
    /// result never exceeds `datasize`, and it is at least 1 whenever
    /// `datasize` is non-zero, so callers that chunk by the returned value
    /// always make progress. Returns 0 only when `datasize` is 0.
    pub fn optimize_batchsize(&mut self, operation: &str, datasize: usize) -> usize {
        let preferred = match self.optimalsizes.get(operation) {
            Some(&learned) => learned,
            None => self.default_batchsize(operation),
        };
        // `max(1)` first, then `min(datasize)`: a tiny memory limit (or a
        // recorded batch size of 0) must not yield a zero-sized batch for
        // non-empty data, while empty data still maps to 0.
        preferred.max(1).min(datasize)
    }

    /// Default batch size for `operation`: `memory_limit / divisor`, capped.
    fn default_batchsize(&self, operation: &str) -> usize {
        let (divisor, cap) = match operation {
            "matrix_vector" => (4, 2048),
            "element_wise" => (2, 4096),
            "decomposition" => (16, 512),
            // "matrix_multiply" and every unrecognized operation share this default.
            _ => (8, 1024),
        };
        (self.memory_limit / divisor).min(cap)
    }

    /// Stores `record` and, if its throughput is positive, refreshes the
    /// learned optimum for its operation.
    ///
    /// The optimum becomes the batch size of the highest-throughput record
    /// for that operation in the retained history (the most recent one on
    /// ties). Records whose throughput is zero, negative or NaN are kept in
    /// the history but never influence the optimum and never cause a panic.
    pub fn record_performance(&mut self, record: BatchPerformanceRecord) {
        // Only the operation name is needed after the record is moved into
        // the history, so clone just that (and only when an update can occur).
        let operation = (record.throughput > 0.0).then(|| record.operation.clone());
        self.performance_history.push(record);

        if let Some(operation) = operation {
            let best_batchsize = self
                .performance_history
                .iter()
                // `> 0.0` is false for NaN, so unordered values never reach
                // the comparison below.
                .filter(|r| r.operation == operation && r.throughput > 0.0)
                .max_by(|a, b| a.throughput.total_cmp(&b.throughput))
                .map(|r| r.batchsize);

            if let Some(batchsize) = best_batchsize {
                self.optimalsizes.insert(operation, batchsize);
            }
        }
    }

    /// Returns all retained records for `operation`, in insertion order.
    pub fn get_performance_history(&self, operation: &str) -> Vec<&BatchPerformanceRecord> {
        self.performance_history
            .iter()
            .filter(|record| record.operation == operation)
            .collect()
    }

    /// Drops all retained records. Learned optimal sizes are left in place.
    pub fn clear_history(&mut self) {
        self.performance_history.clear();
    }

    /// Returns the map of learned optimal batch sizes, keyed by operation name.
    pub fn get_optimal_sizes(&self) -> &HashMap<String, usize> {
        &self.optimalsizes
    }
}