scirs2-metrics 0.3.0

Machine Learning evaluation metrics module for SciRS2 (scirs2-metrics)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
//! Adaptive batching strategies for streaming data processing
//!
//! This module provides strategies for grouping incoming stream elements into
//! batches before processing, trading latency for throughput:
//!
//! - [`FixedBatcher`]   – fixed-size count-based batching
//! - [`AdaptiveBatcher`] – dynamically adjusts batch size based on observed
//!   throughput and downstream processing latency
//! - [`PriorityBatcher`] – emits each batch with its elements ordered from
//!   highest to lowest priority

use crate::error::{MetricsError, Result};
use scirs2_core::numeric::Float;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, VecDeque};
use std::time::{Duration, Instant};

// ── BatchOutcome ─────────────────────────────────────────────────────────────

/// A batch of elements handed back by one of the batchers in this module.
///
/// Produced by the `push` methods when a batch completes and by the `flush`
/// methods when buffered elements are emitted on demand.
#[derive(Clone, Debug)]
pub struct BatchOutcome<T> {
    /// Elements making up this batch. May contain fewer than `target_size`
    /// elements when the batch was emitted early.
    pub items: Vec<T>,
    /// `true` when the batch was emitted because of a timeout (or, for some
    /// batchers, a forced flush) rather than by reaching the target size.
    pub is_timeout_flush: bool,
    /// Batch-size target that was in effect at the moment of emission.
    pub target_size: usize,
}

// ── BatcherStats ─────────────────────────────────────────────────────────────

/// Running totals describing every batch a batcher has emitted so far.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct BatcherStats {
    /// How many batches have been emitted.
    pub batches_emitted: u64,
    /// Sum of the sizes of all emitted batches.
    pub total_elements: u64,
    /// How many emitted batches were flagged as timeout flushes.
    pub timeout_flushes: u64,
    /// Smallest batch size observed (0 until the first batch is recorded).
    pub min_batch_size: usize,
    /// Largest batch size observed (0 until the first batch is recorded).
    pub max_batch_size: usize,
    /// Arithmetic mean of all emitted batch sizes, maintained incrementally.
    pub mean_batch_size: f64,
}

impl BatcherStats {
    /// Fold one emitted batch of `size` elements into the running statistics.
    ///
    /// `is_timeout` marks the batch as a timeout flush and bumps
    /// `timeout_flushes`. The mean is updated incrementally as
    /// `mean += (size - mean) / count`, so no per-batch history is stored.
    pub fn record_batch(&mut self, size: usize, is_timeout: bool) {
        self.batches_emitted += 1;
        self.total_elements += size as u64;
        self.timeout_flushes += u64::from(is_timeout);

        // The very first batch seeds min/max; afterwards they only tighten/widen.
        let is_first = self.batches_emitted == 1;
        self.min_batch_size = if is_first { size } else { size.min(self.min_batch_size) };
        self.max_batch_size = if is_first { size } else { size.max(self.max_batch_size) };

        let count = self.batches_emitted as f64;
        self.mean_batch_size += (size as f64 - self.mean_batch_size) / count;
    }
}

// ── FixedBatcher ─────────────────────────────────────────────────────────────

/// A batcher that groups elements into batches of a fixed size.
///
/// Elements accumulate until `batch_size` of them are buffered, at which point
/// they are emitted together. When `max_wait` is set, a partial batch is also
/// emitted once that much wall-clock time has passed since the first element
/// of the batch was pushed.
///
/// The timeout is only evaluated inside [`FixedBatcher::push`]; there is no
/// background timer, so callers whose stream may go idle should call
/// [`FixedBatcher::flush`] themselves.
#[derive(Debug)]
pub struct FixedBatcher<T> {
    /// Number of elements that completes a batch (always >= 1).
    batch_size: usize,
    /// Optional upper bound on how long a partial batch may wait.
    max_wait: Option<Duration>,
    /// Elements pushed since the last emitted batch.
    buffer: Vec<T>,
    /// When the first element of the current batch was pushed; only
    /// meaningful while `buffer` is non-empty.
    window_start: Instant,
    /// Statistics over every batch emitted so far.
    stats: BatcherStats,
}

impl<T> FixedBatcher<T> {
    /// Create a new fixed-size batcher.
    ///
    /// # Arguments
    /// * `batch_size` – target number of elements per batch (must be >= 1)
    /// * `max_wait`   – optional timeout after which a partial batch is emitted
    ///
    /// # Errors
    /// Returns [`MetricsError::InvalidInput`] when `batch_size` is zero.
    pub fn new(batch_size: usize, max_wait: Option<Duration>) -> Result<Self> {
        if batch_size == 0 {
            return Err(MetricsError::InvalidInput(
                "FixedBatcher batch_size must be >= 1".to_string(),
            ));
        }
        Ok(Self {
            batch_size,
            max_wait,
            buffer: Vec::with_capacity(batch_size),
            window_start: Instant::now(),
            stats: BatcherStats::default(),
        })
    }

    /// Push an element.
    ///
    /// Returns a batch when the buffer reaches `batch_size`
    /// (`is_timeout_flush == false`), or when `max_wait` has elapsed since the
    /// first element of the current batch was pushed
    /// (`is_timeout_flush == true`); otherwise returns `None`.
    pub fn push(&mut self, value: T) -> Option<BatchOutcome<T>> {
        // Open the timeout window with the first element of a new batch so that
        // idle time between batches is not charged to the next batch (which
        // would otherwise flush its first element alone).
        if self.buffer.is_empty() {
            self.window_start = Instant::now();
        }
        self.buffer.push(value);

        if self.buffer.len() >= self.batch_size {
            return Some(self.flush_internal(false));
        }

        let timeout_expired = self
            .max_wait
            .is_some_and(|wait| self.window_start.elapsed() >= wait);
        if timeout_expired {
            Some(self.flush_internal(true))
        } else {
            None
        }
    }

    /// Force-flush whatever is in the buffer regardless of batch size / timeout.
    ///
    /// The emitted batch is marked `is_timeout_flush == true`. Returns `None`
    /// when nothing is buffered.
    pub fn flush(&mut self) -> Option<BatchOutcome<T>> {
        if self.buffer.is_empty() {
            None
        } else {
            Some(self.flush_internal(true))
        }
    }

    /// Number of elements currently buffered (not yet emitted).
    #[inline]
    pub fn buffered_len(&self) -> usize {
        self.buffer.len()
    }

    /// Cumulative statistics since construction.
    #[inline]
    pub fn stats(&self) -> &BatcherStats {
        &self.stats
    }

    /// Emit the entire buffer as one batch and record it in the statistics.
    fn flush_internal(&mut self, is_timeout: bool) -> BatchOutcome<T> {
        // Swap in a fresh, pre-sized buffer for the next batch.
        let items = std::mem::replace(&mut self.buffer, Vec::with_capacity(self.batch_size));
        self.stats.record_batch(items.len(), is_timeout);
        BatchOutcome {
            items,
            is_timeout_flush: is_timeout,
            target_size: self.batch_size,
        }
    }
}

// ── AdaptiveBatcher ──────────────────────────────────────────────────────────

/// Rule an [`AdaptiveBatcher`] applies to pick its next target batch size.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum AdaptationPolicy {
    /// Rescale the target proportionally so that the mean observed batch
    /// latency moves toward `target_ms`.
    LatencyTarget {
        /// Desired processing latency per batch, in milliseconds.
        target_ms: f64,
    },
    /// Grow the target while latency is comfortably below a ceiling and shrink
    /// it once the ceiling is exceeded.
    ThroughputMaximisation {
        /// Latency ceiling per batch, in milliseconds.
        max_latency_ms: f64,
    },
    /// Blend the current target with the mean size of all batches emitted so
    /// far using a fixed weight; the latency values themselves do not enter
    /// the formula.
    ExponentialSmoothing {
        /// Weight α given to the mean emitted batch size; intended range (0, 1].
        alpha: f64,
    },
}

/// A batcher whose target batch size is tuned at runtime from observed
/// processing latency.
///
/// Elements are buffered until the current target is reached. After a batch
/// has been processed the caller reports its latency through
/// [`AdaptiveBatcher::record_latency`]; every `adaptation_interval` reports the
/// configured [`AdaptationPolicy`] recomputes the target, which is always kept
/// within `[min_batch_size, max_batch_size]`.
#[derive(Debug)]
pub struct AdaptiveBatcher<T, F: Float + std::fmt::Debug> {
    /// Lower bound for the target batch size (>= 1).
    min_batch_size: usize,
    /// Upper bound for the target batch size.
    max_batch_size: usize,
    /// Batch size that currently triggers emission in `push`.
    current_target: usize,
    /// Rule used to recompute `current_target`.
    policy: AdaptationPolicy,
    /// Number of `record_latency` calls between adaptations (>= 1).
    adaptation_interval: u64,
    /// Elements waiting to be emitted.
    buffer: VecDeque<T>,
    /// Statistics over every batch emitted so far.
    stats: BatcherStats,
    /// Most recently reported processing latencies (one per batch, ms),
    /// oldest first.
    latency_history: VecDeque<F>,
    /// Maximum number of entries retained in `latency_history`.
    latency_history_cap: usize,
    /// `record_latency` calls since the last adaptation.
    batches_since_adapt: u64,
}

impl<T, F: Float + std::fmt::Debug> AdaptiveBatcher<T, F> {
    /// Create a new adaptive batcher.
    ///
    /// # Arguments
    /// * `min_batch_size`       – smallest allowed batch size
    /// * `max_batch_size`       – largest allowed batch size
    /// * `initial_batch_size`   – starting batch size (clamped to [min, max])
    /// * `policy`               – adaptation policy
    /// * `adaptation_interval`  – number of `record_latency` calls between adaptations
    ///
    /// # Errors
    /// Returns [`MetricsError::InvalidInput`] when `min_batch_size` is zero,
    /// `max_batch_size < min_batch_size`, or `adaptation_interval` is zero.
    pub fn new(
        min_batch_size: usize,
        max_batch_size: usize,
        initial_batch_size: usize,
        policy: AdaptationPolicy,
        adaptation_interval: u64,
    ) -> Result<Self> {
        if min_batch_size == 0 {
            return Err(MetricsError::InvalidInput(
                "AdaptiveBatcher min_batch_size must be >= 1".to_string(),
            ));
        }
        if max_batch_size < min_batch_size {
            return Err(MetricsError::InvalidInput(format!(
                "max_batch_size ({max_batch_size}) < min_batch_size ({min_batch_size})"
            )));
        }
        if adaptation_interval == 0 {
            return Err(MetricsError::InvalidInput(
                "adaptation_interval must be >= 1".to_string(),
            ));
        }
        let current_target = initial_batch_size.clamp(min_batch_size, max_batch_size);
        Ok(Self {
            min_batch_size,
            max_batch_size,
            current_target,
            policy,
            adaptation_interval,
            buffer: VecDeque::new(),
            stats: BatcherStats::default(),
            latency_history: VecDeque::new(),
            latency_history_cap: 64,
            batches_since_adapt: 0,
        })
    }

    /// Ingest an element.  Returns a batch when the current target is reached.
    ///
    /// At most one batch of up to `current_target` elements is returned per
    /// call; if the target shrank while elements were buffered, the surplus
    /// stays buffered for subsequent calls.
    pub fn push(&mut self, value: T) -> Option<BatchOutcome<T>> {
        self.buffer.push_back(value);
        if self.buffer.len() >= self.current_target {
            Some(self.emit_batch(false))
        } else {
            None
        }
    }

    /// Record the processing latency (ms) of the most recently emitted batch
    /// and trigger adaptation once `adaptation_interval` reports have been
    /// collected since the last adaptation.
    pub fn record_latency(&mut self, latency_ms: F) {
        self.latency_history.push_back(latency_ms);
        while self.latency_history.len() > self.latency_history_cap {
            self.latency_history.pop_front();
        }

        self.batches_since_adapt += 1;
        if self.batches_since_adapt >= self.adaptation_interval {
            self.adapt();
            self.batches_since_adapt = 0;
        }
    }

    /// Force-flush the buffer (up to `current_target` elements), marking the
    /// batch as a timeout flush. Returns `None` when nothing is buffered.
    pub fn flush(&mut self) -> Option<BatchOutcome<T>> {
        if self.buffer.is_empty() {
            None
        } else {
            Some(self.emit_batch(true))
        }
    }

    /// Current target batch size.
    #[inline]
    pub fn current_target(&self) -> usize {
        self.current_target
    }

    /// Cumulative statistics since construction.
    #[inline]
    pub fn stats(&self) -> &BatcherStats {
        &self.stats
    }

    /// Remove up to `current_target` elements from the front of the buffer
    /// and package them as a batch.
    fn emit_batch(&mut self, is_timeout: bool) -> BatchOutcome<T> {
        let take = self.current_target.min(self.buffer.len());
        let items: Vec<T> = self.buffer.drain(..take).collect();
        self.stats.record_batch(items.len(), is_timeout);
        BatchOutcome {
            items,
            is_timeout_flush: is_timeout,
            target_size: self.current_target,
        }
    }

    /// Recompute `current_target` from the latencies reported since the last
    /// adaptation.
    ///
    /// Only the most recent `adaptation_interval` samples are averaged: older
    /// samples were measured under a different target, and mixing them in
    /// makes the proportional `LatencyTarget` rescale over- or undershoot.
    fn adapt(&mut self) {
        // `len as u64` is lossless and the minimum is <= len, so the cast back
        // to usize cannot truncate.
        let window = self
            .adaptation_interval
            .min(self.latency_history.len() as u64) as usize;
        if window == 0 {
            return;
        }
        let sum = self
            .latency_history
            .iter()
            .rev()
            .take(window)
            .copied()
            .fold(F::zero(), |acc, x| acc + x);
        let n = F::from(window).expect("usize fits in F");
        let mean_lat_f64 = (sum / n).to_f64().unwrap_or(f64::MAX);
        // A NaN sample carries no usable timing information; keep the target
        // instead of letting `NaN.max(EPSILON)` blow the target up to the max.
        if mean_lat_f64.is_nan() {
            return;
        }

        let current = self.current_target;
        let new_target = match self.policy {
            AdaptationPolicy::LatencyTarget { target_ms } => {
                let ratio = target_ms / mean_lat_f64.max(f64::EPSILON);
                (current as f64 * ratio).round() as usize
            }
            AdaptationPolicy::ThroughputMaximisation { max_latency_ms } => {
                if mean_lat_f64 < max_latency_ms * 0.8 {
                    // Well below budget — grow by ~25%, and by at least one
                    // element: floor(n * 1.25) == n for n in 1..=3, which would
                    // otherwise leave small targets stuck forever.
                    ((current as f64 * 1.25) as usize).max(current.saturating_add(1))
                } else if mean_lat_f64 > max_latency_ms {
                    // Over budget — shrink by ~25%.
                    (current as f64 * 0.75) as usize
                } else {
                    current
                }
            }
            AdaptationPolicy::ExponentialSmoothing { alpha } => {
                let smoothed = alpha * self.stats.mean_batch_size + (1.0 - alpha) * current as f64;
                smoothed.round() as usize
            }
        };

        self.current_target = new_target.clamp(self.min_batch_size, self.max_batch_size);
    }
}

// ── PriorityBatcher ──────────────────────────────────────────────────────────

/// Heap entry used by [`PriorityBatcher`]: an element together with its
/// priority and arrival sequence number.
///
/// The ordering is arranged for `BinaryHeap` (a max-heap): the entry that
/// should be emitted first compares greatest.
#[derive(Debug, Clone)]
struct PrioritizedItem<T> {
    /// Caller-supplied priority; higher values are more urgent.
    priority: u32,
    /// Arrival counter used as a tie-breaker (lower = older).
    sequence: u64,
    /// The wrapped element.
    value: T,
}

impl<T> PartialEq for PrioritizedItem<T> {
    fn eq(&self, other: &Self) -> bool {
        self.priority == other.priority && self.sequence == other.sequence
    }
}

impl<T> Eq for PrioritizedItem<T> {}

impl<T> PartialOrd for PrioritizedItem<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<T> Ord for PrioritizedItem<T> {
    /// Higher priority compares greater. Among equal priorities the older
    /// entry (lower `sequence`) compares greater, so a max-heap pops
    /// equal-priority elements in FIFO order.
    fn cmp(&self, other: &Self) -> Ordering {
        self.priority
            .cmp(&other.priority)
            // Reversed on purpose: the smaller sequence must be the "larger" entry.
            .then_with(|| other.sequence.cmp(&self.sequence))
    }
}

/// A batcher that emits each batch ordered from highest to lowest priority.
///
/// Elements are queued with a caller-supplied priority. When `push` brings the
/// queue to `batch_size` elements, they are popped highest-priority first and
/// emitted. Because emission happens as soon as that threshold is reached, a
/// push-triggered batch contains every queued element; priority determines
/// the order of elements *within* the batch. [`PriorityBatcher::flush`] emits
/// whatever is queued as a partial batch.
#[derive(Debug)]
pub struct PriorityBatcher<T> {
    /// Number of queued elements that triggers emission (always >= 1).
    batch_size: usize,
    /// Pending elements, ordered for max-heap popping.
    queue: BinaryHeap<PrioritizedItem<T>>,
    /// Next arrival sequence number handed to a pushed element.
    sequence_counter: u64,
    /// Statistics over every batch emitted so far.
    stats: BatcherStats,
}

impl<T> PriorityBatcher<T> {
    /// Create a new priority batcher.
    ///
    /// # Arguments
    /// * `batch_size` – number of elements per emitted batch (must be >= 1)
    ///
    /// # Errors
    /// Returns [`MetricsError::InvalidInput`] when `batch_size` is zero.
    pub fn new(batch_size: usize) -> Result<Self> {
        if batch_size == 0 {
            return Err(MetricsError::InvalidInput(
                "PriorityBatcher batch_size must be >= 1".to_string(),
            ));
        }
        Ok(Self {
            batch_size,
            queue: BinaryHeap::new(),
            sequence_counter: 0,
            stats: BatcherStats::default(),
        })
    }

    /// Push an element with a numeric priority (higher = more urgent).
    ///
    /// Returns a batch (`is_timeout_flush == false`) when the queue reaches
    /// the target size; otherwise `None`.
    pub fn push(&mut self, value: T, priority: u32) -> Option<BatchOutcome<T>> {
        self.queue.push(PrioritizedItem {
            priority,
            sequence: self.sequence_counter,
            value,
        });
        self.sequence_counter += 1;

        if self.queue.len() >= self.batch_size {
            Some(self.drain_batch(false))
        } else {
            None
        }
    }

    /// Emit up to `batch_size` queued elements, highest priority first.
    ///
    /// Like the forced flushes of the other batchers in this module, the batch
    /// is marked `is_timeout_flush == true`. Returns `None` when the queue is
    /// empty.
    pub fn flush(&mut self) -> Option<BatchOutcome<T>> {
        if self.queue.is_empty() {
            None
        } else {
            Some(self.drain_batch(true))
        }
    }

    /// Number of elements waiting in the priority queue.
    #[inline]
    pub fn queued_len(&self) -> usize {
        self.queue.len()
    }

    /// Cumulative statistics since construction.
    #[inline]
    pub fn stats(&self) -> &BatcherStats {
        &self.stats
    }

    /// Pop up to `batch_size` elements in priority order and package them.
    fn drain_batch(&mut self, is_timeout: bool) -> BatchOutcome<T> {
        let take = self.batch_size.min(self.queue.len());
        let mut items = Vec::with_capacity(take);
        items.extend(
            std::iter::from_fn(|| self.queue.pop())
                .take(take)
                .map(|entry| entry.value),
        );
        self.stats.record_batch(items.len(), is_timeout);
        BatchOutcome {
            items,
            is_timeout_flush: is_timeout,
            target_size: self.batch_size,
        }
    }
}

// ── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fixed_batcher_emits_at_target() {
        let mut batcher: FixedBatcher<i32> = FixedBatcher::new(3, None).expect("valid");
        for v in 1..3 {
            assert!(batcher.push(v).is_none());
        }
        let batch = batcher.push(3).expect("batch emitted");
        assert_eq!(batch.items, vec![1, 2, 3]);
        assert!(!batch.is_timeout_flush);
        assert_eq!(batcher.stats().batches_emitted, 1);
    }

    #[test]
    fn fixed_batcher_flush() {
        let mut batcher: FixedBatcher<i32> = FixedBatcher::new(10, None).expect("valid");
        for v in 1..=2 {
            let _ = batcher.push(v);
        }
        let batch = batcher.flush().expect("partial flush");
        assert_eq!(batch.items, vec![1, 2]);
        assert!(batch.is_timeout_flush);
        // A second flush has nothing left to emit.
        assert!(batcher.flush().is_none());
    }

    #[test]
    fn fixed_batcher_zero_size_errors() {
        let result = FixedBatcher::<i32>::new(0, None);
        assert!(result.is_err());
    }

    #[test]
    fn adaptive_batcher_basic() {
        let policy = AdaptationPolicy::LatencyTarget { target_ms: 10.0 };
        let mut batcher: AdaptiveBatcher<i32, f64> =
            AdaptiveBatcher::new(1, 100, 4, policy, 5).expect("valid");
        for v in 0..3 {
            assert!(batcher.push(v).is_none());
        }
        assert!(batcher.push(3).is_some());
        assert_eq!(batcher.stats().batches_emitted, 1);
    }

    #[test]
    fn adaptive_batcher_adaptation_does_not_panic() {
        let policy = AdaptationPolicy::ThroughputMaximisation { max_latency_ms: 20.0 };
        let mut batcher: AdaptiveBatcher<i32, f64> =
            AdaptiveBatcher::new(1, 50, 5, policy, 2).expect("valid");
        // Feed 20 elements, reporting a latency after every emitted batch.
        for v in 0..20_i32 {
            if batcher.push(v).is_some() {
                batcher.record_latency(15.0_f64);
            }
        }
        // The adapted target must stay inside [min, max].
        assert!((1..=50).contains(&batcher.current_target()));
    }

    #[test]
    fn priority_batcher_ordering() {
        let mut batcher: PriorityBatcher<&str> = PriorityBatcher::new(3).expect("valid");
        assert!(batcher.push("low", 1).is_none());
        assert!(batcher.push("high", 10).is_none());
        let batch = batcher.push("medium", 5).expect("batch emitted");
        // Highest priority comes first.
        assert_eq!(batch.items[..3], ["high", "medium", "low"]);
    }

    #[test]
    fn priority_batcher_invalid_size() {
        assert!(PriorityBatcher::<i32>::new(0).is_err());
    }

    #[test]
    fn batcher_stats_record() {
        let mut stats = BatcherStats::default();
        stats.record_batch(10, false);
        stats.record_batch(5, true);
        assert_eq!(
            (stats.batches_emitted, stats.total_elements, stats.timeout_flushes),
            (2, 15, 1)
        );
        assert_eq!((stats.min_batch_size, stats.max_batch_size), (5, 10));
        assert!((stats.mean_batch_size - 7.5).abs() < 1e-10);
    }
}