#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use crate::traits::PartialFit;
/// Drives a [`PartialFit`] model over a sequence of `(x, y)` batches.
///
/// Build one with [`StreamingFitter::new`], optionally set the number of passes with
/// [`StreamingFitter::n_epochs`], then call `fit_batches` or `fit_batches_single_epoch`.
#[derive(Debug, Clone)]
pub struct StreamingFitter<M> {
    /// The unfitted model; it is consumed by the first `partial_fit` call.
    model: M,
    /// Number of passes `fit_batches` makes over the batches (`new` sets it to 1).
    n_epochs: usize,
}
impl<M> StreamingFitter<M> {
    /// Creates a fitter wrapping `model` that makes a single pass (`n_epochs == 1`).
    pub fn new(model: M) -> Self {
        Self { model, n_epochs: 1 }
    }

    /// Sets how many passes [`fit_batches`](Self::fit_batches) makes over the batches.
    ///
    /// `0` is accepted here, but `fit_batches` will then panic, because at least one
    /// `partial_fit` call is needed to produce a fitted model.
    #[must_use]
    pub fn n_epochs(mut self, n_epochs: usize) -> Self {
        self.n_epochs = n_epochs;
        self
    }

    /// Fits the model by calling `partial_fit` on every `(x, y)` batch, repeating the
    /// whole sequence `n_epochs` times in the original order.
    ///
    /// The first call consumes the unfitted model `M`. Every later call is made on the
    /// previously returned `M::FitResult`.
    ///
    /// With `n_epochs == 1` the batches are pulled lazily from `batches`. With more
    /// epochs they are first collected into a `Vec` so they can be replayed.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by any `partial_fit` call; no further batches
    /// are fitted after it.
    ///
    /// # Panics
    ///
    /// Panics if `n_epochs` is `0`, or if `batches` yields no items.
    pub fn fit_batches<X, Y, I>(self, batches: I) -> Result<M::FitResult, M::Error>
    where
        M: PartialFit<X, Y>,
        M::FitResult: PartialFit<X, Y, FitResult = M::FitResult, Error = M::Error>,
        I: IntoIterator<Item = (X, Y)>,
    {
        // Checked before consuming `batches`, so a misconfigured fitter fails fast
        // with a message that names the actual problem.
        assert!(
            self.n_epochs > 0,
            "StreamingFitter::fit_batches requires n_epochs >= 1, but n_epochs is 0"
        );
        if self.n_epochs == 1 {
            // A single pass never revisits a batch, so there is nothing to buffer.
            return self.fit_batches_single_epoch(batches);
        }

        let batches: Vec<(X, Y)> = batches.into_iter().collect();
        let ((first_x, first_y), rest) = match batches.split_first() {
            Some(split) => split,
            None => return Err(self.no_batches_error()),
        };

        // Epoch 1 is unrolled: the first batch turns the unfitted `M` into an
        // `M::FitResult`, and the remaining batches continue from that.
        let mut fitted = self.model.partial_fit(first_x, first_y)?;
        for (x, y) in rest {
            fitted = fitted.partial_fit(x, y)?;
        }
        // Epochs 2..=n_epochs replay every buffered batch on the fitted model.
        for _ in 1..self.n_epochs {
            for (x, y) in &batches {
                fitted = fitted.partial_fit(x, y)?;
            }
        }
        Ok(fitted)
    }

    /// Fits the model with exactly one pass over `batches`, ignoring `n_epochs`.
    ///
    /// Batches are pulled lazily from the iterator and are never buffered.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by any `partial_fit` call; no further batches
    /// are pulled from the iterator after it.
    ///
    /// # Panics
    ///
    /// Panics if `batches` yields no items.
    pub fn fit_batches_single_epoch<X, Y, I>(self, batches: I) -> Result<M::FitResult, M::Error>
    where
        M: PartialFit<X, Y>,
        M::FitResult: PartialFit<X, Y, FitResult = M::FitResult, Error = M::Error>,
        I: IntoIterator<Item = (X, Y)>,
    {
        let mut iter = batches.into_iter();
        let (first_x, first_y) = match iter.next() {
            Some(batch) => batch,
            None => return Err(self.no_batches_error()),
        };
        // The first batch turns the unfitted `M` into an `M::FitResult`. The rest are
        // folded into it, and `try_fold` stops at the first error.
        let fitted = self.model.partial_fit(&first_x, &first_y)?;
        iter.try_fold(fitted, |acc, (x, y)| acc.partial_fit(&x, &y))
    }
}
impl<M> StreamingFitter<M> {
    /// Aborts a fit that received no batches.
    ///
    /// Nothing in this module can construct an `M::Error`, so an empty batch stream is
    /// reported by panicking. The function never returns. The generic return type `E`
    /// only lets callers write `return Err(self.no_batches_error())`, and it carries no
    /// trait bound because no value of `E` is ever produced.
    ///
    /// `#[track_caller]` makes the reported panic location the calling fit method
    /// rather than this helper.
    ///
    /// # Panics
    ///
    /// Always.
    #[cold]
    #[track_caller]
    fn no_batches_error<E>(&self) -> E {
        panic!("StreamingFitter called with zero batches; at least one batch is required");
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::FerroError;
    use crate::traits::{PartialFit, Predict};

    /// Unfitted test model: `partial_fit` adds the sum of each `x` batch to `sum`.
    #[derive(Clone)]
    struct Accumulator {
        sum: f64,
    }

    impl Accumulator {
        fn new() -> Self {
            Self { sum: 0.0 }
        }
    }

    /// Fitted counterpart of `Accumulator`. `predict` returns the accumulated sum plus
    /// the sum of the query vector.
    #[derive(Clone)]
    struct FittedAccumulator {
        sum: f64,
    }

    impl Predict<Vec<f64>> for FittedAccumulator {
        type Output = f64;
        type Error = FerroError;
        fn predict(&self, x: &Vec<f64>) -> Result<f64, FerroError> {
            Ok(x.iter().sum::<f64>() + self.sum)
        }
    }

    impl PartialFit<Vec<f64>, Vec<f64>> for Accumulator {
        type FitResult = FittedAccumulator;
        type Error = FerroError;
        fn partial_fit(self, x: &Vec<f64>, _y: &Vec<f64>) -> Result<FittedAccumulator, FerroError> {
            Ok(FittedAccumulator {
                sum: self.sum + x.iter().sum::<f64>(),
            })
        }
    }

    impl PartialFit<Vec<f64>, Vec<f64>> for FittedAccumulator {
        type FitResult = FittedAccumulator;
        type Error = FerroError;
        fn partial_fit(self, x: &Vec<f64>, _y: &Vec<f64>) -> Result<FittedAccumulator, FerroError> {
            Ok(FittedAccumulator {
                sum: self.sum + x.iter().sum::<f64>(),
            })
        }
    }

    #[test]
    fn test_streaming_single_batch() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model);
        let batches = vec![(vec![1.0, 2.0, 3.0], vec![0.0])];
        let fitted = fitter.fit_batches(batches).unwrap();
        let pred = fitted.predict(&vec![0.0]).unwrap();
        assert!((pred - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_streaming_multiple_batches() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model);
        let batches = vec![
            (vec![1.0, 2.0], vec![0.0]),
            (vec![3.0, 4.0], vec![0.0]),
            (vec![5.0], vec![0.0]),
        ];
        let fitted = fitter.fit_batches(batches).unwrap();
        let pred = fitted.predict(&vec![0.0]).unwrap();
        assert!((pred - 15.0).abs() < 1e-10);
    }

    #[test]
    fn test_streaming_multiple_epochs() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model).n_epochs(3);
        let batches = vec![(vec![1.0, 2.0], vec![0.0]), (vec![3.0], vec![0.0])];
        let fitted = fitter.fit_batches(batches).unwrap();
        let pred = fitted.predict(&vec![0.0]).unwrap();
        assert!((pred - 18.0).abs() < 1e-10);
    }

    #[test]
    fn test_streaming_single_epoch_method() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model);
        let batches = vec![(vec![10.0], vec![0.0]), (vec![20.0], vec![0.0])];
        let fitted = fitter.fit_batches_single_epoch(batches).unwrap();
        let pred = fitted.predict(&vec![0.0]).unwrap();
        assert!((pred - 30.0).abs() < 1e-10);
    }

    #[test]
    fn test_streaming_predict_after_fit() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model).n_epochs(1);
        let batches = vec![(vec![5.0], vec![0.0])];
        let fitted = fitter.fit_batches(batches).unwrap();
        let pred = fitted.predict(&vec![1.0, 2.0]).unwrap();
        assert!((pred - 8.0).abs() < 1e-10);
    }

    #[test]
    fn test_streaming_accepts_lazy_iterator() {
        // Any `IntoIterator` of batches is accepted, not only a `Vec`.
        let fitter = StreamingFitter::new(Accumulator::new()).n_epochs(2);
        let batches = (1_i32..=3).map(|i| (vec![f64::from(i)], vec![0.0]));
        let fitted = fitter.fit_batches(batches).unwrap();
        let pred = fitted.predict(&vec![0.0]).unwrap();
        assert!((pred - 12.0).abs() < 1e-10);
    }

    #[test]
    #[should_panic(expected = "zero batches")]
    fn test_streaming_empty_batches_panics() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model);
        let batches: Vec<(Vec<f64>, Vec<f64>)> = vec![];
        let _ = fitter.fit_batches(batches);
    }

    #[test]
    #[should_panic(expected = "zero batches")]
    fn test_streaming_single_epoch_empty_batches_panics() {
        let fitter = StreamingFitter::new(Accumulator::new());
        let batches: Vec<(Vec<f64>, Vec<f64>)> = vec![];
        let _ = fitter.fit_batches_single_epoch(batches);
    }

    #[test]
    #[should_panic]
    fn test_streaming_zero_epochs_panics() {
        // A zero-epoch fit can never produce a fitted model, so it must not succeed.
        let fitter = StreamingFitter::new(Accumulator::new()).n_epochs(0);
        let batches = vec![(vec![1.0], vec![0.0])];
        let _ = fitter.fit_batches(batches);
    }

    #[test]
    fn test_streaming_fitter_defaults_to_one_epoch() {
        let fitter = StreamingFitter::new(Accumulator::new());
        assert_eq!(fitter.n_epochs, 1);
    }

    #[test]
    fn test_streaming_fitter_builder_pattern() {
        let fitter = StreamingFitter::new(Accumulator::new()).n_epochs(5);
        assert_eq!(fitter.n_epochs, 5);
    }
}