1use ferrolearn_core::error::FerroError;
47use ferrolearn_core::introspection::HasClasses;
48use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
49use ferrolearn_core::traits::{Fit, Predict};
50use ndarray::{Array1, Array2};
51use num_traits::{Float, FromPrimitive, ToPrimitive};
52
/// Unfitted Complement Naive Bayes classifier: a bag of hyper-parameters that
/// is turned into a [`FittedComplementNB`] by calling `fit`.
#[derive(Debug, Clone)]
pub struct ComplementNB<F> {
    /// Additive smoothing term. It is passed through `crate::clamp_alpha`
    /// before being added to every complement feature count.
    pub alpha: F,
    /// Optional user-supplied class priors. `fit` only verifies that the
    /// length equals the number of distinct classes found in `y`.
    pub class_prior: Option<Vec<F>>,
    /// Stored configuration flag; `fit` in this file does not read it.
    /// NOTE(review): presumably kept for parity with scikit-learn's
    /// `fit_prior` -- confirm.
    pub fit_prior: bool,
    /// Forwarded to `crate::clamp_alpha` together with `alpha`.
    pub force_alpha: bool,
    /// When `true`, each class's weight row is rescaled by the sum of that
    /// row after the log-weights are computed.
    pub norm: bool,
}
81
82impl<F: Float> ComplementNB<F> {
83 #[must_use]
85 pub fn new() -> Self {
86 Self {
87 alpha: F::one(),
88 class_prior: None,
89 fit_prior: true,
90 force_alpha: true,
91 norm: false,
92 }
93 }
94
95 #[must_use]
97 pub fn with_alpha(mut self, alpha: F) -> Self {
98 self.alpha = alpha;
99 self
100 }
101
102 #[must_use]
108 pub fn with_class_prior(mut self, priors: Vec<F>) -> Self {
109 self.class_prior = Some(priors);
110 self
111 }
112
113 #[must_use]
115 pub fn with_fit_prior(mut self, fit_prior: bool) -> Self {
116 self.fit_prior = fit_prior;
117 self
118 }
119
120 #[must_use]
122 pub fn with_force_alpha(mut self, force_alpha: bool) -> Self {
123 self.force_alpha = force_alpha;
124 self
125 }
126
127 #[must_use]
130 pub fn with_norm(mut self, norm: bool) -> Self {
131 self.norm = norm;
132 self
133 }
134}
135
136impl<F: Float> Default for ComplementNB<F> {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142#[derive(Debug, Clone)]
144pub struct FittedComplementNB<F> {
145 classes: Vec<usize>,
147 weights: Array2<F>,
150 feature_counts: Array2<F>,
152 class_counts: Vec<usize>,
154 alpha: F,
157 norm: bool,
160}
161
162impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ComplementNB<F> {
163 type Fitted = FittedComplementNB<F>;
164 type Error = FerroError;
165
166 fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedComplementNB<F>, FerroError> {
174 let (n_samples, n_features) = x.dim();
175
176 if n_samples == 0 {
177 return Err(FerroError::InsufficientSamples {
178 required: 1,
179 actual: 0,
180 context: "ComplementNB requires at least one sample".into(),
181 });
182 }
183
184 if n_samples != y.len() {
185 return Err(FerroError::ShapeMismatch {
186 expected: vec![n_samples],
187 actual: vec![y.len()],
188 context: "y length must match number of samples in X".into(),
189 });
190 }
191
192 if x.iter().any(|&v| v < F::zero()) {
194 return Err(FerroError::InvalidParameter {
195 name: "X".into(),
196 reason: "ComplementNB requires non-negative feature values".into(),
197 });
198 }
199
200 let mut classes: Vec<usize> = y.to_vec();
202 classes.sort_unstable();
203 classes.dedup();
204 let n_classes = classes.len();
205
206 let n_feat_f = F::from(n_features).unwrap();
207 let alpha = crate::clamp_alpha(self.alpha, self.force_alpha);
208
209 let mut class_feature_counts = Array2::<F>::zeros((n_classes, n_features));
211 let mut class_counts = vec![0usize; n_classes];
212
213 for (sample_idx, &label) in y.iter().enumerate() {
214 let ci = classes.iter().position(|&c| c == label).unwrap();
215 class_counts[ci] += 1;
216 for j in 0..n_features {
217 class_feature_counts[[ci, j]] = class_feature_counts[[ci, j]] + x[[sample_idx, j]];
218 }
219 }
220
221 let total_feature_counts: Array1<F> = class_feature_counts.rows().into_iter().fold(
223 Array1::<F>::zeros(n_features),
224 |acc, row| {
225 let mut result = acc;
226 for j in 0..n_features {
227 result[j] = result[j] + row[j];
228 }
229 result
230 },
231 );
232
233 let total_all: F = total_feature_counts.sum();
234
235 let mut weights = Array2::<F>::zeros((n_classes, n_features));
237
238 for ci in 0..n_classes {
239 let complement_total = total_all - class_feature_counts.row(ci).sum();
241
242 let denom = complement_total + alpha * n_feat_f;
243
244 for j in 0..n_features {
245 let complement_count_j = total_feature_counts[j] - class_feature_counts[[ci, j]];
246 weights[[ci, j]] = ((complement_count_j + alpha) / denom).ln();
247 }
248 }
249
250 if self.norm {
251 apply_norm_inplace(&mut weights);
252 }
253
254 if let Some(ref priors) = self.class_prior {
256 if priors.len() != n_classes {
257 return Err(FerroError::InvalidParameter {
258 name: "class_prior".into(),
259 reason: format!(
260 "length {} does not match number of classes {}",
261 priors.len(),
262 n_classes
263 ),
264 });
265 }
266 }
267
268 Ok(FittedComplementNB {
269 classes,
270 weights,
271 feature_counts: class_feature_counts,
272 class_counts,
273 alpha,
274 norm: self.norm,
275 })
276 }
277}
278
279fn apply_norm_inplace<F: Float>(weights: &mut Array2<F>) {
288 let n_classes = weights.nrows();
289 let n_features = weights.ncols();
290 for ci in 0..n_classes {
291 let row_sum = (0..n_features).fold(F::zero(), |acc, j| acc + weights[[ci, j]]);
292 if row_sum == F::zero() {
293 continue;
294 }
295 for j in 0..n_features {
296 weights[[ci, j]] = -(weights[[ci, j]] / row_sum);
297 }
298 }
299}
300
301impl<F: Float + Send + Sync + 'static> FittedComplementNB<F> {
302 pub fn partial_fit(&mut self, x: &Array2<F>, y: &Array1<usize>) -> Result<(), FerroError> {
313 let (n_samples, n_features) = x.dim();
314
315 if n_samples == 0 {
316 return Ok(());
317 }
318
319 if n_samples != y.len() {
320 return Err(FerroError::ShapeMismatch {
321 expected: vec![n_samples],
322 actual: vec![y.len()],
323 context: "y length must match number of samples in X".into(),
324 });
325 }
326
327 if n_features != self.weights.ncols() {
328 return Err(FerroError::ShapeMismatch {
329 expected: vec![self.weights.ncols()],
330 actual: vec![n_features],
331 context: "number of features must match fitted ComplementNB".into(),
332 });
333 }
334
335 if x.iter().any(|&v| v < F::zero()) {
336 return Err(FerroError::InvalidParameter {
337 name: "X".into(),
338 reason: "ComplementNB requires non-negative feature values".into(),
339 });
340 }
341
342 for (ci, &class_label) in self.classes.clone().iter().enumerate() {
344 let new_indices: Vec<usize> = y
345 .iter()
346 .enumerate()
347 .filter_map(|(i, &label)| if label == class_label { Some(i) } else { None })
348 .collect();
349
350 if new_indices.is_empty() {
351 continue;
352 }
353
354 self.class_counts[ci] += new_indices.len();
355
356 for &i in &new_indices {
357 for j in 0..n_features {
358 self.feature_counts[[ci, j]] = self.feature_counts[[ci, j]] + x[[i, j]];
359 }
360 }
361 }
362
363 let n_classes = self.classes.len();
365 let n_feat_f = F::from(n_features).unwrap();
366
367 let total_feature_counts: Array1<F> = self.feature_counts.rows().into_iter().fold(
368 Array1::<F>::zeros(n_features),
369 |acc, row| {
370 let mut result = acc;
371 for j in 0..n_features {
372 result[j] = result[j] + row[j];
373 }
374 result
375 },
376 );
377
378 let total_all: F = total_feature_counts.sum();
379
380 for ci in 0..n_classes {
381 let complement_total = total_all - self.feature_counts.row(ci).sum();
382 let denom = complement_total + self.alpha * n_feat_f;
383 for j in 0..n_features {
384 let complement_count_j = total_feature_counts[j] - self.feature_counts[[ci, j]];
385 self.weights[[ci, j]] = ((complement_count_j + self.alpha) / denom).ln();
386 }
387 }
388
389 if self.norm {
390 apply_norm_inplace(&mut self.weights);
391 }
392
393 Ok(())
394 }
395
396 fn complement_scores(&self, x: &Array2<F>) -> Array2<F> {
400 let n_samples = x.nrows();
401 let n_classes = self.classes.len();
402 let n_features = x.ncols();
403
404 let mut scores = Array2::<F>::zeros((n_samples, n_classes));
405
406 for i in 0..n_samples {
407 for ci in 0..n_classes {
408 let mut score = F::zero();
409 for j in 0..n_features {
410 score = score + x[[i, j]] * self.weights[[ci, j]];
411 }
412 scores[[i, ci]] = score;
413 }
414 }
415
416 scores
417 }
418
419 pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
431 let n_features_fitted = self.weights.ncols();
432 if x.ncols() != n_features_fitted {
433 return Err(FerroError::ShapeMismatch {
434 expected: vec![n_features_fitted],
435 actual: vec![x.ncols()],
436 context: "number of features must match fitted ComplementNB".into(),
437 });
438 }
439
440 let neg_scores = self.complement_scores(x).mapv(|v| -v);
442 let n_samples = x.nrows();
443 let n_classes = self.classes.len();
444 let mut proba = Array2::<F>::zeros((n_samples, n_classes));
445
446 for i in 0..n_samples {
447 let max_score = neg_scores
448 .row(i)
449 .iter()
450 .fold(F::neg_infinity(), |a, &b| a.max(b));
451
452 let mut row_sum = F::zero();
453 for ci in 0..n_classes {
454 let p = (neg_scores[[i, ci]] - max_score).exp();
455 proba[[i, ci]] = p;
456 row_sum = row_sum + p;
457 }
458 for ci in 0..n_classes {
459 proba[[i, ci]] = proba[[i, ci]] / row_sum;
460 }
461 }
462
463 Ok(proba)
464 }
465
466 pub fn predict_joint_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
479 let n_features_fitted = self.weights.ncols();
480 if x.ncols() != n_features_fitted {
481 return Err(FerroError::ShapeMismatch {
482 expected: vec![n_features_fitted],
483 actual: vec![x.ncols()],
484 context: "number of features must match fitted ComplementNB".into(),
485 });
486 }
487 Ok(self.complement_scores(x).mapv(|v| -v))
488 }
489
490 pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
499 let jll = self.predict_joint_log_proba(x)?;
500 Ok(crate::log_softmax_rows(&jll))
501 }
502
503 pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
512 if x.nrows() != y.len() {
513 return Err(FerroError::ShapeMismatch {
514 expected: vec![x.nrows()],
515 actual: vec![y.len()],
516 context: "y length must match number of samples in X".into(),
517 });
518 }
519 let preds = self.predict(x)?;
520 let n = y.len();
521 if n == 0 {
522 return Ok(F::zero());
523 }
524 let correct = preds.iter().zip(y.iter()).filter(|(p, t)| p == t).count();
525 Ok(F::from(correct).unwrap() / F::from(n).unwrap())
526 }
527}
528
529impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedComplementNB<F> {
530 type Output = Array1<usize>;
531 type Error = FerroError;
532
533 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
542 let n_features_fitted = self.weights.ncols();
543 if x.ncols() != n_features_fitted {
544 return Err(FerroError::ShapeMismatch {
545 expected: vec![n_features_fitted],
546 actual: vec![x.ncols()],
547 context: "number of features must match fitted ComplementNB".into(),
548 });
549 }
550
551 let scores = self.complement_scores(x);
552 let n_samples = x.nrows();
553 let n_classes = self.classes.len();
554
555 let mut predictions = Array1::<usize>::zeros(n_samples);
556 for i in 0..n_samples {
557 let mut best_class = 0;
559 let mut best_score = scores[[i, 0]];
560 for ci in 1..n_classes {
561 if scores[[i, ci]] < best_score {
562 best_score = scores[[i, ci]];
563 best_class = ci;
564 }
565 }
566 predictions[i] = self.classes[best_class];
567 }
568
569 Ok(predictions)
570 }
571}
572
573impl<F: Float + Send + Sync + 'static> HasClasses for FittedComplementNB<F> {
574 fn classes(&self) -> &[usize] {
575 &self.classes
576 }
577
578 fn n_classes(&self) -> usize {
579 self.classes.len()
580 }
581}
582
583impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
585 for ComplementNB<F>
586{
587 fn fit_pipeline(
588 &self,
589 x: &Array2<F>,
590 y: &Array1<F>,
591 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
592 let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
593 let fitted = self.fit(x, &y_usize)?;
594 Ok(Box::new(FittedComplementNBPipeline(fitted)))
595 }
596}
597
598struct FittedComplementNBPipeline<F: Float + Send + Sync + 'static>(FittedComplementNB<F>);
599
600unsafe impl<F: Float + Send + Sync + 'static> Send for FittedComplementNBPipeline<F> {}
601unsafe impl<F: Float + Send + Sync + 'static> Sync for FittedComplementNBPipeline<F> {}
602
603impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
604 for FittedComplementNBPipeline<F>
605{
606 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
607 let preds = self.0.predict(x)?;
608 Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
609 }
610}
611
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Six samples, three features: class 0 is heavy on feature 0, class 1 is
    /// heavy on feature 2.
    fn make_count_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (6, 3),
            vec![
                5.0, 1.0, 0.0, //
                4.0, 2.0, 0.0, //
                6.0, 0.0, 1.0, //
                0.0, 1.0, 5.0, //
                1.0, 0.0, 4.0, //
                0.0, 2.0, 6.0, //
            ],
        )
        .unwrap();
        let y = array![0usize, 0, 0, 1, 1, 1];
        (x, y)
    }

    #[test]
    fn test_complement_nb_fit_predict() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert_eq!(correct, 6);
    }

    #[test]
    fn test_complement_nb_predict_proba_sums_to_one() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();
        for i in 0..proba.nrows() {
            assert_relative_eq!(proba.row(i).sum(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_complement_nb_has_classes() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_complement_nb_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((4, 3), vec![1.0; 12]).unwrap();
        // Only two labels for four samples.
        let y = array![0usize, 1];
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_shape_mismatch_predict() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let x_bad = Array2::from_shape_vec((3, 5), vec![1.0; 15]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
        assert!(fitted.predict_proba(&x_bad).is_err());
    }

    #[test]
    fn test_complement_nb_negative_features_error() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, -0.5, 3.0, 2.0, 1.0, 0.0, 4.0]).unwrap();
        let y = array![0usize, 0, 1, 1];
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_single_class() {
        let x = Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
            .unwrap();
        let y = array![0usize, 0, 0];
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0]);
        let preds = fitted.predict(&x).unwrap();
        assert!(preds.iter().all(|&p| p == 0));
    }

    #[test]
    fn test_complement_nb_empty_data() {
        let x = Array2::<f64>::zeros((0, 3));
        let y = Array1::<usize>::zeros(0);
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_default() {
        let model = ComplementNB::<f64>::default();
        assert_relative_eq!(model.alpha, 1.0, epsilon = 1e-15);
    }

    #[test]
    fn test_complement_nb_imbalanced_data() {
        // Ten samples of class 0 against only two of class 1.
        let x = Array2::from_shape_vec(
            (12, 3),
            vec![
                5.0, 1.0, 0.0, //
                4.0, 2.0, 0.0, //
                6.0, 0.0, 1.0, //
                5.0, 1.0, 0.0, //
                4.0, 2.0, 0.0, //
                6.0, 0.0, 1.0, //
                5.0, 1.0, 0.0, //
                4.0, 2.0, 0.0, //
                6.0, 0.0, 1.0, //
                5.0, 1.0, 0.0, //
                0.0, 1.0, 5.0, //
                0.0, 2.0, 6.0, //
            ],
        )
        .unwrap();
        let y = array![0usize, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1];

        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // The minority class must still be recognised.
        assert_eq!(preds[10], 1);
        assert_eq!(preds[11], 1);
    }

    #[test]
    fn test_complement_nb_partial_fit() {
        let x1 = Array2::from_shape_vec(
            (4, 3),
            vec![5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 0.0, 1.0, 5.0, 1.0, 0.0, 4.0],
        )
        .unwrap();
        let y1 = array![0usize, 0, 1, 1];

        let model = ComplementNB::<f64>::new();
        let mut fitted = model.fit(&x1, &y1).unwrap();

        let x2 = Array2::from_shape_vec((2, 3), vec![6.0, 0.0, 1.0, 0.0, 2.0, 6.0]).unwrap();
        let y2 = array![0usize, 1];

        fitted.partial_fit(&x2, &y2).unwrap();

        // The updated model still separates the original batch.
        let preds = fitted.predict(&x1).unwrap();
        assert_eq!(preds, y1);

        // Incremental training must agree with a single fit on all the data.
        let x_all = Array2::from_shape_vec(
            (6, 3),
            vec![
                5.0, 1.0, 0.0, //
                4.0, 2.0, 0.0, //
                0.0, 1.0, 5.0, //
                1.0, 0.0, 4.0, //
                6.0, 0.0, 1.0, //
                0.0, 2.0, 6.0, //
            ],
        )
        .unwrap();
        let y_all = array![0usize, 0, 1, 1, 0, 1];
        let batch = model.fit(&x_all, &y_all).unwrap();

        let incremental_proba = fitted.predict_proba(&x_all).unwrap();
        let batch_proba = batch.predict_proba(&x_all).unwrap();
        for (inc, full) in incremental_proba.iter().zip(batch_proba.iter()) {
            assert_relative_eq!(*inc, *full, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_complement_nb_partial_fit_shape_mismatch() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let mut fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 5), vec![1.0; 10]).unwrap();
        let y_bad = array![0usize, 1];
        assert!(fitted.partial_fit(&x_bad, &y_bad).is_err());
    }

    #[test]
    fn test_complement_nb_class_prior() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new().with_class_prior(vec![0.5, 0.5]);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        // A correctly sized prior is accepted and the separable training
        // data is still classified perfectly.
        assert_eq!(preds, y);
    }

    #[test]
    fn test_complement_nb_class_prior_wrong_length() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new().with_class_prior(vec![1.0]);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_three_classes() {
        let x = Array2::from_shape_vec(
            (9, 3),
            vec![
                5.0, 0.0, 0.0, //
                6.0, 0.0, 0.0, //
                4.0, 1.0, 0.0, //
                0.0, 5.0, 0.0, //
                0.0, 6.0, 0.0, //
                1.0, 4.0, 0.0, //
                0.0, 0.0, 5.0, //
                0.0, 0.0, 6.0, //
                0.0, 1.0, 4.0, //
            ],
        )
        .unwrap();
        let y = array![0usize, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_classes(), 3);
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 7);
    }
}