use crate::error::ClusteringError;
/// Tuning parameters for [`BatchSizeController`].
///
/// The controller multiplies its current batch size by `growth_factor` or
/// `decay_factor` depending on how the most recent losses behave, and keeps
/// the result between `min_batch` and `max_batch`.
#[derive(Debug, Clone)]
pub struct AdaptiveBatchConfig {
    /// Batch size the controller starts from (and returns to on reset),
    /// after being limited to the `min_batch..=max_batch` range.
    pub initial_batch_size: usize,

    /// Lower limit applied to every resized batch.
    pub min_batch: usize,

    /// Upper limit applied to every resized batch.
    pub max_batch: usize,

    /// Multiplier applied to the current size when the recent losses are
    /// stable. `BatchSizeController::validate` requires it to be above 1.
    pub growth_factor: f64,

    /// Multiplier applied to the current size when the recent mean loss is
    /// higher than the preceding one. `BatchSizeController::validate`
    /// requires it to lie strictly between 0 and 1.
    pub decay_factor: f64,

    /// Number of most recent losses inspected per recommendation; the
    /// controller splits them into an older and a newer part and enforces a
    /// small minimum on this value.
    pub window: usize,
}
impl Default for AdaptiveBatchConfig {
fn default() -> Self {
Self {
initial_batch_size: 32,
min_batch: 16,
max_batch: 2048,
growth_factor: 1.5,
decay_factor: 0.8,
window: 6, }
}
}
/// Adjusts a training batch size from the losses it is fed.
///
/// Losses are appended with `record_loss` (or `adapt`, which also applies the
/// resulting recommendation); `recommend_size` inspects the tail of the
/// history to decide whether to grow, shrink, or keep the current size.
pub struct BatchSizeController {
    /// Batch size currently in effect; overwritten by `adapt` and `reset`.
    pub current_size: usize,

    /// Every loss recorded since construction or the last `reset`, oldest
    /// first.
    pub loss_history: Vec<f64>,

    /// Tuning parameters supplied at construction.
    config: AdaptiveBatchConfig,
}
impl BatchSizeController {
    /// Smallest effective look-back window.
    ///
    /// The window is split into an older and a newer half, and the newer half
    /// needs at least two losses: `std_dev` of a single sample is 0, which
    /// would always look like a perfect plateau and grow the batch on every
    /// call.
    const MIN_WINDOW: usize = 4;

    /// Creates a controller whose starting size is `config.initial_batch_size`
    /// limited to the configured bounds.
    ///
    /// Construction never panics, even for an inconsistent configuration
    /// (`min_batch > max_batch`); call [`validate`](Self::validate) to have
    /// such a configuration reported as an error.
    pub fn new(config: AdaptiveBatchConfig) -> Self {
        let current_size = Self::clamp_to_bounds(&config, config.initial_batch_size);
        Self {
            current_size,
            loss_history: Vec::new(),
            config,
        }
    }

    /// Appends `loss` to the history without changing the current size.
    ///
    /// The history is never trimmed here; it only shrinks on
    /// [`reset`](Self::reset).
    pub fn record_loss(&mut self, loss: f64) {
        self.loss_history.push(loss);
    }

    /// Returns the batch size suggested by the most recent losses, without
    /// modifying the controller.
    ///
    /// The last `window` losses (at least [`Self::MIN_WINDOW`]) are split
    /// into an older part and a newer part of `window / 2` entries:
    ///
    /// * fewer than `window` losses recorded: the current size is returned;
    /// * the newer part is stable (standard deviation relative to its mean
    ///   below 1%): the current size times `growth_factor`;
    /// * otherwise, the newer mean exceeds the older mean: the current size
    ///   times `decay_factor`;
    /// * otherwise: the current size.
    ///
    /// Resized values are rounded and limited to `min_batch..=max_batch`.
    pub fn recommend_size(&self) -> usize {
        let window = self.config.window.max(Self::MIN_WINDOW);
        let half = window / 2;
        let recorded = self.loss_history.len();
        if recorded < window {
            return self.current_size;
        }

        // Older losses come first; the trailing `half` entries are the newer part.
        let (prev, recent) = self.loss_history[recorded - window..].split_at(window - half);

        let mean_recent = mean(recent);
        let std_recent = std_dev(recent);
        // Normalise the spread by the mean unless the mean is effectively
        // zero, in which case the absolute spread is used to avoid dividing
        // by (almost) nothing.
        let relative_std = if mean_recent.abs() > 1e-12 {
            std_recent / mean_recent.abs()
        } else {
            std_recent
        };

        // The plateau check takes precedence over the trend check.
        if relative_std < 0.01 {
            return self.scaled(self.config.growth_factor);
        }
        if mean_recent > mean(prev) {
            return self.scaled(self.config.decay_factor);
        }
        self.current_size
    }

    /// Records `loss`, applies the resulting recommendation to
    /// `current_size`, and returns the new size.
    pub fn adapt(&mut self, loss: f64) -> usize {
        self.record_loss(loss);
        let new_size = self.recommend_size();
        self.current_size = new_size;
        new_size
    }

    /// Restores the starting batch size and discards all recorded losses.
    pub fn reset(&mut self) {
        self.current_size = Self::clamp_to_bounds(&self.config, self.config.initial_batch_size);
        self.loss_history.clear();
    }

    /// Checks the configuration for values the controller cannot work with.
    ///
    /// # Errors
    ///
    /// Returns `ClusteringError::InvalidInput` when `growth_factor` is NaN or
    /// not above 1, when `decay_factor` is NaN or outside the open interval
    /// (0, 1), or when `min_batch` exceeds `max_batch`.
    pub fn validate(&self) -> Result<(), ClusteringError> {
        let cfg = &self.config;

        // NaN makes every ordered comparison false, so it has to be rejected
        // explicitly or it would slip past the range checks.
        if cfg.growth_factor.is_nan() || cfg.growth_factor <= 1.0 {
            return Err(ClusteringError::InvalidInput(
                "growth_factor must be > 1".into(),
            ));
        }
        if cfg.decay_factor.is_nan() || cfg.decay_factor <= 0.0 || cfg.decay_factor >= 1.0 {
            return Err(ClusteringError::InvalidInput(
                "decay_factor must be in (0, 1)".into(),
            ));
        }
        if cfg.min_batch > cfg.max_batch {
            return Err(ClusteringError::InvalidInput(
                "min_batch must be ≤ max_batch".into(),
            ));
        }
        Ok(())
    }

    /// Multiplies the current size by `factor`, rounds, and limits the result
    /// to the configured bounds.
    fn scaled(&self, factor: f64) -> usize {
        // Float-to-integer `as` casts saturate (NaN and negatives become 0,
        // overflow becomes `usize::MAX`), so the bounds below always apply to
        // a well-defined value.
        let resized = (self.current_size as f64 * factor).round() as usize;
        Self::clamp_to_bounds(&self.config, resized)
    }

    /// Limits `size` to `config.min_batch..=config.max_batch` without
    /// panicking.
    ///
    /// `usize::clamp` panics when the lower bound exceeds the upper one; this
    /// max-then-min form gives the same answer for ordered bounds and falls
    /// back to `max_batch` for inverted ones, leaving the error to `validate`.
    fn clamp_to_bounds(config: &AdaptiveBatchConfig, size: usize) -> usize {
        size.max(config.min_batch).min(config.max_batch)
    }
}
/// Arithmetic mean of `xs`; an empty slice yields `0.0`.
fn mean(xs: &[f64]) -> f64 {
    match xs.len() {
        0 => 0.0,
        count => {
            let total: f64 = xs.iter().sum();
            total / count as f64
        }
    }
}
/// Population standard deviation of `xs` (divides by the number of samples).
///
/// Slices with fewer than two samples yield `0.0`.
fn std_dev(xs: &[f64]) -> f64 {
    let count = xs.len();
    if count < 2 {
        return 0.0;
    }

    let center = mean(xs);
    let squared_deviations: f64 = xs
        .iter()
        .map(|&x| {
            let delta = x - center;
            delta * delta
        })
        .sum();

    (squared_deviations / count as f64).sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Records every loss in `losses`, in order, without adapting the size.
    fn record_all(ctrl: &mut BatchSizeController, losses: &[f64]) {
        for &loss in losses {
            ctrl.record_loss(loss);
        }
    }

    /// Six losses starting at 1.0 and dropping by `step` each time.
    fn falling_losses(step: f64) -> Vec<f64> {
        (0..6).map(|i| 1.0 - step * i as f64).collect()
    }

    /// An initial size below `min_batch` is raised to `min_batch`.
    #[test]
    fn test_initial_size_clamped() {
        let ctrl = BatchSizeController::new(AdaptiveBatchConfig {
            initial_batch_size: 4,
            min_batch: 16,
            max_batch: 2048,
            ..Default::default()
        });
        assert_eq!(ctrl.current_size, 16);
    }

    /// With fewer losses than the window, the recommendation is the current size.
    #[test]
    fn test_not_enough_history_returns_current() {
        let mut ctrl = BatchSizeController::new(AdaptiveBatchConfig::default());
        record_all(&mut ctrl, &[1.0, 0.9]);
        assert_eq!(ctrl.recommend_size(), ctrl.current_size);
    }

    /// A slowly falling, low-variance loss makes the batch grow.
    #[test]
    fn test_decreasing_loss_grows_batch() {
        let mut ctrl = BatchSizeController::new(AdaptiveBatchConfig {
            initial_batch_size: 64,
            min_batch: 16,
            max_batch: 2048,
            growth_factor: 2.0,
            decay_factor: 0.5,
            window: 6,
        });
        record_all(&mut ctrl, &falling_losses(0.001));

        let size = ctrl.recommend_size();
        assert!(
            size > 64,
            "Batch size should grow on stable decreasing loss, got {}",
            size
        );
    }

    /// A jump in loss between the two halves of the window makes the batch shrink.
    #[test]
    fn test_increasing_loss_shrinks_batch() {
        let mut ctrl = BatchSizeController::new(AdaptiveBatchConfig {
            initial_batch_size: 256,
            min_batch: 16,
            max_batch: 2048,
            growth_factor: 1.5,
            decay_factor: 0.5,
            window: 6,
        });
        record_all(&mut ctrl, &[0.1, 0.11, 0.12, 1.5, 1.6, 1.7]);

        let size = ctrl.recommend_size();
        assert!(
            size < 256,
            "Batch size should shrink on increasing loss, got {}",
            size
        );
    }

    /// `adapt` returns the same value it stores in `current_size`.
    #[test]
    fn test_adapt_updates_current_size() {
        let mut ctrl = BatchSizeController::new(AdaptiveBatchConfig {
            initial_batch_size: 256,
            window: 6,
            ..Default::default()
        });
        for &loss in &[0.1, 0.11, 0.12, 1.5, 1.6] {
            ctrl.adapt(loss);
        }

        let final_size = ctrl.adapt(1.7);
        assert_eq!(
            final_size, ctrl.current_size,
            "adapt() should update current_size"
        );
    }

    /// Extreme multipliers still leave the recommendation inside the bounds.
    #[test]
    fn test_bounds_respected() {
        let mut ctrl = BatchSizeController::new(AdaptiveBatchConfig {
            initial_batch_size: 17,
            min_batch: 16,
            max_batch: 18,
            growth_factor: 1000.0,
            decay_factor: 0.001,
            window: 6,
        });

        // Stable losses trigger growth, which must stop at max_batch.
        record_all(&mut ctrl, &falling_losses(0.0001));
        let grown = ctrl.recommend_size();
        assert!(grown <= 18, "Must not exceed max_batch");

        // A sharp increase triggers decay, which must stop at min_batch.
        ctrl.reset();
        record_all(&mut ctrl, &[0.01, 0.01, 0.01, 10.0, 10.0, 10.0]);
        let shrunk = ctrl.recommend_size();
        assert!(shrunk >= 16, "Must not go below min_batch");
    }

    /// The default configuration validates; a growth factor below 1 does not.
    #[test]
    fn test_validate_config() {
        let ctrl = BatchSizeController::new(AdaptiveBatchConfig::default());
        assert!(ctrl.validate().is_ok());

        let bad = BatchSizeController::new(AdaptiveBatchConfig {
            growth_factor: 0.5,
            ..Default::default()
        });
        assert!(bad.validate().is_err());
    }
}