use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::neural::activation::Activation;
use crate::neural::callback::{
    self, CallbackAction, EpochMetrics, TrainingCallback, TrainingHistory,
};
use crate::neural::layer::FastRng;
use crate::neural::network::{self, Network};
use crate::neural::optimizer::{LearningRateSchedule, OptimizerKind, OptimizerState};
use crate::partial_fit::PartialFit;

/// Multi-layer perceptron classifier trained with mini-batch gradient
/// descent and a softmax/cross-entropy output layer.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct MLPClassifier {
    hidden_layers: Vec<usize>,
    activation: Activation,
    optimizer_kind: OptimizerKind,
    learning_rate: f64,
    max_iter: usize,
    batch_size: usize,
    alpha: f64,
    tolerance: f64,
    early_stopping: bool,
    validation_fraction: f64,
    n_iter_no_change: usize,
    seed: u64,
    dropout_rate: f64,
    lr_schedule: LearningRateSchedule,
    fitted: bool,
    n_features: usize,
    n_classes: usize,
    class_labels: Vec<f64>,
    network_weights: Vec<(Vec<f64>, Vec<f64>)>,
    network_dims: Vec<(usize, usize)>,
    pub loss_curve: Vec<f64>,
    training_history: TrainingHistory,
    #[cfg_attr(feature = "serde", serde(skip))]
    callbacks: Vec<Box<dyn TrainingCallback>>,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl Clone for MLPClassifier {
    fn clone(&self) -> Self {
        Self {
            hidden_layers: self.hidden_layers.clone(),
            activation: self.activation,
            optimizer_kind: self.optimizer_kind,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            batch_size: self.batch_size,
            alpha: self.alpha,
            tolerance: self.tolerance,
            early_stopping: self.early_stopping,
            validation_fraction: self.validation_fraction,
            n_iter_no_change: self.n_iter_no_change,
            seed: self.seed,
            dropout_rate: self.dropout_rate,
            lr_schedule: self.lr_schedule,
            fitted: self.fitted,
            n_features: self.n_features,
            n_classes: self.n_classes,
            class_labels: self.class_labels.clone(),
            network_weights: self.network_weights.clone(),
            network_dims: self.network_dims.clone(),
            loss_curve: self.loss_curve.clone(),
            training_history: self.training_history.clone(),
            // Trait-object callbacks cannot be cloned, so a cloned model
            // starts with an empty callback list.
            callbacks: Vec::new(),
            _schema_version: 0,
        }
    }
}

impl MLPClassifier {
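    /// Creates a classifier with the defaults set below: one hidden layer
    /// of 100 units, ReLU activation, learning rate 0.001, and up to 200
    /// epochs.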
    pub fn new() -> Self {
        Self {
            hidden_layers: vec![100],
            activation: Activation::Relu,
            optimizer_kind: OptimizerKind::default(),
            learning_rate: 0.001,
            max_iter: 200,
            batch_size: 200,
            alpha: 0.0001,
            tolerance: 1e-4,
            early_stopping: false,
            validation_fraction: 0.1,
            n_iter_no_change: 10,
            seed: 42,
            dropout_rate: 0.0,
            lr_schedule: LearningRateSchedule::Constant,
            fitted: false,
            n_features: 0,
            n_classes: 0,
            class_labels: Vec::new(),
            network_weights: Vec::new(),
            network_dims: Vec::new(),
            loss_curve: Vec::new(),
            training_history: TrainingHistory::new(),
            callbacks: Vec::new(),
            _schema_version: 0,
        }
    }

    /// Sets the sizes of the hidden layers.
    pub fn hidden_layers(mut self, sizes: &[usize]) -> Self {
        self.hidden_layers = sizes.to_vec();
        self
    }

    pub fn activation(mut self, activation: Activation) -> Self {
        self.activation = activation;
        self
    }

    pub fn optimizer(mut self, kind: OptimizerKind) -> Self {
        self.optimizer_kind = kind;
        self
    }

    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    pub fn max_iter(mut self, n: usize) -> Self {
        self.max_iter = n;
        self
    }

    pub fn batch_size(mut self, n: usize) -> Self {
        self.batch_size = n;
        self
    }

    /// Sets the L2 regularization strength.
    pub fn alpha(mut self, a: f64) -> Self {
        self.alpha = a;
        self
    }

    /// Sets the minimum loss improvement that counts as progress for the
    /// convergence and early-stopping checks.
    pub fn tolerance(mut self, tol: f64) -> Self {
        self.tolerance = tol;
        self
    }

    /// Alias for [`Self::tolerance`].
    pub fn tol(self, t: f64) -> Self {
        self.tolerance(t)
    }

    pub fn early_stopping(mut self, enable: bool) -> Self {
        self.early_stopping = enable;
        self
    }

    /// Sets the fraction of samples held out for validation when early
    /// stopping is enabled.
    pub fn validation_fraction(mut self, frac: f64) -> Self {
        self.validation_fraction = frac;
        self
    }

    /// Sets how many epochs without improvement are tolerated before
    /// training stops.
    pub fn n_iter_no_change(mut self, n: usize) -> Self {
        self.n_iter_no_change = n;
        self
    }

    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

    pub fn learning_rate_schedule(mut self, schedule: LearningRateSchedule) -> Self {
        self.lr_schedule = schedule;
        self
    }

    /// Sets the dropout probability used during training (default 0.0,
    /// i.e. disabled).
    pub fn dropout(mut self, p: f64) -> Self {
        self.dropout_rate = p;
        self
    }

    /// Registers a callback invoked at the end of each training epoch.
    pub fn callback(mut self, cb: Box<dyn TrainingCallback>) -> Self {
        self.callbacks.push(cb);
        self
    }
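
    /// Trains the network on `data` with mini-batch gradient descent.
    ///
    /// A minimal usage sketch, mirroring the XOR test below (illustrative
    /// only, so the block is not compiled as a doctest):
    ///
    /// ```ignore
    /// let data = xor_dataset(); // any `Dataset`; see the test helpers below
    /// let mut clf = MLPClassifier::new()
    ///     .hidden_layers(&[10, 10])
    ///     .learning_rate(0.01)
    ///     .max_iter(1000)
    ///     .seed(42);
    /// clf.fit(&data).unwrap();
    /// let preds = clf.predict(&[vec![0.0, 1.0], vec![1.0, 1.0]]).unwrap();
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error for an empty dataset, non-finite feature values
    /// (via `validate_finite`), or fewer than two classes.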
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        let n_samples = data.n_samples();
        let n_features = data.n_features();

        if n_samples == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }

        // Collect the sorted, deduplicated set of class labels.
        let mut class_labels: Vec<f64> = data.target.clone();
        class_labels.sort_by(|a, b| a.total_cmp(b));
        class_labels.dedup();
        let n_classes = class_labels.len();

        if n_classes < 2 {
            return Err(ScryLearnError::InvalidParameter(
                "need at least 2 classes".into(),
            ));
        }

        let x = build_row_major(&data.features, n_samples, n_features);

        // Encode each target as the index of its class label.
        let y: Vec<f64> = data
            .target
            .iter()
            .map(|&t| {
                class_labels
                    .iter()
                    .position(|&c| (c - t).abs() < f64::EPSILON)
                    .expect("target value must appear in class_labels") as f64
            })
            .collect();

        // With early stopping, hold out a shuffled validation split.
        let (train_x, train_y, val_x, val_y) = if self.early_stopping {
            let mut rng = FastRng::new(self.seed);
            let val_size = (n_samples as f64 * self.validation_fraction).max(1.0) as usize;
            let train_size = n_samples - val_size;
            let mut indices: Vec<usize> = (0..n_samples).collect();
            rng.shuffle(&mut indices);

            let mut tx = Vec::with_capacity(train_size * n_features);
            let mut ty = Vec::with_capacity(train_size);
            let mut vx = Vec::with_capacity(val_size * n_features);
            let mut vy = Vec::with_capacity(val_size);

            for &i in &indices[..train_size] {
                tx.extend_from_slice(&x[i * n_features..(i + 1) * n_features]);
                ty.push(y[i]);
            }
            for &i in &indices[train_size..] {
                vx.extend_from_slice(&x[i * n_features..(i + 1) * n_features]);
                vy.push(y[i]);
            }
            (tx, ty, Some(vx), Some(vy))
        } else {
            (x, y, None, None)
        };

        let train_n = train_y.len();

        // Layer sizes: input width, hidden layers, one output per class.
        let mut sizes = Vec::with_capacity(self.hidden_layers.len() + 2);
        sizes.push(n_features);
        sizes.extend_from_slice(&self.hidden_layers);
        sizes.push(n_classes);

        let mut net =
            Network::new_with_dropout(&sizes, self.activation, self.seed, self.dropout_rate);
        let param_sizes = net.param_group_sizes();
        let mut optimizer = OptimizerState::new_with_schedule(
            self.optimizer_kind,
            self.learning_rate,
            &param_sizes,
            self.lr_schedule,
        );

        let batch_size = self.batch_size.min(train_n);
        let mut rng = FastRng::new(self.seed.wrapping_add(1));
        let mut indices: Vec<usize> = (0..train_n).collect();

        self.loss_curve.clear();
        self.training_history = TrainingHistory::new();
        let mut best_val_loss = f64::INFINITY;
        let mut best_weights: Option<Vec<(Vec<f64>, Vec<f64>)>> = None;
        let mut no_improve = 0;

        // Take the callbacks out of `self` so they can be invoked mutably
        // while other fields of `self` are updated inside the loop.
        let mut callbacks = std::mem::take(&mut self.callbacks);

        for epoch_idx in 0..self.max_iter {
            let epoch_start = std::time::Instant::now();
            rng.shuffle(&mut indices);

            let mut epoch_loss = 0.0;
            let mut n_batches = 0;
            let mut last_grad_norm = 0.0;
            let mut epoch_correct = 0usize;
            let mut epoch_total = 0usize;

            for chunk in indices.chunks(batch_size) {
                let b = chunk.len();
                let mut batch_x = Vec::with_capacity(b * n_features);
                let mut batch_y = Vec::with_capacity(b);
                for &i in chunk {
                    batch_x.extend_from_slice(&train_x[i * n_features..(i + 1) * n_features]);
                    batch_y.push(train_y[i]);
                }

                let logits = net.forward(&batch_x, b, true);
                let (loss, grad) = network::cross_entropy_loss(&logits, &batch_y, b, n_classes);
                epoch_loss += loss;
                n_batches += 1;

                // Track training accuracy from the batch predictions.
                let preds = network::argmax_predictions(&logits, b, n_classes);
                for (p, t) in preds.iter().zip(batch_y.iter()) {
                    if (*p - *t).abs() < f64::EPSILON {
                        epoch_correct += 1;
                    }
                    epoch_total += 1;
                }

                let layer_grads = net.backward(&grad, self.alpha);
                last_grad_norm = callback::compute_grad_norm(&layer_grads);
                optimizer.tick();
                net.apply_gradients(&layer_grads, &mut optimizer);
            }

            let avg_loss = epoch_loss / n_batches as f64;
            self.loss_curve.push(avg_loss);

            optimizer.adjust_lr(avg_loss);

            let train_accuracy = if epoch_total > 0 {
                Some(epoch_correct as f64 / epoch_total as f64)
            } else {
                None
            };

            let mut val_loss_epoch = None;
            let mut val_metric_epoch = None;

            if self.early_stopping {
                if let (Some(ref vx), Some(ref vy)) = (&val_x, &val_y) {
                    let val_n = vy.len();
                    let val_logits = net.forward(vx, val_n, false);
                    let (val_loss, _) =
                        network::cross_entropy_loss(&val_logits, vy, val_n, n_classes);
                    val_loss_epoch = Some(val_loss);

                    let val_preds = network::argmax_predictions(&val_logits, val_n, n_classes);
                    let val_correct = val_preds
                        .iter()
                        .zip(vy.iter())
                        .filter(|(p, t)| (**p - **t).abs() < f64::EPSILON)
                        .count();
                    val_metric_epoch = Some(val_correct as f64 / val_n as f64);

                    if val_loss < best_val_loss - self.tolerance {
                        best_val_loss = val_loss;
                        best_weights = Some(net.save_weights());
                        no_improve = 0;
                    } else {
                        no_improve += 1;
                    }
                }
            } else {
                // Without a validation split, use the training-loss
                // improvement between epochs as the convergence signal.
                let n = self.loss_curve.len();
                if n >= 2 {
                    let improvement = self.loss_curve[n - 2] - self.loss_curve[n - 1];
                    if improvement.abs() < self.tolerance {
                        no_improve += 1;
                    } else {
                        no_improve = 0;
                    }
                }
            }

            let elapsed = epoch_start.elapsed();
            let metrics = EpochMetrics {
                epoch: epoch_idx,
                train_loss: avg_loss,
                val_loss: val_loss_epoch,
                train_metric: train_accuracy,
                val_metric: val_metric_epoch,
                learning_rate: optimizer.current_lr(),
                grad_norm: last_grad_norm,
                elapsed_ms: elapsed.as_millis() as u64,
            };

            let mut cb_stop = false;
            for cb in &mut callbacks {
                if cb.on_epoch_end(&metrics) == CallbackAction::Stop {
                    cb_stop = true;
                }
            }

            self.training_history.push(metrics);

            if cb_stop {
                break;
            }

            if no_improve >= self.n_iter_no_change
                && (self.early_stopping || self.loss_curve.len() >= 2)
            {
                break;
            }
        }

        for cb in &mut callbacks {
            cb.on_training_end();
        }
        self.callbacks = callbacks;

        // If early stopping tracked a best snapshot, restore it.
        if let Some(ref best) = best_weights {
            net.restore_weights(best);
        }

        self.network_weights = net.save_weights();
        self.network_dims = net.layer_dims();
        self.n_features = n_features;
        self.n_classes = n_classes;
        self.class_labels = class_labels;
        self.fitted = true;

        Ok(())
    }
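
    /// Predicts a class label for each row of `features`, mapping the
    /// argmax of the class probabilities back to the original label
    /// values passed to `fit`.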
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        let proba = self.predict_proba(features)?;
        let batch = features.len();
        let preds = network::argmax_predictions(&proba, batch, self.n_classes);
        Ok(preds
            .iter()
            .map(|&i| self.class_labels[i as usize])
            .collect())
    }
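
    /// Returns class probabilities as a flattened row-major matrix of
    /// shape `(n_samples, n_classes)`: the probability of class `c` for
    /// sample `i` lives at index `i * n_classes + c`. For example, with
    /// two classes, `proba[0]` and `proba[1]` belong to the first sample
    /// and sum to (approximately) 1.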
    pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }

        let batch = features.len();
        if batch == 0 {
            return Ok(Vec::new());
        }

        let n_feat = features[0].len();
        if n_feat != self.n_features {
            return Err(ScryLearnError::ShapeMismatch {
                expected: self.n_features,
                got: n_feat,
            });
        }

        let mut net = self.rebuild_network();
        // Flatten the row vectors into one contiguous row-major buffer.
        let x: Vec<f64> = features
            .iter()
            .flat_map(|row| row.iter().copied())
            .collect();
        let logits = net.forward(&x, batch, false);
        Ok(network::softmax(&logits, batch, self.n_classes))
    }

    pub fn n_classes(&self) -> usize {
        self.n_classes
    }

    pub fn n_features(&self) -> usize {
        self.n_features
    }

    pub fn loss_curve(&self) -> &[f64] {
        &self.loss_curve
    }

    /// Returns the per-epoch training history from the last `fit` call,
    /// or `None` if no history has been recorded.
    pub fn history(&self) -> Option<&TrainingHistory> {
        if self.training_history.is_empty() {
            None
        } else {
            Some(&self.training_history)
        }
    }

    pub fn weights(&self) -> &[(Vec<f64>, Vec<f64>)] {
        &self.network_weights
    }

    pub fn layer_dims(&self) -> &[(usize, usize)] {
        &self.network_dims
    }

    pub fn activation_fn(&self) -> Activation {
        self.activation
    }

    /// Reconstructs a `Network` from the stored layer dimensions and
    /// weights so inference does not depend on the training-time object.
    fn rebuild_network(&self) -> Network {
        let mut sizes = Vec::with_capacity(self.network_dims.len() + 1);
        sizes.push(self.network_dims[0].0);
        for &(_, out) in &self.network_dims {
            sizes.push(out);
        }
        // The seed is irrelevant here: the stored weights are restored
        // immediately after construction.
        let mut net = Network::new_with_dropout(&sizes, self.activation, 0, self.dropout_rate);
        net.restore_weights(&self.network_weights);
        net
    }
}

impl PartialFit for MLPClassifier {
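    /// Runs one training pass over `data`. The first call initializes the
    /// network, fixing the feature count and the set of classes, so the
    /// first batch must contain every class that will ever appear.
    ///
    /// A streaming sketch (illustrative only; `load_chunk` is a
    /// hypothetical loader, not part of this crate):
    ///
    /// ```ignore
    /// let mut clf = MLPClassifier::new().hidden_layers(&[32]).seed(42);
    /// for chunk_id in 0..n_chunks {
    ///     let data = load_chunk(chunk_id)?; // hypothetical data source
    ///     clf.partial_fit(&data)?;
    /// }
    /// ```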
    fn partial_fit(&mut self, data: &Dataset) -> Result<()> {
        let n_samples = data.n_samples();
        let n_features = data.n_features();
        if n_samples == 0 {
            if self.is_initialized() {
                return Ok(());
            }
            return Err(ScryLearnError::EmptyDataset);
        }

        let mut batch_labels: Vec<f64> = data.target.clone();
        batch_labels.sort_by(|a, b| a.total_cmp(b));
        batch_labels.dedup();

        if self.is_initialized() {
            if n_features != self.n_features {
                return Err(ScryLearnError::ShapeMismatch {
                    expected: self.n_features,
                    got: n_features,
                });
            }
            // Reject labels unseen at initialization: the output layer
            // width is fixed once the network has been built.
            for &label in &batch_labels {
                if !self
                    .class_labels
                    .iter()
                    .any(|&c| (c - label).abs() < f64::EPSILON)
                {
                    return Err(ScryLearnError::InvalidParameter(format!(
                        "partial_fit encountered new class {label} not seen during \
                        initialization (known classes: {:?}). MLPClassifier cannot add \
                        classes after network initialization; pass all possible classes \
                        in the first batch.",
                        self.class_labels
                    )));
                }
            }
        } else {
            let n_classes = batch_labels.len();
            if n_classes < 2 {
                return Err(ScryLearnError::InvalidParameter(
                    "need at least 2 classes".into(),
                ));
            }

            // First batch: build the network from scratch.
            let mut sizes = Vec::with_capacity(self.hidden_layers.len() + 2);
            sizes.push(n_features);
            sizes.extend_from_slice(&self.hidden_layers);
            sizes.push(n_classes);

            let net = Network::new(&sizes, self.activation, self.seed);
            self.network_weights = net.save_weights();
            self.network_dims = net.layer_dims();
            self.n_features = n_features;
            self.n_classes = n_classes;
            self.class_labels = batch_labels;
            self.loss_curve.clear();
        }

        let x = build_row_major(&data.features, n_samples, n_features);
        let y: Vec<f64> = data
            .target
            .iter()
            .map(|&t| {
                self.class_labels
                    .iter()
                    .position(|&c| (c - t).abs() < f64::EPSILON)
                    // Unknown labels were rejected above, so this fallback
                    // is never reached in practice.
                    .unwrap_or(0) as f64
            })
            .collect();

        let mut net = self.rebuild_network();
        let param_sizes = net.param_group_sizes();
        let mut optimizer =
            OptimizerState::new(self.optimizer_kind, self.learning_rate, &param_sizes);

        let batch_size = self.batch_size.min(n_samples);
        // Vary the shuffle seed with the number of completed passes so
        // repeated partial_fit calls see different batch orders.
        let mut rng = FastRng::new(self.seed.wrapping_add(self.loss_curve.len() as u64));
        let mut indices: Vec<usize> = (0..n_samples).collect();

        rng.shuffle(&mut indices);
        let mut epoch_loss = 0.0;
        let mut n_batches = 0;

        for chunk in indices.chunks(batch_size) {
            let b = chunk.len();
            let mut batch_x = Vec::with_capacity(b * n_features);
            let mut batch_y = Vec::with_capacity(b);
            for &i in chunk {
                batch_x.extend_from_slice(&x[i * n_features..(i + 1) * n_features]);
                batch_y.push(y[i]);
            }

            let logits = net.forward(&batch_x, b, true);
            let (loss, grad) = network::cross_entropy_loss(&logits, &batch_y, b, self.n_classes);
            epoch_loss += loss;
            n_batches += 1;

            let layer_grads = net.backward(&grad, self.alpha);
            optimizer.tick();
            net.apply_gradients(&layer_grads, &mut optimizer);
        }

        self.loss_curve.push(epoch_loss / n_batches as f64);
        self.network_weights = net.save_weights();
        self.fitted = true;
        Ok(())
    }

    fn is_initialized(&self) -> bool {
        !self.network_weights.is_empty()
    }
}

impl Default for MLPClassifier {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Debug for MLPClassifier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MLPClassifier")
            .field("hidden_layers", &self.hidden_layers)
            .field("activation", &self.activation)
            .field("fitted", &self.fitted)
            .field("n_classes", &self.n_classes)
            .finish()
    }
}
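
/// Converts column-major `features` (one inner `Vec` per feature) into a
/// flattened row-major sample matrix. A worked example: two features over
/// two samples, `features = [[a, b], [c, d]]`, yield `[a, c, b, d]`, so
/// sample 0 is the row `(a, c)` and sample 1 is `(b, d)`.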
fn build_row_major(features: &[Vec<f64>], n_samples: usize, n_features: usize) -> Vec<f64> {
    let mut x = vec![0.0; n_samples * n_features];
    for j in 0..n_features {
        for i in 0..n_samples {
            x[i * n_features + j] = features[j][i];
        }
    }
    x
}

#[cfg(test)]
mod tests {
    use super::*;

    fn xor_dataset() -> Dataset {
        Dataset::new(
            vec![vec![0.0, 0.0, 1.0, 1.0], vec![0.0, 1.0, 0.0, 1.0]],
            vec![0.0, 1.0, 1.0, 0.0],
            vec!["x1".into(), "x2".into()],
            "xor",
        )
    }

    fn linearly_separable() -> Dataset {
        let mut f1 = Vec::new();
        let mut f2 = Vec::new();
        let mut target = Vec::new();
        for i in 0..50 {
            let v = i as f64 * 0.1;
            f1.push(v);
            f2.push(v + 0.5);
            target.push(0.0);
            f1.push(v + 5.0);
            f2.push(v + 5.5);
            target.push(1.0);
        }
        Dataset::new(
            vec![f1, f2],
            target,
            vec!["f1".into(), "f2".into()],
            "class",
        )
    }

    #[test]
    fn not_fitted_error() {
        let clf = MLPClassifier::new();
        let result = clf.predict(&[vec![1.0, 2.0]]);
        assert!(matches!(result, Err(ScryLearnError::NotFitted)));
    }

    #[test]
    fn xor_problem() {
        let data = xor_dataset();
        let mut clf = MLPClassifier::new()
            .hidden_layers(&[10, 10])
            .learning_rate(0.01)
            .max_iter(1000)
            .batch_size(4)
            .seed(42);
        clf.fit(&data).unwrap();

        let preds = clf
            .predict(&[
                vec![0.0, 0.0],
                vec![0.0, 1.0],
                vec![1.0, 0.0],
                vec![1.0, 1.0],
            ])
            .unwrap();

        let correct = preds
            .iter()
            .zip([0.0, 1.0, 1.0, 0.0].iter())
            .filter(|(p, t)| (**p - **t).abs() < f64::EPSILON)
            .count();

        assert!(
            correct >= 3,
            "XOR: got {correct}/4 correct, preds={preds:?}"
        );
    }

    #[test]
    fn linearly_separable_data() {
        let data = linearly_separable();
        let mut clf = MLPClassifier::new()
            .hidden_layers(&[20])
            .max_iter(200)
            .seed(42);
        clf.fit(&data).unwrap();

        let test_x = vec![vec![0.5, 1.0], vec![5.5, 6.0]];
        let preds = clf.predict(&test_x).unwrap();
        assert!((preds[0] - 0.0).abs() < f64::EPSILON);
        assert!((preds[1] - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn early_stopping_halts() {
        let data = linearly_separable();
        let mut clf = MLPClassifier::new()
            .hidden_layers(&[20])
            .max_iter(500)
            .early_stopping(true)
            .n_iter_no_change(5)
            .seed(42);
        clf.fit(&data).unwrap();

        assert!(
            clf.loss_curve.len() < 500,
            "expected early stop, got {} epochs",
            clf.loss_curve.len()
        );
    }

    #[test]
    fn predict_proba_sums_to_one() {
        let data = linearly_separable();
        let mut clf = MLPClassifier::new()
            .hidden_layers(&[10])
            .max_iter(50)
            .seed(42);
        clf.fit(&data).unwrap();

        let proba = clf.predict_proba(&[vec![1.0, 1.5]]).unwrap();
        let sum: f64 = proba.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn shape_mismatch_error() {
        let data = linearly_separable();
        let mut clf = MLPClassifier::new()
            .hidden_layers(&[10])
            .max_iter(10)
            .seed(42);
        clf.fit(&data).unwrap();

        let result = clf.predict(&[vec![1.0, 2.0, 3.0]]);
        assert!(matches!(result, Err(ScryLearnError::ShapeMismatch { .. })));
    }

    #[test]
    fn loss_decreases() {
        let data = linearly_separable();
        let mut clf = MLPClassifier::new()
            .hidden_layers(&[20])
            .max_iter(50)
            .seed(42);
        clf.fit(&data).unwrap();

        let curve = clf.loss_curve();
        assert!(curve.len() >= 2);
        assert!(curve.first().unwrap() > curve.last().unwrap());
    }

    #[test]
    fn partial_fit_is_initialized() {
        let mut clf = MLPClassifier::new();
        assert!(!clf.is_initialized());

        let data = linearly_separable();
        clf.partial_fit(&data).unwrap();
        assert!(clf.is_initialized());
    }

    #[test]
    fn partial_fit_loss_decreases() {
        let data = linearly_separable();
        let mut clf = MLPClassifier::new()
            .hidden_layers(&[20])
            .learning_rate(0.01)
            .batch_size(32)
            .seed(42);

        for _ in 0..10 {
            clf.partial_fit(&data).unwrap();
        }

        let curve = clf.loss_curve();
        assert_eq!(curve.len(), 10);
        assert!(
            curve.first().unwrap() > curve.last().unwrap(),
            "loss should decrease: first={} last={}",
            curve.first().unwrap(),
            curve.last().unwrap()
        );
    }

    #[test]
    fn partial_fit_classifies_after_batches() {
        let mut clf = MLPClassifier::new()
            .hidden_layers(&[20])
            .learning_rate(0.01)
            .batch_size(32)
            .seed(42);

        let data = linearly_separable();
        for _ in 0..50 {
            clf.partial_fit(&data).unwrap();
        }

        let preds = clf.predict(&[vec![0.5, 1.0], vec![5.5, 6.0]]).unwrap();
        assert!(
            (preds[0] - 0.0).abs() < f64::EPSILON,
            "x=0.5 should be class 0"
        );
        assert!(
            (preds[1] - 1.0).abs() < f64::EPSILON,
            "x=5.5 should be class 1"
        );
    }
}