1use ferrolearn_core::error::FerroError;
49use ferrolearn_core::introspection::HasClasses;
50use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
51use ferrolearn_core::traits::{Fit, Predict};
52use ndarray::{Array1, Array2};
53use num_traits::{Float, FromPrimitive, ToPrimitive};
54
/// Hyperparameters for a Complement Naive Bayes classifier.
///
/// This type only stores configuration. Fitting it on data yields a
/// [`FittedComplementNB`] that carries the learned per-class weights.
#[derive(Clone, Debug)]
pub struct ComplementNB<F> {
    /// Additive smoothing term. During fitting the (possibly clamped) value is
    /// added to every complement feature count before taking the logarithm.
    pub alpha: F,
    /// Optional user-supplied class priors.
    ///
    /// The fitting code in this file only checks that the length equals the
    /// number of distinct classes; the values themselves are not read there.
    pub class_prior: Option<Vec<F>>,
    /// Whether class priors should be learned from the data.
    ///
    /// NOTE(review): the fitting code visible in this file never reads this
    /// flag; confirm whether it is consumed elsewhere.
    pub fit_prior: bool,
    /// Forwarded together with `alpha` to `crate::clamp_alpha` when fitting;
    /// that helper decides the smoothing value that is actually used.
    pub force_alpha: bool,
    /// When `true`, each class's weight row is divided by its own sum after
    /// the weights have been computed.
    pub norm: bool,
}
83
84impl<F: Float> ComplementNB<F> {
85 #[must_use]
87 pub fn new() -> Self {
88 Self {
89 alpha: F::one(),
90 class_prior: None,
91 fit_prior: true,
92 force_alpha: true,
93 norm: false,
94 }
95 }
96
97 #[must_use]
99 pub fn with_alpha(mut self, alpha: F) -> Self {
100 self.alpha = alpha;
101 self
102 }
103
104 #[must_use]
110 pub fn with_class_prior(mut self, priors: Vec<F>) -> Self {
111 self.class_prior = Some(priors);
112 self
113 }
114
115 #[must_use]
117 pub fn with_fit_prior(mut self, fit_prior: bool) -> Self {
118 self.fit_prior = fit_prior;
119 self
120 }
121
122 #[must_use]
124 pub fn with_force_alpha(mut self, force_alpha: bool) -> Self {
125 self.force_alpha = force_alpha;
126 self
127 }
128
129 #[must_use]
132 pub fn with_norm(mut self, norm: bool) -> Self {
133 self.norm = norm;
134 self
135 }
136}
137
138impl<F: Float> Default for ComplementNB<F> {
139 fn default() -> Self {
140 Self::new()
141 }
142}
143
144#[derive(Debug, Clone)]
146pub struct FittedComplementNB<F> {
147 classes: Vec<usize>,
149 weights: Array2<F>,
152 feature_counts: Array2<F>,
154 class_counts: Vec<usize>,
156 alpha: F,
159 norm: bool,
162}
163
164impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ComplementNB<F> {
165 type Fitted = FittedComplementNB<F>;
166 type Error = FerroError;
167
168 fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedComplementNB<F>, FerroError> {
176 let (n_samples, n_features) = x.dim();
177
178 if n_samples == 0 {
179 return Err(FerroError::InsufficientSamples {
180 required: 1,
181 actual: 0,
182 context: "ComplementNB requires at least one sample".into(),
183 });
184 }
185
186 if n_samples != y.len() {
187 return Err(FerroError::ShapeMismatch {
188 expected: vec![n_samples],
189 actual: vec![y.len()],
190 context: "y length must match number of samples in X".into(),
191 });
192 }
193
194 if x.iter().any(|&v| v < F::zero()) {
196 return Err(FerroError::InvalidParameter {
197 name: "X".into(),
198 reason: "ComplementNB requires non-negative feature values".into(),
199 });
200 }
201
202 let mut classes: Vec<usize> = y.to_vec();
204 classes.sort_unstable();
205 classes.dedup();
206 let n_classes = classes.len();
207
208 let n_feat_f = F::from(n_features).unwrap();
209 let alpha = crate::clamp_alpha(self.alpha, self.force_alpha);
210
211 let mut class_feature_counts = Array2::<F>::zeros((n_classes, n_features));
213 let mut class_counts = vec![0usize; n_classes];
214
215 for (sample_idx, &label) in y.iter().enumerate() {
216 let ci = classes.iter().position(|&c| c == label).unwrap();
217 class_counts[ci] += 1;
218 for j in 0..n_features {
219 class_feature_counts[[ci, j]] = class_feature_counts[[ci, j]] + x[[sample_idx, j]];
220 }
221 }
222
223 let total_feature_counts: Array1<F> = class_feature_counts.rows().into_iter().fold(
225 Array1::<F>::zeros(n_features),
226 |acc, row| {
227 let mut result = acc;
228 for j in 0..n_features {
229 result[j] = result[j] + row[j];
230 }
231 result
232 },
233 );
234
235 let total_all: F = total_feature_counts.sum();
236
237 let mut weights = Array2::<F>::zeros((n_classes, n_features));
243
244 for ci in 0..n_classes {
245 let complement_total = total_all - class_feature_counts.row(ci).sum();
247
248 let denom = complement_total + alpha * n_feat_f;
249
250 for j in 0..n_features {
251 let complement_count_j = total_feature_counts[j] - class_feature_counts[[ci, j]];
252 weights[[ci, j]] = -((complement_count_j + alpha) / denom).ln();
256 }
257 }
258
259 if self.norm {
260 apply_norm_inplace(&mut weights);
261 }
262
263 if let Some(ref priors) = self.class_prior {
265 if priors.len() != n_classes {
266 return Err(FerroError::InvalidParameter {
267 name: "class_prior".into(),
268 reason: format!(
269 "length {} does not match number of classes {}",
270 priors.len(),
271 n_classes
272 ),
273 });
274 }
275 }
276
277 Ok(FittedComplementNB {
278 classes,
279 weights,
280 feature_counts: class_feature_counts,
281 class_counts,
282 alpha,
283 norm: self.norm,
284 })
285 }
286}
287
288fn apply_norm_inplace<F: Float>(weights: &mut Array2<F>) {
294 let n_classes = weights.nrows();
295 let n_features = weights.ncols();
296 for ci in 0..n_classes {
297 let row_sum = (0..n_features).fold(F::zero(), |acc, j| acc + weights[[ci, j]]);
298 if row_sum == F::zero() {
299 continue;
300 }
301 for j in 0..n_features {
302 weights[[ci, j]] = weights[[ci, j]] / row_sum;
303 }
304 }
305}
306
307impl<F: Float + Send + Sync + 'static> FittedComplementNB<F> {
308 pub fn partial_fit(&mut self, x: &Array2<F>, y: &Array1<usize>) -> Result<(), FerroError> {
319 let (n_samples, n_features) = x.dim();
320
321 if n_samples == 0 {
322 return Ok(());
323 }
324
325 if n_samples != y.len() {
326 return Err(FerroError::ShapeMismatch {
327 expected: vec![n_samples],
328 actual: vec![y.len()],
329 context: "y length must match number of samples in X".into(),
330 });
331 }
332
333 if n_features != self.weights.ncols() {
334 return Err(FerroError::ShapeMismatch {
335 expected: vec![self.weights.ncols()],
336 actual: vec![n_features],
337 context: "number of features must match fitted ComplementNB".into(),
338 });
339 }
340
341 if x.iter().any(|&v| v < F::zero()) {
342 return Err(FerroError::InvalidParameter {
343 name: "X".into(),
344 reason: "ComplementNB requires non-negative feature values".into(),
345 });
346 }
347
348 for (ci, &class_label) in self.classes.clone().iter().enumerate() {
350 let new_indices: Vec<usize> = y
351 .iter()
352 .enumerate()
353 .filter_map(|(i, &label)| if label == class_label { Some(i) } else { None })
354 .collect();
355
356 if new_indices.is_empty() {
357 continue;
358 }
359
360 self.class_counts[ci] += new_indices.len();
361
362 for &i in &new_indices {
363 for j in 0..n_features {
364 self.feature_counts[[ci, j]] = self.feature_counts[[ci, j]] + x[[i, j]];
365 }
366 }
367 }
368
369 let n_classes = self.classes.len();
371 let n_feat_f = F::from(n_features).unwrap();
372
373 let total_feature_counts: Array1<F> = self.feature_counts.rows().into_iter().fold(
374 Array1::<F>::zeros(n_features),
375 |acc, row| {
376 let mut result = acc;
377 for j in 0..n_features {
378 result[j] = result[j] + row[j];
379 }
380 result
381 },
382 );
383
384 let total_all: F = total_feature_counts.sum();
385
386 for ci in 0..n_classes {
387 let complement_total = total_all - self.feature_counts.row(ci).sum();
388 let denom = complement_total + self.alpha * n_feat_f;
389 for j in 0..n_features {
390 let complement_count_j = total_feature_counts[j] - self.feature_counts[[ci, j]];
391 self.weights[[ci, j]] = -((complement_count_j + self.alpha) / denom).ln();
393 }
394 }
395
396 if self.norm {
397 apply_norm_inplace(&mut self.weights);
398 }
399
400 Ok(())
401 }
402
403 fn complement_scores(&self, x: &Array2<F>) -> Array2<F> {
409 let n_samples = x.nrows();
410 let n_classes = self.classes.len();
411 let n_features = x.ncols();
412
413 let mut scores = Array2::<F>::zeros((n_samples, n_classes));
414
415 for i in 0..n_samples {
416 for ci in 0..n_classes {
417 let mut score = F::zero();
418 for j in 0..n_features {
419 score = score + x[[i, j]] * self.weights[[ci, j]];
420 }
421 scores[[i, ci]] = score;
422 }
423 }
424
425 scores
426 }
427
428 pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
440 let n_features_fitted = self.weights.ncols();
441 if x.ncols() != n_features_fitted {
442 return Err(FerroError::ShapeMismatch {
443 expected: vec![n_features_fitted],
444 actual: vec![x.ncols()],
445 context: "number of features must match fitted ComplementNB".into(),
446 });
447 }
448
449 let scores = self.complement_scores(x);
452 let n_samples = x.nrows();
453 let n_classes = self.classes.len();
454 let mut proba = Array2::<F>::zeros((n_samples, n_classes));
455
456 for i in 0..n_samples {
457 let max_score = scores
458 .row(i)
459 .iter()
460 .fold(F::neg_infinity(), |a, &b| a.max(b));
461
462 let mut row_sum = F::zero();
463 for ci in 0..n_classes {
464 let p = (scores[[i, ci]] - max_score).exp();
465 proba[[i, ci]] = p;
466 row_sum = row_sum + p;
467 }
468 for ci in 0..n_classes {
469 proba[[i, ci]] = proba[[i, ci]] / row_sum;
470 }
471 }
472
473 Ok(proba)
474 }
475
476 pub fn predict_joint_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
489 let n_features_fitted = self.weights.ncols();
490 if x.ncols() != n_features_fitted {
491 return Err(FerroError::ShapeMismatch {
492 expected: vec![n_features_fitted],
493 actual: vec![x.ncols()],
494 context: "number of features must match fitted ComplementNB".into(),
495 });
496 }
497 Ok(self.complement_scores(x))
500 }
501
502 pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
511 let jll = self.predict_joint_log_proba(x)?;
512 Ok(crate::log_softmax_rows(&jll))
513 }
514
515 pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
524 if x.nrows() != y.len() {
525 return Err(FerroError::ShapeMismatch {
526 expected: vec![x.nrows()],
527 actual: vec![y.len()],
528 context: "y length must match number of samples in X".into(),
529 });
530 }
531 let preds = self.predict(x)?;
532 let n = y.len();
533 if n == 0 {
534 return Ok(F::zero());
535 }
536 let correct = preds.iter().zip(y.iter()).filter(|(p, t)| p == t).count();
537 Ok(F::from(correct).unwrap() / F::from(n).unwrap())
538 }
539}
540
541impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedComplementNB<F> {
542 type Output = Array1<usize>;
543 type Error = FerroError;
544
545 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
554 let n_features_fitted = self.weights.ncols();
555 if x.ncols() != n_features_fitted {
556 return Err(FerroError::ShapeMismatch {
557 expected: vec![n_features_fitted],
558 actual: vec![x.ncols()],
559 context: "number of features must match fitted ComplementNB".into(),
560 });
561 }
562
563 let scores = self.complement_scores(x);
564 let n_samples = x.nrows();
565 let n_classes = self.classes.len();
566
567 let mut predictions = Array1::<usize>::zeros(n_samples);
568 for i in 0..n_samples {
569 let mut best_class = 0;
572 let mut best_score = scores[[i, 0]];
573 for ci in 1..n_classes {
574 if scores[[i, ci]] > best_score {
575 best_score = scores[[i, ci]];
576 best_class = ci;
577 }
578 }
579 predictions[i] = self.classes[best_class];
580 }
581
582 Ok(predictions)
583 }
584}
585
586impl<F: Float + Send + Sync + 'static> HasClasses for FittedComplementNB<F> {
587 fn classes(&self) -> &[usize] {
588 &self.classes
589 }
590
591 fn n_classes(&self) -> usize {
592 self.classes.len()
593 }
594}
595
596impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
598 for ComplementNB<F>
599{
600 fn fit_pipeline(
601 &self,
602 x: &Array2<F>,
603 y: &Array1<F>,
604 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
605 let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
606 let fitted = self.fit(x, &y_usize)?;
607 Ok(Box::new(FittedComplementNBPipeline(fitted)))
608 }
609}
610
611struct FittedComplementNBPipeline<F: Float + Send + Sync + 'static>(FittedComplementNB<F>);
612
613unsafe impl<F: Float + Send + Sync + 'static> Send for FittedComplementNBPipeline<F> {}
614unsafe impl<F: Float + Send + Sync + 'static> Sync for FittedComplementNBPipeline<F> {}
615
616impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
617 for FittedComplementNBPipeline<F>
618{
619 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
620 let preds = self.0.predict(x)?;
621 Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
622 }
623}
624
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Six samples over three features: class 0 is dominated by the first
    /// feature, class 1 by the last one.
    fn make_count_data() -> (Array2<f64>, Array1<usize>) {
        let features: Array2<f64> = array![
            [5.0, 1.0, 0.0],
            [4.0, 2.0, 0.0],
            [6.0, 0.0, 1.0],
            [0.0, 1.0, 5.0],
            [1.0, 0.0, 4.0],
            [0.0, 2.0, 6.0],
        ];
        let labels = array![0usize, 0, 0, 1, 1, 1];
        (features, labels)
    }

    /// Number of positions at which `predicted` and `expected` agree.
    fn count_matches(predicted: &Array1<usize>, expected: &Array1<usize>) -> usize {
        predicted
            .iter()
            .zip(expected.iter())
            .filter(|(p, e)| p == e)
            .count()
    }

    #[test]
    fn test_complement_nb_fit_predict() {
        let (x, y) = make_count_data();
        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(count_matches(&preds, &y), 6);
    }

    #[test]
    fn test_complement_nb_predict_proba_sums_to_one() {
        let (x, y) = make_count_data();
        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();
        for i in 0..proba.nrows() {
            assert_relative_eq!(proba.row(i).sum(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_complement_nb_has_classes() {
        let (x, y) = make_count_data();
        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_complement_nb_shape_mismatch_fit() {
        // Four samples but only two labels.
        let x = Array2::<f64>::from_elem((4, 3), 1.0);
        let y = array![0usize, 1];
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_shape_mismatch_predict() {
        let (x, y) = make_count_data();
        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();
        // Five columns where the model was fitted on three.
        let x_bad = Array2::<f64>::from_elem((3, 5), 1.0);
        assert!(fitted.predict(&x_bad).is_err());
        assert!(fitted.predict_proba(&x_bad).is_err());
    }

    #[test]
    fn test_complement_nb_negative_features_error() {
        let x: Array2<f64> = array![[1.0, 2.0], [-0.5, 3.0], [2.0, 1.0], [0.0, 4.0]];
        let y = array![0usize, 0, 1, 1];
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_single_class() {
        let x: Array2<f64> = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let y = array![0usize, 0, 0];
        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0]);
        let preds = fitted.predict(&x).unwrap();
        assert!(preds.iter().all(|&p| p == 0));
    }

    #[test]
    fn test_complement_nb_empty_data() {
        let x = Array2::<f64>::zeros((0, 3));
        let y = Array1::<usize>::zeros(0);
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_default() {
        let model = ComplementNB::<f64>::default();
        assert_relative_eq!(model.alpha, 1.0, epsilon = 1e-15);
    }

    #[test]
    fn test_complement_nb_imbalanced_data() {
        // Ten samples of class 0 followed by two samples of class 1.
        let x: Array2<f64> = array![
            [5.0, 1.0, 0.0],
            [4.0, 2.0, 0.0],
            [6.0, 0.0, 1.0],
            [5.0, 1.0, 0.0],
            [4.0, 2.0, 0.0],
            [6.0, 0.0, 1.0],
            [5.0, 1.0, 0.0],
            [4.0, 2.0, 0.0],
            [6.0, 0.0, 1.0],
            [5.0, 1.0, 0.0],
            [0.0, 1.0, 5.0],
            [0.0, 2.0, 6.0],
        ];
        let y = array![0usize, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1];

        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // The minority class must still be recognised.
        assert_eq!(preds[10], 1);
        assert_eq!(preds[11], 1);
    }

    #[test]
    fn test_complement_nb_partial_fit() {
        let x1: Array2<f64> = array![
            [5.0, 1.0, 0.0],
            [4.0, 2.0, 0.0],
            [0.0, 1.0, 5.0],
            [1.0, 0.0, 4.0],
        ];
        let y1 = array![0usize, 0, 1, 1];

        let mut fitted = ComplementNB::<f64>::new().fit(&x1, &y1).unwrap();

        let x2: Array2<f64> = array![[6.0, 0.0, 1.0], [0.0, 2.0, 6.0]];
        let y2 = array![0usize, 1];

        fitted.partial_fit(&x2, &y2).unwrap();

        let preds = fitted.predict(&x1).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_complement_nb_partial_fit_shape_mismatch() {
        let (x, y) = make_count_data();
        let mut fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();

        // Five columns where the model was fitted on three.
        let x_bad = Array2::<f64>::from_elem((2, 5), 1.0);
        let y_bad = array![0usize, 1];
        assert!(fitted.partial_fit(&x_bad, &y_bad).is_err());
    }

    #[test]
    fn test_complement_nb_class_prior() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new().with_class_prior(vec![0.5, 0.5]);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_complement_nb_class_prior_wrong_length() {
        let (x, y) = make_count_data();
        // One prior for two classes.
        let model = ComplementNB::<f64>::new().with_class_prior(vec![1.0]);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_three_classes() {
        let x: Array2<f64> = array![
            [5.0, 0.0, 0.0],
            [6.0, 0.0, 0.0],
            [4.0, 1.0, 0.0],
            [0.0, 5.0, 0.0],
            [0.0, 6.0, 0.0],
            [1.0, 4.0, 0.0],
            [0.0, 0.0, 5.0],
            [0.0, 0.0, 6.0],
            [0.0, 1.0, 4.0],
        ];
        let y = array![0usize, 0, 0, 1, 1, 1, 2, 2, 2];

        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.n_classes(), 3);
        let preds = fitted.predict(&x).unwrap();
        assert!(count_matches(&preds, &y) >= 7);
    }
}