// Point labeling for the approximated DBSCAN (`appx_dbscan`) clustering module.
use crate::appx_dbscan::cells_grid::CellsGrid;
use crate::appx_dbscan::hyperparameters::AppxDbscanHyperParams;
use linfa::Float;
use ndarray::{Array1, ArrayView2};
/// Holds the cluster labels computed for a set of observations by the
/// approximated DBSCAN labeling procedure implemented in this file.
pub struct AppxDbscanLabeler {
/// One entry per observation row: `Some(cluster index)` when the point was
/// assigned to a cluster, `None` when it was left unassigned.
labels: Array1<Option<usize>>,
}
impl AppxDbscanLabeler {
    /// Labels every row of `observations` and stores the result.
    ///
    /// A [`CellsGrid`] is built from `observations` and `params`, and the
    /// labels are then computed from that grid.
    ///
    /// # Arguments
    /// * `observations` - the points to label, one point per row
    /// * `params` - hyperparameters forwarded to the grid and to the
    ///   approximate range counting
    ///
    /// # Panics
    /// Panics under the same conditions as the internal border-labeling pass
    /// (a neighbour index missing from the grid, or a matching neighbour cell
    /// without a cluster index).
    pub fn new<F: Float>(
        observations: &ArrayView2<F>,
        params: &AppxDbscanHyperParams<F>,
    ) -> AppxDbscanLabeler {
        let mut grid = CellsGrid::new(observations, params);
        let labels = Self::label(&mut grid, observations, params);
        AppxDbscanLabeler { labels }
    }

    /// Returns the computed labels, one entry per observation row.
    ///
    /// An entry is `Some(cluster index)` for a point assigned to a cluster and
    /// `None` for a point that was left unassigned.
    pub fn labels(&self) -> &Array1<Option<usize>> {
        &self.labels
    }

    /// Computes the labels in two passes: first the core cells of the grid are
    /// assigned cluster indices, then the remaining (non-core) points are
    /// attached to a neighbouring cluster where possible.
    fn label<F: Float>(
        grid: &mut CellsGrid<F>,
        points: &ArrayView2<F>,
        params: &AppxDbscanHyperParams<F>,
    ) -> Array1<Option<usize>> {
        let mut labels = Self::label_connected_components(grid, points, params);
        Self::label_border_noise_points(grid, points, &mut labels, params);
        labels
    }

    /// Assigns a cluster index to every core cell of the grid.
    ///
    /// The grid is asked to label its points first if it has not done so yet.
    /// Each set yielded by `all_sets_mut()` (presumably one connected
    /// component of cells; confirm against `CellsGrid`) gets one cluster
    /// index, shared by all of its core cells.
    ///
    /// Returns a vector with one entry per observation row, initialised to
    /// `None` and filled in by `assign_to_cluster`.
    fn label_connected_components<F: Float>(
        grid: &mut CellsGrid<F>,
        observations: &ArrayView2<F>,
        params: &AppxDbscanHyperParams<F>,
    ) -> Array1<Option<usize>> {
        if !grid.labeled() {
            grid.label_points(observations, params);
        }
        let mut labels = Array1::from_elem(observations.dim().0, None);
        let mut current_cluster_i: usize = 0;
        for set in grid.cells_mut().all_sets_mut() {
            let mut core_cells_count = 0;
            for cell in set.filter(|(_, c)| c.is_core()).map(|(_, c)| c) {
                cell.assign_to_cluster(current_cluster_i, &mut labels.view_mut());
                core_cells_count += 1;
            }
            // Only consume a cluster index when the set actually contained a
            // core cell, so cluster indices stay consecutive starting from 0.
            if core_cells_count > 0 {
                current_cluster_i += 1;
            }
        }
        labels
    }

    /// Attaches non-core points to a neighbouring cluster.
    ///
    /// For every point that is not a core point, the neighbour cells of the
    /// point's own cell are scanned in order; the first neighbour whose
    /// approximate range count for the point is positive gives the point its
    /// cluster index. A point with no such neighbour keeps whatever label it
    /// already has in `clusters`.
    ///
    /// # Panics
    /// Panics if a neighbour index is not present in the grid, or if the
    /// matching neighbour cell has no cluster index.
    fn label_border_noise_points<F: Float>(
        grid: &CellsGrid<F>,
        observations: &ArrayView2<F>,
        clusters: &mut Array1<Option<usize>>,
        params: &AppxDbscanHyperParams<F>,
    ) {
        for cell in grid.cells() {
            // Core points were already labelled together with their cell;
            // only the non-core (border or noise) points are handled here.
            for cp_index in cell
                .points()
                .iter()
                .filter(|x| !x.is_core())
                .map(|x| x.index())
            {
                let curr_point = &observations.row(cp_index);
                for neighbour_i in cell.neighbours_indexes() {
                    // NOTE(review): assumes every neighbour index stored in a
                    // cell refers to a cell present in the grid; confirm
                    // against `CellsGrid`.
                    let neighbour = grid.cells().get(*neighbour_i).unwrap();
                    if neighbour.approximate_range_counting(curr_point, params) > 0 {
                        clusters[cp_index] = Some(
                            neighbour
                                .cluster_i()
                                .expect("Attempted to get cluster index of a non core cell"),
                        );
                        // The first matching neighbour decides the cluster.
                        break;
                    }
                }
            }
        }
    }
}
// The `tests` module body lives in a separate file and is compiled only for
// test builds.
#[cfg(test)]
mod tests;