1use ferrolearn_core::error::FerroError;
47use ferrolearn_core::introspection::HasClasses;
48use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
49use ferrolearn_core::traits::{Fit, Predict};
50use ndarray::{Array1, Array2};
51use num_traits::{Float, FromPrimitive, ToPrimitive};
52
/// Configuration for a Complement Naive Bayes classifier (the unfitted estimator).
///
/// Fitting this value on non-negative feature data yields a
/// [`FittedComplementNB`].
#[derive(Clone, Debug)]
pub struct ComplementNB<F> {
    /// Additive smoothing term added to every complement feature count.
    pub alpha: F,
    /// Optional user-supplied class priors.
    ///
    /// During fitting only the length of this vector is checked against the
    /// number of distinct classes found in the targets.
    pub class_prior: Option<Vec<F>>,
}
70
71impl<F: Float> ComplementNB<F> {
72 #[must_use]
74 pub fn new() -> Self {
75 Self {
76 alpha: F::one(),
77 class_prior: None,
78 }
79 }
80
81 #[must_use]
83 pub fn with_alpha(mut self, alpha: F) -> Self {
84 self.alpha = alpha;
85 self
86 }
87
88 #[must_use]
94 pub fn with_class_prior(mut self, priors: Vec<F>) -> Self {
95 self.class_prior = Some(priors);
96 self
97 }
98}
99
100impl<F: Float> Default for ComplementNB<F> {
101 fn default() -> Self {
102 Self::new()
103 }
104}
105
106#[derive(Debug, Clone)]
108pub struct FittedComplementNB<F> {
109 classes: Vec<usize>,
111 weights: Array2<F>,
114 feature_counts: Array2<F>,
116 class_counts: Vec<usize>,
118 alpha: F,
120}
121
122impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ComplementNB<F> {
123 type Fitted = FittedComplementNB<F>;
124 type Error = FerroError;
125
126 fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedComplementNB<F>, FerroError> {
134 let (n_samples, n_features) = x.dim();
135
136 if n_samples == 0 {
137 return Err(FerroError::InsufficientSamples {
138 required: 1,
139 actual: 0,
140 context: "ComplementNB requires at least one sample".into(),
141 });
142 }
143
144 if n_samples != y.len() {
145 return Err(FerroError::ShapeMismatch {
146 expected: vec![n_samples],
147 actual: vec![y.len()],
148 context: "y length must match number of samples in X".into(),
149 });
150 }
151
152 if x.iter().any(|&v| v < F::zero()) {
154 return Err(FerroError::InvalidParameter {
155 name: "X".into(),
156 reason: "ComplementNB requires non-negative feature values".into(),
157 });
158 }
159
160 let mut classes: Vec<usize> = y.to_vec();
162 classes.sort_unstable();
163 classes.dedup();
164 let n_classes = classes.len();
165
166 let n_feat_f = F::from(n_features).unwrap();
167
168 let mut class_feature_counts = Array2::<F>::zeros((n_classes, n_features));
170 let mut class_counts = vec![0usize; n_classes];
171
172 for (sample_idx, &label) in y.iter().enumerate() {
173 let ci = classes.iter().position(|&c| c == label).unwrap();
174 class_counts[ci] += 1;
175 for j in 0..n_features {
176 class_feature_counts[[ci, j]] = class_feature_counts[[ci, j]] + x[[sample_idx, j]];
177 }
178 }
179
180 let total_feature_counts: Array1<F> = class_feature_counts.rows().into_iter().fold(
182 Array1::<F>::zeros(n_features),
183 |acc, row| {
184 let mut result = acc;
185 for j in 0..n_features {
186 result[j] = result[j] + row[j];
187 }
188 result
189 },
190 );
191
192 let total_all: F = total_feature_counts.sum();
193
194 let mut weights = Array2::<F>::zeros((n_classes, n_features));
196
197 for ci in 0..n_classes {
198 let complement_total = total_all - class_feature_counts.row(ci).sum();
200
201 let denom = complement_total + self.alpha * n_feat_f;
202
203 for j in 0..n_features {
204 let complement_count_j = total_feature_counts[j] - class_feature_counts[[ci, j]];
205 weights[[ci, j]] = ((complement_count_j + self.alpha) / denom).ln();
206 }
207 }
208
209 if let Some(ref priors) = self.class_prior {
211 if priors.len() != n_classes {
212 return Err(FerroError::InvalidParameter {
213 name: "class_prior".into(),
214 reason: format!(
215 "length {} does not match number of classes {}",
216 priors.len(),
217 n_classes
218 ),
219 });
220 }
221 }
222
223 Ok(FittedComplementNB {
224 classes,
225 weights,
226 feature_counts: class_feature_counts,
227 class_counts,
228 alpha: self.alpha,
229 })
230 }
231}
232
233impl<F: Float + Send + Sync + 'static> FittedComplementNB<F> {
234 pub fn partial_fit(
245 &mut self,
246 x: &Array2<F>,
247 y: &Array1<usize>,
248 ) -> Result<(), FerroError> {
249 let (n_samples, n_features) = x.dim();
250
251 if n_samples == 0 {
252 return Ok(());
253 }
254
255 if n_samples != y.len() {
256 return Err(FerroError::ShapeMismatch {
257 expected: vec![n_samples],
258 actual: vec![y.len()],
259 context: "y length must match number of samples in X".into(),
260 });
261 }
262
263 if n_features != self.weights.ncols() {
264 return Err(FerroError::ShapeMismatch {
265 expected: vec![self.weights.ncols()],
266 actual: vec![n_features],
267 context: "number of features must match fitted ComplementNB".into(),
268 });
269 }
270
271 if x.iter().any(|&v| v < F::zero()) {
272 return Err(FerroError::InvalidParameter {
273 name: "X".into(),
274 reason: "ComplementNB requires non-negative feature values".into(),
275 });
276 }
277
278 for (ci, &class_label) in self.classes.clone().iter().enumerate() {
280 let new_indices: Vec<usize> = y
281 .iter()
282 .enumerate()
283 .filter_map(|(i, &label)| if label == class_label { Some(i) } else { None })
284 .collect();
285
286 if new_indices.is_empty() {
287 continue;
288 }
289
290 self.class_counts[ci] += new_indices.len();
291
292 for &i in &new_indices {
293 for j in 0..n_features {
294 self.feature_counts[[ci, j]] = self.feature_counts[[ci, j]] + x[[i, j]];
295 }
296 }
297 }
298
299 let n_classes = self.classes.len();
301 let n_feat_f = F::from(n_features).unwrap();
302
303 let total_feature_counts: Array1<F> = self.feature_counts.rows().into_iter().fold(
304 Array1::<F>::zeros(n_features),
305 |acc, row| {
306 let mut result = acc;
307 for j in 0..n_features {
308 result[j] = result[j] + row[j];
309 }
310 result
311 },
312 );
313
314 let total_all: F = total_feature_counts.sum();
315
316 for ci in 0..n_classes {
317 let complement_total = total_all - self.feature_counts.row(ci).sum();
318 let denom = complement_total + self.alpha * n_feat_f;
319 for j in 0..n_features {
320 let complement_count_j =
321 total_feature_counts[j] - self.feature_counts[[ci, j]];
322 self.weights[[ci, j]] = ((complement_count_j + self.alpha) / denom).ln();
323 }
324 }
325
326 Ok(())
327 }
328
329 fn complement_scores(&self, x: &Array2<F>) -> Array2<F> {
333 let n_samples = x.nrows();
334 let n_classes = self.classes.len();
335 let n_features = x.ncols();
336
337 let mut scores = Array2::<F>::zeros((n_samples, n_classes));
338
339 for i in 0..n_samples {
340 for ci in 0..n_classes {
341 let mut score = F::zero();
342 for j in 0..n_features {
343 score = score + x[[i, j]] * self.weights[[ci, j]];
344 }
345 scores[[i, ci]] = score;
346 }
347 }
348
349 scores
350 }
351
352 pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
364 let n_features_fitted = self.weights.ncols();
365 if x.ncols() != n_features_fitted {
366 return Err(FerroError::ShapeMismatch {
367 expected: vec![n_features_fitted],
368 actual: vec![x.ncols()],
369 context: "number of features must match fitted ComplementNB".into(),
370 });
371 }
372
373 let neg_scores = self.complement_scores(x).mapv(|v| -v);
375 let n_samples = x.nrows();
376 let n_classes = self.classes.len();
377 let mut proba = Array2::<F>::zeros((n_samples, n_classes));
378
379 for i in 0..n_samples {
380 let max_score = neg_scores
381 .row(i)
382 .iter()
383 .fold(F::neg_infinity(), |a, &b| a.max(b));
384
385 let mut row_sum = F::zero();
386 for ci in 0..n_classes {
387 let p = (neg_scores[[i, ci]] - max_score).exp();
388 proba[[i, ci]] = p;
389 row_sum = row_sum + p;
390 }
391 for ci in 0..n_classes {
392 proba[[i, ci]] = proba[[i, ci]] / row_sum;
393 }
394 }
395
396 Ok(proba)
397 }
398}
399
400impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedComplementNB<F> {
401 type Output = Array1<usize>;
402 type Error = FerroError;
403
404 fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
413 let n_features_fitted = self.weights.ncols();
414 if x.ncols() != n_features_fitted {
415 return Err(FerroError::ShapeMismatch {
416 expected: vec![n_features_fitted],
417 actual: vec![x.ncols()],
418 context: "number of features must match fitted ComplementNB".into(),
419 });
420 }
421
422 let scores = self.complement_scores(x);
423 let n_samples = x.nrows();
424 let n_classes = self.classes.len();
425
426 let mut predictions = Array1::<usize>::zeros(n_samples);
427 for i in 0..n_samples {
428 let mut best_class = 0;
430 let mut best_score = scores[[i, 0]];
431 for ci in 1..n_classes {
432 if scores[[i, ci]] < best_score {
433 best_score = scores[[i, ci]];
434 best_class = ci;
435 }
436 }
437 predictions[i] = self.classes[best_class];
438 }
439
440 Ok(predictions)
441 }
442}
443
444impl<F: Float + Send + Sync + 'static> HasClasses for FittedComplementNB<F> {
445 fn classes(&self) -> &[usize] {
446 &self.classes
447 }
448
449 fn n_classes(&self) -> usize {
450 self.classes.len()
451 }
452}
453
454impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
456 for ComplementNB<F>
457{
458 fn fit_pipeline(
459 &self,
460 x: &Array2<F>,
461 y: &Array1<F>,
462 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
463 let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
464 let fitted = self.fit(x, &y_usize)?;
465 Ok(Box::new(FittedComplementNBPipeline(fitted)))
466 }
467}
468
469struct FittedComplementNBPipeline<F: Float + Send + Sync + 'static>(FittedComplementNB<F>);
470
471unsafe impl<F: Float + Send + Sync + 'static> Send for FittedComplementNBPipeline<F> {}
472unsafe impl<F: Float + Send + Sync + 'static> Sync for FittedComplementNBPipeline<F> {}
473
474impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
475 for FittedComplementNBPipeline<F>
476{
477 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
478 let preds = self.0.predict(x)?;
479 Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Six samples over three features: class 0 is heavy on the first
    /// feature, class 1 on the last.
    fn make_count_data() -> (Array2<f64>, Array1<usize>) {
        let x = array![
            [5.0, 1.0, 0.0],
            [4.0, 2.0, 0.0],
            [6.0, 0.0, 1.0],
            [0.0, 1.0, 5.0],
            [1.0, 0.0, 4.0],
            [0.0, 2.0, 6.0],
        ];
        let y = array![0usize, 0, 0, 1, 1, 1];
        (x, y)
    }

    /// Fits a default-configured model on the shared fixture.
    fn fit_count_data() -> (FittedComplementNB<f64>, Array2<f64>, Array1<usize>) {
        let (x, y) = make_count_data();
        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();
        (fitted, x, y)
    }

    #[test]
    fn test_complement_nb_fit_predict() {
        let (fitted, x, y) = fit_count_data();
        let preds = fitted.predict(&x).unwrap();
        // Every training sample must be classified correctly.
        assert_eq!(preds, y);
    }

    #[test]
    fn test_complement_nb_predict_proba_sums_to_one() {
        let (fitted, x, _) = fit_count_data();
        let proba = fitted.predict_proba(&x).unwrap();
        for row in proba.rows() {
            assert_relative_eq!(row.sum(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_complement_nb_has_classes() {
        let (fitted, _, _) = fit_count_data();
        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_complement_nb_shape_mismatch_fit() {
        let x = Array2::<f64>::from_elem((4, 3), 1.0);
        // Only two labels for four samples.
        let y = array![0usize, 1];
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_shape_mismatch_predict() {
        let (fitted, _, _) = fit_count_data();
        // Five columns instead of the three the model was fitted on.
        let x_bad = Array2::<f64>::from_elem((3, 5), 1.0);
        assert!(fitted.predict(&x_bad).is_err());
        assert!(fitted.predict_proba(&x_bad).is_err());
    }

    #[test]
    fn test_complement_nb_negative_features_error() {
        let x = array![[1.0, 2.0], [-0.5, 3.0], [2.0, 1.0], [0.0, 4.0]];
        let y = array![0usize, 0, 1, 1];
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_single_class() {
        let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let y = array![0usize, 0, 0];
        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0]);
        let preds = fitted.predict(&x).unwrap();
        assert!(preds.iter().all(|&p| p == 0));
    }

    #[test]
    fn test_complement_nb_empty_data() {
        let x = Array2::<f64>::zeros((0, 3));
        let y = Array1::<usize>::zeros(0);
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_default() {
        let model = ComplementNB::<f64>::default();
        assert_relative_eq!(model.alpha, 1.0, epsilon = 1e-15);
    }

    #[test]
    fn test_complement_nb_imbalanced_data() {
        // Ten samples of class 0 against two of class 1.
        let x = array![
            [5.0, 1.0, 0.0],
            [4.0, 2.0, 0.0],
            [6.0, 0.0, 1.0],
            [5.0, 1.0, 0.0],
            [4.0, 2.0, 0.0],
            [6.0, 0.0, 1.0],
            [5.0, 1.0, 0.0],
            [4.0, 2.0, 0.0],
            [6.0, 0.0, 1.0],
            [5.0, 1.0, 0.0],
            [0.0, 1.0, 5.0],
            [0.0, 2.0, 6.0],
        ];
        let y = array![0usize, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1];

        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // The minority class must still be recognised.
        assert_eq!(preds[10], 1);
        assert_eq!(preds[11], 1);
    }

    #[test]
    fn test_complement_nb_partial_fit() {
        let x1 = array![
            [5.0, 1.0, 0.0],
            [4.0, 2.0, 0.0],
            [0.0, 1.0, 5.0],
            [1.0, 0.0, 4.0],
        ];
        let y1 = array![0usize, 0, 1, 1];
        let mut fitted = ComplementNB::<f64>::new().fit(&x1, &y1).unwrap();

        let x2 = array![[6.0, 0.0, 1.0], [0.0, 2.0, 6.0]];
        let y2 = array![0usize, 1];
        fitted.partial_fit(&x2, &y2).unwrap();

        let preds = fitted.predict(&x1).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_complement_nb_partial_fit_shape_mismatch() {
        let (mut fitted, _, _) = fit_count_data();

        let x_bad = Array2::<f64>::from_elem((2, 5), 1.0);
        let y_bad = array![0usize, 1];
        assert!(fitted.partial_fit(&x_bad, &y_bad).is_err());
    }

    #[test]
    fn test_complement_nb_class_prior() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new().with_class_prior(vec![0.5, 0.5]);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_complement_nb_class_prior_wrong_length() {
        let (x, y) = make_count_data();
        // One prior for two classes.
        let model = ComplementNB::<f64>::new().with_class_prior(vec![1.0]);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_three_classes() {
        let x = array![
            [5.0, 0.0, 0.0],
            [6.0, 0.0, 0.0],
            [4.0, 1.0, 0.0],
            [0.0, 5.0, 0.0],
            [0.0, 6.0, 0.0],
            [1.0, 4.0, 0.0],
            [0.0, 0.0, 5.0],
            [0.0, 0.0, 6.0],
            [0.0, 1.0, 4.0],
        ];
        let y = array![0usize, 0, 0, 1, 1, 1, 2, 2, 2];

        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.n_classes(), 3);

        let preds = fitted.predict(&x).unwrap();
        let correct = preds
            .iter()
            .zip(y.iter())
            .filter(|(predicted, actual)| predicted == actual)
            .count();
        assert!(correct >= 7);
    }
}