1use std::collections::{HashMap, HashSet, BinaryHeap};
2use std::f64;
3use std::hash::Hash;
4use std::slice;
5
6use order_stat;
7
8use {Point, RegionQuery, ListPoints, Points};
9
10pub struct Optics<P: Points> where P::Point: Hash + Eq + Clone {
71 computed_eps: f64,
72 min_pts: usize,
73 #[allow(dead_code)] points: P,
74 order: Vec<P::Point>,
75 core_dist: HashMap<P::Point, f64>,
76 reachability: HashMap<P::Point, f64>,
77}
78
79impl<P: RegionQuery + ListPoints> Optics<P>
80 where P::Point: Hash + Eq + Clone
81{
82 pub fn new(points: P, eps: f64, min_pts: usize) -> Optics<P> {
93 let mut processed = HashSet::new();
94 let mut order = vec![];
95 let mut reachability = HashMap::new();
96 let mut core_dist = HashMap::new();
97 let mut seeds = BinaryHeap::new();
98 for p in points.all_points() {
99 seeds.clear();
100 seeds.push(Dist { dist: 0.0, point: p });
101 while let Some(q) = seeds.pop() {
102 if !processed.insert(q.point.clone()) {
103 continue
104 }
105
106 let mut neighbours = points.neighbours(&q.point, eps)
107 .map(|t| Dist { dist: t.0, point: t.1 })
108 .collect::<Vec<_>>();
109 order.push(q.point.clone());
110 if let Some(cd) = compute_core_dist(&mut neighbours, min_pts) {
111 core_dist.insert(q.point.clone(), cd);
112 update(&neighbours, cd, &processed, &mut seeds, &mut reachability)
113 }
114 }
115 }
116 Optics {
117 points: points,
118 min_pts: min_pts,
119 computed_eps: eps,
120 order: order,
121 core_dist: core_dist,
122 reachability: reachability,
123 }
124 }
125
126 pub fn dbscan_clustering<'a>(&'a self, eps: f64) -> OpticsDbscanClustering<'a, P> {
136 assert!(eps <= self.computed_eps);
137 OpticsDbscanClustering {
138 noise: vec![],
139 order: self.order.iter(),
140 optics: self,
141 next: None,
142 eps: eps,
143 }
144 }
145}
146
147pub struct OpticsDbscanClustering<'a, P: 'a + Points>
152 where P::Point: 'a + Eq + Hash + Clone
153{
154 noise: Vec<P::Point>,
155 order: slice::Iter<'a, P::Point>,
156 optics: &'a Optics<P>,
157 next: Option<P::Point>,
158 eps: f64,
159}
160
161impl<'a, P: Points> OpticsDbscanClustering<'a, P>
162 where P::Point: 'a + Eq + Hash + Clone
163{
164 pub fn noise_points(&self) -> &[P::Point] {
165 &self.noise
166 }
167}
168impl<'a, P: RegionQuery + ListPoints> Iterator for OpticsDbscanClustering<'a, P>
169 where P::Point: 'a + Eq + Hash + Clone + ::std::fmt::Debug
170{
171 type Item = Vec<P::Point>;
172 #[inline(never)]
173 fn next(&mut self) -> Option<Vec<P::Point>> {
174 let mut current = Vec::with_capacity(self.optics.min_pts);
175 if let Some(x) = self.next.take() {
176 current.push(x)
177 }
178
179 for p in &mut self.order {
180 if *self.optics.reachability.get(p).unwrap_or(&f64::INFINITY) > self.eps {
181 if *self.optics.core_dist.get(p).unwrap_or(&f64::INFINITY) <= self.eps {
182 if current.len() > 0 {
183 self.next = Some(p.clone());
184 return Some(current)
185 }
186 } else {
187 self.noise.push(p.clone());
188 continue
189 }
190 }
191 current.push(p.clone())
192 }
193 if current.len() > 0 {
194 Some(current)
195 } else {
196 None
197 }
198 }
199}
200
201#[inline(never)]
202fn update<P>(neighbours: &[Dist<P>],
203 core_dist: f64,
204 processed: &HashSet<P>,
205 seeds: &mut BinaryHeap<Dist<P>>,
206 reachability: &mut HashMap<P, f64>)
207 where P: Hash + Eq + Clone
208{
209 for n in neighbours {
210 if processed.contains(&n.point) {
211 continue
212 }
213
214 let new_reach_dist = core_dist.max(n.dist);
215 let entry = reachability.entry(n.point.clone()).or_insert(f64::INFINITY);
216 if new_reach_dist < *entry {
217 *entry = new_reach_dist;
218 seeds.push(Dist { dist: -new_reach_dist, point: n.point.clone() })
220 }
221 }
222}
223
/// A point tagged with a distance-like key.
///
/// Used both for neighbour records (distance from the query point) and for
/// seed-heap entries, where `dist` holds the negated reachability so that
/// `BinaryHeap`, a max-heap, pops the smallest reachability first.
#[derive(Clone)]
struct Dist<P> {
    /// Distance, or negated reachability when stored in the seed heap.
    dist: f64,
    /// The point the key belongs to.
    point: P,
}
229impl<P> PartialEq for Dist<P> {
230 fn eq(&self, other: &Dist<P>) -> bool {
231 self.dist == other.dist
232 }
233}
234impl<P> Eq for Dist<P> {}
235use std::cmp::Ordering;
236impl<P> PartialOrd for Dist<P> {
237 fn partial_cmp(&self, other: &Dist<P>) -> Option<Ordering> {
238 self.dist.partial_cmp(&other.dist)
239 }
240}
241impl<P> Ord for Dist<P> {
242 fn cmp(&self, other: &Dist<P>) -> Ordering {
243 self.partial_cmp(other).unwrap()
244 }
245}
246
247fn compute_core_dist<P>(x: &mut [Dist<P>], n: usize) -> Option<f64> {
248 if x.len() >= n {
249 Some(order_stat::kth(x, n - 1).dist)
250 } else {
251 None
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use {Point, BruteScan};
    use std::f64::INFINITY as INF;

    /// A point on the real line; distance is the absolute difference.
    #[derive(Copy, Clone)]
    struct Linear(f64);
    impl Point for Linear {
        fn dist(&self, other: &Linear) -> f64 {
            (self.0 - other.0).abs()
        }
        fn dist_lower_bound(&self, other: &Linear) -> f64 {
            self.dist(other)
        }
    }

    /// Builds an array of `Linear` points from a list of coordinates.
    macro_rules! l {
        ($($e: expr),*) => {
            [$(Linear($e),)*]
        }
    }

    /// Runs OPTICS over `points` and checks the visit order and reachability
    /// distances against `expected`, a list of `(coordinate, reachability)`
    /// pairs in the order the points should be visited.
    fn check_reachability(points: &[Linear], eps: f64, min_pts: usize,
                          expected: &[(f64, f64)]) {
        let optics = Optics::new(BruteScan::new(points), eps, min_pts);

        assert_eq!(optics.order.len(), points.len());
        for (&idx, &(point, reachability)) in optics.order.iter().zip(expected) {
            let idx_point = points[idx];
            assert_eq!(idx_point.0, point);

            let computed_r = optics.reachability.get(&idx).map_or(INF, |&f| f);
            assert!((reachability == computed_r) || (reachability - computed_r).abs() < 1e-5,
                    "difference in reachability for {} ({}): true {}, computed {}", idx, point,
                    reachability, computed_r);
        }
    }

    #[test]
    fn smoke() {
        let points = [Linear(0.0), Linear(10.0), Linear(9.5), Linear(0.5), Linear(0.6),
                      Linear(9.1), Linear(9.9), Linear(5.0)];
        let points = BruteScan::new(&points);
        let optics = Optics::new(points, 0.8, 3);
        let mut clustering = optics.dbscan_clustering(0.8);
        let mut clusters = clustering.by_ref().collect::<Vec<_>>();

        // Cluster contents and cluster order depend on traversal order, so
        // normalise both before comparing.
        for x in &mut clusters { x.sort() }
        clusters.sort();

        assert_eq!(clusters, &[&[0usize, 3, 4] as &[_], &[1usize, 2, 5, 6] as &_]);
        assert_eq!(clustering.noise_points().iter().cloned().collect::<Vec<_>>(),
                   &[7]);
    }

    #[test]
    fn reachability_restricted() {
        let points = l![0.0, 0.01, 10.0, 9.5, 0.6, 0.5, 9.1, 9.9, 5.0, 5.3];
        let expected = [(0.0, INF),
                        (0.01, 0.5),
                        (0.5, 0.49),
                        (0.6, 0.49),
                        (10.0, INF),
                        (9.9, 0.5),
                        (9.5, 0.4),
                        (9.1, 0.4),
                        (5.0, INF),
                        (5.3, INF)];
        check_reachability(&points, 0.5, 3, &expected);
    }

    #[test]
    fn reachability_unrestricted() {
        let points = l![0.0, 0.01, 10.0, 9.5, 0.6, 0.5, 9.1, 9.9, 5.0, 5.3];
        let expected = [(0.0, INF),
                        (0.01, 0.5),
                        (0.5, 0.49),
                        (0.6, 0.49),
                        (5.0, 4.4),
                        (5.3, 4.1),
                        (9.1, 3.8),
                        (9.5, 0.8),
                        (9.9, 0.4),
                        (10.0, 0.4)];
        check_reachability(&points, 1e10, 3, &expected);
    }
}
359
360make_benches!(|p, e, mp| super::Optics::new(p, e, mp).extract_clustering(e).count());