Skip to main content

ferrolearn_bayes/
complement.rs

1//! Complement Naive Bayes classifier.
2//!
3//! This module provides [`ComplementNB`], a variant of Multinomial Naive Bayes
4//! that is particularly well-suited for imbalanced datasets. Instead of estimating
5//! the likelihood of a feature given a class, it estimates the likelihood of the
6//! feature given all *other* (complement) classes and inverts the weights.
7//!
//! The weight for feature `j` in class `c` is stored with sklearn's sign
//! convention:
//!
//! ```text
//! w_cj = -log( (N_~cj + alpha) / (N_~c + alpha * n_features) )
//! ```
//!
//! where `N_~cj` is the total count of feature `j` in all classes except `c`,
//! and `N_~c` is the total count of all features in all classes except `c`.
//! The stored weights are therefore positive (`-log(complement_prob)`), and
//! prediction uses `argmax_c sum_j x_j * w_cj`, matching sklearn's
//! `argmax(X @ feature_log_prob_.T)`. With `norm` enabled, each class row of
//! weights is additionally divided by its sum.
21//!
22//! # Examples
23//!
24//! ```
25//! use ferrolearn_bayes::ComplementNB;
26//! use ferrolearn_core::{Fit, Predict};
27//! use ndarray::{array, Array2};
28//!
29//! let x = Array2::from_shape_vec(
30//!     (6, 3),
31//!     vec![
32//!         5.0, 1.0, 0.0,
33//!         4.0, 2.0, 0.0,
34//!         6.0, 0.0, 1.0,
35//!         0.0, 1.0, 5.0,
36//!         1.0, 0.0, 4.0,
37//!         0.0, 2.0, 6.0,
38//!     ],
39//! ).unwrap();
40//! let y = array![0usize, 0, 0, 1, 1, 1];
41//!
42//! let model = ComplementNB::<f64>::new();
43//! let fitted = model.fit(&x, &y).unwrap();
44//! let preds = fitted.predict(&x).unwrap();
45//! assert_eq!(preds.len(), 6);
46//! ```
47
48use ferrolearn_core::error::FerroError;
49use ferrolearn_core::introspection::HasClasses;
50use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
51use ferrolearn_core::traits::{Fit, Predict};
52use ndarray::{Array1, Array2};
53use num_traits::{Float, FromPrimitive, ToPrimitive};
54
/// Complement Naive Bayes classifier.
///
/// A variant of Multinomial NB whose per-class weights are built from the
/// statistics of the complement classes (every class except the one being
/// scored). More robust for imbalanced datasets.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct ComplementNB<F> {
    /// Additive (Laplace) smoothing parameter. Default: `1.0`.
    pub alpha: F,
    /// Optional user-supplied class priors. ComplementNB's decision rule is
    /// driven entirely by the complement weights: at fit time these priors
    /// are only checked to contain one entry per discovered class, and they
    /// are not used when scoring or predicting. Provided for API consistency
    /// with the other Naive Bayes variants.
    pub class_prior: Option<Vec<F>>,
    /// Whether to learn class priors from the data. Stored for API
    /// consistency only: neither fitting nor prediction reads this flag.
    /// Default: `true`.
    pub fit_prior: bool,
    /// When `false`, `alpha` values below `1e-10` are silently raised to
    /// `1e-10` (legacy behavior). Default: `true`.
    pub force_alpha: bool,
    /// When `true`, performs a second L1 normalization of the weights
    /// (Rennie et al. 2003 §4.4 "normalized weights" variant). Default:
    /// `false`.
    pub norm: bool,
}
83
84impl<F: Float> ComplementNB<F> {
85    /// Create a new `ComplementNB` with Laplace smoothing (`alpha = 1.0`).
86    #[must_use]
87    pub fn new() -> Self {
88        Self {
89            alpha: F::one(),
90            class_prior: None,
91            fit_prior: true,
92            force_alpha: true,
93            norm: false,
94        }
95    }
96
97    /// Set the Laplace smoothing parameter.
98    #[must_use]
99    pub fn with_alpha(mut self, alpha: F) -> Self {
100        self.alpha = alpha;
101        self
102    }
103
104    /// Set user-supplied class priors.
105    ///
106    /// The priors must have length equal to the number of classes discovered
107    /// during fitting. Note: ComplementNB uses complement-class weights rather
108    /// than direct class priors, but the priors are stored for API consistency.
109    #[must_use]
110    pub fn with_class_prior(mut self, priors: Vec<F>) -> Self {
111        self.class_prior = Some(priors);
112        self
113    }
114
115    /// Toggle `fit_prior`. Stored for API consistency with other discrete NBs.
116    #[must_use]
117    pub fn with_fit_prior(mut self, fit_prior: bool) -> Self {
118        self.fit_prior = fit_prior;
119        self
120    }
121
122    /// Toggle the `force_alpha` policy. See struct field doc.
123    #[must_use]
124    pub fn with_force_alpha(mut self, force_alpha: bool) -> Self {
125        self.force_alpha = force_alpha;
126        self
127    }
128
129    /// Toggle the second L1 normalization on weights (sklearn's `norm`
130    /// parameter; Rennie et al. 2003 §4.4).
131    #[must_use]
132    pub fn with_norm(mut self, norm: bool) -> Self {
133        self.norm = norm;
134        self
135    }
136}
137
138impl<F: Float> Default for ComplementNB<F> {
139    fn default() -> Self {
140        Self::new()
141    }
142}
143
144/// Fitted Complement Naive Bayes classifier.
145#[derive(Debug, Clone)]
146pub struct FittedComplementNB<F> {
147    /// Sorted unique class labels.
148    classes: Vec<usize>,
149    /// Complement weights per class, shape `(n_classes, n_features)`.
150    /// Each entry is `log( (N_~cj + alpha) / (N_~c + alpha * n_features) )`.
151    weights: Array2<F>,
152    /// Raw per-class feature count sums, shape `(n_classes, n_features)`.
153    feature_counts: Array2<F>,
154    /// Per-class sample counts.
155    class_counts: Vec<usize>,
156    /// Smoothing parameter carried forward for partial_fit (post-clamp
157    /// when `force_alpha=false`).
158    alpha: F,
159    /// Whether to apply the second L1 normalization on weights (carried
160    /// forward for partial_fit).
161    norm: bool,
162}
163
164impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for ComplementNB<F> {
165    type Fitted = FittedComplementNB<F>;
166    type Error = FerroError;
167
168    /// Fit the Complement NB model.
169    ///
170    /// # Errors
171    ///
172    /// - [`FerroError::ShapeMismatch`] if `x` and `y` have different numbers of rows.
173    /// - [`FerroError::InsufficientSamples`] if there are no samples.
174    /// - [`FerroError::InvalidParameter`] if any feature value is negative.
175    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedComplementNB<F>, FerroError> {
176        let (n_samples, n_features) = x.dim();
177
178        if n_samples == 0 {
179            return Err(FerroError::InsufficientSamples {
180                required: 1,
181                actual: 0,
182                context: "ComplementNB requires at least one sample".into(),
183            });
184        }
185
186        if n_samples != y.len() {
187            return Err(FerroError::ShapeMismatch {
188                expected: vec![n_samples],
189                actual: vec![y.len()],
190                context: "y length must match number of samples in X".into(),
191            });
192        }
193
194        // Validate non-negative features.
195        if x.iter().any(|&v| v < F::zero()) {
196            return Err(FerroError::InvalidParameter {
197                name: "X".into(),
198                reason: "ComplementNB requires non-negative feature values".into(),
199            });
200        }
201
202        // Collect sorted unique classes.
203        let mut classes: Vec<usize> = y.to_vec();
204        classes.sort_unstable();
205        classes.dedup();
206        let n_classes = classes.len();
207
208        let n_feat_f = F::from(n_features).unwrap();
209        let alpha = crate::clamp_alpha(self.alpha, self.force_alpha);
210
211        // Compute per-class feature count sums, shape (n_classes, n_features).
212        let mut class_feature_counts = Array2::<F>::zeros((n_classes, n_features));
213        let mut class_counts = vec![0usize; n_classes];
214
215        for (sample_idx, &label) in y.iter().enumerate() {
216            let ci = classes.iter().position(|&c| c == label).unwrap();
217            class_counts[ci] += 1;
218            for j in 0..n_features {
219                class_feature_counts[[ci, j]] = class_feature_counts[[ci, j]] + x[[sample_idx, j]];
220            }
221        }
222
223        // Total feature counts across all classes.
224        let total_feature_counts: Array1<F> = class_feature_counts.rows().into_iter().fold(
225            Array1::<F>::zeros(n_features),
226            |acc, row| {
227                let mut result = acc;
228                for j in 0..n_features {
229                    result[j] = result[j] + row[j];
230                }
231                result
232            },
233        );
234
235        let total_all: F = total_feature_counts.sum();
236
237        // Compute complement-log weights for each class. sklearn stores
238        // `feature_log_prob_ = -log((complement_count + alpha) / (total + alpha*n_features))`
239        // (positive values — see #346). ferrolearn previously stored the
240        // pre-negation value; we now match sklearn's convention so
241        // introspection is parity-correct and predict uses argmax.
242        let mut weights = Array2::<F>::zeros((n_classes, n_features));
243
244        for ci in 0..n_classes {
245            // Complement counts: sum over all other classes.
246            let complement_total = total_all - class_feature_counts.row(ci).sum();
247
248            let denom = complement_total + alpha * n_feat_f;
249
250            for j in 0..n_features {
251                let complement_count_j = total_feature_counts[j] - class_feature_counts[[ci, j]];
252                // Negate so the stored value matches sklearn's
253                // `feature_log_prob_` exactly: positive values whose
254                // *smaller* indicates higher complement probability.
255                weights[[ci, j]] = -((complement_count_j + alpha) / denom).ln();
256            }
257        }
258
259        if self.norm {
260            apply_norm_inplace(&mut weights);
261        }
262
263        // Validate class_prior length if provided.
264        if let Some(ref priors) = self.class_prior {
265            if priors.len() != n_classes {
266                return Err(FerroError::InvalidParameter {
267                    name: "class_prior".into(),
268                    reason: format!(
269                        "length {} does not match number of classes {}",
270                        priors.len(),
271                        n_classes
272                    ),
273                });
274            }
275        }
276
277        Ok(FittedComplementNB {
278            classes,
279            weights,
280            feature_counts: class_feature_counts,
281            class_counts,
282            alpha,
283            norm: self.norm,
284        })
285    }
286}
287
288/// Apply sklearn's `norm=True` second L1 normalization to complement weights.
289///
290/// `weights` is already stored as sklearn's positive `-log(complement_prob)`.
291/// sklearn divides each row by its sum so rows sum to 1 (still positive,
292/// since the unnormalised values are positive).
293fn apply_norm_inplace<F: Float>(weights: &mut Array2<F>) {
294    let n_classes = weights.nrows();
295    let n_features = weights.ncols();
296    for ci in 0..n_classes {
297        let row_sum = (0..n_features).fold(F::zero(), |acc, j| acc + weights[[ci, j]]);
298        if row_sum == F::zero() {
299            continue;
300        }
301        for j in 0..n_features {
302            weights[[ci, j]] = weights[[ci, j]] / row_sum;
303        }
304    }
305}
306
307impl<F: Float + Send + Sync + 'static> FittedComplementNB<F> {
308    /// Incrementally update the model with new data.
309    ///
310    /// Accumulates feature counts and class counts, then recomputes
311    /// the complement weights.
312    ///
313    /// # Errors
314    ///
315    /// - [`FerroError::ShapeMismatch`] if `x` and `y` have different row counts
316    ///   or the number of features does not match the fitted model.
317    /// - [`FerroError::InvalidParameter`] if any feature value is negative.
318    pub fn partial_fit(&mut self, x: &Array2<F>, y: &Array1<usize>) -> Result<(), FerroError> {
319        let (n_samples, n_features) = x.dim();
320
321        if n_samples == 0 {
322            return Ok(());
323        }
324
325        if n_samples != y.len() {
326            return Err(FerroError::ShapeMismatch {
327                expected: vec![n_samples],
328                actual: vec![y.len()],
329                context: "y length must match number of samples in X".into(),
330            });
331        }
332
333        if n_features != self.weights.ncols() {
334            return Err(FerroError::ShapeMismatch {
335                expected: vec![self.weights.ncols()],
336                actual: vec![n_features],
337                context: "number of features must match fitted ComplementNB".into(),
338            });
339        }
340
341        if x.iter().any(|&v| v < F::zero()) {
342            return Err(FerroError::InvalidParameter {
343                name: "X".into(),
344                reason: "ComplementNB requires non-negative feature values".into(),
345            });
346        }
347
348        // Accumulate counts for each existing class.
349        for (ci, &class_label) in self.classes.clone().iter().enumerate() {
350            let new_indices: Vec<usize> = y
351                .iter()
352                .enumerate()
353                .filter_map(|(i, &label)| if label == class_label { Some(i) } else { None })
354                .collect();
355
356            if new_indices.is_empty() {
357                continue;
358            }
359
360            self.class_counts[ci] += new_indices.len();
361
362            for &i in &new_indices {
363                for j in 0..n_features {
364                    self.feature_counts[[ci, j]] = self.feature_counts[[ci, j]] + x[[i, j]];
365                }
366            }
367        }
368
369        // Recompute complement weights from accumulated feature_counts.
370        let n_classes = self.classes.len();
371        let n_feat_f = F::from(n_features).unwrap();
372
373        let total_feature_counts: Array1<F> = self.feature_counts.rows().into_iter().fold(
374            Array1::<F>::zeros(n_features),
375            |acc, row| {
376                let mut result = acc;
377                for j in 0..n_features {
378                    result[j] = result[j] + row[j];
379                }
380                result
381            },
382        );
383
384        let total_all: F = total_feature_counts.sum();
385
386        for ci in 0..n_classes {
387            let complement_total = total_all - self.feature_counts.row(ci).sum();
388            let denom = complement_total + self.alpha * n_feat_f;
389            for j in 0..n_features {
390                let complement_count_j = total_feature_counts[j] - self.feature_counts[[ci, j]];
391                // sklearn-parity sign: positive -log(complement_prob).
392                self.weights[[ci, j]] = -((complement_count_j + self.alpha) / denom).ln();
393            }
394        }
395
396        if self.norm {
397            apply_norm_inplace(&mut self.weights);
398        }
399
400        Ok(())
401    }
402
403    /// Compute the joint log-likelihood scores for each class.
404    ///
405    /// Returns `X @ feature_log_prob_.T` (shape `(n_samples, n_classes)`).
406    /// With ferrolearn's sklearn-parity sign for `feature_log_prob_`,
407    /// **higher is better** and `argmax(scores, axis=1)` predicts the class.
408    fn complement_scores(&self, x: &Array2<F>) -> Array2<F> {
409        let n_samples = x.nrows();
410        let n_classes = self.classes.len();
411        let n_features = x.ncols();
412
413        let mut scores = Array2::<F>::zeros((n_samples, n_classes));
414
415        for i in 0..n_samples {
416            for ci in 0..n_classes {
417                let mut score = F::zero();
418                for j in 0..n_features {
419                    score = score + x[[i, j]] * self.weights[[ci, j]];
420                }
421                scores[[i, ci]] = score;
422            }
423        }
424
425        scores
426    }
427
428    /// Predict class probabilities for the given feature matrix.
429    ///
430    /// Converts complement scores (lower=better) to probabilities by negating
431    /// and applying softmax.
432    ///
433    /// Returns shape `(n_samples, n_classes)` where each row sums to 1.
434    ///
435    /// # Errors
436    ///
437    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
438    /// not match the fitted model.
439    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
440        let n_features_fitted = self.weights.ncols();
441        if x.ncols() != n_features_fitted {
442            return Err(FerroError::ShapeMismatch {
443                expected: vec![n_features_fitted],
444                actual: vec![x.ncols()],
445                context: "number of features must match fitted ComplementNB".into(),
446            });
447        }
448
449        // scores are joint log-likelihoods (sklearn convention, higher=better).
450        // softmax directly without negation.
451        let scores = self.complement_scores(x);
452        let n_samples = x.nrows();
453        let n_classes = self.classes.len();
454        let mut proba = Array2::<F>::zeros((n_samples, n_classes));
455
456        for i in 0..n_samples {
457            let max_score = scores
458                .row(i)
459                .iter()
460                .fold(F::neg_infinity(), |a, &b| a.max(b));
461
462            let mut row_sum = F::zero();
463            for ci in 0..n_classes {
464                let p = (scores[[i, ci]] - max_score).exp();
465                proba[[i, ci]] = p;
466                row_sum = row_sum + p;
467            }
468            for ci in 0..n_classes {
469                proba[[i, ci]] = proba[[i, ci]] / row_sum;
470            }
471        }
472
473        Ok(proba)
474    }
475
476    /// Compute the joint log-likelihood scores using sklearn's sign
477    /// convention: argmax(jll) gives the predicted class.
478    ///
479    /// Returns shape `(n_samples, n_classes)`. ComplementNB's complement
480    /// scoring is "lower=better", so this method returns
481    /// `-complement_scores` to match sklearn's convention where higher is
482    /// better. Matches sklearn `ComplementNB._joint_log_likelihood`.
483    ///
484    /// # Errors
485    ///
486    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
487    /// not match the fitted model.
488    pub fn predict_joint_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
489        let n_features_fitted = self.weights.ncols();
490        if x.ncols() != n_features_fitted {
491            return Err(FerroError::ShapeMismatch {
492                expected: vec![n_features_fitted],
493                actual: vec![x.ncols()],
494                context: "number of features must match fitted ComplementNB".into(),
495            });
496        }
497        // With the sklearn-parity sign, complement_scores IS the joint
498        // log-likelihood directly (no negation needed).
499        Ok(self.complement_scores(x))
500    }
501
502    /// Compute log of class probabilities (numerically stable).
503    ///
504    /// Returns shape `(n_samples, n_classes)`.
505    ///
506    /// # Errors
507    ///
508    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
509    /// not match the fitted model.
510    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
511        let jll = self.predict_joint_log_proba(x)?;
512        Ok(crate::log_softmax_rows(&jll))
513    }
514
515    /// Mean accuracy on the given test data and labels.
516    ///
517    /// Equivalent to sklearn's `ClassifierMixin.score`.
518    ///
519    /// # Errors
520    ///
521    /// Returns [`FerroError::ShapeMismatch`] if `x.nrows() != y.len()` or
522    /// the feature count does not match the fitted model.
523    pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
524        if x.nrows() != y.len() {
525            return Err(FerroError::ShapeMismatch {
526                expected: vec![x.nrows()],
527                actual: vec![y.len()],
528                context: "y length must match number of samples in X".into(),
529            });
530        }
531        let preds = self.predict(x)?;
532        let n = y.len();
533        if n == 0 {
534            return Ok(F::zero());
535        }
536        let correct = preds.iter().zip(y.iter()).filter(|(p, t)| p == t).count();
537        Ok(F::from(correct).unwrap() / F::from(n).unwrap())
538    }
539}
540
541impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedComplementNB<F> {
542    type Output = Array1<usize>;
543    type Error = FerroError;
544
545    /// Predict class labels for the given feature matrix.
546    ///
547    /// Predicts the class with the *lowest* complement score.
548    ///
549    /// # Errors
550    ///
551    /// Returns [`FerroError::ShapeMismatch`] if the number of features does
552    /// not match the fitted model.
553    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
554        let n_features_fitted = self.weights.ncols();
555        if x.ncols() != n_features_fitted {
556            return Err(FerroError::ShapeMismatch {
557                expected: vec![n_features_fitted],
558                actual: vec![x.ncols()],
559                context: "number of features must match fitted ComplementNB".into(),
560            });
561        }
562
563        let scores = self.complement_scores(x);
564        let n_samples = x.nrows();
565        let n_classes = self.classes.len();
566
567        let mut predictions = Array1::<usize>::zeros(n_samples);
568        for i in 0..n_samples {
569            // Argmax: with sklearn-parity sign, higher joint-log-likelihood
570            // wins.
571            let mut best_class = 0;
572            let mut best_score = scores[[i, 0]];
573            for ci in 1..n_classes {
574                if scores[[i, ci]] > best_score {
575                    best_score = scores[[i, ci]];
576                    best_class = ci;
577                }
578            }
579            predictions[i] = self.classes[best_class];
580        }
581
582        Ok(predictions)
583    }
584}
585
586impl<F: Float + Send + Sync + 'static> HasClasses for FittedComplementNB<F> {
587    fn classes(&self) -> &[usize] {
588        &self.classes
589    }
590
591    fn n_classes(&self) -> usize {
592        self.classes.len()
593    }
594}
595
596// Pipeline integration.
597impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
598    for ComplementNB<F>
599{
600    fn fit_pipeline(
601        &self,
602        x: &Array2<F>,
603        y: &Array1<F>,
604    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
605        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
606        let fitted = self.fit(x, &y_usize)?;
607        Ok(Box::new(FittedComplementNBPipeline(fitted)))
608    }
609}
610
611struct FittedComplementNBPipeline<F: Float + Send + Sync + 'static>(FittedComplementNB<F>);
612
613unsafe impl<F: Float + Send + Sync + 'static> Send for FittedComplementNBPipeline<F> {}
614unsafe impl<F: Float + Send + Sync + 'static> Sync for FittedComplementNBPipeline<F> {}
615
616impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
617    for FittedComplementNBPipeline<F>
618{
619    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
620        let preds = self.0.predict(x)?;
621        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
622    }
623}
624
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::{array, Axis};

    fn make_count_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (6, 3),
            vec![
                5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 6.0, 0.0, 1.0, 0.0, 1.0, 5.0, 1.0, 0.0, 4.0, 0.0,
                2.0, 6.0,
            ],
        )
        .unwrap();
        let y = array![0usize, 0, 0, 1, 1, 1];
        (x, y)
    }

    #[test]
    fn test_complement_nb_fit_predict() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert_eq!(correct, 6);
    }

    #[test]
    fn test_complement_nb_predict_proba_sums_to_one() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();
        for i in 0..proba.nrows() {
            assert_relative_eq!(proba.row(i).sum(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_complement_nb_has_classes() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_complement_nb_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((4, 3), vec![1.0; 12]).unwrap();
        let y = array![0usize, 1]; // Wrong length
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_shape_mismatch_predict() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let x_bad = Array2::from_shape_vec((3, 5), vec![1.0; 15]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
        assert!(fitted.predict_proba(&x_bad).is_err());
    }

    #[test]
    fn test_complement_nb_negative_features_error() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, -0.5, 3.0, 2.0, 1.0, 0.0, 4.0]).unwrap();
        let y = array![0usize, 0, 1, 1];
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_single_class() {
        let x = Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
            .unwrap();
        let y = array![0usize, 0, 0];
        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0]);
        let preds = fitted.predict(&x).unwrap();
        assert!(preds.iter().all(|&p| p == 0));
    }

    #[test]
    fn test_complement_nb_empty_data() {
        let x = Array2::<f64>::zeros((0, 3));
        let y = Array1::<usize>::zeros(0);
        let model = ComplementNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_default() {
        let model = ComplementNB::<f64>::default();
        assert_relative_eq!(model.alpha, 1.0, epsilon = 1e-15);
    }

    #[test]
    fn test_complement_nb_imbalanced_data() {
        // ComplementNB is designed for imbalanced data.
        // 10 samples of class 0, 2 samples of class 1.
        let x = Array2::from_shape_vec(
            (12, 3),
            vec![
                5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 6.0, 0.0, 1.0, 5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 6.0,
                0.0, 1.0, 5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 6.0, 0.0, 1.0, 5.0, 1.0, 0.0, 0.0, 1.0,
                5.0, // class 1
                0.0, 2.0, 6.0, // class 1
            ],
        )
        .unwrap();
        let y = array![0usize, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1];

        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // Class 1 samples should be predicted as class 1.
        assert_eq!(preds[10], 1);
        assert_eq!(preds[11], 1);
    }

    #[test]
    fn test_complement_nb_partial_fit() {
        let x1 = Array2::from_shape_vec(
            (4, 3),
            vec![5.0, 1.0, 0.0, 4.0, 2.0, 0.0, 0.0, 1.0, 5.0, 1.0, 0.0, 4.0],
        )
        .unwrap();
        let y1 = array![0usize, 0, 1, 1];

        let model = ComplementNB::<f64>::new();
        let mut fitted = model.fit(&x1, &y1).unwrap();

        let x2 = Array2::from_shape_vec((2, 3), vec![6.0, 0.0, 1.0, 0.0, 2.0, 6.0]).unwrap();
        let y2 = array![0usize, 1];

        fitted.partial_fit(&x2, &y2).unwrap();

        let preds = fitted.predict(&x1).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_complement_nb_partial_fit_shape_mismatch() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new();
        let mut fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 5), vec![1.0; 10]).unwrap();
        let y_bad = array![0usize, 1];
        assert!(fitted.partial_fit(&x_bad, &y_bad).is_err());
    }

    #[test]
    fn test_complement_nb_class_prior() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new().with_class_prior(vec![0.5, 0.5]);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_complement_nb_class_prior_wrong_length() {
        let (x, y) = make_count_data();
        let model = ComplementNB::<f64>::new().with_class_prior(vec![1.0]);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_complement_nb_three_classes() {
        let x = Array2::from_shape_vec(
            (9, 3),
            vec![
                5.0, 0.0, 0.0, 6.0, 0.0, 0.0, 4.0, 1.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0, 0.0, 1.0,
                4.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0, 0.0, 1.0, 4.0,
            ],
        )
        .unwrap();
        let y = array![0usize, 0, 0, 1, 1, 1, 2, 2, 2];

        let model = ComplementNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_classes(), 3);
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 7);
    }

    #[test]
    fn test_complement_nb_weights_follow_sklearn_sign() {
        // Per-class totals: class 0 -> [15, 3, 1], class 1 -> [1, 3, 15].
        // With alpha = 1 and 3 features each complement denominator is
        // 19 + 3 = 22, and w_cj = -ln((N_~cj + 1) / 22) = ln(22 / (N_~cj + 1)).
        let (x, y) = make_count_data();
        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();
        let expected = [
            [(22.0_f64 / 2.0).ln(), (22.0_f64 / 4.0).ln(), (22.0_f64 / 16.0).ln()],
            [(22.0_f64 / 16.0).ln(), (22.0_f64 / 4.0).ln(), (22.0_f64 / 2.0).ln()],
        ];
        for (ci, row) in expected.iter().enumerate() {
            for (j, &w) in row.iter().enumerate() {
                assert_relative_eq!(fitted.weights[[ci, j]], w, epsilon = 1e-12);
            }
        }
        assert!(fitted.weights.iter().all(|&w| w > 0.0));
    }

    #[test]
    fn test_complement_nb_joint_log_proba_is_dot_product() {
        let (x, y) = make_count_data();
        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();
        let jll = fitted.predict_joint_log_proba(&x).unwrap();
        // Sample 0 is [5, 1, 0].
        assert_relative_eq!(jll[[0, 0]], 5.0 * 11.0_f64.ln() + 5.5_f64.ln(), epsilon = 1e-10);
        assert_relative_eq!(jll[[0, 1]], 5.0 * 1.375_f64.ln() + 5.5_f64.ln(), epsilon = 1e-10);
        // argmax of the joint log-likelihood agrees with predict.
        let preds = fitted.predict(&x).unwrap();
        for (i, row) in jll.rows().into_iter().enumerate() {
            let best = if row[1] > row[0] { 1 } else { 0 };
            assert_eq!(preds[i], best);
        }
    }

    #[test]
    fn test_complement_nb_log_proba_matches_proba() {
        let (x, y) = make_count_data();
        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();
        let log_proba = fitted.predict_log_proba(&x).unwrap();
        for (p, lp) in proba.iter().zip(log_proba.iter()) {
            assert_relative_eq!(*p, lp.exp(), epsilon = 1e-10);
        }
    }

    #[test]
    fn test_complement_nb_norm_rows_sum_to_one() {
        let (x, y) = make_count_data();
        let fitted = ComplementNB::<f64>::new()
            .with_norm(true)
            .fit(&x, &y)
            .unwrap();
        for row in fitted.weights.rows() {
            assert_relative_eq!(row.sum(), 1.0, epsilon = 1e-12);
        }
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds, y);
    }

    #[test]
    fn test_complement_nb_score_perfect_on_separable_data() {
        let (x, y) = make_count_data();
        let fitted = ComplementNB::<f64>::new().fit(&x, &y).unwrap();
        assert_relative_eq!(fitted.score(&x, &y).unwrap(), 1.0, epsilon = 1e-15);
    }

    #[test]
    fn test_complement_nb_partial_fit_matches_full_fit() {
        let (x, y) = make_count_data();
        let full = ComplementNB::<f64>::new().fit(&x, &y).unwrap();

        let head = [0, 1, 3, 4];
        let tail = [2, 5];
        let mut inc = ComplementNB::<f64>::new()
            .fit(&x.select(Axis(0), &head), &y.select(Axis(0), &head))
            .unwrap();
        inc.partial_fit(&x.select(Axis(0), &tail), &y.select(Axis(0), &tail))
            .unwrap();

        assert_eq!(inc.class_counts, full.class_counts);
        for (a, b) in inc.weights.iter().zip(full.weights.iter()) {
            assert_relative_eq!(*a, *b, epsilon = 1e-12);
        }
    }
}