Skip to main content

ferrolearn_core/
streaming.rs

1//! Streaming data adapter for incremental learning.
2//!
3//! The [`StreamingFitter`] feeds batches from an iterator to a
4//! [`PartialFit`](crate::PartialFit) model, enabling online/streaming
5//! learning workflows where the full dataset does not fit in memory.
6//!
7//! # Example
8//!
9//! ```ignore
10//! use ferrolearn_core::streaming::StreamingFitter;
11//!
12//! // Assume `model` implements PartialFit<Array2<f64>, Array1<f64>>
13//! let batches = vec![
14//!     (x_batch1, y_batch1),
15//!     (x_batch2, y_batch2),
16//! ];
17//!
18//! let fitter = StreamingFitter::new(model).n_epochs(3);
19//! let fitted = fitter.fit_batches(batches)?;
20//! let predictions = fitted.predict(&x_test)?;
21//! ```
22
23#[cfg(not(feature = "std"))]
24use alloc::vec::Vec;
25
26use crate::traits::PartialFit;
27
/// Feeds batches from an iterator to a [`PartialFit`] model.
///
/// The adapter owns the model together with the number of passes (epochs)
/// to make over the supplied batches. The fitting methods consume the
/// adapter and hand back the model's fitted form.
///
/// `Debug` and `Clone` are available whenever the wrapped model provides
/// them, so a configured fitter can be logged or reused as a template.
///
/// # Type Parameters
///
/// - `M`: The model type, which must implement [`PartialFit<X, Y>`] for the
///   fitting methods to be callable.
#[derive(Debug, Clone)]
pub struct StreamingFitter<M> {
    /// The initial (unfitted or partially fitted) model.
    model: M,
    /// Number of passes over the batches; set via the `n_epochs` builder.
    n_epochs: usize,
}
43
44impl<M> StreamingFitter<M> {
45    /// Create a new `StreamingFitter` wrapping the given model.
46    ///
47    /// The default number of epochs is 1 (a single pass over the data).
48    pub fn new(model: M) -> Self {
49        Self { model, n_epochs: 1 }
50    }
51
52    /// Set the number of epochs (passes over the data).
53    ///
54    /// Each epoch replays all batches in order. More epochs can improve
55    /// convergence for online learning algorithms.
56    ///
57    /// # Panics
58    ///
59    /// This method does not panic, but [`fit_batches`](StreamingFitter::fit_batches)
60    /// will return early with the initial model state if `n_epochs` is 0.
61    #[must_use]
62    pub fn n_epochs(mut self, n_epochs: usize) -> Self {
63        self.n_epochs = n_epochs;
64        self
65    }
66
67    /// Feed all batches to the model, returning the fitted result.
68    ///
69    /// The batches are collected into a `Vec` so they can be replayed
70    /// across multiple epochs. For a single epoch with a non-cloneable
71    /// iterator, use [`fit_batches_single_epoch`](StreamingFitter::fit_batches_single_epoch).
72    ///
73    /// # Errors
74    ///
75    /// Returns the first error encountered during any `partial_fit` call.
76    pub fn fit_batches<X, Y, I>(self, batches: I) -> Result<M::FitResult, M::Error>
77    where
78        M: PartialFit<X, Y>,
79        M::FitResult: PartialFit<X, Y, FitResult = M::FitResult, Error = M::Error>,
80        I: IntoIterator<Item = (X, Y)>,
81    {
82        let batches: Vec<(X, Y)> = batches.into_iter().collect();
83
84        if batches.is_empty() || self.n_epochs == 0 {
85            // Feed a zero-length sequence: we need at least one batch.
86            // This is inherent to PartialFit requiring at least one call.
87            // Return an error if there are no batches at all.
88            // However, since we can't construct a FitResult without data,
89            // we must have at least one batch.
90            return Err(self.no_batches_error());
91        }
92
93        // First epoch, first batch: transition from M to M::FitResult.
94        let mut batch_iter = batches.iter();
95        let (first_x, first_y) = batch_iter.next().unwrap();
96        let mut fitted = self.model.partial_fit(first_x, first_y)?;
97
98        // First epoch, remaining batches.
99        for (x, y) in batch_iter {
100            fitted = fitted.partial_fit(x, y)?;
101        }
102
103        // Subsequent epochs.
104        for _ in 1..self.n_epochs {
105            for (x, y) in &batches {
106                fitted = fitted.partial_fit(x, y)?;
107            }
108        }
109
110        Ok(fitted)
111    }
112
113    /// Feed batches from a single-pass iterator to the model.
114    ///
115    /// Unlike [`fit_batches`](StreamingFitter::fit_batches), this method
116    /// does not collect the batches, so it only supports a single epoch.
117    /// The `n_epochs` setting is ignored.
118    ///
119    /// # Errors
120    ///
121    /// Returns the first error encountered during any `partial_fit` call.
122    pub fn fit_batches_single_epoch<X, Y, I>(self, batches: I) -> Result<M::FitResult, M::Error>
123    where
124        M: PartialFit<X, Y>,
125        M::FitResult: PartialFit<X, Y, FitResult = M::FitResult, Error = M::Error>,
126        I: IntoIterator<Item = (X, Y)>,
127    {
128        let mut iter = batches.into_iter();
129
130        let (first_x, first_y) = match iter.next() {
131            Some(batch) => batch,
132            None => return Err(self.no_batches_error()),
133        };
134
135        let mut fitted = self.model.partial_fit(&first_x, &first_y)?;
136
137        for (x, y) in iter {
138            fitted = fitted.partial_fit(&x, &y)?;
139        }
140
141        Ok(fitted)
142    }
143}
144
145impl<M> StreamingFitter<M> {
146    /// Produce a "no batches" error. This is a helper that constructs
147    /// the appropriate error when no data is available.
148    ///
149    /// We use a trick: we need `M::Error`, but we can only get it from
150    /// a failed `partial_fit` call. Instead, we construct a simple
151    /// sentinel. This requires `M::Error: From<&str>` or similar.
152    /// Since we cannot guarantee that, we panic with a descriptive message.
153    /// This is acceptable because calling `fit_batches` with zero batches
154    /// is a programming error.
155    fn no_batches_error<E>(&self) -> E
156    where
157        E: core::fmt::Display,
158    {
159        // We cannot generically construct an arbitrary error type.
160        // Panicking here is acceptable: zero batches is a precondition violation.
161        panic!(
162            "StreamingFitter::fit_batches called with zero batches; at least one batch is required"
163        );
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::FerroError;
    use crate::traits::{PartialFit, Predict};

    /// A simple accumulator model for testing streaming fits.
    /// Accumulates the sum of all values seen in each batch.
    #[derive(Clone)]
    struct Accumulator {
        sum: f64,
    }

    impl Accumulator {
        fn new() -> Self {
            Self { sum: 0.0 }
        }
    }

    /// The fitted version of Accumulator.
    #[derive(Clone)]
    struct FittedAccumulator {
        sum: f64,
    }

    impl Predict<Vec<f64>> for FittedAccumulator {
        type Output = f64;
        type Error = FerroError;

        fn predict(&self, x: &Vec<f64>) -> Result<f64, FerroError> {
            // Predict: the sum of the inputs offset by the accumulated sum.
            Ok(x.iter().sum::<f64>() + self.sum)
        }
    }

    impl PartialFit<Vec<f64>, Vec<f64>> for Accumulator {
        type FitResult = FittedAccumulator;
        type Error = FerroError;

        fn partial_fit(self, x: &Vec<f64>, _y: &Vec<f64>) -> Result<FittedAccumulator, FerroError> {
            Ok(FittedAccumulator {
                sum: self.sum + x.iter().sum::<f64>(),
            })
        }
    }

    impl PartialFit<Vec<f64>, Vec<f64>> for FittedAccumulator {
        type FitResult = FittedAccumulator;
        type Error = FerroError;

        fn partial_fit(self, x: &Vec<f64>, _y: &Vec<f64>) -> Result<FittedAccumulator, FerroError> {
            Ok(FittedAccumulator {
                sum: self.sum + x.iter().sum::<f64>(),
            })
        }
    }

    #[test]
    fn test_streaming_single_batch() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model);

        let batches = vec![(vec![1.0, 2.0, 3.0], vec![0.0])];

        let fitted = fitter.fit_batches(batches).unwrap();
        // Sum of [1, 2, 3] = 6
        let pred = fitted.predict(&vec![0.0]).unwrap();
        assert!((pred - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_streaming_multiple_batches() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model);

        let batches = vec![
            (vec![1.0, 2.0], vec![0.0]),
            (vec![3.0, 4.0], vec![0.0]),
            (vec![5.0], vec![0.0]),
        ];

        let fitted = fitter.fit_batches(batches).unwrap();
        // Sum = 1+2+3+4+5 = 15
        let pred = fitted.predict(&vec![0.0]).unwrap();
        assert!((pred - 15.0).abs() < 1e-10);
    }

    #[test]
    fn test_streaming_multiple_epochs() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model).n_epochs(3);

        let batches = vec![(vec![1.0, 2.0], vec![0.0]), (vec![3.0], vec![0.0])];

        let fitted = fitter.fit_batches(batches).unwrap();
        // Per epoch sum = 1+2+3 = 6, 3 epochs = 18
        let pred = fitted.predict(&vec![0.0]).unwrap();
        assert!((pred - 18.0).abs() < 1e-10);
    }

    #[test]
    fn test_streaming_single_epoch_method() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model);

        let batches = vec![(vec![10.0], vec![0.0]), (vec![20.0], vec![0.0])];

        let fitted = fitter.fit_batches_single_epoch(batches).unwrap();
        // Sum = 10 + 20 = 30
        let pred = fitted.predict(&vec![0.0]).unwrap();
        assert!((pred - 30.0).abs() < 1e-10);
    }

    #[test]
    fn test_streaming_single_epoch_ignores_n_epochs() {
        let fitter = StreamingFitter::new(Accumulator::new()).n_epochs(4);

        let batches = vec![(vec![1.0, 2.0], vec![0.0])];

        let fitted = fitter.fit_batches_single_epoch(batches).unwrap();
        // One pass only, regardless of n_epochs: sum = 1 + 2 = 3
        let pred = fitted.predict(&vec![0.0]).unwrap();
        assert!((pred - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_streaming_predict_after_fit() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model).n_epochs(1);

        let batches = vec![(vec![5.0], vec![0.0])];

        let fitted = fitter.fit_batches(batches).unwrap();
        // Predict with input [1.0, 2.0]: result = (1+2) + 5 = 8
        let pred = fitted.predict(&vec![1.0, 2.0]).unwrap();
        assert!((pred - 8.0).abs() < 1e-10);
    }

    #[test]
    #[should_panic(expected = "zero batches")]
    fn test_streaming_empty_batches_panics() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model);

        let batches: Vec<(Vec<f64>, Vec<f64>)> = vec![];
        let _ = fitter.fit_batches(batches);
    }

    #[test]
    #[should_panic(expected = "zero batches")]
    fn test_streaming_single_epoch_empty_batches_panics() {
        let fitter = StreamingFitter::new(Accumulator::new());

        let batches: Vec<(Vec<f64>, Vec<f64>)> = vec![];
        let _ = fitter.fit_batches_single_epoch(batches);
    }

    #[test]
    #[should_panic]
    fn test_streaming_zero_epochs_panics() {
        // Zero passes can never yield a fitted model, so this must panic
        // even though data is available.
        let fitter = StreamingFitter::new(Accumulator::new()).n_epochs(0);

        let batches = vec![(vec![1.0], vec![0.0])];
        let _ = fitter.fit_batches(batches);
    }

    #[test]
    fn test_streaming_fitter_builder_pattern() {
        let fitter = StreamingFitter::new(Accumulator::new()).n_epochs(5);
        assert_eq!(fitter.n_epochs, 5);
    }
}