Skip to main content

ferrolearn_bayes/
complement.rs

//! Complement Naive Bayes classifier.
//!
//! This module provides [`ComplementNB`], a variant of Multinomial Naive Bayes
//! that is particularly well-suited to imbalanced datasets. Instead of estimating
//! the likelihood of a feature given a class, it estimates the likelihood of the
//! feature given all *other* (complement) classes, and then picks the class whose
//! complement explains a sample *least* well.
//!
//! The weight for feature `j` in class `c` is:
//!
//! ```text
//! w_cj = log( (N_~cj + alpha) / (N_~c + alpha * n_features) )
//! ```
//!
//! where `N_~cj` is the total count of feature `j` in all classes except `c`,
//! `N_~c` is the total count of all features in all classes except `c`,
//! `alpha` is the additive smoothing parameter, and `n_features` is the number
//! of columns of `X`.
//!
//! Prediction uses `argmin_c sum_j x_j * w_cj` (i.e., the class with the
//! *smallest* complement score is chosen).
//!
//! # Examples
//!
//! ```
//! use ferrolearn_bayes::ComplementNB;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array2};
//!
//! let x = Array2::from_shape_vec(
//!     (6, 3),
//!     vec![
//!         5.0, 1.0, 0.0,
//!         4.0, 2.0, 0.0,
//!         6.0, 0.0, 1.0,
//!         0.0, 1.0, 5.0,
//!         1.0, 0.0, 4.0,
//!         0.0, 2.0, 6.0,
//!     ],
//! ).unwrap();
//! let y = array![0usize, 0, 0, 1, 1, 1];
//!
//! let model = ComplementNB::<f64>::new();
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 6);
//! ```
45
46use ferrolearn_core::error::FerroError;
47use ferrolearn_core::introspection::HasClasses;
48use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
49use ferrolearn_core::traits::{Fit, Predict};
50use ndarray::{Array1, Array2};
51use num_traits::{Float, FromPrimitive, ToPrimitive};
52
/// Complement Naive Bayes classifier.
///
/// A variant of Multinomial NB that scores each class using feature
/// statistics gathered from all *other* (complement) classes, which makes it
/// more robust on imbalanced datasets. Feature values must be non-negative
/// (typically counts or frequencies).
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct ComplementNB<F> {
    /// Additive (Laplace/Lidstone) smoothing parameter. Default: `1.0`.
    ///
    /// Should be finite and non-negative. With `0.0`, a feature whose
    /// complement count is zero gets a weight of `ln(0) = -inf`, and a zero
    /// feature value multiplied by that weight produces a NaN score.
    pub alpha: F,
    /// Optional user-supplied class priors.
    ///
    /// Fitting only checks that the length equals the number of distinct
    /// labels in `y`. The values themselves are not used: predictions depend
    /// solely on the complement weights. The field exists for API consistency
    /// with the other Naive Bayes variants.
    pub class_prior: Option<Vec<F>>,
}
70
71impl<F: Float> ComplementNB<F> {
72    /// Create a new `ComplementNB` with Laplace smoothing (`alpha = 1.0`).
73    #[must_use]
74    pub fn new() -> Self {
75        Self {
76            alpha: F::one(),
77            class_prior: None,
78        }
79    }
80
81    /// Set the Laplace smoothing parameter.
82    #[must_use]
83    pub fn with_alpha(mut self, alpha: F) -> Self {
84        self.alpha = alpha;
85        self
86    }
87
88    /// Set user-supplied class priors.
89    ///
90    /// The priors must have length equal to the number of classes discovered
91    /// during fitting. Note: ComplementNB uses complement-class weights rather
92    /// than direct class priors, but the priors are stored for API consistency.
93    #[must_use]
94    pub fn with_class_prior(mut self, priors: Vec<F>) -> Self {
95        self.class_prior = Some(priors);
96        self
97    }
98}
99
100impl<F: Float> Default for ComplementNB<F> {
101    fn default() -> Self {
102        Self::new()
103    }
104}
105
106/// Fitted Complement Naive Bayes classifier.
107#[derive(Debug, Clone)]
108pub struct FittedComplementNB<F> {
109    /// Sorted unique class labels.
110    classes: Vec<usize>,
111    /// Complement weights per class, shape `(n_classes, n_features)`.
112    /// Each entry is `log( (N_~cj + alpha) / (N_~c + alpha * n_features) )`.
113    weights: Array2<F>,
114    /// Raw per-class feature count sums, shape `(n_classes, n_features)`.
115    feature_counts: Array2<F>,
116    /// Per-class sample counts.
117    class_counts: Vec<usize>,
118    /// Smoothing parameter carried forward for partial_fit.
119    alpha: F,
120}
121
122impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ComplementNB<F> {
123    type Fitted = FittedComplementNB<F>;
124    type Error = FerroError;
125
126    /// Fit the Complement NB model.
127    ///
128    /// # Errors
129    ///
130    /// - [`FerroError::ShapeMismatch`] if `x` and `y` have different numbers of rows.
131    /// - [`FerroError::InsufficientSamples`] if there are no samples.
132    /// - [`FerroError::InvalidParameter`] if any feature value is negative.
133    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedComplementNB<F>, FerroError> {
134        let (n_samples, n_features) = x.dim();
135
136        if n_samples == 0 {
137            return Err(FerroError::InsufficientSamples {
138                required: 1,
139                actual: 0,
140                context: "ComplementNB requires at least one sample".into(),
141            });
142        }
143
144        if n_samples != y.len() {
145            return Err(FerroError::ShapeMismatch {
146                expected: vec![n_samples],
147                actual: vec![y.len()],
148                context: "y length must match number of samples in X".into(),
149            });
150        }
151
152        // Validate non-negative features.
153        if x.iter().any(|&v| v < F::zero()) {
154            return Err(FerroError::InvalidParameter {
155                name: "X".into(),
156                reason: "ComplementNB requires non-negative feature values".into(),
157            });
158        }
159
160        // Collect sorted unique classes.
161        let mut classes: Vec<usize> = y.to_vec();
162        classes.sort_unstable();
163        classes.dedup();
164        let n_classes = classes.len();
165
166        let n_feat_f = F::from(n_features).unwrap();
167
168        // Compute per-class feature count sums, shape (n_classes, n_features).
169        let mut class_feature_counts = Array2::<F>::zeros((n_classes, n_features));
170        let mut class_counts = vec![0usize; n_classes];
171
172        for (sample_idx, &label) in y.iter().enumerate() {
173            let ci = classes.iter().position(|&c| c == label).unwrap();
174            class_counts[ci] += 1;
175            for j in 0..n_features {
176                class_feature_counts[[ci, j]] = class_feature_counts[[ci, j]] + x[[sample_idx, j]];
177            }
178        }
179
180        // Total feature counts across all classes.
181        let total_feature_counts: Array1<F> = class_feature_counts.rows().into_iter().fold(
182            Array1::<F>::zeros(n_features),
183            |acc, row| {
184                let mut result = acc;
185                for j in 0..n_features {
186                    result[j] = result[j] + row[j];
187                }
188                result
189            },
190        );
191
192        let total_all: F = total_feature_counts.sum();
193
194        // Compute complement weights for each class.
195        let mut weights = Array2::<F>::zeros((n_classes, n_features));
196
197        for ci in 0..n_classes {
198            // Complement counts: sum over all other classes.
199            let complement_total = total_all - class_feature_counts.row(ci).sum();
200
201            let denom = complement_total + self.alpha * n_feat_f;
202
203            for j in 0..n_features {
204                let complement_count_j = total_feature_counts[j] - class_feature_counts[[ci, j]];
205                weights[[ci, j]] = ((complement_count_j + self.alpha) / denom).ln();
206            }
207        }
208
209        // Validate class_prior length if provided.
210        if let Some(ref priors) = self.class_prior {
211            if priors.len() != n_classes {
212                return Err(FerroError::InvalidParameter {
213                    name: "class_prior".into(),
214                    reason: format!(
215                        "length {} does not match number of classes {}",
216                        priors.len(),
217                        n_classes
218                    ),
219                });
220            }
221        }
222
223        Ok(FittedComplementNB {
224            classes,
225            weights,
226            feature_counts: class_feature_counts,
227            class_counts,
228            alpha: self.alpha,
229        })
230    }
231}
232
233impl<F: Float + Send + Sync + 'static> FittedComplementNB<F> {
234    /// Incrementally update the model with new data.
235    ///
236    /// Accumulates feature counts and class counts, then recomputes
237    /// the complement weights.
238    ///
239    /// # Errors
240    ///
241    /// - [`FerroError::ShapeMismatch`] if `x` and `y` have different row counts
242    ///   or the number of features does not match the fitted model.
243    /// - [`FerroError::InvalidParameter`] if any feature value is negative.
244    pub fn partial_fit(
245        &mut self,
246        x: &Array2<F>,
247        y: &Array1<usize>,
248    ) -> Result<(), FerroError> {
249        let (n_samples, n_features) = x.dim();
250
251        if n_samples == 0 {
252            return Ok(());
253        }
254
255        if n_samples != y.len() {
256            return Err(FerroError::ShapeMismatch {
257                expected: vec![n_samples],
258                actual: vec![y.len()],
259                context: "y length must match number of samples in X".into(),
260            });
261        }
262
263        if n_features != self.weights.ncols() {
264            return Err(FerroError::ShapeMismatch {
265                expected: vec![self.weights.ncols()],
266                actual: vec![n_features],
267                context: "number of features must match fitted ComplementNB".into(),
268            });
269        }
270
271        if x.iter().any(|&v| v < F::zero()) {
272            return Err(FerroError::InvalidParameter {
273                name: "X".into(),
274                reason: "ComplementNB requires non-negative feature values".into(),
275            });
276        }
277
278        // Accumulate counts for each existing class.
279        for (ci, &class_label) in self.classes.clone().iter().enumerate() {
280            let new_indices: Vec<usize> = y
281                .iter()
282                .enumerate()
283                .filter_map(|(i, &label)| if label == class_label { Some(i) } else { None })
284                .collect();
285
286            if new_indices.is_empty() {
287                continue;
288            }
289
290            self.class_counts[ci] += new_indices.len();
291
292            for &i in &new_indices {
293                for j in 0..n_features {
294                    self.feature_counts[[ci, j]] = self.feature_counts[[ci, j]] + x[[i, j]];
295                }
296            }
297        }
298
299        // Recompute complement weights from accumulated feature_counts.
300        let n_classes = self.classes.len();
301        let n_feat_f = F::from(n_features).unwrap();
302
303        let total_feature_counts: Array1<F> = self.feature_counts.rows().into_iter().fold(
304            Array1::<F>::zeros(n_features),
305            |acc, row| {
306                let mut result = acc;
307                for j in 0..n_features {
308                    result[j] = result[j] + row[j];
309                }
310                result
311            },
312        );
313
314        let total_all: F = total_feature_counts.sum();
315
316        for ci in 0..n_classes {
317            let complement_total = total_all - self.feature_counts.row(ci).sum();
318            let denom = complement_total + self.alpha * n_feat_f;
319            for j in 0..n_features {
320                let complement_count_j =
321                    total_feature_counts[j] - self.feature_counts[[ci, j]];
322                self.weights[[ci, j]] = ((complement_count_j + self.alpha) / denom).ln();
323            }
324        }
325
326        Ok(())
327    }
328
329    /// Compute complement scores for each class.
330    ///
331    /// Returns shape `(n_samples, n_classes)`. Lower is better.
332    fn complement_scores(&self, x: &Array2<F>) -> Array2<F> {
333        let n_samples = x.nrows();
334        let n_classes = self.classes.len();
335        let n_features = x.ncols();
336
337        let mut scores = Array2::<F>::zeros((n_samples, n_classes));
338
339        for i in 0..n_samples {
340            for ci in 0..n_classes {
341                let mut score = F::zero();
342                for j in 0..n_features {
343                    score = score + x[[i, j]] * self.weights[[ci, j]];
344                }
345                scores[[i, ci]] = score;
346            }
347        }
348
349        scores
350    }
351
352    /// Predict class probabilities for the given feature matrix.
353    ///
354    /// Converts complement scores (lower=better) to probabilities by negating
355    /// and applying softmax.
356    ///
357    /// Returns shape `(n_samples, n_classes)` where each row sums to 1.
358    ///
359    /// # Errors
360    ///
361    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
362    /// not match the fitted model.
363    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
364        let n_features_fitted = self.weights.ncols();
365        if x.ncols() != n_features_fitted {
366            return Err(FerroError::ShapeMismatch {
367                expected: vec![n_features_fitted],
368                actual: vec![x.ncols()],
369                context: "number of features must match fitted ComplementNB".into(),
370            });
371        }
372
373        // Negate complement scores so that lower complement score → higher probability.
374        let neg_scores = self.complement_scores(x).mapv(|v| -v);
375        let n_samples = x.nrows();
376        let n_classes = self.classes.len();
377        let mut proba = Array2::<F>::zeros((n_samples, n_classes));
378
379        for i in 0..n_samples {
380            let max_score = neg_scores
381                .row(i)
382                .iter()
383                .fold(F::neg_infinity(), |a, &b| a.max(b));
384
385            let mut row_sum = F::zero();
386            for ci in 0..n_classes {
387                let p = (neg_scores[[i, ci]] - max_score).exp();
388                proba[[i, ci]] = p;
389                row_sum = row_sum + p;
390            }
391            for ci in 0..n_classes {
392                proba[[i, ci]] = proba[[i, ci]] / row_sum;
393            }
394        }
395
396        Ok(proba)
397    }
398}
399
400impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedComplementNB<F> {
401    type Output = Array1<usize>;
402    type Error = FerroError;
403
404    /// Predict class labels for the given feature matrix.
405    ///
406    /// Predicts the class with the *lowest* complement score.
407    ///
408    /// # Errors
409    ///
410    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
411    /// not match the fitted model.
412    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
413        let n_features_fitted = self.weights.ncols();
414        if x.ncols() != n_features_fitted {
415            return Err(FerroError::ShapeMismatch {
416                expected: vec![n_features_fitted],
417                actual: vec![x.ncols()],
418                context: "number of features must match fitted ComplementNB".into(),
419            });
420        }
421
422        let scores = self.complement_scores(x);
423        let n_samples = x.nrows();
424        let n_classes = self.classes.len();
425
426        let mut predictions = Array1::<usize>::zeros(n_samples);
427        for i in 0..n_samples {
428            // Argmin: class with the smallest complement score.
429            let mut best_class = 0;
430            let mut best_score = scores[[i, 0]];
431            for ci in 1..n_classes {
432                if scores[[i, ci]] < best_score {
433                    best_score = scores[[i, ci]];
434                    best_class = ci;
435                }
436            }
437            predictions[i] = self.classes[best_class];
438        }
439
440        Ok(predictions)
441    }
442}
443
444impl<F: Float + Send + Sync + 'static> HasClasses for FittedComplementNB<F> {
445    fn classes(&self) -> &[usize] {
446        &self.classes
447    }
448
449    fn n_classes(&self) -> usize {
450        self.classes.len()
451    }
452}
453
454// Pipeline integration.
455impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
456    for ComplementNB<F>
457{
458    fn fit_pipeline(
459        &self,
460        x: &Array2<F>,
461        y: &Array1<F>,
462    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
463        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
464        let fitted = self.fit(x, &y_usize)?;
465        Ok(Box::new(FittedComplementNBPipeline(fitted)))
466    }
467}
468
469struct FittedComplementNBPipeline<F: Float + Send + Sync + 'static>(FittedComplementNB<F>);
470
471unsafe impl<F: Float + Send + Sync + 'static> Send for FittedComplementNBPipeline<F> {}
472unsafe impl<F: Float + Send + Sync + 'static> Sync for FittedComplementNBPipeline<F> {}
473
474impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
475    for FittedComplementNBPipeline<F>
476{
477    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
478        let preds = self.0.predict(x)?;
479        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
480    }
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    fn make_count_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (6, 3),
            vec![
                5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 6.0, 0.0, 1.0, 0.0, 1.0, 5.0, 1.0, 0.0, 4.0, 0.0,
                2.0, 6.0,
            ],
        )
        .unwrap();
        let y = array![0usize, 0, 0, 1, 1, 1];
        (x, y)
    }

    #[test]
    fn test_complement_nb_fit_predict() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert_eq!(correct, 6);
    }

    #[test]
    fn test_complement_nb_weights_match_formula() {
        let (x, y) = make_count_data();
        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();

        // Class sums: class 0 -> [15, 3, 1], class 1 -> [1, 3, 15].
        assert_eq!(
            fitted.feature_counts,
            array![[15.0, 3.0, 1.0], [1.0, 3.0, 15.0]]
        );
        assert_eq!(fitted.class_counts, vec![3, 3]);

        // The complement of each class is the other class. With alpha = 1 the
        // smoothed complement counts are [2, 4, 16] / [16, 4, 2], each summing to 22.
        let expected = array![
            [(2.0f64 / 22.0).ln(), (4.0f64 / 22.0).ln(), (16.0f64 / 22.0).ln()],
            [(16.0f64 / 22.0).ln(), (4.0f64 / 22.0).ln(), (2.0f64 / 22.0).ln()],
        ];
        assert_eq!(fitted.weights.dim(), expected.dim());
        for (got, want) in fitted.weights.iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_complement_nb_predict_proba_sums_to_one() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();
        for i in 0..proba.nrows() {
            assert_relative_eq!(proba.row(i).sum(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_complement_nb_predict_proba_agrees_with_predict() {
        let (x, y) = make_count_data();
        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();
        let preds = fitted.predict(&x).unwrap();
        for (row, &pred) in proba.rows().into_iter().zip(preds.iter()) {
            let best = row
                .iter()
                .enumerate()
                .fold(0, |b, (i, &p)| if p > row[b] { i } else { b });
            assert_eq!(fitted.classes()[best], pred);
        }
    }

    #[test]
    fn test_complement_nb_has_classes() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_complement_nb_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((4, 3), vec![1.0; 12]).unwrap();
        let y = array![0usize, 1]; // Wrong length
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_shape_mismatch_predict() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let x_bad = Array2::from_shape_vec((3, 5), vec![1.0; 15]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
        assert!(fitted.predict_proba(&x_bad).is_err());
    }

    #[test]
    fn test_complement_nb_negative_features_error() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, -0.5, 3.0, 2.0, 1.0, 0.0, 4.0]).unwrap();
        let y = array![0usize, 0, 1, 1];
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_single_class() {
        let x = Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
            .unwrap();
        let y = array![0usize, 0, 0];
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0]);
        let preds = fitted.predict(&x).unwrap();
        assert!(preds.iter().all(|&p| p == 0));
    }

    #[test]
    fn test_complement_nb_empty_data() {
        let x = Array2::<f64>::zeros((0, 3));
        let y = Array1::<usize>::zeros(0);
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_default() {
        let model = ComplementNB::<f64>::default();
        assert_relative_eq!(model.alpha, 1.0, epsilon = 1e-15);
    }

    #[test]
    fn test_complement_nb_imbalanced_data() {
        // ComplementNB is designed for imbalanced data.
        // 10 samples of class 0, 2 samples of class 1.
        let x = Array2::from_shape_vec(
            (12, 3),
            vec![
                5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 6.0, 0.0, 1.0, 5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 6.0,
                0.0, 1.0, 5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 6.0, 0.0, 1.0, 5.0, 1.0, 0.0, 0.0, 1.0,
                5.0, // class 1
                0.0, 2.0, 6.0, // class 1
            ],
        )
        .unwrap();
        let y = array![0usize, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1];

        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // Class 1 samples should be predicted as class 1.
        assert_eq!(preds[10], 1);
        assert_eq!(preds[11], 1);
    }

    #[test]
    fn test_complement_nb_partial_fit() {
        let x1 = Array2::from_shape_vec(
            (4, 3),
            vec![5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 0.0, 1.0, 5.0, 1.0, 0.0, 4.0],
        )
        .unwrap();
        let y1 = array![0usize, 0, 1, 1];

        let model = ComplementNB::<f64>::new();
        let mut fitted = model.fit(&x1, &y1).unwrap();

        let x2 = Array2::from_shape_vec((2, 3), vec![6.0, 0.0, 1.0, 0.0, 2.0, 6.0]).unwrap();
        let y2 = array![0usize, 1];

        fitted.partial_fit(&x2, &y2).unwrap();

        let preds = fitted.predict(&x1).unwrap();
        assert_eq!(preds.len(), 4);

        // An incremental update must match a single fit on all of the data.
        let x_all = Array2::from_shape_vec(
            (6, 3),
            vec![
                5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 0.0, 1.0, 5.0, 1.0, 0.0, 4.0, 6.0, 0.0, 1.0, 0.0,
                2.0, 6.0,
            ],
        )
        .unwrap();
        let y_all = array![0usize, 0, 1, 1, 0, 1];
        let full = ComplementNB::<f64>::new().fit(&x_all, &y_all).unwrap();

        assert_eq!(fitted.class_counts, full.class_counts);
        assert_eq!(fitted.feature_counts, full.feature_counts);
        for (a, b) in fitted.weights.iter().zip(full.weights.iter()) {
            assert_relative_eq!(*a, *b, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_complement_nb_partial_fit_shape_mismatch() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let mut fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 5), vec![1.0; 10]).unwrap();
        let y_bad = array![0usize, 1];
        assert!(fitted.partial_fit(&x_bad, &y_bad).is_err());
    }

    #[test]
    fn test_complement_nb_class_prior() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new().with_class_prior(vec![0.5, 0.5]);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_complement_nb_class_prior_wrong_length() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new().with_class_prior(vec![1.0]);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_three_classes() {
        let x = Array2::from_shape_vec(
            (9, 3),
            vec![
                5.0, 0.0, 0.0, 6.0, 0.0, 0.0, 4.0, 1.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0, 0.0, 1.0,
                4.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0, 0.0, 1.0, 4.0,
            ],
        )
        .unwrap();
        let y = array![0usize, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_classes(), 3);
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 7);
    }
}