1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
use crate::NaiveBayesError;
use linfa::{Float, ParamGuard};
use std::marker::PhantomData;

/// A verified hyper-parameter set ready for the estimation of a Gaussian Naive Bayes model
///
/// See [`GaussianNbParams`](crate::hyperparams::GaussianNbParams) for more information.
///
/// The set can be cloned and compared, e.g. to reuse one verified
/// configuration for several fits.
#[derive(Debug, Clone, PartialEq)]
pub struct GaussianNbValidParams<F, L> {
    // Variance smoothing term, required for calculation stability
    var_smoothing: F,
    // Zero-sized marker that ties the parameter set to the label type `L`;
    // no label value is stored
    label: PhantomData<L>,
}

impl<F: Float, L> GaussianNbValidParams<F, L> {
    /// Returns the variance smoothing value held by this verified parameter set
    pub fn var_smoothing(&self) -> F {
        // Only the smoothing field is of interest; the label marker is ignored.
        let Self { var_smoothing, .. } = self;
        *var_smoothing
    }
}

/// A hyper-parameter set during construction
///
/// The parameter set can be verified into a
/// [`GaussianNbValidParams`](crate::hyperparams::GaussianNbValidParams) by calling
/// [ParamGuard::check](Self::check). It is also possible to directly fit a model with
/// [Fit::fit](linfa::traits::Fit::fit) or
/// [FitWith::fit_with](linfa::traits::FitWith::fit_with) which implicitly verifies the parameter set
/// prior to the model estimation and forwards any error.
///
/// # Parameters
/// | Name | Default | Purpose | Range |
/// | :--- | :--- | :---| :--- |
/// | [var_smoothing](Self::var_smoothing) | `1e-9` | Stabilize variance calculation if ratios are small in update step | `[0, inf)` |
///
/// # Errors
///
/// The following errors can come from invalid hyper-parameters:
///
/// Returns [`InvalidSmoothing`](NaiveBayesError::InvalidSmoothing) if the smoothing
/// parameter is negative.
///
/// # Example
///
/// ```rust
/// use linfa_bayes::{GaussianNbParams, GaussianNbValidParams, Result};
/// use linfa::prelude::*;
/// use ndarray::array;
///
/// let x = array![
///     [-2., -1.],
///     [-1., -1.],
///     [-1., -2.],
///     [1., 1.],
///     [1., 2.],
///     [2., 1.]
/// ];
/// let y = array![1, 1, 1, 2, 2, 2];
/// let ds = DatasetView::new(x.view(), y.view());
///
/// // create a new parameter set with variance smoothing equals `1e-5`
/// let unchecked_params = GaussianNbParams::new()
///     .var_smoothing(1e-5);
///
/// // fit model with unchecked parameter set
/// let model = unchecked_params.fit(&ds)?;
///
/// // transform into a verified parameter set
/// let checked_params = unchecked_params.check()?;
///
/// // update model with the verified parameters, this only returns
/// // errors originating from the fitting process
/// let model = checked_params.fit_with(Some(model), &ds)?;
/// # Result::Ok(())
/// ```
// `Debug` is derived so the unchecked set can be printed/logged just like the
// verified set it wraps.
#[derive(Debug)]
pub struct GaussianNbParams<F, L>(GaussianNbValidParams<F, L>);

impl<F: Float, L> Default for GaussianNbParams<F, L> {
    fn default() -> Self {
        Self::new()
    }
}

impl<F: Float, L> GaussianNbParams<F, L> {
    /// Create new [GaussianNbParams] set with default values for its parameters
    ///
    /// The variance smoothing starts out at `1e-9`.
    pub fn new() -> Self {
        let defaults = GaussianNbValidParams {
            var_smoothing: F::cast(1e-9),
            label: PhantomData,
        };
        Self(defaults)
    }

    /// Specifies the portion of the largest variance of all the features that
    /// is added to the variance for calculation stability
    ///
    /// The value is stored as given; no validation happens in this setter.
    pub fn var_smoothing(self, var_smoothing: F) -> Self {
        let Self(previous) = self;
        // Replace only the smoothing value, carry every other field over.
        Self(GaussianNbValidParams {
            var_smoothing,
            ..previous
        })
    }
}

impl<F: Float, L> ParamGuard for GaussianNbParams<F, L> {
    type Checked = GaussianNbValidParams<F, L>;
    type Error = NaiveBayesError;

    /// Verifies the parameter set by reference
    ///
    /// Returns a reference to the wrapped verified set unless `var_smoothing`
    /// is reported as negative, in which case
    /// [`InvalidSmoothing`](NaiveBayesError::InvalidSmoothing) carrying the
    /// offending value (as `f64`) is returned.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        let inner = &self.0;
        let smoothing = inner.var_smoothing;
        if !smoothing.is_negative() {
            return Ok(inner);
        }
        Err(NaiveBayesError::InvalidSmoothing(
            smoothing.to_f64().unwrap(),
        ))
    }

    /// Verifies the parameter set and hands out the wrapped verified set by value
    ///
    /// Performs exactly the validation of [`check_ref`](Self::check_ref) and
    /// forwards its error unchanged.
    fn check(self) -> Result<Self::Checked, Self::Error> {
        // Discard the borrowed result first so `self` can be moved afterwards.
        let verdict = self.check_ref().map(|_| ());
        verdict.map(|()| self.0)
    }
}