// ferrolearn_bayes/complement.rs

//! Complement Naive Bayes classifier.
//!
//! This module provides [`ComplementNB`], a variant of Multinomial Naive Bayes
//! that is particularly well-suited for imbalanced datasets. Instead of estimating
//! the likelihood of a feature given a class, it estimates the likelihood of the
//! feature given all *other* (complement) classes and inverts the weights.
//!
//! The weight for feature `j` in class `c` is:
//!
//! ```text
//! w_cj = log( (N_~cj + alpha) / (N_~c + alpha * n_features) )
//! ```
//!
//! where `N_~cj` is the total count of feature `j` in all classes except `c`,
//! and `N_~c` is the total count of all features in all classes except `c`.
//!
//! Prediction uses `argmin_c sum_j x_j * w_cj` (i.e., the class with the
//! *smallest* complement score is chosen).
//!
//! # Examples
//!
//! ```
//! use ferrolearn_bayes::ComplementNB;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array2};
//!
//! let x = Array2::from_shape_vec(
//!     (6, 3),
//!     vec![
//!         5.0, 1.0, 0.0,
//!         4.0, 2.0, 0.0,
//!         6.0, 0.0, 1.0,
//!         0.0, 1.0, 5.0,
//!         1.0, 0.0, 4.0,
//!         0.0, 2.0, 6.0,
//!     ],
//! ).unwrap();
//! let y = array![0usize, 0, 0, 1, 1, 1];
//!
//! let model = ComplementNB::<f64>::new();
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 6);
//! ```
use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::HasClasses;
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2};
use num_traits::{Float, FromPrimitive, ToPrimitive};

/// Complement Naive Bayes classifier.
///
/// A variant of Multinomial NB that uses complement-class statistics.
/// More robust for imbalanced datasets.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct ComplementNB<F> {
    /// Additive (Laplace) smoothing parameter. Default: `1.0`.
    pub alpha: F,
    /// Optional user-supplied class priors. Note: ComplementNB does not
    /// use priors in the standard way (it uses complement weights), but
    /// this field is provided for API consistency with other NB variants.
    /// When set, its length is validated against the number of classes
    /// during `fit`.
    pub class_prior: Option<Vec<F>>,
    /// Whether to learn class priors from the data. Stored for API
    /// consistency; ComplementNB's predict does not consult priors in the
    /// multi-class case. Default: `true`.
    pub fit_prior: bool,
    /// When `false`, `alpha` values below `1e-10` are silently raised to
    /// `1e-10` (legacy behavior). Default: `true`.
    pub force_alpha: bool,
    /// When `true`, performs a second L1 normalization of the weights
    /// (Rennie et al. 2003 §4.4 "normalized weights" variant). Default:
    /// `false`.
    pub norm: bool,
}

82impl<F: Float> ComplementNB<F> {
83    /// Create a new `ComplementNB` with Laplace smoothing (`alpha = 1.0`).
84    #[must_use]
85    pub fn new() -> Self {
86        Self {
87            alpha: F::one(),
88            class_prior: None,
89            fit_prior: true,
90            force_alpha: true,
91            norm: false,
92        }
93    }
94
95    /// Set the Laplace smoothing parameter.
96    #[must_use]
97    pub fn with_alpha(mut self, alpha: F) -> Self {
98        self.alpha = alpha;
99        self
100    }
101
102    /// Set user-supplied class priors.
103    ///
104    /// The priors must have length equal to the number of classes discovered
105    /// during fitting. Note: ComplementNB uses complement-class weights rather
106    /// than direct class priors, but the priors are stored for API consistency.
107    #[must_use]
108    pub fn with_class_prior(mut self, priors: Vec<F>) -> Self {
109        self.class_prior = Some(priors);
110        self
111    }
112
113    /// Toggle `fit_prior`. Stored for API consistency with other discrete NBs.
114    #[must_use]
115    pub fn with_fit_prior(mut self, fit_prior: bool) -> Self {
116        self.fit_prior = fit_prior;
117        self
118    }
119
120    /// Toggle the `force_alpha` policy. See struct field doc.
121    #[must_use]
122    pub fn with_force_alpha(mut self, force_alpha: bool) -> Self {
123        self.force_alpha = force_alpha;
124        self
125    }
126
127    /// Toggle the second L1 normalization on weights (sklearn's `norm`
128    /// parameter; Rennie et al. 2003 §4.4).
129    #[must_use]
130    pub fn with_norm(mut self, norm: bool) -> Self {
131        self.norm = norm;
132        self
133    }
134}
135
136impl<F: Float> Default for ComplementNB<F> {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
/// Fitted Complement Naive Bayes classifier.
///
/// Produced by [`ComplementNB::fit`]; holds the learned complement
/// weights plus the raw counts needed for incremental updates.
#[derive(Debug, Clone)]
pub struct FittedComplementNB<F> {
    /// Sorted unique class labels.
    classes: Vec<usize>,
    /// Complement weights per class, shape `(n_classes, n_features)`.
    /// Each entry is `log( (N_~cj + alpha) / (N_~c + alpha * n_features) )`.
    weights: Array2<F>,
    /// Raw per-class feature count sums, shape `(n_classes, n_features)`.
    feature_counts: Array2<F>,
    /// Per-class sample counts.
    class_counts: Vec<usize>,
    /// Smoothing parameter carried forward for partial_fit (post-clamp
    /// when `force_alpha=false`).
    alpha: F,
    /// Whether to apply the second L1 normalization on weights (carried
    /// forward for partial_fit).
    norm: bool,
}

162impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ComplementNB<F> {
163    type Fitted = FittedComplementNB<F>;
164    type Error = FerroError;
165
166    /// Fit the Complement NB model.
167    ///
168    /// # Errors
169    ///
170    /// - [`FerroError::ShapeMismatch`] if `x` and `y` have different numbers of rows.
171    /// - [`FerroError::InsufficientSamples`] if there are no samples.
172    /// - [`FerroError::InvalidParameter`] if any feature value is negative.
173    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedComplementNB<F>, FerroError> {
174        let (n_samples, n_features) = x.dim();
175
176        if n_samples == 0 {
177            return Err(FerroError::InsufficientSamples {
178                required: 1,
179                actual: 0,
180                context: "ComplementNB requires at least one sample".into(),
181            });
182        }
183
184        if n_samples != y.len() {
185            return Err(FerroError::ShapeMismatch {
186                expected: vec![n_samples],
187                actual: vec![y.len()],
188                context: "y length must match number of samples in X".into(),
189            });
190        }
191
192        // Validate non-negative features.
193        if x.iter().any(|&v| v < F::zero()) {
194            return Err(FerroError::InvalidParameter {
195                name: "X".into(),
196                reason: "ComplementNB requires non-negative feature values".into(),
197            });
198        }
199
200        // Collect sorted unique classes.
201        let mut classes: Vec<usize> = y.to_vec();
202        classes.sort_unstable();
203        classes.dedup();
204        let n_classes = classes.len();
205
206        let n_feat_f = F::from(n_features).unwrap();
207        let alpha = crate::clamp_alpha(self.alpha, self.force_alpha);
208
209        // Compute per-class feature count sums, shape (n_classes, n_features).
210        let mut class_feature_counts = Array2::<F>::zeros((n_classes, n_features));
211        let mut class_counts = vec![0usize; n_classes];
212
213        for (sample_idx, &label) in y.iter().enumerate() {
214            let ci = classes.iter().position(|&c| c == label).unwrap();
215            class_counts[ci] += 1;
216            for j in 0..n_features {
217                class_feature_counts[[ci, j]] = class_feature_counts[[ci, j]] + x[[sample_idx, j]];
218            }
219        }
220
221        // Total feature counts across all classes.
222        let total_feature_counts: Array1<F> = class_feature_counts.rows().into_iter().fold(
223            Array1::<F>::zeros(n_features),
224            |acc, row| {
225                let mut result = acc;
226                for j in 0..n_features {
227                    result[j] = result[j] + row[j];
228                }
229                result
230            },
231        );
232
233        let total_all: F = total_feature_counts.sum();
234
235        // Compute complement weights for each class.
236        let mut weights = Array2::<F>::zeros((n_classes, n_features));
237
238        for ci in 0..n_classes {
239            // Complement counts: sum over all other classes.
240            let complement_total = total_all - class_feature_counts.row(ci).sum();
241
242            let denom = complement_total + alpha * n_feat_f;
243
244            for j in 0..n_features {
245                let complement_count_j = total_feature_counts[j] - class_feature_counts[[ci, j]];
246                weights[[ci, j]] = ((complement_count_j + alpha) / denom).ln();
247            }
248        }
249
250        if self.norm {
251            apply_norm_inplace(&mut weights);
252        }
253
254        // Validate class_prior length if provided.
255        if let Some(ref priors) = self.class_prior {
256            if priors.len() != n_classes {
257                return Err(FerroError::InvalidParameter {
258                    name: "class_prior".into(),
259                    reason: format!(
260                        "length {} does not match number of classes {}",
261                        priors.len(),
262                        n_classes
263                    ),
264                });
265            }
266        }
267
268        Ok(FittedComplementNB {
269            classes,
270            weights,
271            feature_counts: class_feature_counts,
272            class_counts,
273            alpha,
274            norm: self.norm,
275        })
276    }
277}
278
279/// Apply sklearn's `norm=True` second L1 normalization to complement
280/// weights, then negate so ferrolearn's `argmin` predict semantics keep
281/// matching sklearn's `argmax(X @ feature_log_prob.T)`.
282///
283/// Walks each row of `weights` (= sklearn's pre-negation `logged`):
284/// - row_sum = Σ_j logged[c, j]    (≤ 0 since logs of probabilities)
285/// - normalized = logged / row_sum  (≥ 0, rows sum to 1)
286/// - weights[c, j] = -normalized    (re-negated to keep argmin convention)
287fn apply_norm_inplace<F: Float>(weights: &mut Array2<F>) {
288    let n_classes = weights.nrows();
289    let n_features = weights.ncols();
290    for ci in 0..n_classes {
291        let row_sum = (0..n_features).fold(F::zero(), |acc, j| acc + weights[[ci, j]]);
292        if row_sum == F::zero() {
293            continue;
294        }
295        for j in 0..n_features {
296            weights[[ci, j]] = -(weights[[ci, j]] / row_sum);
297        }
298    }
299}
300
301impl<F: Float + Send + Sync + 'static> FittedComplementNB<F> {
302    /// Incrementally update the model with new data.
303    ///
304    /// Accumulates feature counts and class counts, then recomputes
305    /// the complement weights.
306    ///
307    /// # Errors
308    ///
309    /// - [`FerroError::ShapeMismatch`] if `x` and `y` have different row counts
310    ///   or the number of features does not match the fitted model.
311    /// - [`FerroError::InvalidParameter`] if any feature value is negative.
312    pub fn partial_fit(&mut self, x: &Array2<F>, y: &Array1<usize>) -> Result<(), FerroError> {
313        let (n_samples, n_features) = x.dim();
314
315        if n_samples == 0 {
316            return Ok(());
317        }
318
319        if n_samples != y.len() {
320            return Err(FerroError::ShapeMismatch {
321                expected: vec![n_samples],
322                actual: vec![y.len()],
323                context: "y length must match number of samples in X".into(),
324            });
325        }
326
327        if n_features != self.weights.ncols() {
328            return Err(FerroError::ShapeMismatch {
329                expected: vec![self.weights.ncols()],
330                actual: vec![n_features],
331                context: "number of features must match fitted ComplementNB".into(),
332            });
333        }
334
335        if x.iter().any(|&v| v < F::zero()) {
336            return Err(FerroError::InvalidParameter {
337                name: "X".into(),
338                reason: "ComplementNB requires non-negative feature values".into(),
339            });
340        }
341
342        // Accumulate counts for each existing class.
343        for (ci, &class_label) in self.classes.clone().iter().enumerate() {
344            let new_indices: Vec<usize> = y
345                .iter()
346                .enumerate()
347                .filter_map(|(i, &label)| if label == class_label { Some(i) } else { None })
348                .collect();
349
350            if new_indices.is_empty() {
351                continue;
352            }
353
354            self.class_counts[ci] += new_indices.len();
355
356            for &i in &new_indices {
357                for j in 0..n_features {
358                    self.feature_counts[[ci, j]] = self.feature_counts[[ci, j]] + x[[i, j]];
359                }
360            }
361        }
362
363        // Recompute complement weights from accumulated feature_counts.
364        let n_classes = self.classes.len();
365        let n_feat_f = F::from(n_features).unwrap();
366
367        let total_feature_counts: Array1<F> = self.feature_counts.rows().into_iter().fold(
368            Array1::<F>::zeros(n_features),
369            |acc, row| {
370                let mut result = acc;
371                for j in 0..n_features {
372                    result[j] = result[j] + row[j];
373                }
374                result
375            },
376        );
377
378        let total_all: F = total_feature_counts.sum();
379
380        for ci in 0..n_classes {
381            let complement_total = total_all - self.feature_counts.row(ci).sum();
382            let denom = complement_total + self.alpha * n_feat_f;
383            for j in 0..n_features {
384                let complement_count_j = total_feature_counts[j] - self.feature_counts[[ci, j]];
385                self.weights[[ci, j]] = ((complement_count_j + self.alpha) / denom).ln();
386            }
387        }
388
389        if self.norm {
390            apply_norm_inplace(&mut self.weights);
391        }
392
393        Ok(())
394    }
395
396    /// Compute complement scores for each class.
397    ///
398    /// Returns shape `(n_samples, n_classes)`. Lower is better.
399    fn complement_scores(&self, x: &Array2<F>) -> Array2<F> {
400        let n_samples = x.nrows();
401        let n_classes = self.classes.len();
402        let n_features = x.ncols();
403
404        let mut scores = Array2::<F>::zeros((n_samples, n_classes));
405
406        for i in 0..n_samples {
407            for ci in 0..n_classes {
408                let mut score = F::zero();
409                for j in 0..n_features {
410                    score = score + x[[i, j]] * self.weights[[ci, j]];
411                }
412                scores[[i, ci]] = score;
413            }
414        }
415
416        scores
417    }
418
419    /// Predict class probabilities for the given feature matrix.
420    ///
421    /// Converts complement scores (lower=better) to probabilities by negating
422    /// and applying softmax.
423    ///
424    /// Returns shape `(n_samples, n_classes)` where each row sums to 1.
425    ///
426    /// # Errors
427    ///
428    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
429    /// not match the fitted model.
430    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
431        let n_features_fitted = self.weights.ncols();
432        if x.ncols() != n_features_fitted {
433            return Err(FerroError::ShapeMismatch {
434                expected: vec![n_features_fitted],
435                actual: vec![x.ncols()],
436                context: "number of features must match fitted ComplementNB".into(),
437            });
438        }
439
440        // Negate complement scores so that lower complement score → higher probability.
441        let neg_scores = self.complement_scores(x).mapv(|v| -v);
442        let n_samples = x.nrows();
443        let n_classes = self.classes.len();
444        let mut proba = Array2::<F>::zeros((n_samples, n_classes));
445
446        for i in 0..n_samples {
447            let max_score = neg_scores
448                .row(i)
449                .iter()
450                .fold(F::neg_infinity(), |a, &b| a.max(b));
451
452            let mut row_sum = F::zero();
453            for ci in 0..n_classes {
454                let p = (neg_scores[[i, ci]] - max_score).exp();
455                proba[[i, ci]] = p;
456                row_sum = row_sum + p;
457            }
458            for ci in 0..n_classes {
459                proba[[i, ci]] = proba[[i, ci]] / row_sum;
460            }
461        }
462
463        Ok(proba)
464    }
465
466    /// Compute the joint log-likelihood scores using sklearn's sign
467    /// convention: argmax(jll) gives the predicted class.
468    ///
469    /// Returns shape `(n_samples, n_classes)`. ComplementNB's complement
470    /// scoring is "lower=better", so this method returns
471    /// `-complement_scores` to match sklearn's convention where higher is
472    /// better. Matches sklearn `ComplementNB._joint_log_likelihood`.
473    ///
474    /// # Errors
475    ///
476    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
477    /// not match the fitted model.
478    pub fn predict_joint_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
479        let n_features_fitted = self.weights.ncols();
480        if x.ncols() != n_features_fitted {
481            return Err(FerroError::ShapeMismatch {
482                expected: vec![n_features_fitted],
483                actual: vec![x.ncols()],
484                context: "number of features must match fitted ComplementNB".into(),
485            });
486        }
487        Ok(self.complement_scores(x).mapv(|v| -v))
488    }
489
490    /// Compute log of class probabilities (numerically stable).
491    ///
492    /// Returns shape `(n_samples, n_classes)`.
493    ///
494    /// # Errors
495    ///
496    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
497    /// not match the fitted model.
498    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
499        let jll = self.predict_joint_log_proba(x)?;
500        Ok(crate::log_softmax_rows(&jll))
501    }
502
503    /// Mean accuracy on the given test data and labels.
504    ///
505    /// Equivalent to sklearn's `ClassifierMixin.score`.
506    ///
507    /// # Errors
508    ///
509    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
510    /// the feature count does not match the fitted model.
511    pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
512        if x.nrows() != y.len() {
513            return Err(FerroError::ShapeMismatch {
514                expected: vec![x.nrows()],
515                actual: vec![y.len()],
516                context: "y length must match number of samples in X".into(),
517            });
518        }
519        let preds = self.predict(x)?;
520        let n = y.len();
521        if n == 0 {
522            return Ok(F::zero());
523        }
524        let correct = preds.iter().zip(y.iter()).filter(|(p, t)| p == t).count();
525        Ok(F::from(correct).unwrap() / F::from(n).unwrap())
526    }
527}
528
529impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedComplementNB<F> {
530    type Output = Array1<usize>;
531    type Error = FerroError;
532
533    /// Predict class labels for the given feature matrix.
534    ///
535    /// Predicts the class with the *lowest* complement score.
536    ///
537    /// # Errors
538    ///
539    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
540    /// not match the fitted model.
541    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
542        let n_features_fitted = self.weights.ncols();
543        if x.ncols() != n_features_fitted {
544            return Err(FerroError::ShapeMismatch {
545                expected: vec![n_features_fitted],
546                actual: vec![x.ncols()],
547                context: "number of features must match fitted ComplementNB".into(),
548            });
549        }
550
551        let scores = self.complement_scores(x);
552        let n_samples = x.nrows();
553        let n_classes = self.classes.len();
554
555        let mut predictions = Array1::<usize>::zeros(n_samples);
556        for i in 0..n_samples {
557            // Argmin: class with the smallest complement score.
558            let mut best_class = 0;
559            let mut best_score = scores[[i, 0]];
560            for ci in 1..n_classes {
561                if scores[[i, ci]] < best_score {
562                    best_score = scores[[i, ci]];
563                    best_class = ci;
564                }
565            }
566            predictions[i] = self.classes[best_class];
567        }
568
569        Ok(predictions)
570    }
571}
572
573impl<F: Float + Send + Sync + 'static> HasClasses for FittedComplementNB<F> {
574    fn classes(&self) -> &[usize] {
575        &self.classes
576    }
577
578    fn n_classes(&self) -> usize {
579        self.classes.len()
580    }
581}
582
583// Pipeline integration.
584impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
585    for ComplementNB<F>
586{
587    fn fit_pipeline(
588        &self,
589        x: &Array2<F>,
590        y: &Array1<F>,
591    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
592        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
593        let fitted = self.fit(x, &y_usize)?;
594        Ok(Box::new(FittedComplementNBPipeline(fitted)))
595    }
596}
597
598struct FittedComplementNBPipeline<F: Float + Send + Sync + 'static>(FittedComplementNB<F>);
599
600unsafe impl<F: Float + Send + Sync + 'static> Send for FittedComplementNBPipeline<F> {}
601unsafe impl<F: Float + Send + Sync + 'static> Sync for FittedComplementNBPipeline<F> {}
602
603impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
604    for FittedComplementNBPipeline<F>
605{
606    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
607        let preds = self.0.predict(x)?;
608        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
609    }
610}
611
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Small balanced fixture: 6 samples, 3 features, 2 classes.
    /// Class 0 is dominated by feature 0, class 1 by feature 2.
    fn make_count_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (6, 3),
            vec![
                5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 6.0, 0.0, 1.0, 0.0, 1.0, 5.0, 1.0, 0.0, 4.0, 0.0,
                2.0, 6.0,
            ],
        )
        .unwrap();
        let y = array![0usize, 0, 0, 1, 1, 1];
        (x, y)
    }

    #[test]
    fn test_complement_nb_fit_predict() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        // The fixture is well separated, so all training samples classify correctly.
        assert_eq!(correct, 6);
    }

    #[test]
    fn test_complement_nb_predict_proba_sums_to_one() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();
        // Softmax output: every row must be a valid probability distribution.
        for i in 0..proba.nrows() {
            assert_relative_eq!(proba.row(i).sum(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_complement_nb_has_classes() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_complement_nb_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((4, 3), vec![1.0; 12]).unwrap();
        let y = array![0usize, 1]; // Wrong length
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_shape_mismatch_predict() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        // 5 features at predict time vs. 3 at fit time.
        let x_bad = Array2::from_shape_vec((3, 5), vec![1.0; 15]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
        assert!(fitted.predict_proba(&x_bad).is_err());
    }

    #[test]
    fn test_complement_nb_negative_features_error() {
        // -0.5 violates the non-negative count assumption.
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, -0.5, 3.0, 2.0, 1.0, 0.0, 4.0]).unwrap();
        let y = array![0usize, 0, 1, 1];
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_single_class() {
        let x = Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
            .unwrap();
        let y = array![0usize, 0, 0];
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0]);
        // With one class every prediction must be that class.
        let preds = fitted.predict(&x).unwrap();
        assert!(preds.iter().all(|&p| p == 0));
    }

    #[test]
    fn test_complement_nb_empty_data() {
        let x = Array2::<f64>::zeros((0, 3));
        let y = Array1::<usize>::zeros(0);
        let model = ComplementNB::<f64>::new();
        // Zero samples -> InsufficientSamples.
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_default() {
        let model = ComplementNB::<f64>::default();
        assert_relative_eq!(model.alpha, 1.0, epsilon = 1e-15);
    }

    #[test]
    fn test_complement_nb_imbalanced_data() {
        // ComplementNB is designed for imbalanced data.
        // 10 samples of class 0, 2 samples of class 1.
        let x = Array2::from_shape_vec(
            (12, 3),
            vec![
                5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 6.0, 0.0, 1.0, 5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 6.0,
                0.0, 1.0, 5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 6.0, 0.0, 1.0, 5.0, 1.0, 0.0, 0.0, 1.0,
                5.0, // class 1
                0.0, 2.0, 6.0, // class 1
            ],
        )
        .unwrap();
        let y = array![0usize, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1];

        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // Class 1 samples should be predicted as class 1 despite the imbalance.
        assert_eq!(preds[10], 1);
        assert_eq!(preds[11], 1);
    }

    #[test]
    fn test_complement_nb_partial_fit() {
        let x1 = Array2::from_shape_vec(
            (4, 3),
            vec![5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 0.0, 1.0, 5.0, 1.0, 0.0, 4.0],
        )
        .unwrap();
        let y1 = array![0usize, 0, 1, 1];

        let model = ComplementNB::<f64>::new();
        let mut fitted = model.fit(&x1, &y1).unwrap();

        // Second batch uses only labels already seen during fit.
        let x2 = Array2::from_shape_vec((2, 3), vec![6.0, 0.0, 1.0, 0.0, 2.0, 6.0]).unwrap();
        let y2 = array![0usize, 1];

        fitted.partial_fit(&x2, &y2).unwrap();

        let preds = fitted.predict(&x1).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_complement_nb_partial_fit_shape_mismatch() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let mut fitted = model.fit(&x, &y).unwrap();

        // 5 features vs. the 3 the model was fitted with.
        let x_bad = Array2::from_shape_vec((2, 5), vec![1.0; 10]).unwrap();
        let y_bad = array![0usize, 1];
        assert!(fitted.partial_fit(&x_bad, &y_bad).is_err());
    }

    #[test]
    fn test_complement_nb_class_prior() {
        let (x, y) = make_count_data();
        // Priors are stored for API consistency; fitting must still succeed.
        let model = ComplementNB::<f64>::new().with_class_prior(vec![0.5, 0.5]);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_complement_nb_class_prior_wrong_length() {
        let (x, y) = make_count_data();
        // One prior for two classes -> InvalidParameter.
        let model = ComplementNB::<f64>::new().with_class_prior(vec![1.0]);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_three_classes() {
        let x = Array2::from_shape_vec(
            (9, 3),
            vec![
                5.0, 0.0, 0.0, 6.0, 0.0, 0.0, 4.0, 1.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0, 0.0, 1.0,
                4.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0, 0.0, 1.0, 4.0,
            ],
        )
        .unwrap();
        let y = array![0usize, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_classes(), 3);
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 7);
    }
}