1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
use crate::appx_dbscan::cells_grid::CellsGrid;
use crate::appx_dbscan::hyperparameters::AppxDbscanHyperParams;
use linfa::Float;

use ndarray::{Array1, ArrayView2};

/// Result of running the Approximated DBSCAN algorithm on a set of points.
///
/// Holds exactly one label per input point (one per row of the observations the
/// labeler was built from): `Some(cluster_index)` when the point was assigned to a
/// cluster, `None` when no cluster was found for it (i.e. the point is noise).
#[derive(Debug, Clone, PartialEq)]
pub struct AppxDbscanLabeler {
    /// Label of each observation, indexed by the observation's row index.
    labels: Array1<Option<usize>>,
}

impl AppxDbscanLabeler {
    /// Runs the Approximated DBSCAN algorithm on the provided `observations` using the
    /// given `params`, and stores the resulting label of every point.
    ///
    /// ## Parameters:
    /// * `observations`: the points to cluster according to the approximated DBSCAN rule,
    ///   one point per row;
    /// * `params`: the hyperparameters for the approximated DBSCAN algorithm.
    ///
    /// ## Return
    ///
    /// An `AppxDbscanLabeler` containing the label associated with each row of
    /// `observations` (see [`AppxDbscanLabeler::labels`]).
    ///
    pub fn new<F: Float>(
        observations: &ArrayView2<F>,
        params: &AppxDbscanHyperParams<F>,
    ) -> AppxDbscanLabeler {
        let mut grid = CellsGrid::new(observations, params);
        AppxDbscanLabeler {
            labels: Self::label(&mut grid, observations, params),
        }
    }

    /// Gives the labels of every point provided in input to the constructor.
    ///
    /// The label at index `i` belongs to row `i` of the observations: `Some(c)` means the
    /// point was assigned to cluster `c`, `None` means no cluster was found for it (noise).
    ///
    /// ## Example:
    ///
    /// ```rust
    ///
    /// use ndarray::{array, Axis};
    /// use linfa_clustering::{AppxDbscanLabeler, AppxDbscanHyperParams};
    ///
    /// // Let's define some observations and set the desired params
    /// let observations = array![[0.,0.], [1., 0.], [0., 1.]];
    /// let params = AppxDbscanHyperParams::new(2).build();
    /// // Now we build the labels for each observation using the labeler struct
    /// let labeler = AppxDbscanLabeler::new(&observations.view(),&params);
    /// // Here we can access the label of each point in `observations`
    /// for (i, _point) in observations.axis_iter(Axis(0)).enumerate() {
    ///     let _label_for_point = labeler.labels()[i];
    /// }
    /// ```
    ///
    pub fn labels(&self) -> &Array1<Option<usize>> {
        &self.labels
    }

    /// Computes the full label array: core points first (by connected component of the
    /// cell graph), then every remaining non-core point.
    fn label<F: Float>(
        grid: &mut CellsGrid<F>,
        points: &ArrayView2<F>,
        params: &AppxDbscanHyperParams<F>,
    ) -> Array1<Option<usize>> {
        let mut labels = Self::label_connected_components(grid, points, params);
        Self::label_border_noise_points(grid, points, &mut labels, params);
        labels
    }

    /// Explores the graph of cells contained in `grid` and labels the core points of core
    /// cells. Core cells in the same connected component receive the same cluster label.
    /// Core cells in different connected components receive different cluster labels.
    ///
    /// Cluster indices are consecutive starting from `0`: the counter only advances for sets
    /// that contain at least one core cell. Every entry not written by
    /// `assign_to_cluster` stays `None`.
    ///
    /// If the points in the input grid have not been labeled yet, they are labeled here
    /// first (via `grid.label_points`).
    fn label_connected_components<F: Float>(
        grid: &mut CellsGrid<F>,
        observations: &ArrayView2<F>,
        params: &AppxDbscanHyperParams<F>,
    ) -> Array1<Option<usize>> {
        if !grid.labeled() {
            grid.label_points(observations, params);
        }
        let mut labels = Array1::from_elem(observations.dim().0, None);
        let mut current_cluster_i: usize = 0;
        for set in grid.cells_mut().all_sets_mut() {
            let mut core_cells_count = 0;
            for cell in set.filter(|(_, c)| c.is_core()).map(|(_, c)| c) {
                cell.assign_to_cluster(current_cluster_i, &mut labels.view_mut());
                core_cells_count += 1;
            }
            // Sets without any core cell do not form a cluster, so they must not
            // consume a cluster index.
            if core_cells_count > 0 {
                current_cluster_i += 1;
            }
        }
        labels
    }

    /// Loops through all non-core points of the dataset and gives each one the label of
    /// one of the clusters it belongs to. A point belongs to a cluster when a neighbouring
    /// cell reports a positive approximate range count for it. If no such cluster is found,
    /// the point keeps its label of `None` (noise).
    ///
    /// # Panics
    ///
    /// Panics if a neighbour reports a positive range count but has no cluster index
    /// (i.e. it is not a core cell). It also panics if a neighbour index is missing from
    /// the grid. Both cases indicate a broken grid invariant.
    fn label_border_noise_points<F: Float>(
        grid: &CellsGrid<F>,
        observations: &ArrayView2<F>,
        clusters: &mut Array1<Option<usize>>,
        params: &AppxDbscanHyperParams<F>,
    ) {
        for cell in grid.cells() {
            // Only non-core points still need a label: core points were already labeled
            // by `label_connected_components`.
            for point_index in cell
                .points()
                .iter()
                .filter(|x| !x.is_core())
                .map(|x| x.index())
            {
                let curr_point = &observations.row(point_index);
                'nbrs: for neighbour_i in cell.neighbours_indexes() {
                    let neighbour = grid
                        .cells()
                        .get(*neighbour_i)
                        .expect("indexes are added to neighbours only if they are in the table");
                    if neighbour.approximate_range_counting(curr_point, params) > 0 {
                        clusters[point_index] = Some(neighbour.cluster_i().unwrap_or_else(|| {
                            panic!("Attempted to get cluster index of a non core cell")
                        }));
                        // assign only to first matching cluster for compatibility
                        break 'nbrs;
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests;