use crate::NaiveBayesError;
use linfa::{Float, ParamGuard};
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
/// Checked hyper-parameter set for the Gaussian naive Bayes estimator.
///
/// This is the `Checked` type handed out by the `ParamGuard` implementation on
/// [`GaussianNbParams`] once the smoothing value has passed validation.
/// When the `serde` feature is enabled the type additionally derives
/// `Serialize` and `Deserialize`.
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GaussianNbValidParams<F, L> {
/// Variance-smoothing value, readable through the `var_smoothing()` getter.
var_smoothing: F,
/// Zero-sized marker recording the label type `L` without storing a label.
label: PhantomData<L>,
}
impl<F: Float, L> GaussianNbValidParams<F, L> {
    /// Returns a copy of the stored variance-smoothing value.
    pub fn var_smoothing(&self) -> F {
        let Self { var_smoothing, .. } = *self;
        var_smoothing
    }
}
/// Unchecked builder for Gaussian naive Bayes hyper-parameters.
///
/// Wraps a [`GaussianNbValidParams`] whose contents have not been validated
/// yet; the `ParamGuard` implementation on this type performs the validation
/// and yields the wrapped value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GaussianNbParams<F, L>(GaussianNbValidParams<F, L>);
impl<F: Float, L> Default for GaussianNbParams<F, L> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float, L> GaussianNbParams<F, L> {
    /// Creates an unchecked parameter set with `var_smoothing` initialised to `1e-9`.
    pub fn new() -> Self {
        let defaults = GaussianNbValidParams {
            var_smoothing: F::cast(1e-9),
            label: PhantomData,
        };
        Self(defaults)
    }

    /// Replaces the variance-smoothing value and returns the builder for chaining.
    ///
    /// The value is stored as given; validation only happens in the
    /// `ParamGuard` methods (`check` / `check_ref`).
    pub fn var_smoothing(self, var_smoothing: F) -> Self {
        Self(GaussianNbValidParams {
            var_smoothing,
            ..self.0
        })
    }
}
impl<F: Float, L> ParamGuard for GaussianNbParams<F, L> {
    type Checked = GaussianNbValidParams<F, L>;
    type Error = NaiveBayesError;

    /// Validates the smoothing value and borrows the checked parameter set.
    ///
    /// # Errors
    ///
    /// Returns [`NaiveBayesError::InvalidSmoothing`] carrying the offending
    /// value (converted to `f64`) when `var_smoothing` is negative or NaN.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        let smoothing = self.0.var_smoothing;
        // NaN is rejected explicitly: the sign test alone does not rule it out,
        // and a NaN smoothing value would silently propagate into later math.
        if smoothing.is_nan() || smoothing.is_negative() {
            return Err(NaiveBayesError::InvalidSmoothing(
                // Report NaN instead of panicking should the conversion be unavailable.
                smoothing.to_f64().unwrap_or(f64::NAN),
            ));
        }
        Ok(&self.0)
    }

    /// Validates the smoothing value and returns the checked parameter set by value.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`check_ref`](Self::check_ref).
    fn check(self) -> Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}
/// Checked hyper-parameter set for the multinomial naive Bayes estimator.
///
/// This is the `Checked` type handed out by the `ParamGuard` implementation on
/// [`MultinomialNbParams`] once the smoothing value has passed validation.
/// When the `serde` feature is enabled the type additionally derives
/// `Serialize` and `Deserialize`.
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultinomialNbValidParams<F, L> {
/// Smoothing value, readable through the `alpha()` getter.
alpha: F,
/// Zero-sized marker recording the label type `L` without storing a label.
label: PhantomData<L>,
}
impl<F: Float, L> MultinomialNbValidParams<F, L> {
    /// Returns a copy of the stored smoothing value `alpha`.
    pub fn alpha(&self) -> F {
        let Self { alpha, .. } = *self;
        alpha
    }
}
/// Unchecked builder for multinomial naive Bayes hyper-parameters.
///
/// Wraps a [`MultinomialNbValidParams`] whose contents have not been validated
/// yet; the `ParamGuard` implementation on this type performs the validation
/// and yields the wrapped value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultinomialNbParams<F, L>(MultinomialNbValidParams<F, L>);
impl<F: Float, L> Default for MultinomialNbParams<F, L> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float, L> MultinomialNbParams<F, L> {
    /// Creates an unchecked parameter set with `alpha` initialised to `1`.
    pub fn new() -> Self {
        let defaults = MultinomialNbValidParams {
            alpha: F::cast(1),
            label: PhantomData,
        };
        Self(defaults)
    }

    /// Replaces the smoothing value and returns the builder for chaining.
    ///
    /// The value is stored as given; validation only happens in the
    /// `ParamGuard` methods (`check` / `check_ref`).
    pub fn alpha(self, alpha: F) -> Self {
        Self(MultinomialNbValidParams { alpha, ..self.0 })
    }
}
impl<F: Float, L> ParamGuard for MultinomialNbParams<F, L> {
    type Checked = MultinomialNbValidParams<F, L>;
    type Error = NaiveBayesError;

    /// Validates the smoothing value and borrows the checked parameter set.
    ///
    /// # Errors
    ///
    /// Returns [`NaiveBayesError::InvalidSmoothing`] carrying the offending
    /// value (converted to `f64`) when `alpha` is negative or NaN.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        let alpha = self.0.alpha;
        // NaN is rejected explicitly: the sign test alone does not rule it out,
        // and a NaN smoothing value would silently propagate into later math.
        if alpha.is_nan() || alpha.is_negative() {
            return Err(NaiveBayesError::InvalidSmoothing(
                // Report NaN instead of panicking should the conversion be unavailable.
                alpha.to_f64().unwrap_or(f64::NAN),
            ));
        }
        Ok(&self.0)
    }

    /// Validates the smoothing value and returns the checked parameter set by value.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`check_ref`](Self::check_ref).
    fn check(self) -> Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}
/// Checked hyper-parameter set for the Bernoulli naive Bayes estimator.
///
/// This is the `Checked` type handed out by the `ParamGuard` implementation on
/// [`BernoulliNbParams`] once the smoothing value has passed validation.
/// When the `serde` feature is enabled the type additionally derives
/// `Serialize` and `Deserialize`, matching the other validated parameter sets
/// in this module.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct BernoulliNbValidParams<F, L> {
    /// Smoothing value, readable through the `alpha()` getter.
    alpha: F,
    /// Optional binarization threshold, readable through the `binarize()` getter.
    binarize: Option<F>,
    /// Zero-sized marker recording the label type `L` without storing a label.
    label: PhantomData<L>,
}
impl<F: Float, L> BernoulliNbValidParams<F, L> {
    /// Returns a copy of the stored smoothing value `alpha`.
    pub fn alpha(&self) -> F {
        let Self { alpha, .. } = *self;
        alpha
    }

    /// Returns a copy of the stored optional binarization threshold.
    pub fn binarize(&self) -> Option<F> {
        let Self { binarize, .. } = *self;
        binarize
    }
}
/// Unchecked builder for Bernoulli naive Bayes hyper-parameters.
///
/// Wraps a [`BernoulliNbValidParams`] whose contents have not been validated
/// yet; the `ParamGuard` implementation on this type performs the validation
/// and yields the wrapped value.
#[derive(Debug, Clone, PartialEq)]
pub struct BernoulliNbParams<F, L>(BernoulliNbValidParams<F, L>);
impl<F: Float, L> Default for BernoulliNbParams<F, L> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float, L> BernoulliNbParams<F, L> {
    /// Creates an unchecked parameter set with `alpha` set to one and the
    /// binarization threshold set to `Some(0)`.
    pub fn new() -> Self {
        let defaults = BernoulliNbValidParams {
            alpha: F::one(),
            binarize: Some(F::zero()),
            label: PhantomData,
        };
        Self(defaults)
    }

    /// Replaces the smoothing value and returns the builder for chaining.
    ///
    /// The value is stored as given; validation only happens in the
    /// `ParamGuard` methods (`check` / `check_ref`).
    pub fn alpha(self, alpha: F) -> Self {
        Self(BernoulliNbValidParams { alpha, ..self.0 })
    }

    /// Replaces the optional binarization threshold and returns the builder
    /// for chaining. `None` is stored as-is.
    pub fn binarize(self, threshold: Option<F>) -> Self {
        Self(BernoulliNbValidParams {
            binarize: threshold,
            ..self.0
        })
    }
}
impl<F: Float, L> ParamGuard for BernoulliNbParams<F, L> {
    type Checked = BernoulliNbValidParams<F, L>;
    type Error = NaiveBayesError;

    /// Validates the smoothing value and borrows the checked parameter set.
    ///
    /// Only `alpha` is validated; the binarization threshold is not inspected.
    ///
    /// # Errors
    ///
    /// Returns [`NaiveBayesError::InvalidSmoothing`] carrying the offending
    /// value (converted to `f64`) when `alpha` is negative or NaN.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        let alpha = self.0.alpha;
        // NaN is rejected explicitly: the sign test alone does not rule it out,
        // and a NaN smoothing value would silently propagate into later math.
        if alpha.is_nan() || alpha.is_negative() {
            return Err(NaiveBayesError::InvalidSmoothing(
                // Report NaN instead of panicking should the conversion be unavailable.
                alpha.to_f64().unwrap_or(f64::NAN),
            ));
        }
        Ok(&self.0)
    }

    /// Validates the smoothing value and returns the checked parameter set by value.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`check_ref`](Self::check_ref).
    fn check(self) -> Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}