1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
use linfa::{param_guard::TransformGuard, prelude::*, Float};
use linfa_nn::{distance::Distance, NearestNeighbour};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use thiserror::Error;

/// Hyperparameters for the [DBSCAN algorithm](struct.Dbscan.html) after validation.
///
/// When obtained by checking a [`DbscanParams`] builder, `min_points` is greater
/// than 1 and `tolerance` has passed the positivity check.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug)]
pub struct DbscanValidParams<F, D, N>
where
    F: Float,
    D: Distance<F>,
    N: NearestNeighbour,
{
    /// Tolerance value, exposed through `DbscanValidParams::tolerance`.
    pub(crate) tolerance: F,
    /// Minimum number of neighbouring points needed to form a core point.
    pub(crate) min_points: usize,
    /// Distance metric.
    pub(crate) dist_fn: D,
    /// Nearest neighbour algorithm used for range queries.
    pub(crate) nn_algo: N,
}

#[derive(Debug)]
/// Builder for a set of [DBSCAN hyperparameters](DbscanValidParams).
///
/// Each setter consumes and returns the builder. Call `check` (from [`ParamGuard`])
/// to validate the values and obtain a [`DbscanValidParams`].
pub struct DbscanParams<F: Float, D: Distance<F>, N: NearestNeighbour>(DbscanValidParams<F, D, N>);

#[derive(Error, Debug)]
pub enum DbscanParamsError {
    #[error("min_points must be greater than 1")]
    MinPoints,
    #[error("tolerance must be greater than 0")]
    Tolerance,
}

impl<F: Float, D: Distance<F>, N: NearestNeighbour> DbscanParams<F, D, N> {
    /// Create a builder from a minimum point count, a distance metric and a
    /// nearest neighbour algorithm. The tolerance starts at `1e-4`.
    ///
    /// No validation happens here; invalid values are reported when the builder
    /// is checked.
    pub(crate) fn new(min_points: usize, dist_fn: D, nn_algo: N) -> Self {
        Self(DbscanValidParams {
            min_points,
            tolerance: F::cast(1e-4),
            dist_fn,
            nn_algo,
        })
    }

    /// Set the tolerance.
    ///
    /// Zero or negative values are rejected when the parameters are checked,
    /// with [`DbscanParamsError::Tolerance`].
    #[must_use = "builder setters return the updated builder"]
    pub fn tolerance(mut self, tolerance: F) -> Self {
        self.0.tolerance = tolerance;
        self
    }

    /// Set the nearest neighbour algorithm used for range queries.
    #[must_use = "builder setters return the updated builder"]
    pub fn nn_algo(mut self, nn_algo: N) -> Self {
        self.0.nn_algo = nn_algo;
        self
    }

    /// Set the distance metric.
    #[must_use = "builder setters return the updated builder"]
    pub fn dist_fn(mut self, dist_fn: D) -> Self {
        self.0.dist_fn = dist_fn;
        self
    }
}

impl<F: Float, D: Distance<F>, N: NearestNeighbour> ParamGuard for DbscanParams<F, D, N> {
    type Checked = DbscanValidParams<F, D, N>;
    type Error = DbscanParamsError;

    /// Validate the hyperparameters without consuming the builder.
    ///
    /// # Errors
    /// - [`DbscanParamsError::MinPoints`] if `min_points` is 0 or 1.
    /// - [`DbscanParamsError::Tolerance`] if `tolerance` is not greater than zero
    ///   (this includes NaN).
    ///
    /// `min_points` is checked first, so it is the error reported when both are invalid.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        // Phrased as "is greater than zero" instead of `<= 0` on purpose: NaN compares
        // false against everything, so `NaN <= 0` would let a NaN tolerance through.
        let tolerance_is_positive = self.0.tolerance > F::zero();
        if self.0.min_points <= 1 {
            Err(DbscanParamsError::MinPoints)
        } else if !tolerance_is_positive {
            Err(DbscanParamsError::Tolerance)
        } else {
            Ok(&self.0)
        }
    }

    /// Validate the hyperparameters and unwrap them into [`DbscanValidParams`].
    ///
    /// # Errors
    /// Same conditions as `check_ref`.
    fn check(self) -> Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}
// Opts the unchecked builder into linfa's `TransformGuard`.
// NOTE(review): presumably this lets `DbscanParams` be used directly as a
// transformer, validating via `ParamGuard` first — confirm against linfa docs.
impl<F, D, N> TransformGuard for DbscanParams<F, D, N>
where
    F: Float,
    D: Distance<F>,
    N: NearestNeighbour,
{
}

impl<F: Float, D: Distance<F>, N: NearestNeighbour> DbscanValidParams<F, D, N> {
    /// Tolerance configured for the DBSCAN run.
    ///
    /// Validation requires it to be greater than zero; the builder default is `1e-4`.
    // NOTE(review): presumably the neighbourhood radius used in range queries with
    // `dist_fn` — confirm against the algorithm implementation.
    pub fn tolerance(&self) -> F {
        self.tolerance
    }

    /// Minimum number of neighboring points a point needs to have to be a core
    /// point and not a noise point.
    pub fn minimum_points(&self) -> usize {
        self.min_points
    }

    /// Distance metric used in the DBSCAN calculation.
    pub fn dist_fn(&self) -> &D {
        &self.dist_fn
    }

    /// Nearest neighbour algorithm used for range queries.
    pub fn nn_algo(&self) -> &N {
        &self.nn_algo
    }
}

#[cfg(test)]
mod tests {
    use linfa_nn::{distance::L2Dist, CommonNearestNeighbour};

    use super::*;

    #[test]
    fn tolerance_cannot_be_zero() {
        let res = DbscanParams::new(2, L2Dist, CommonNearestNeighbour::KdTree)
            .tolerance(0.0)
            .check();
        assert!(matches!(res, Err(DbscanParamsError::Tolerance)));
    }

    #[test]
    fn tolerance_cannot_be_negative() {
        let res = DbscanParams::new(2, L2Dist, CommonNearestNeighbour::KdTree)
            .tolerance(-1.0)
            .check();
        assert!(matches!(res, Err(DbscanParamsError::Tolerance)));
    }

    #[test]
    fn min_points_at_least_2() {
        let res = DbscanParams::new(1, L2Dist, CommonNearestNeighbour::KdTree)
            .tolerance(3.3)
            .check();
        assert!(matches!(res, Err(DbscanParamsError::MinPoints)));
    }

    #[test]
    fn min_points_is_checked_before_tolerance() {
        let res = DbscanParams::new(0, L2Dist, CommonNearestNeighbour::KdTree)
            .tolerance(0.0)
            .check();
        assert!(matches!(res, Err(DbscanParamsError::MinPoints)));
    }

    #[test]
    fn valid_params_are_accepted() {
        let params = DbscanParams::new(2, L2Dist, CommonNearestNeighbour::KdTree).tolerance(0.5);
        assert!(params.check_ref().is_ok());
        let valid = params
            .check()
            .expect("min_points = 2 and tolerance = 0.5 are valid");
        assert_eq!(valid.minimum_points(), 2);
        assert_eq!(valid.tolerance(), 0.5);
    }

    #[test]
    fn default_tolerance_is_valid() {
        let valid = DbscanParams::<f64, _, _>::new(3, L2Dist, CommonNearestNeighbour::KdTree)
            .check()
            .expect("default tolerance must pass validation");
        assert_eq!(valid.minimum_points(), 3);
        assert_eq!(valid.tolerance(), 1e-4);
    }
}