use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::HasClasses;
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2};
use num_traits::{Float, FromPrimitive, ToPrimitive};
use std::marker::PhantomData;
/// Lower bound on the number of categories assumed for each feature.
///
/// During fitting, category codes `0..min` are injected into a feature's
/// category set even if they never occur in the training data, so smoothing
/// accounts for them.
#[derive(Debug, Clone)]
pub enum MinCategories {
    /// The same minimum applies to every feature.
    Scalar(usize),
    /// One minimum per feature; the length must equal the number of columns
    /// in `X` (validated at fit time).
    PerFeature(Vec<usize>),
}
/// Categorical Naive Bayes estimator (unfitted configuration).
///
/// Features are expected to be non-negative integer codes stored in a float
/// matrix; values are converted with `to_usize()` at fit/predict time.
#[derive(Debug, Clone)]
pub struct CategoricalNB<F: Float + Send + Sync + 'static> {
    // Additive smoothing strength applied to category counts.
    alpha: F,
    // Optional fixed class priors; when set, overrides `fit_prior`.
    class_prior: Option<Vec<F>>,
    // When true (and no explicit prior), priors are estimated from class frequencies;
    // otherwise a uniform prior is used.
    fit_prior: bool,
    // When true, a non-positive alpha is rejected at fit time; the fit error
    // text indicates that `force_alpha = false` clamps instead (via `crate::clamp_alpha`).
    force_alpha: bool,
    // Optional lower bound on the category set per feature.
    min_categories: Option<MinCategories>,
    // Ties the generic float type to the struct without storing a value.
    _marker: PhantomData<F>,
}
impl<F: Float + Send + Sync + 'static> CategoricalNB<F> {
    /// Creates an estimator with `alpha = 1`, fitted priors, strict alpha
    /// validation, and no category minimum.
    #[must_use]
    pub fn new() -> Self {
        Self {
            alpha: F::one(),
            class_prior: None,
            fit_prior: true,
            force_alpha: true,
            min_categories: None,
            _marker: PhantomData,
        }
    }
    /// Sets the additive smoothing parameter `alpha`.
    #[must_use]
    pub fn with_alpha(self, alpha: F) -> Self {
        Self { alpha, ..self }
    }
    /// Fixes the class priors explicitly (one entry per class).
    #[must_use]
    pub fn with_class_prior(self, priors: Vec<F>) -> Self {
        Self {
            class_prior: Some(priors),
            ..self
        }
    }
    /// Chooses between frequency-estimated priors (`true`) and uniform priors.
    #[must_use]
    pub fn with_fit_prior(self, fit_prior: bool) -> Self {
        Self { fit_prior, ..self }
    }
    /// Controls whether a non-positive `alpha` is rejected at fit time.
    #[must_use]
    pub fn with_force_alpha(self, force_alpha: bool) -> Self {
        Self { force_alpha, ..self }
    }
    /// Requires at least `min` categories for every feature.
    #[must_use]
    pub fn with_min_categories(self, min: usize) -> Self {
        Self {
            min_categories: Some(MinCategories::Scalar(min)),
            ..self
        }
    }
    /// Requires per-feature category minimums (length checked at fit time).
    #[must_use]
    pub fn with_min_categories_per_feature(self, mins: Vec<usize>) -> Self {
        Self {
            min_categories: Some(MinCategories::PerFeature(mins)),
            ..self
        }
    }
}
impl<F: Float + Send + Sync + 'static> Default for CategoricalNB<F> {
fn default() -> Self {
Self::new()
}
}
/// Fitted state of [`CategoricalNB`], produced by `fit`.
#[derive(Debug, Clone)]
pub struct FittedCategoricalNB<F: Float + Send + Sync + 'static> {
    // Distinct class labels seen in training, sorted ascending.
    classes: Vec<usize>,
    // ln P(class), aligned with `classes`.
    class_log_prior: Vec<F>,
    // Smoothed ln P(category | class), indexed [feature][class][category].
    feature_log_prob: Vec<Vec<Vec<F>>>,
    // Raw occurrence counts, indexed [feature][class][category];
    // retained so `partial_fit` can update and re-derive probabilities.
    category_counts: Vec<Vec<Vec<usize>>>,
    // Known category values per feature, sorted ascending.
    categories: Vec<Vec<usize>>,
    // Number of training samples seen per class.
    class_counts: Vec<usize>,
    // Number of feature columns the model was fitted on.
    n_features: usize,
    // Effective smoothing value used for the fitted tables.
    alpha: F,
    // Configuration copied from the estimator, reused by `partial_fit`
    // when the class log-prior is recomputed.
    class_prior: Option<Vec<F>>,
    fit_prior: bool,
}
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, Array1<usize>> for CategoricalNB<F> {
    type Fitted = FittedCategoricalNB<F>;
    type Error = FerroError;
    /// Fits the model: discovers the category set per feature, counts category
    /// occurrences per (feature, class), then derives smoothed log-probability
    /// tables and class log-priors.
    ///
    /// Feature values are converted with `to_usize()`; values that fail to
    /// convert fall back to category 0.
    /// NOTE(review): negative/NaN inputs silently become category 0 — confirm
    /// that callers validate their encoding upstream.
    ///
    /// # Errors
    /// * `InsufficientSamples` — `x` has zero rows.
    /// * `ShapeMismatch` — `y.len()` differs from the number of rows in `x`.
    /// * `InvalidParameter` — `alpha <= 0` while `force_alpha` is set, or a
    ///   `PerFeature` min-categories vector whose length is not `n_features`.
    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedCategoricalNB<F>, FerroError> {
        let (n_samples, n_features) = x.dim();
        if n_samples == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "CategoricalNB requires at least one sample".into(),
            });
        }
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if self.alpha <= F::zero() && self.force_alpha {
            return Err(FerroError::InvalidParameter {
                name: "alpha".into(),
                reason: "alpha must be positive for CategoricalNB \
                         (or set force_alpha=false to clamp to 1e-10)"
                    .into(),
            });
        }
        // Clamp (or pass through) alpha according to the force_alpha policy.
        let alpha = crate::clamp_alpha(self.alpha, self.force_alpha);
        if let Some(MinCategories::PerFeature(ref mins)) = self.min_categories
            && mins.len() != n_features
        {
            return Err(FerroError::InvalidParameter {
                name: "min_categories".into(),
                reason: format!(
                    "PerFeature length {} does not match n_features {}",
                    mins.len(),
                    n_features
                ),
            });
        }
        // Sorted, deduplicated labels define the class index order used everywhere.
        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();
        let n_classes = classes.len();
        // Count samples per class and remember which rows belong to each class
        // so the per-feature pass below can iterate class-by-class.
        let mut class_counts = vec![0usize; n_classes];
        let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); n_classes];
        for (sample_idx, &label) in y.iter().enumerate() {
            let ci = classes.iter().position(|&c| c == label).unwrap();
            class_counts[ci] += 1;
            class_indices[ci].push(sample_idx);
        }
        // Per feature: build the sorted category set, then count occurrences
        // for every (class, category) pair.
        let mut category_counts: Vec<Vec<Vec<usize>>> = Vec::with_capacity(n_features);
        let mut categories_per_feature: Vec<Vec<usize>> = Vec::with_capacity(n_features);
        for j in 0..n_features {
            let mut cats: Vec<usize> = Vec::new();
            for i in 0..n_samples {
                let val = x[[i, j]].to_usize().unwrap_or(0);
                cats.push(val);
            }
            cats.sort_unstable();
            cats.dedup();
            let min_cats_j = match self.min_categories {
                Some(MinCategories::Scalar(m)) => m,
                Some(MinCategories::PerFeature(ref v)) => v[j],
                None => 0,
            };
            // Inject category codes 0..min so smoothing covers categories that
            // never occur in the training data (keeps `cats` sorted).
            if min_cats_j > 0 {
                for cv in 0..min_cats_j {
                    if cats.binary_search(&cv).is_err() {
                        let pos = cats.partition_point(|&c| c < cv);
                        cats.insert(pos, cv);
                    }
                }
            }
            let mut counts_for_feature: Vec<Vec<usize>> =
                vec![vec![0usize; cats.len()]; n_classes];
            for (ci, indices) in class_indices.iter().enumerate() {
                for &sample_idx in indices {
                    let val = x[[sample_idx, j]].to_usize().unwrap_or(0);
                    if let Ok(cat_idx) = cats.binary_search(&val) {
                        counts_for_feature[ci][cat_idx] += 1;
                    }
                }
            }
            category_counts.push(counts_for_feature);
            categories_per_feature.push(cats);
        }
        // Turn raw counts into smoothed log-probabilities and log-priors.
        let feature_log_prob =
            recompute_feature_log_prob(&category_counts, &class_counts, alpha);
        let class_log_prior =
            resolve_class_log_prior(&class_counts, n_classes, &self.class_prior, self.fit_prior)?;
        Ok(FittedCategoricalNB {
            classes,
            class_log_prior,
            feature_log_prob,
            category_counts,
            categories: categories_per_feature,
            class_counts,
            n_features,
            alpha,
            class_prior: self.class_prior.clone(),
            fit_prior: self.fit_prior,
        })
    }
}
/// Derives smoothed log-probabilities from raw counts:
/// `ln((count + alpha) / (class_count + alpha * n_categories))`,
/// returned with the same `[feature][class][category]` indexing as the input.
fn recompute_feature_log_prob<F: Float>(
    category_counts: &[Vec<Vec<usize>>],
    class_counts: &[usize],
    alpha: F,
) -> Vec<Vec<Vec<F>>> {
    category_counts
        .iter()
        .map(|per_class_counts| {
            // Every class row for a feature has the same width; an empty
            // outer vec (no classes) yields zero categories.
            let n_cats = per_class_counts.first().map_or(0, Vec::len);
            let n_cats_f = F::from(n_cats).unwrap();
            per_class_counts
                .iter()
                .zip(class_counts)
                .map(|(counts, &class_count)| {
                    let denom = F::from(class_count).unwrap() + alpha * n_cats_f;
                    counts
                        .iter()
                        .map(|&count| ((F::from(count).unwrap() + alpha) / denom).ln())
                        .collect()
                })
                .collect()
        })
        .collect()
}
/// Resolves the class log-prior vector from, in priority order:
/// an explicit `class_prior`, empirical class frequencies (`fit_prior`),
/// or a uniform distribution.
///
/// # Errors
/// `InvalidParameter` when an explicit prior's length differs from `n_classes`.
fn resolve_class_log_prior<F: Float>(
    class_counts: &[usize],
    n_classes: usize,
    class_prior: &Option<Vec<F>>,
    fit_prior: bool,
) -> Result<Vec<F>, FerroError> {
    match class_prior {
        Some(priors) => {
            if priors.len() != n_classes {
                return Err(FerroError::InvalidParameter {
                    name: "class_prior".into(),
                    reason: format!(
                        "length {} does not match number of classes {}",
                        priors.len(),
                        n_classes
                    ),
                });
            }
            Ok(priors.iter().map(|&p| p.ln()).collect())
        }
        None if fit_prior => {
            // Empirical prior: ln(count_c / total).
            let total_f = F::from(class_counts.iter().sum::<usize>()).unwrap();
            Ok(class_counts
                .iter()
                .map(|&c| (F::from(c).unwrap() / total_f).ln())
                .collect())
        }
        None => {
            // Uniform prior: ln(1 / n_classes) for every class.
            let uniform = (F::one() / F::from(n_classes).unwrap()).ln();
            Ok(vec![uniform; n_classes])
        }
    }
}
impl<F: Float + Send + Sync + 'static> FittedCategoricalNB<F> {
    /// Incrementally updates counts with a new batch, then recomputes the
    /// derived log-probability tables and log-priors.
    ///
    /// Class labels and category values not seen before are inserted on the
    /// fly with zero-initialized counts, so the model can grow beyond the
    /// initial `fit`. An empty batch is a no-op.
    ///
    /// # Errors
    /// * `ShapeMismatch` — `y.len()` differs from the batch's row count, or
    ///   the batch's column count differs from the fitted `n_features`.
    pub fn partial_fit(&mut self, x: &Array2<F>, y: &Array1<usize>) -> Result<(), FerroError> {
        let (n_samples, n_features) = x.dim();
        if n_samples == 0 {
            return Ok(());
        }
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if n_features != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![n_features],
                context: "number of features must match fitted CategoricalNB".into(),
            });
        }
        // First pass: register any unseen class labels, keeping every
        // class-indexed structure aligned with the sorted `classes` vec.
        // Placeholder prior/probability rows are overwritten by the
        // recomputation at the end of this method.
        for &label in y {
            if self.classes.binary_search(&label).is_err() {
                let pos = self.classes.partition_point(|&c| c < label);
                self.classes.insert(pos, label);
                self.class_counts.insert(pos, 0);
                self.class_log_prior.insert(pos, F::neg_infinity());
                for j in 0..self.n_features {
                    let n_cats = self.categories[j].len();
                    self.category_counts[j].insert(pos, vec![0usize; n_cats]);
                    self.feature_log_prob[j].insert(pos, vec![F::zero(); n_cats]);
                }
            }
        }
        // Second pass: accumulate counts, inserting unseen category values
        // into the sorted per-feature category list as they appear.
        for sample_idx in 0..n_samples {
            let label = y[sample_idx];
            let ci = self.classes.binary_search(&label).unwrap();
            self.class_counts[ci] += 1;
            for j in 0..self.n_features {
                // NOTE(review): values that fail to_usize() (negative/NaN)
                // silently become category 0, as in `fit`.
                let val = x[[sample_idx, j]].to_usize().unwrap_or(0);
                let cat_idx = match self.categories[j].binary_search(&val) {
                    Ok(idx) => idx,
                    Err(insert_pos) => {
                        self.categories[j].insert(insert_pos, val);
                        // Widen every class's count row to stay rectangular.
                        for c in 0..self.classes.len() {
                            self.category_counts[j][c].insert(insert_pos, 0);
                        }
                        insert_pos
                    }
                };
                self.category_counts[j][ci][cat_idx] += 1;
            }
        }
        // Recompute all derived tables from the updated raw counts.
        self.feature_log_prob =
            recompute_feature_log_prob(&self.category_counts, &self.class_counts, self.alpha);
        self.class_log_prior = resolve_class_log_prior(
            &self.class_counts,
            self.classes.len(),
            &self.class_prior,
            self.fit_prior,
        )?;
        Ok(())
    }
    /// Looks up the smoothed log-probability for a category value; a value
    /// never seen for this feature falls back to `ln(1 / (n_categories + 1))`.
    fn log_prob_for(&self, feature_idx: usize, class_idx: usize, cat_value: usize) -> F {
        let cats = &self.categories[feature_idx];
        if let Ok(cat_idx) = cats.binary_search(&cat_value) {
            self.feature_log_prob[feature_idx][class_idx][cat_idx]
        } else {
            let n_cats_plus_one = F::from(cats.len() + 1).unwrap();
            (F::one() / n_cats_plus_one).ln()
        }
    }
    /// Unnormalized per-class scores: ln P(class) + sum_j ln P(x_j | class),
    /// one row per sample, one column per class.
    fn joint_log_likelihood(&self, x: &Array2<F>) -> Array2<F> {
        let n_samples = x.nrows();
        let n_classes = self.classes.len();
        let mut scores = Array2::<F>::zeros((n_samples, n_classes));
        for i in 0..n_samples {
            for ci in 0..n_classes {
                let mut score = self.class_log_prior[ci];
                for j in 0..self.n_features {
                    let cat_value = x[[i, j]].to_usize().unwrap_or(0);
                    score = score + self.log_prob_for(j, ci, cat_value);
                }
                scores[[i, ci]] = score;
            }
        }
        scores
    }
    /// Returns class membership probabilities (rows sum to 1), computed as a
    /// numerically stable softmax over the joint log-likelihoods.
    ///
    /// # Errors
    /// `ShapeMismatch` when `x` has a different column count than the model.
    pub fn predict_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted CategoricalNB".into(),
            });
        }
        let log_scores = self.joint_log_likelihood(x);
        let n_samples = x.nrows();
        let n_classes = self.classes.len();
        let mut proba = Array2::<F>::zeros((n_samples, n_classes));
        for i in 0..n_samples {
            // Subtract the row maximum before exponentiating so the largest
            // term is exp(0) = 1 and the row sum can never underflow to 0.
            let max_score = log_scores
                .row(i)
                .iter()
                .fold(F::neg_infinity(), |a, &b| a.max(b));
            let mut row_sum = F::zero();
            for ci in 0..n_classes {
                let p = (log_scores[[i, ci]] - max_score).exp();
                proba[[i, ci]] = p;
                row_sum = row_sum + p;
            }
            for ci in 0..n_classes {
                proba[[i, ci]] = proba[[i, ci]] / row_sum;
            }
        }
        Ok(proba)
    }
    /// Returns the raw (unnormalized) joint log-likelihood matrix.
    ///
    /// # Errors
    /// `ShapeMismatch` when `x` has a different column count than the model.
    pub fn predict_joint_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted CategoricalNB".into(),
            });
        }
        Ok(self.joint_log_likelihood(x))
    }
    /// Returns normalized log-probabilities (log-softmax of the joint
    /// log-likelihoods, row-wise).
    ///
    /// # Errors
    /// Propagates the shape check from [`Self::predict_joint_log_proba`].
    pub fn predict_log_proba(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        let jll = self.predict_joint_log_proba(x)?;
        Ok(crate::log_softmax_rows(&jll))
    }
    /// Mean accuracy of `predict(x)` against `y`; returns 0 for empty input.
    ///
    /// # Errors
    /// `ShapeMismatch` when `x` and `y` disagree on sample count, plus any
    /// error from `predict`.
    pub fn score(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<F, FerroError> {
        if x.nrows() != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows()],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        let preds = self.predict(x)?;
        let n = y.len();
        if n == 0 {
            return Ok(F::zero());
        }
        let correct = preds.iter().zip(y.iter()).filter(|(p, t)| p == t).count();
        Ok(F::from(correct).unwrap() / F::from(n).unwrap())
    }
}
impl<F: Float + Send + Sync + 'static> Predict<Array2<F>> for FittedCategoricalNB<F> {
    type Output = Array1<usize>;
    type Error = FerroError;
    /// Predicts one class label per row of `x` by taking the arg-max of the
    /// joint log-likelihoods (the first maximum wins on exact ties).
    ///
    /// # Errors
    /// `ShapeMismatch` when `x` has a different column count than the model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted CategoricalNB".into(),
            });
        }
        let scores = self.joint_log_likelihood(x);
        let labels: Vec<usize> = (0..x.nrows())
            .map(|row| {
                // Arg-max with strict `>` so ties keep the lowest class index.
                let mut winner = 0usize;
                let mut top = scores[[row, 0]];
                for class_idx in 1..self.classes.len() {
                    let candidate = scores[[row, class_idx]];
                    if candidate > top {
                        top = candidate;
                        winner = class_idx;
                    }
                }
                self.classes[winner]
            })
            .collect();
        Ok(Array1::from_vec(labels))
    }
}
impl<F: Float + Send + Sync + 'static> HasClasses for FittedCategoricalNB<F> {
    /// Sorted slice of the class labels seen during fitting.
    fn classes(&self) -> &[usize] {
        self.classes.as_slice()
    }
    /// Number of distinct classes.
    fn n_classes(&self) -> usize {
        self.classes.len()
    }
}
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> PipelineEstimator<F>
    for CategoricalNB<F>
{
    /// Pipeline adapter: converts float-typed labels to `usize`, fits, and
    /// boxes the result as a type-erased fitted estimator.
    ///
    /// NOTE(review): labels that fail `to_usize()` (negative, NaN, ...) are
    /// silently mapped to class 0 — confirm pipelines pre-validate targets.
    ///
    /// # Errors
    /// Propagates any error from [`Fit::fit`].
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedCategoricalNBPipeline(fitted)))
    }
}
/// Newtype wrapper adapting [`FittedCategoricalNB`] to the type-erased
/// pipeline interface.
///
/// No manual `unsafe impl Send`/`Sync` is needed here: every field of
/// `FittedCategoricalNB<F>` (`Vec`s of `usize`/`F`, `usize`, `F`,
/// `Option<Vec<F>>`, `bool`) is `Send + Sync` when `F: Send + Sync`, so the
/// auto traits already apply to the wrapper. The previous `unsafe impl`s were
/// redundant and needlessly bypassed the compiler's thread-safety checking.
struct FittedCategoricalNBPipeline<F: Float + Send + Sync + 'static>(FittedCategoricalNB<F>);
impl<F: Float + ToPrimitive + FromPrimitive + Send + Sync + 'static> FittedPipelineEstimator<F>
    for FittedCategoricalNBPipeline<F>
{
    /// Predicts integer labels and converts them back to the pipeline's float
    /// type; labels not representable in `F` become NaN.
    ///
    /// # Errors
    /// Propagates any error from the inner model's `predict`.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or_else(F::nan)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// 8 samples x 3 features; rows 0..4 belong to class 0 (low codes),
    /// rows 4..8 to class 1 (high codes).
    fn make_categorical_data() -> (Array2<f64>, Array1<usize>) {
        let x = Array2::from_shape_vec(
            (8, 3),
            vec![
                0.0, 0.0, 1.0, //
                1.0, 0.0, 0.0, //
                0.0, 1.0, 0.0, //
                0.0, 0.0, 0.0, //
                2.0, 2.0, 1.0, //
                2.0, 1.0, 2.0, //
                1.0, 2.0, 2.0, //
                2.0, 2.0, 2.0, //
            ],
        )
        .unwrap();
        let y = array![0usize, 0, 0, 0, 1, 1, 1, 1];
        (x, y)
    }

    #[test]
    fn test_categorical_nb_fit_predict() {
        let (x, y) = make_categorical_data();
        let model = CategoricalNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 6, "expected at least 6 correct, got {correct}");
    }

    #[test]
    fn test_categorical_nb_predict_proba_sums_to_one() {
        let (x, y) = make_categorical_data();
        let model = CategoricalNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();
        assert_eq!(proba.nrows(), 8);
        assert_eq!(proba.ncols(), 2);
        for i in 0..proba.nrows() {
            assert_relative_eq!(proba.row(i).sum(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_categorical_nb_has_classes() {
        let (x, y) = make_categorical_data();
        let model = CategoricalNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[0, 1]);
        assert_eq!(fitted.n_classes(), 2);
    }

    #[test]
    fn test_categorical_nb_alpha_smoothing_effect() {
        // Heavier smoothing should pull probabilities toward uniform,
        // so the winning class's probability drops.
        let (x, y) = make_categorical_data();
        let model_sharp = CategoricalNB::<f64>::new().with_alpha(0.01);
        let fitted_sharp = model_sharp.fit(&x, &y).unwrap();
        let proba_sharp = fitted_sharp.predict_proba(&x).unwrap();
        let model_smooth = CategoricalNB::<f64>::new().with_alpha(100.0);
        let fitted_smooth = model_smooth.fit(&x, &y).unwrap();
        let proba_smooth = fitted_smooth.predict_proba(&x).unwrap();
        let sharp_max = proba_sharp[[0, 0]].max(proba_sharp[[0, 1]]);
        let smooth_max = proba_smooth[[0, 0]].max(proba_smooth[[0, 1]]);
        assert!(smooth_max < sharp_max);
    }

    #[test]
    fn test_categorical_nb_invalid_alpha_zero() {
        let (x, y) = make_categorical_data();
        let model = CategoricalNB::<f64>::new().with_alpha(0.0);
        let result = model.fit(&x, &y);
        assert!(result.is_err());
        match result.unwrap_err() {
            FerroError::InvalidParameter { name, .. } => assert_eq!(name, "alpha"),
            e => panic!("expected InvalidParameter, got {e:?}"),
        }
    }

    #[test]
    fn test_categorical_nb_invalid_alpha_negative() {
        let (x, y) = make_categorical_data();
        let model = CategoricalNB::<f64>::new().with_alpha(-1.0);
        let result = model.fit(&x, &y);
        assert!(result.is_err());
        match result.unwrap_err() {
            FerroError::InvalidParameter { name, .. } => assert_eq!(name, "alpha"),
            e => panic!("expected InvalidParameter, got {e:?}"),
        }
    }

    #[test]
    fn test_categorical_nb_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((4, 3), vec![0.0; 12]).unwrap();
        let y = array![0usize, 1];
        let model = CategoricalNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_categorical_nb_shape_mismatch_predict() {
        let (x, y) = make_categorical_data();
        let model = CategoricalNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let x_bad = Array2::from_shape_vec((3, 5), vec![0.0; 15]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
        assert!(fitted.predict_proba(&x_bad).is_err());
    }

    #[test]
    fn test_categorical_nb_empty_data() {
        let x = Array2::<f64>::zeros((0, 3));
        let y = Array1::<usize>::zeros(0);
        let model = CategoricalNB::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_categorical_nb_single_class() {
        let x = Array2::from_shape_vec((3, 2), vec![0.0, 1.0, 1.0, 0.0, 0.0, 1.0]).unwrap();
        let y = array![2usize, 2, 2];
        let model = CategoricalNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[2]);
        let preds = fitted.predict(&x).unwrap();
        assert!(preds.iter().all(|&p| p == 2));
    }

    #[test]
    fn test_categorical_nb_default() {
        let model = CategoricalNB::<f64>::default();
        assert_relative_eq!(model.alpha, 1.0, epsilon = 1e-15);
    }

    #[test]
    fn test_categorical_nb_unseen_category() {
        // Category value 5 never appears in training; prediction must still
        // succeed and produce non-degenerate probabilities.
        let x = Array2::from_shape_vec(
            (4, 2),
            vec![
                0.0, 0.0, //
                0.0, 1.0, //
                1.0, 0.0, //
                1.0, 1.0, //
            ],
        )
        .unwrap();
        let y = array![0usize, 0, 1, 1];
        let model = CategoricalNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let x_new = Array2::from_shape_vec((1, 2), vec![5.0, 0.0]).unwrap();
        let preds = fitted.predict(&x_new).unwrap();
        assert_eq!(preds.len(), 1);
        let proba = fitted.predict_proba(&x_new).unwrap();
        assert_relative_eq!(proba.row(0).sum(), 1.0, epsilon = 1e-10);
        assert!(proba[[0, 0]] > 0.0 && proba[[0, 0]] < 1.0);
        assert!(proba[[0, 1]] > 0.0 && proba[[0, 1]] < 1.0);
    }

    #[test]
    fn test_categorical_nb_three_classes() {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, //
                0.0, 0.0, //
                0.0, 0.0, //
                1.0, 1.0, //
                1.0, 1.0, //
                1.0, 1.0, //
                2.0, 2.0, //
                2.0, 2.0, //
                2.0, 2.0, //
            ],
        )
        .unwrap();
        let y = array![0usize, 0, 0, 1, 1, 1, 2, 2, 2];
        let model = CategoricalNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_classes(), 3);
        assert_eq!(fitted.classes(), &[0, 1, 2]);
        let preds = fitted.predict(&x).unwrap();
        let correct = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert_eq!(correct, 9);
    }

    #[test]
    fn test_categorical_nb_pipeline() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 2.0, 1.0, 2.0, 0.0],
        )
        .unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);
        let model = CategoricalNB::<f64>::new();
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_categorical_nb_predict_proba_ordering() {
        let (x, y) = make_categorical_data();
        let model = CategoricalNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let proba = fitted.predict_proba(&x).unwrap();
        for i in 0..4 {
            assert!(
                proba[[i, 0]] > proba[[i, 1]],
                "sample {i}: P(c=0)={} should be > P(c=1)={}",
                proba[[i, 0]],
                proba[[i, 1]]
            );
        }
        for i in 4..8 {
            assert!(
                proba[[i, 1]] > proba[[i, 0]],
                "sample {i}: P(c=1)={} should be > P(c=0)={}",
                proba[[i, 1]],
                proba[[i, 0]]
            );
        }
    }

    #[test]
    fn test_categorical_nb_f32() {
        let x = Array2::from_shape_vec((4, 2), vec![0.0f32, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .unwrap();
        let y = array![0usize, 0, 1, 1];
        let model = CategoricalNB::<f32>::new();
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
        let proba = fitted.predict_proba(&x).unwrap();
        for i in 0..proba.nrows() {
            let sum: f32 = proba.row(i).sum();
            assert!((sum - 1.0f32).abs() < 1e-5);
        }
    }

    #[test]
    fn test_categorical_nb_unordered_classes() {
        // Labels need not start at 0 or be contiguous.
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 2.0, 1.0, 2.0, 0.0],
        )
        .unwrap();
        let y = array![5usize, 5, 5, 10, 10, 10];
        let model = CategoricalNB::<f64>::new();
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.classes(), &[5, 10]);
        let preds = fitted.predict(&x).unwrap();
        for i in 0..3 {
            assert_eq!(preds[i], 5);
        }
        for i in 3..6 {
            assert_eq!(preds[i], 10);
        }
    }
}