ferrolearn_core/
streaming.rs1#[cfg(not(feature = "std"))]
24use alloc::vec::Vec;
25
26use crate::traits::PartialFit;
27
/// Drives incremental training of a model over a sequence of `(x, y)` batches.
///
/// The fitter owns the not-yet-fitted model together with the number of passes
/// (epochs) to make over the batches. It is consumed by the fitting methods,
/// which hand back the fitted model.
#[derive(Debug, Clone)]
pub struct StreamingFitter<M> {
    /// The model that receives the first batch.
    model: M,
    /// How many times the full batch sequence is replayed; `new` sets this to 1.
    n_epochs: usize,
}
43
44impl<M> StreamingFitter<M> {
45 pub fn new(model: M) -> Self {
49 Self { model, n_epochs: 1 }
50 }
51
52 #[must_use]
62 pub fn n_epochs(mut self, n_epochs: usize) -> Self {
63 self.n_epochs = n_epochs;
64 self
65 }
66
67 pub fn fit_batches<X, Y, I>(self, batches: I) -> Result<M::FitResult, M::Error>
77 where
78 M: PartialFit<X, Y>,
79 M::FitResult: PartialFit<X, Y, FitResult = M::FitResult, Error = M::Error>,
80 I: IntoIterator<Item = (X, Y)>,
81 {
82 let batches: Vec<(X, Y)> = batches.into_iter().collect();
83
84 if batches.is_empty() || self.n_epochs == 0 {
85 return Err(self.no_batches_error());
91 }
92
93 let mut batch_iter = batches.iter();
95 let (first_x, first_y) = batch_iter.next().unwrap();
96 let mut fitted = self.model.partial_fit(first_x, first_y)?;
97
98 for (x, y) in batch_iter {
100 fitted = fitted.partial_fit(x, y)?;
101 }
102
103 for _ in 1..self.n_epochs {
105 for (x, y) in &batches {
106 fitted = fitted.partial_fit(x, y)?;
107 }
108 }
109
110 Ok(fitted)
111 }
112
113 pub fn fit_batches_single_epoch<X, Y, I>(self, batches: I) -> Result<M::FitResult, M::Error>
123 where
124 M: PartialFit<X, Y>,
125 M::FitResult: PartialFit<X, Y, FitResult = M::FitResult, Error = M::Error>,
126 I: IntoIterator<Item = (X, Y)>,
127 {
128 let mut iter = batches.into_iter();
129
130 let (first_x, first_y) = match iter.next() {
131 Some(batch) => batch,
132 None => return Err(self.no_batches_error()),
133 };
134
135 let mut fitted = self.model.partial_fit(&first_x, &first_y)?;
136
137 for (x, y) in iter {
138 fitted = fitted.partial_fit(&x, &y)?;
139 }
140
141 Ok(fitted)
142 }
143}
144
145impl<M> StreamingFitter<M> {
146 fn no_batches_error<E>(&self) -> E
156 where
157 E: core::fmt::Display,
158 {
159 panic!(
162 "StreamingFitter::fit_batches called with zero batches; at least one batch is required"
163 );
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::FerroError;
    use crate::traits::{PartialFit, Predict};

    /// Unfitted test model: a running sum that starts at zero.
    #[derive(Clone)]
    struct Accumulator {
        sum: f64,
    }

    impl Accumulator {
        fn new() -> Self {
            Self { sum: 0.0 }
        }
    }

    /// Fitted counterpart of [`Accumulator`], holding the sum of every value
    /// seen across all `partial_fit` calls so far.
    #[derive(Clone)]
    struct FittedAccumulator {
        sum: f64,
    }

    impl Predict<Vec<f64>> for FittedAccumulator {
        type Output = f64;
        type Error = FerroError;

        /// Returns the accumulated sum plus the sum of `x`, which lets the
        /// tests read the trained state back by predicting on `[0.0]`.
        fn predict(&self, x: &Vec<f64>) -> Result<f64, FerroError> {
            Ok(x.iter().sum::<f64>() + self.sum)
        }
    }

    impl PartialFit<Vec<f64>, Vec<f64>> for Accumulator {
        type FitResult = FittedAccumulator;
        type Error = FerroError;

        /// Adds the values of `x` to the running sum; `_y` is ignored.
        fn partial_fit(self, x: &Vec<f64>, _y: &Vec<f64>) -> Result<FittedAccumulator, FerroError> {
            Ok(FittedAccumulator {
                sum: self.sum + x.iter().sum::<f64>(),
            })
        }
    }

    impl PartialFit<Vec<f64>, Vec<f64>> for FittedAccumulator {
        type FitResult = FittedAccumulator;
        type Error = FerroError;

        /// Adds the values of `x` to the running sum; `_y` is ignored.
        fn partial_fit(self, x: &Vec<f64>, _y: &Vec<f64>) -> Result<FittedAccumulator, FerroError> {
            Ok(FittedAccumulator {
                sum: self.sum + x.iter().sum::<f64>(),
            })
        }
    }

    #[test]
    fn test_streaming_single_batch() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model);

        let batches = vec![(vec![1.0, 2.0, 3.0], vec![0.0])];

        let fitted = fitter.fit_batches(batches).unwrap();
        let pred = fitted.predict(&vec![0.0]).unwrap();
        // 1 + 2 + 3
        assert!((pred - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_streaming_multiple_batches() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model);

        let batches = vec![
            (vec![1.0, 2.0], vec![0.0]),
            (vec![3.0, 4.0], vec![0.0]),
            (vec![5.0], vec![0.0]),
        ];

        let fitted = fitter.fit_batches(batches).unwrap();
        let pred = fitted.predict(&vec![0.0]).unwrap();
        // 1 + 2 + 3 + 4 + 5
        assert!((pred - 15.0).abs() < 1e-10);
    }

    #[test]
    fn test_streaming_multiple_epochs() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model).n_epochs(3);

        let batches = vec![(vec![1.0, 2.0], vec![0.0]), (vec![3.0], vec![0.0])];

        let fitted = fitter.fit_batches(batches).unwrap();
        let pred = fitted.predict(&vec![0.0]).unwrap();
        // Each epoch contributes 1 + 2 + 3 = 6; three epochs give 18.
        assert!((pred - 18.0).abs() < 1e-10);
    }

    #[test]
    fn test_streaming_single_epoch_method() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model);

        let batches = vec![(vec![10.0], vec![0.0]), (vec![20.0], vec![0.0])];

        let fitted = fitter.fit_batches_single_epoch(batches).unwrap();
        let pred = fitted.predict(&vec![0.0]).unwrap();
        // 10 + 20
        assert!((pred - 30.0).abs() < 1e-10);
    }

    #[test]
    fn test_streaming_predict_after_fit() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model).n_epochs(1);

        let batches = vec![(vec![5.0], vec![0.0])];

        let fitted = fitter.fit_batches(batches).unwrap();
        let pred = fitted.predict(&vec![1.0, 2.0]).unwrap();
        // Trained sum 5 plus the prediction input 1 + 2.
        assert!((pred - 8.0).abs() < 1e-10);
    }

    #[test]
    #[should_panic(expected = "zero batches")]
    fn test_streaming_empty_batches_panics() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model);

        let batches: Vec<(Vec<f64>, Vec<f64>)> = vec![];
        let _ = fitter.fit_batches(batches);
    }

    #[test]
    #[should_panic(expected = "zero batches")]
    fn test_streaming_single_epoch_empty_batches_panics() {
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model);

        let batches: Vec<(Vec<f64>, Vec<f64>)> = vec![];
        let _ = fitter.fit_batches_single_epoch(batches);
    }

    #[test]
    #[should_panic]
    fn test_streaming_zero_epochs_panics() {
        // Zero epochs cannot produce a fitted model even when batches exist.
        let model = Accumulator::new();
        let fitter = StreamingFitter::new(model).n_epochs(0);

        let batches = vec![(vec![1.0], vec![0.0])];
        let _ = fitter.fit_batches(batches);
    }

    #[test]
    fn test_streaming_fitter_builder_pattern() {
        let fitter = StreamingFitter::new(Accumulator::new()).n_epochs(5);
        assert_eq!(fitter.n_epochs, 5);
    }
}
308}