Skip to main content

scirs2_cluster/adaptive/
mod.rs

1//! Self-adaptive mini-batch size controller.
2//!
3//! `BatchSizeController` monitors a streaming loss signal and automatically
4//! grows or shrinks the mini-batch size using simple heuristics:
5//!
//! - If the recent losses are **stable** (the relative standard deviation of
//!   the recent half-window is below 1%): the current gradient direction is
//!   treated as reliable → **grow** the batch to exploit it.
//! - Otherwise, if the loss **increases** (the mean of the recent half-window
//!   is above the mean of the previous half-window): gradient estimates are
//!   noisy or the optimum is near → **shrink** the batch for finer stochastic
//!   exploration.
//! - In every other case the batch size is left unchanged.
12
13use crate::error::ClusteringError;
14
/// Configuration for [`BatchSizeController`].
///
/// The constraints noted on the fields are not checked when this struct is
/// built; `BatchSizeController::validate` reports violations of the
/// `growth_factor`, `decay_factor` and `min_batch`/`max_batch` constraints.
#[derive(Debug, Clone)]
pub struct AdaptiveBatchConfig {
    /// Starting batch size.
    ///
    /// Limited to `[min_batch, max_batch]` when a controller is created or
    /// reset.
    pub initial_batch_size: usize,
    /// Hard lower bound on batch size (must be ≤ `max_batch`).
    pub min_batch: usize,
    /// Hard upper bound on batch size.
    pub max_batch: usize,
    /// Multiplicative factor when growing the batch (must be > 1).
    pub growth_factor: f64,
    /// Multiplicative factor when shrinking the batch (must be in (0, 1)).
    pub decay_factor: f64,
    /// Number of recent losses used to compute statistics.
    ///
    /// The controller treats values below 2 as 2. The newest `window / 2`
    /// losses form the "recent" group and the rest of the window forms the
    /// "previous" group.
    pub window: usize,
}
31
32impl Default for AdaptiveBatchConfig {
33    fn default() -> Self {
34        Self {
35            initial_batch_size: 32,
36            min_batch: 16,
37            max_batch: 2048,
38            growth_factor: 1.5,
39            decay_factor: 0.8,
40            window: 6, // 3 "recent" + 3 "previous"
41        }
42    }
43}
44
45/// Online controller that tracks a loss history and recommends a batch size.
46pub struct BatchSizeController {
47    /// Current recommended batch size.
48    pub current_size: usize,
49    /// Full history of recorded losses (oldest first).
50    pub loss_history: Vec<f64>,
51    config: AdaptiveBatchConfig,
52}
53
54impl BatchSizeController {
55    /// Create a new controller with the given configuration.
56    pub fn new(config: AdaptiveBatchConfig) -> Self {
57        let initial = config
58            .initial_batch_size
59            .clamp(config.min_batch, config.max_batch);
60        Self {
61            current_size: initial,
62            loss_history: Vec::new(),
63            config,
64        }
65    }
66
67    /// Record a new loss value **without** updating the batch size recommendation.
68    pub fn record_loss(&mut self, loss: f64) {
69        self.loss_history.push(loss);
70    }
71
72    /// Recommend a batch size based on the recorded loss history.
73    ///
74    /// Decision rules (applied to the most recent `window` observations):
75    ///
76    /// 1. Not enough history → return current size unchanged.
77    /// 2. Split history into `last_half` and `prev_half` (each `window/2` long).
78    /// 3. If `std(last_half) / mean(last_half) < 0.01` → **grow** (stable descent).
79    /// 4. If `mean(last_half) > mean(prev_half)` → **shrink** (loss increased).
80    /// 5. Otherwise → no change.
81    pub fn recommend_size(&self) -> usize {
82        let w = self.config.window.max(2);
83        let half = w / 2;
84
85        if self.loss_history.len() < w {
86            return self.current_size;
87        }
88
89        let recent: &[f64] = &self.loss_history[self.loss_history.len() - half..];
90        let prev: &[f64] =
91            &self.loss_history[self.loss_history.len() - w..self.loss_history.len() - half];
92
93        let mean_recent = mean(recent);
94        let mean_prev = mean(prev);
95        let std_recent = std_dev(recent);
96
97        // Rule 1: loss is decreasing reliably → grow
98        let relative_std = if mean_recent.abs() > 1e-12 {
99            std_recent / mean_recent.abs()
100        } else {
101            std_recent
102        };
103
104        if relative_std < 0.01 {
105            let new_size =
106                ((self.current_size as f64) * self.config.growth_factor).round() as usize;
107            return new_size.clamp(self.config.min_batch, self.config.max_batch);
108        }
109
110        // Rule 2: loss has increased → shrink
111        if mean_recent > mean_prev {
112            let new_size = ((self.current_size as f64) * self.config.decay_factor).round() as usize;
113            return new_size.clamp(self.config.min_batch, self.config.max_batch);
114        }
115
116        self.current_size
117    }
118
119    /// Record `loss`, update `current_size` and return the new recommended size.
120    pub fn adapt(&mut self, loss: f64) -> usize {
121        self.record_loss(loss);
122        let new_size = self.recommend_size();
123        self.current_size = new_size;
124        new_size
125    }
126
127    /// Reset to initial state.
128    pub fn reset(&mut self) {
129        self.current_size = self
130            .config
131            .initial_batch_size
132            .clamp(self.config.min_batch, self.config.max_batch);
133        self.loss_history.clear();
134    }
135
136    /// Validate the configuration.
137    pub fn validate(&self) -> Result<(), ClusteringError> {
138        if self.config.growth_factor <= 1.0 {
139            return Err(ClusteringError::InvalidInput(
140                "growth_factor must be > 1".into(),
141            ));
142        }
143        if self.config.decay_factor <= 0.0 || self.config.decay_factor >= 1.0 {
144            return Err(ClusteringError::InvalidInput(
145                "decay_factor must be in (0, 1)".into(),
146            ));
147        }
148        if self.config.min_batch > self.config.max_batch {
149            return Err(ClusteringError::InvalidInput(
150                "min_batch must be ≤ max_batch".into(),
151            ));
152        }
153        Ok(())
154    }
155}
156
157// ---------------------------------------------------------------------------
158// Statistics helpers
159// ---------------------------------------------------------------------------
160
/// Arithmetic mean of `xs`; an empty slice yields `0.0`.
fn mean(xs: &[f64]) -> f64 {
    match xs.len() {
        0 => 0.0,
        count => {
            let total: f64 = xs.iter().sum();
            total / count as f64
        }
    }
}
167
168fn std_dev(xs: &[f64]) -> f64 {
169    if xs.len() < 2 {
170        return 0.0;
171    }
172    let m = mean(xs);
173    let var = xs.iter().map(|x| (x - m) * (x - m)).sum::<f64>() / xs.len() as f64;
174    var.sqrt()
175}
176
177// ---------------------------------------------------------------------------
178// Tests
179// ---------------------------------------------------------------------------
180
#[cfg(test)]
mod tests {
    use super::*;

    /// An initial size below `min_batch` is raised to `min_batch`.
    #[test]
    fn test_initial_size_clamped() {
        let config = AdaptiveBatchConfig {
            initial_batch_size: 4,
            min_batch: 16,
            max_batch: 2048,
            ..Default::default()
        };
        let ctrl = BatchSizeController::new(config);
        assert_eq!(ctrl.current_size, 16);
    }

    /// With fewer than `window` losses the recommendation is the current size.
    #[test]
    fn test_not_enough_history_returns_current() {
        let mut ctrl = BatchSizeController::new(AdaptiveBatchConfig::default());
        ctrl.record_loss(1.0);
        ctrl.record_loss(0.9);
        // window=6 not reached yet
        assert_eq!(ctrl.recommend_size(), ctrl.current_size);
    }

    /// A low-variance recent window triggers the grow rule.
    #[test]
    fn test_decreasing_loss_grows_batch() {
        let config = AdaptiveBatchConfig {
            initial_batch_size: 64,
            min_batch: 16,
            max_batch: 2048,
            growth_factor: 2.0,
            decay_factor: 0.5,
            window: 6,
        };
        let mut ctrl = BatchSizeController::new(config);

        // Very stable decreasing losses → small relative std
        for i in 0..6 {
            ctrl.record_loss(1.0 - 0.001 * i as f64);
        }
        let size = ctrl.recommend_size();
        assert!(
            size > 64,
            "Batch size should grow on stable decreasing loss, got {}",
            size
        );
    }

    /// A noisy recent window whose mean exceeds the previous mean triggers
    /// the shrink rule.
    #[test]
    fn test_increasing_loss_shrinks_batch() {
        let config = AdaptiveBatchConfig {
            initial_batch_size: 256,
            min_batch: 16,
            max_batch: 2048,
            growth_factor: 1.5,
            decay_factor: 0.5,
            window: 6,
        };
        let mut ctrl = BatchSizeController::new(config);

        // prev_half: low losses; last_half: high losses → mean_recent > mean_prev
        ctrl.record_loss(0.1);
        ctrl.record_loss(0.11);
        ctrl.record_loss(0.12);
        ctrl.record_loss(1.5);
        ctrl.record_loss(1.6);
        ctrl.record_loss(1.7);

        let size = ctrl.recommend_size();
        assert!(
            size < 256,
            "Batch size should shrink on increasing loss, got {}",
            size
        );
    }

    /// `adapt` must store the recommendation in `current_size`, not merely
    /// return it.
    #[test]
    fn test_adapt_updates_current_size() {
        let mut ctrl = BatchSizeController::new(AdaptiveBatchConfig {
            initial_batch_size: 256,
            window: 6,
            ..Default::default()
        });

        // Force shrink: inject increasing losses
        ctrl.adapt(0.1);
        ctrl.adapt(0.11);
        ctrl.adapt(0.12);
        ctrl.adapt(1.5);
        ctrl.adapt(1.6);
        // Only five losses so far: the window is not full, nothing may change.
        assert_eq!(ctrl.current_size, 256);

        let final_size = ctrl.adapt(1.7);
        assert_eq!(
            final_size, ctrl.current_size,
            "adapt() should update current_size"
        );
        // Comparing the two values alone would also pass if nothing happened,
        // so additionally require that the shrink actually took place.
        assert!(
            final_size < 256,
            "adapt() should have shrunk the batch, got {}",
            final_size
        );
    }

    /// Extreme factors must still produce sizes inside `[min_batch, max_batch]`.
    #[test]
    fn test_bounds_respected() {
        let config = AdaptiveBatchConfig {
            initial_batch_size: 17,
            min_batch: 16,
            max_batch: 18,
            growth_factor: 1000.0, // extreme
            decay_factor: 0.001,   // extreme
            window: 6,
        };
        let mut ctrl = BatchSizeController::new(config);

        // Trigger grow
        for i in 0..6 {
            ctrl.record_loss(1.0 - 0.0001 * i as f64);
        }
        let grown = ctrl.recommend_size();
        assert_eq!(grown, 18, "Must not exceed max_batch");

        // Trigger shrink. The recent losses must vary: a constant recent
        // window has zero standard deviation and would take the grow rule
        // instead, leaving the lower bound untested.
        ctrl.reset();
        ctrl.record_loss(0.01);
        ctrl.record_loss(0.01);
        ctrl.record_loss(0.01);
        ctrl.record_loss(9.0);
        ctrl.record_loss(10.0);
        ctrl.record_loss(11.0);
        let shrunk = ctrl.recommend_size();
        assert_eq!(shrunk, 16, "Must not go below min_batch");
    }

    /// `validate` accepts the defaults and rejects a growth factor ≤ 1.
    #[test]
    fn test_validate_config() {
        let ctrl = BatchSizeController::new(AdaptiveBatchConfig::default());
        assert!(ctrl.validate().is_ok());

        let bad = BatchSizeController::new(AdaptiveBatchConfig {
            growth_factor: 0.5, // invalid
            ..Default::default()
        });
        assert!(bad.validate().is_err());
    }
}
317            ..Default::default()
318        });
319        assert!(bad.validate().is_err());
320    }
321}