use crate::error::{Result, TrustformersError};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, RwLock};
use std::time::Instant;
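/// Configuration for [`AdaptiveBatchOptimizer`].
///
/// `throughput_weight`, `latency_weight`, and `memory_weight` control how the
/// per-batch-size score is blended (the defaults sum to 1.0), while
/// `min_batch_size`/`max_batch_size` bound the search and `samples_per_size`
/// sets how many measurements are collected before moving on. Note that the
/// scorer currently evaluates p95 latency regardless of
/// `target_latency_percentile`. A sketch of overriding a few fields while
/// keeping the remaining defaults:
///
/// ```ignore
/// let config = AdaptiveBatchConfig {
///     max_batch_size: 32,
///     target_latency_ms: 50.0,
///     ..Default::default()
/// };
/// ```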
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveBatchConfig {
pub min_batch_size: usize,
pub max_batch_size: usize,
pub samples_per_size: usize,
pub warmup_iterations: usize,
pub target_latency_percentile: f64,
pub target_latency_ms: f64,
pub throughput_weight: f64,
pub latency_weight: f64,
pub memory_weight: f64,
pub reevaluation_interval_secs: u64,
}
impl Default for AdaptiveBatchConfig {
fn default() -> Self {
Self {
min_batch_size: 1,
max_batch_size: 64,
samples_per_size: 10,
warmup_iterations: 3,
target_latency_percentile: 95.0,
target_latency_ms: 100.0,
throughput_weight: 0.4,
latency_weight: 0.4,
memory_weight: 0.2,
            reevaluation_interval_secs: 300,
        }
}
}
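/// A single measurement for one batch size, as recorded via
/// [`AdaptiveBatchOptimizer::record_sample`].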
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceSample {
pub batch_size: usize,
pub latency_ms: f64,
pub throughput_rps: f64,
pub memory_usage_mb: f64,
pub gpu_memory_mb: f64,
pub cpu_utilization: f32,
pub gpu_utilization: f32,
pub timestamp: std::time::SystemTime,
}
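/// Aggregated statistics for one batch size, derived from its recorded
/// [`PerformanceSample`]s. `score` is the weighted blend of latency,
/// throughput, and memory used to rank batch sizes.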
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchSizeStats {
pub batch_size: usize,
pub sample_count: usize,
pub avg_latency_ms: f64,
pub p95_latency_ms: f64,
pub p99_latency_ms: f64,
pub avg_throughput_rps: f64,
pub max_throughput_rps: f64,
pub avg_memory_mb: f64,
pub max_memory_mb: f64,
pub avg_cpu_utilization: f32,
pub avg_gpu_utilization: f32,
    pub score: f64,
}
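/// Runtime batch-size tuner: callers record [`PerformanceSample`]s, the
/// optimizer aggregates them per batch size, scores each size, and exposes the
/// best one. A minimal usage sketch (the measured numbers below are
/// placeholders; a real caller would time its own batches):
///
/// ```ignore
/// let optimizer = AdaptiveBatchOptimizer::new(AdaptiveBatchConfig::default());
/// // Ask which batch size to try next, run a batch, then report what was measured.
/// if let Some(size) = optimizer.get_next_test_size() {
///     optimizer.record_sample(PerformanceSample {
///         batch_size: size,
///         latency_ms: 42.0,
///         throughput_rps: size as f64 / 0.042,
///         memory_usage_mb: 512.0,
///         gpu_memory_mb: 1024.0,
///         cpu_utilization: 0.5,
///         gpu_utilization: 0.7,
///         timestamp: std::time::SystemTime::now(),
///     })?;
/// }
/// let best = optimizer.get_optimal_batch_size();
/// ```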
#[derive(Debug)]
pub struct AdaptiveBatchOptimizer {
config: AdaptiveBatchConfig,
performance_data: Arc<RwLock<HashMap<usize, VecDeque<PerformanceSample>>>>,
batch_stats: Arc<RwLock<HashMap<usize, BatchSizeStats>>>,
optimal_batch_size: Arc<RwLock<Option<usize>>>,
last_evaluation: Arc<RwLock<Instant>>,
current_test_size: Arc<RwLock<Option<usize>>>,
test_iteration: Arc<RwLock<usize>>,
}
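/// Convenience alias for [`AdaptiveBatchOptimizer`].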
pub type AdaptiveBatchManager = AdaptiveBatchOptimizer;
impl AdaptiveBatchOptimizer {
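    /// Creates an optimizer with no recorded samples and no batch size selected yet.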
pub fn new(config: AdaptiveBatchConfig) -> Self {
Self {
config,
performance_data: Arc::new(RwLock::new(HashMap::new())),
batch_stats: Arc::new(RwLock::new(HashMap::new())),
optimal_batch_size: Arc::new(RwLock::new(None)),
last_evaluation: Arc::new(RwLock::new(Instant::now())),
current_test_size: Arc::new(RwLock::new(None)),
test_iteration: Arc::new(RwLock::new(0)),
}
}
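    /// Returns the currently selected batch size, or `None` until an evaluation
    /// has picked one.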
pub fn get_optimal_batch_size(&self) -> Option<usize> {
*self.optimal_batch_size.read().expect("lock should not be poisoned")
}
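    /// Records one measurement, recomputes the stats for that batch size, and
    /// re-evaluates the optimal size once `reevaluation_interval_secs` has
    /// elapsed since the last evaluation.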
pub fn record_sample(&self, sample: PerformanceSample) -> Result<()> {
let batch_size = sample.batch_size;
{
let mut data = self.performance_data.write().expect("lock should not be poisoned");
let samples = data.entry(batch_size).or_default();
samples.push_back(sample);
while samples.len() > self.config.samples_per_size * 2 {
samples.pop_front();
}
}
self.update_batch_stats(batch_size)?;
let should_reevaluate = {
let last_eval = self.last_evaluation.read().expect("lock should not be poisoned");
last_eval.elapsed().as_secs() >= self.config.reevaluation_interval_secs
};
if should_reevaluate {
self.evaluate_optimal_batch_size()?;
}
Ok(())
}
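    /// Returns the batch size to benchmark next. The sweep starts at
    /// `min_batch_size`, stays on a size until `samples_per_size` samples have
    /// been recorded for it, then advances; `None` means the sweep is exhausted.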
pub fn get_next_test_size(&self) -> Option<usize> {
let mut current_test = self.current_test_size.write().expect("lock should not be poisoned");
let mut iteration = self.test_iteration.write().expect("lock should not be poisoned");
match *current_test {
None => {
*current_test = Some(self.config.min_batch_size);
*iteration = 0;
Some(self.config.min_batch_size)
},
Some(size) => {
*iteration += 1;
let enough_samples = {
let data = self.performance_data.read().expect("lock should not be poisoned");
data.get(&size)
.map(|samples| samples.len() >= self.config.samples_per_size)
.unwrap_or(false)
};
if enough_samples {
let next_size = self.get_next_size_to_test(size);
*current_test = next_size;
*iteration = 0;
next_size
} else {
Some(size)
}
},
}
}
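    /// Advances along a power-of-two progression, falling back to doubling for
    /// sizes outside the table, and stops at `max_batch_size`.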
fn get_next_size_to_test(&self, current_size: usize) -> Option<usize> {
let progression = [1, 2, 4, 8, 16, 32, 64, 128];
if let Some(pos) = progression.iter().position(|&x| x == current_size) {
if pos + 1 < progression.len() && progression[pos + 1] <= self.config.max_batch_size {
Some(progression[pos + 1])
} else {
                None
            }
} else {
let next = current_size * 2;
if next <= self.config.max_batch_size {
Some(next)
} else {
None
}
}
}
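    /// Recomputes the aggregate stats for `batch_size` once at least three
    /// samples exist. Percentiles use a simple nearest-rank index into the
    /// sorted latencies.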
fn update_batch_stats(&self, batch_size: usize) -> Result<()> {
let samples = {
let data = self.performance_data.read().expect("lock should not be poisoned");
data.get(&batch_size).cloned().unwrap_or_default()
};
if samples.len() < 3 {
            return Ok(());
        }
let sample_vec: Vec<_> = samples.iter().collect();
let avg_latency =
sample_vec.iter().map(|s| s.latency_ms).sum::<f64>() / sample_vec.len() as f64;
let avg_throughput =
sample_vec.iter().map(|s| s.throughput_rps).sum::<f64>() / sample_vec.len() as f64;
let max_throughput = sample_vec.iter().map(|s| s.throughput_rps).fold(0.0, f64::max);
let avg_memory =
sample_vec.iter().map(|s| s.memory_usage_mb).sum::<f64>() / sample_vec.len() as f64;
let max_memory = sample_vec.iter().map(|s| s.memory_usage_mb).fold(0.0, f64::max);
let avg_cpu =
sample_vec.iter().map(|s| s.cpu_utilization).sum::<f32>() / sample_vec.len() as f32;
let avg_gpu =
sample_vec.iter().map(|s| s.gpu_utilization).sum::<f32>() / sample_vec.len() as f32;
let mut latencies: Vec<f64> = sample_vec.iter().map(|s| s.latency_ms).collect();
latencies.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let p95_index = ((latencies.len() as f64) * 0.95) as usize;
let p99_index = ((latencies.len() as f64) * 0.99) as usize;
let p95_latency = latencies.get(p95_index).copied().unwrap_or(avg_latency);
let p99_latency = latencies.get(p99_index).copied().unwrap_or(avg_latency);
let score =
self.calculate_optimization_score(avg_latency, p95_latency, avg_throughput, avg_memory);
let stats = BatchSizeStats {
batch_size,
sample_count: sample_vec.len(),
avg_latency_ms: avg_latency,
p95_latency_ms: p95_latency,
p99_latency_ms: p99_latency,
avg_throughput_rps: avg_throughput,
max_throughput_rps: max_throughput,
avg_memory_mb: avg_memory,
max_memory_mb: max_memory,
avg_cpu_utilization: avg_cpu,
avg_gpu_utilization: avg_gpu,
score,
};
{
let mut batch_stats = self.batch_stats.write().expect("lock should not be poisoned");
batch_stats.insert(batch_size, stats);
}
Ok(())
}
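    /// Blends latency, throughput, and memory into a single score:
    ///
    /// * latency: `1.0` when `p95_latency` meets `target_latency_ms`, otherwise
    ///   `target / p95` floored at `0.1`;
    /// * throughput: `avg_throughput / 100.0`, capped at `1.0`;
    /// * memory: `1000.0 / (avg_memory + 100.0)`, capped at `1.0`.
    ///
    /// Each term is weighted by the corresponding config weight. The average
    /// latency argument is accepted but not currently factored into the score.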
fn calculate_optimization_score(
&self,
        _avg_latency: f64,
p95_latency: f64,
avg_throughput: f64,
avg_memory: f64,
) -> f64 {
        let latency_score = if p95_latency <= self.config.target_latency_ms {
            1.0
        } else {
            (self.config.target_latency_ms / p95_latency).max(0.1)
        };
let throughput_score = (avg_throughput / 100.0).min(1.0);
let memory_score = (1000.0 / (avg_memory + 100.0)).min(1.0);
(self.config.latency_weight * latency_score)
+ (self.config.throughput_weight * throughput_score)
+ (self.config.memory_weight * memory_score)
}
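    /// Picks the highest-scoring batch size among those with at least
    /// `samples_per_size` samples and resets the evaluation timer.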
fn evaluate_optimal_batch_size(&self) -> Result<()> {
let stats = self.batch_stats.read().expect("lock should not be poisoned");
if stats.is_empty() {
return Ok(());
}
let optimal = stats
.values()
.filter(|s| s.sample_count >= self.config.samples_per_size)
.max_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal))
.map(|s| s.batch_size);
if let Some(optimal_size) = optimal {
let mut current_optimal =
self.optimal_batch_size.write().expect("lock should not be poisoned");
*current_optimal = Some(optimal_size);
tracing::info!(
"Updated optimal batch size to {}: score {:.3}",
optimal_size,
stats.get(&optimal_size).expect("optimal_size came from stats").score
);
}
{
let mut last_eval = self.last_evaluation.write().expect("lock should not be poisoned");
*last_eval = Instant::now();
}
Ok(())
}
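    /// Snapshots the per-batch-size stats into a [`PerformanceReport`], sorted
    /// by batch size.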
pub fn get_performance_report(&self) -> PerformanceReport {
let stats = self.batch_stats.read().expect("lock should not be poisoned");
let optimal = *self.optimal_batch_size.read().expect("lock should not be poisoned");
let mut batch_performances: Vec<_> = stats.values().cloned().collect();
batch_performances.sort_by_key(|s| s.batch_size);
PerformanceReport {
optimal_batch_size: optimal,
batch_performances,
total_evaluations: stats.len(),
last_evaluation: *self.last_evaluation.read().expect("lock should not be poisoned"),
}
}
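    /// Serializes the per-batch-size stats to pretty-printed JSON.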
pub fn export_data(&self) -> Result<String> {
let stats = self.batch_stats.read().expect("lock should not be poisoned");
let data = stats.values().collect::<Vec<_>>();
serde_json::to_string_pretty(&data)
.map_err(|e| TrustformersError::runtime_error(format!("Failed to export data: {}", e)))
}
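    /// Merges previously exported stats into this optimizer and re-evaluates
    /// the optimal batch size.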
pub fn import_data(&self, data: &str) -> Result<()> {
let imported_stats: Vec<BatchSizeStats> = serde_json::from_str(data).map_err(|e| {
TrustformersError::runtime_error(format!("Failed to import data: {}", e))
})?;
{
let mut stats = self.batch_stats.write().expect("lock should not be poisoned");
for stat in imported_stats {
stats.insert(stat.batch_size, stat);
}
}
self.evaluate_optimal_batch_size()?;
Ok(())
}
}
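/// Snapshot of the optimizer's per-batch-size statistics and current choice.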
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceReport {
pub optimal_batch_size: Option<usize>,
pub batch_performances: Vec<BatchSizeStats>,
pub total_evaluations: usize,
#[serde(skip, default = "Instant::now")]
pub last_evaluation: Instant,
}
impl PerformanceReport {
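    /// Returns up to `n` batch sizes with the highest scores, best first.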
pub fn get_top_performers(&self, n: usize) -> Vec<&BatchSizeStats> {
let mut sorted = self.batch_performances.iter().collect::<Vec<_>>();
sorted.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
sorted.into_iter().take(n).collect()
}
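    /// Compares two batch sizes from the report; returns `None` if either has
    /// no recorded stats.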
pub fn compare_batch_sizes(&self, size1: usize, size2: usize) -> Option<BatchComparison> {
let stats1 = self.batch_performances.iter().find(|s| s.batch_size == size1)?;
let stats2 = self.batch_performances.iter().find(|s| s.batch_size == size2)?;
Some(BatchComparison {
size1,
size2,
latency_improvement: stats1.avg_latency_ms / stats2.avg_latency_ms,
throughput_improvement: stats2.avg_throughput_rps / stats1.avg_throughput_rps,
memory_difference: stats2.avg_memory_mb - stats1.avg_memory_mb,
score_difference: stats2.score - stats1.score,
})
}
}
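/// Relative comparison of two batch sizes. `latency_improvement` and
/// `throughput_improvement` are ratios where values above `1.0` favour
/// `size2`; the difference fields are `size2` minus `size1`.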
#[derive(Debug, Clone)]
pub struct BatchComparison {
pub size1: usize,
pub size2: usize,
    pub latency_improvement: f64,
    pub throughput_improvement: f64,
    pub memory_difference: f64,
    pub score_difference: f64,
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::SystemTime;
fn make_sample(
batch_size: usize,
latency_ms: f64,
throughput_rps: f64,
memory_usage_mb: f64,
) -> PerformanceSample {
PerformanceSample {
batch_size,
latency_ms,
throughput_rps,
memory_usage_mb,
gpu_memory_mb: memory_usage_mb * 2.0,
cpu_utilization: 0.6,
gpu_utilization: 0.7,
timestamp: SystemTime::now(),
}
}
fn populate_optimizer(optimizer: &AdaptiveBatchOptimizer, batch_size: usize, n: usize) {
for _ in 0..n {
let sample = make_sample(batch_size, 60.0, 50.0, 256.0);
optimizer.record_sample(sample).expect("record_sample should succeed");
}
}
#[test]
fn test_config_min_less_than_max() {
let config = AdaptiveBatchConfig::default();
assert!(config.min_batch_size < config.max_batch_size);
}
#[test]
fn test_config_weights_sum_to_one() {
let config = AdaptiveBatchConfig::default();
let total = config.throughput_weight + config.latency_weight + config.memory_weight;
assert!((total - 1.0).abs() < 1e-9, "weights must sum to 1.0");
}
#[test]
fn test_config_target_latency_positive() {
let config = AdaptiveBatchConfig::default();
assert!(config.target_latency_ms > 0.0);
}
#[test]
fn test_optimizer_starts_with_no_optimal() {
let optimizer = AdaptiveBatchOptimizer::new(AdaptiveBatchConfig::default());
assert!(
optimizer.get_optimal_batch_size().is_none(),
"no optimal batch size before any data"
);
}
#[test]
fn test_adaptive_batch_optimizer_basic() {
let config = AdaptiveBatchConfig::default();
let optimizer = AdaptiveBatchOptimizer::new(config);
let sample = make_sample(8, 50.0, 160.0, 512.0);
optimizer.record_sample(sample).expect("record_sample should succeed");
assert_eq!(optimizer.get_next_test_size(), Some(1));
assert_eq!(optimizer.get_next_test_size(), Some(1));
}
#[test]
fn test_optimization_score_calculation() {
let config = AdaptiveBatchConfig::default();
let optimizer = AdaptiveBatchOptimizer::new(config);
let score1 = optimizer.calculate_optimization_score(80.0, 95.0, 50.0, 500.0);
let score2 = optimizer.calculate_optimization_score(120.0, 150.0, 50.0, 500.0);
assert!(
score1 > score2,
"score should be higher when meeting latency target"
);
}
#[test]
fn test_optimization_score_lower_memory_better() {
let config = AdaptiveBatchConfig::default();
let optimizer = AdaptiveBatchOptimizer::new(config);
let score_low_mem = optimizer.calculate_optimization_score(80.0, 90.0, 50.0, 100.0);
let score_high_mem = optimizer.calculate_optimization_score(80.0, 90.0, 50.0, 5000.0);
assert!(
score_low_mem > score_high_mem,
"lower memory should yield better score"
);
}
#[test]
fn test_optimization_score_higher_throughput_better() {
let config = AdaptiveBatchConfig::default();
let optimizer = AdaptiveBatchOptimizer::new(config);
let score_high = optimizer.calculate_optimization_score(80.0, 90.0, 200.0, 500.0);
let score_low = optimizer.calculate_optimization_score(80.0, 90.0, 5.0, 500.0);
assert!(
score_high > score_low,
"higher throughput should yield better score"
);
}
#[test]
fn test_dynamic_batch_size_based_on_throughput() {
let throughput = 1.0_f64 / (50.0_f64 / 1000.0);
assert!(
throughput > 10.0,
"batch 1 at 50ms latency should exceed 10 rps threshold"
);
}
#[test]
fn test_batch_size_increase_on_spare_capacity() {
let batch = 8_usize;
        let latency_ms = 5.0_f64;
        let throughput = batch as f64 / (latency_ms / 1000.0);
assert!(
throughput > 100.0,
"large batch at low latency should show spare capacity"
);
}
#[test]
fn test_exponential_smoothing_formula() {
let old_ema = 50.0_f64;
let alpha = 0.1_f64;
let new_val = 100.0_f64;
let new_ema = old_ema * (1.0 - alpha) + new_val * alpha;
assert!((new_ema - 55.0).abs() < 1e-9, "EMA update should be 55.0");
}
#[test]
fn test_get_next_size_to_test_exponential() {
let config = AdaptiveBatchConfig {
min_batch_size: 1,
max_batch_size: 64,
..Default::default()
};
let optimizer = AdaptiveBatchOptimizer::new(config);
        populate_optimizer(&optimizer, 1, 15);
        // The first call selects the minimum batch size; once enough samples exist
        // for it, the next call advances along the exponential progression.
        assert_eq!(optimizer.get_next_test_size(), Some(1));
        assert_eq!(
            optimizer.get_next_test_size(),
            Some(2),
            "should transition from 1 to 2 in exponential progression"
        );
}
#[test]
fn test_get_next_size_none_when_exhausted() {
let config = AdaptiveBatchConfig {
min_batch_size: 64,
max_batch_size: 64,
samples_per_size: 2,
..Default::default()
};
let optimizer = AdaptiveBatchOptimizer::new(config);
for _ in 0..5 {
optimizer
.record_sample(make_sample(64, 50.0, 50.0, 512.0))
.expect("record_sample should succeed");
}
        // With min == max, the only size is exhausted once its samples are collected.
        assert_eq!(optimizer.get_next_test_size(), Some(64));
        assert_eq!(optimizer.get_next_test_size(), None, "no batch sizes left to test");
    }
#[test]
fn test_performance_report() {
let config = AdaptiveBatchConfig::default();
let optimizer = AdaptiveBatchOptimizer::new(config);
for batch_size in [2_usize, 4, 8, 16] {
for _ in 0..5 {
let sample = make_sample(
batch_size,
50.0 + batch_size as f64 * 5.0,
batch_size as f64 * 10.0,
batch_size as f64 * 64.0,
);
optimizer.record_sample(sample).expect("record_sample should succeed");
}
}
let report = optimizer.get_performance_report();
assert!(!report.batch_performances.is_empty());
let top_performers = report.get_top_performers(2);
assert!(top_performers.len() <= 2);
}
#[test]
fn test_performance_report_empty() {
let config = AdaptiveBatchConfig::default();
let optimizer = AdaptiveBatchOptimizer::new(config);
let report = optimizer.get_performance_report();
assert!(
report.batch_performances.is_empty(),
"fresh optimizer has no performance data"
);
assert!(report.optimal_batch_size.is_none());
}
#[test]
fn test_batch_size_comparison_throughput_improvement() {
let config = AdaptiveBatchConfig::default();
let optimizer = AdaptiveBatchOptimizer::new(config);
for _ in 0..5 {
optimizer
.record_sample(make_sample(4, 80.0, 50.0, 256.0))
.expect("record_sample should succeed");
optimizer
.record_sample(make_sample(8, 80.0, 100.0, 512.0))
.expect("record_sample should succeed");
}
let report = optimizer.get_performance_report();
if let Some(cmp) = report.compare_batch_sizes(4, 8) {
assert!(
cmp.throughput_improvement > 1.0,
"batch 8 should have better throughput than batch 4"
);
}
}
#[test]
fn test_export_and_import_data_roundtrip() {
let config = AdaptiveBatchConfig::default();
let optimizer = AdaptiveBatchOptimizer::new(config.clone());
for _ in 0..5 {
optimizer
.record_sample(make_sample(4, 60.0, 60.0, 256.0))
.expect("record_sample should succeed");
}
let exported = optimizer.export_data().expect("export_data should succeed");
let optimizer2 = AdaptiveBatchOptimizer::new(config);
optimizer2.import_data(&exported).expect("import_data should succeed");
let report = optimizer2.get_performance_report();
assert!(
!report.batch_performances.is_empty(),
"imported data should populate report"
);
}
#[test]
fn test_alias_type() {
let _mgr: AdaptiveBatchManager =
AdaptiveBatchOptimizer::new(AdaptiveBatchConfig::default());
}
#[test]
fn test_latency_target_adjustment_formula() {
let config = AdaptiveBatchConfig {
target_latency_ms: 100.0,
..Default::default()
};
let optimizer = AdaptiveBatchOptimizer::new(config);
let score = optimizer.calculate_optimization_score(150.0, 200.0, 30.0, 500.0);
assert!(
score < 1.0,
"exceeding latency target should penalise score"
);
}
}