1use crate::optics::hyperparams::{OpticsParams, OpticsValidParams};
2use linfa::traits::Transformer;
3use linfa::Float;
4use linfa_nn::distance::{Distance, L2Dist};
5use linfa_nn::{CommonNearestNeighbour, NearestNeighbour, NearestNeighbourIndex};
6use ndarray::{ArrayView, Ix1, Ix2};
7use noisy_float::{checkers::NumChecker, NoisyFloat};
8#[cfg(feature = "serde")]
9use serde_crate::{Deserialize, Serialize};
10use std::cmp::Ordering;
11use std::collections::BTreeSet;
12use std::ops::Index;
13use std::slice::SliceIndex;
14
/// Entry point for the OPTICS (Ordering Points To Identify the Clustering Structure)
/// algorithm.
///
/// `Optics` itself holds no state: build hyperparameters with [`Optics::params`] or
/// [`Optics::params_with`] and call `transform` on them to obtain an [`OpticsAnalysis`].
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize), serde(crate = "serde_crate"))]
pub struct Optics;
34
/// One dataset sample as tracked by OPTICS.
///
/// Both distances are `None` while undefined, e.g. for samples that are not core points or
/// that were never reached from one.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize), serde(crate = "serde_crate"))]
pub struct Sample<F> {
    /// Row of this sample in the dataset passed to `transform`.
    index: usize,
    /// Core distance computed by `set_core_distance`; `None` when the sample has fewer than
    /// `min_points` neighbours within the tolerance.
    core_distance: Option<F>,
    /// Smallest reachability distance assigned so far from a visited core sample.
    /// (`find_neighbors` reuses this field to carry the distance to the query point.)
    reachability_distance: Option<F>,
}
51
52impl<F: Float> Sample<F> {
53 fn new(index: usize) -> Self {
55 Self {
56 index,
57 core_distance: None,
58 reachability_distance: None,
59 }
60 }
61
62 pub fn index(&self) -> usize {
64 self.index
65 }
66
67 pub fn reachability_distance(&self) -> &Option<F> {
70 &self.reachability_distance
71 }
72
73 pub fn core_distance(&self) -> &Option<F> {
75 &self.core_distance
76 }
77}
78
79impl<F: Float> Eq for Sample<F> {}
80
81impl<F: Float> PartialEq for Sample<F> {
82 fn eq(&self, other: &Self) -> bool {
83 self.reachability_distance == other.reachability_distance
84 }
85}
86
87#[allow(clippy::non_canonical_partial_ord_impl)]
88impl<F: Float> PartialOrd for Sample<F> {
89 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
90 self.reachability_distance
91 .partial_cmp(&other.reachability_distance)
92 }
93}
94
95impl<F: Float> Ord for Sample<F> {
96 fn cmp(&self, other: &Self) -> Ordering {
97 self.reachability_distance
98 .map(NoisyFloat::<_, NumChecker>::new)
99 .cmp(
100 &other
101 .reachability_distance
102 .map(NoisyFloat::<_, NumChecker>::new),
103 )
104 }
105}
106
107#[derive(Clone, Debug, PartialEq)]
112#[cfg_attr(
113 feature = "serde",
114 derive(Serialize, Deserialize),
115 serde(crate = "serde_crate")
116)]
117pub struct OpticsAnalysis<F: Float> {
118 orderings: Vec<Sample<F>>,
121}
122
123impl<F: Float> OpticsAnalysis<F> {
124 pub fn as_slice(&self) -> &[Sample<F>] {
126 self.orderings.as_slice()
127 }
128
129 pub fn iter(&self) -> std::slice::Iter<'_, Sample<F>> {
131 self.orderings.iter()
132 }
133}
134
135impl<I, F: Float> Index<I> for OpticsAnalysis<F>
136where
137 I: SliceIndex<[Sample<F>]>,
138{
139 type Output = I::Output;
140
141 fn index(&self, index: I) -> &Self::Output {
142 self.orderings.index(index)
143 }
144}
145
146impl Optics {
147 pub fn params<F: Float>(min_points: usize) -> OpticsParams<F, L2Dist, CommonNearestNeighbour> {
154 OpticsParams::new(min_points, L2Dist, CommonNearestNeighbour::KdTree)
155 }
156
157 pub fn params_with<F: Float, D: Distance<F>, N: NearestNeighbour>(
160 min_points: usize,
161 dist_fn: D,
162 nn_algo: N,
163 ) -> OpticsParams<F, D, N> {
164 OpticsParams::new(min_points, dist_fn, nn_algo)
165 }
166}
167
168impl<F: Float, D: Distance<F>, N: NearestNeighbour>
169 Transformer<ArrayView<'_, F, Ix2>, OpticsAnalysis<F>> for OpticsValidParams<F, D, N>
170{
171 fn transform(&self, observations: ArrayView<F, Ix2>) -> OpticsAnalysis<F> {
172 let mut result = OpticsAnalysis { orderings: vec![] };
173
174 let mut points = (0..observations.nrows())
175 .map(Sample::new)
176 .collect::<Vec<_>>();
177
178 let nn = match self
179 .nn_algo()
180 .from_batch(&observations, self.dist_fn().clone())
181 {
182 Ok(nn) => nn,
183 Err(linfa_nn::BuildError::ZeroDimension) => {
184 return OpticsAnalysis { orderings: points }
185 }
186 Err(e) => panic!("Unexpected nearest neighbour error: {}", e),
187 };
188
189 let mut processed = BTreeSet::new();
192 let mut index = 0;
193 let mut seeds = Vec::new();
194 loop {
195 if index == points.len() {
196 break;
197 } else if processed.contains(&index) {
198 index += 1;
199 continue;
200 }
201 let mut expected = if processed.is_empty() { 0 } else { index };
202 let mut points_index = index;
203 for index in processed.range(index..) {
205 if expected != *index {
206 points_index = expected;
207 break;
208 }
209 expected += 1;
210 }
211 index += 1;
212 let neighbors = self.find_neighbors(&*nn, observations.row(points_index));
213 let n = &mut points[points_index];
214 self.set_core_distance(n, &neighbors, observations);
215 if n.core_distance.is_some() {
216 seeds.clear();
217 self.get_seeds(
220 observations,
221 n.clone(),
222 &neighbors,
223 &mut points,
224 &processed,
225 &mut seeds,
226 );
227 while !seeds.is_empty() {
228 seeds.sort_unstable_by(|a, b| b.cmp(a));
229 let (i, min_point) = seeds
230 .iter()
231 .enumerate()
232 .min_by(|(_, a), (_, b)| points[**a].cmp(&points[**b]))
233 .unwrap();
234 let n = &mut points[*min_point];
235 seeds.remove(i);
236 processed.insert(n.index);
237 let neighbors = self.find_neighbors(&*nn, observations.row(n.index));
238
239 self.set_core_distance(n, &neighbors, observations);
240 result.orderings.push(n.clone());
241 if n.core_distance.is_some() {
242 self.get_seeds(
243 observations,
244 n.clone(),
245 &neighbors,
246 &mut points,
247 &processed,
248 &mut seeds,
249 );
250 }
251 }
252 } else {
253 result.orderings.push(n.clone());
256 processed.insert(n.index);
257 }
258 }
259 result
260 }
261}
262
263impl<F: Float, D: Distance<F>, N: NearestNeighbour> OpticsValidParams<F, D, N> {
264 fn find_neighbors(
268 &self,
269 nn: &dyn NearestNeighbourIndex<F>,
270 candidate: ArrayView<F, Ix1>,
271 ) -> Vec<Sample<F>> {
272 nn.within_range(candidate, self.tolerance())
275 .unwrap()
276 .into_iter()
277 .map(|(pt, index)| Sample {
278 index,
279 reachability_distance: Some(self.dist_fn().distance(pt, candidate)),
280 core_distance: None,
281 })
282 .collect()
283 }
284
285 fn set_core_distance(
287 &self,
288 point: &mut Sample<F>,
289 neighbors: &[Sample<F>],
290 dataset: ArrayView<F, Ix2>,
291 ) {
292 let observation = dataset.row(point.index);
293 point.core_distance = neighbors
294 .get(self.minimum_points() - 1)
295 .map(|x| dataset.row(x.index))
296 .map(|x| self.dist_fn().distance(observation, x));
297 }
298
299 fn get_seeds(
302 &self,
303 observations: ArrayView<F, Ix2>,
304 sample: Sample<F>,
305 neighbors: &[Sample<F>],
306 points: &mut [Sample<F>],
307 processed: &BTreeSet<usize>,
308 seeds: &mut Vec<usize>,
309 ) {
310 for n in neighbors.iter().filter(|x| !processed.contains(&x.index)) {
311 let dist = self
312 .dist_fn()
313 .distance(observations.row(n.index), observations.row(sample.index));
314 let r_dist = F::max(sample.core_distance.unwrap(), dist);
315 match points[n.index].reachability_distance {
316 None => {
317 points[n.index].reachability_distance = Some(r_dist);
318 seeds.push(n.index);
319 }
320 Some(s) if r_dist < s => points[n.index].reachability_distance = Some(r_dist),
321 _ => {}
322 }
323 }
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OpticsError;
    use linfa::ParamGuard;
    use linfa_nn::KdTree;
    use ndarray::Array2;
    use std::collections::BTreeSet;

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<OpticsAnalysis<f64>>();
        has_autotraits::<Optics>();
        has_autotraits::<Sample<f64>>();
        has_autotraits::<OpticsError>();
        has_autotraits::<OpticsParams<f64, L2Dist, KdTree>>();
        has_autotraits::<OpticsValidParams<f64, L2Dist, KdTree>>();
    }

    #[test]
    fn optics_consistency() {
        let params = Optics::params(3);
        let data = vec![1.0, 2.0, 3.0, 8.0, 8.0, 7.0, 2.0, 5.0, 6.0, 7.0, 8.0, 3.0];
        let data: Array2<f64> = Array2::from_shape_vec((data.len(), 1), data).unwrap();

        let samples = params.transform(data.view()).unwrap();

        // Every sample must appear in the ordering...
        let indexes = samples
            .orderings
            .iter()
            .map(|x| x.index)
            .collect::<BTreeSet<_>>();
        assert!((0..data.len()).all(|x| indexes.contains(&x)));

        // ...and with these parameters every sample is a core point.
        assert!(samples.orderings.iter().all(|x| x.core_distance.is_some()));
    }

    #[test]
    fn simple_dataset() {
        let params = Optics::params(3).tolerance(4.0);
        let data = vec![
            1.0, 2.0, 3.0, 10.0, 18.0, 18.0, 15.0, 2.0, 15.0, 18.0, 3.0, 100.0, 101.0,
        ];
        let data: Array2<f64> = Array2::from_shape_vec((data.len(), 1), data).unwrap();

        // Sample indices of the two dense groups in the 1-D data above.
        let first_grouping = [0, 1, 2, 7, 10].iter().collect::<BTreeSet<_>>();
        let second_grouping = [4, 5, 6, 8, 9].iter().collect::<BTreeSet<_>>();

        let samples = params.transform(data.view()).unwrap();

        let indexes = samples
            .orderings
            .iter()
            .map(|x| x.index)
            .collect::<BTreeSet<_>>();
        assert!((0..data.len()).all(|x| indexes.contains(&x)));

        assert!(samples
            .orderings
            .iter()
            .take(first_grouping.len())
            .all(|x| first_grouping.contains(&x.index)));
        // Skip the first group plus the isolated sample 3, which is visited between the groups.
        let skip_len = first_grouping.len() + 1;
        assert!(samples
            .orderings
            .iter()
            .skip(skip_len)
            .take(second_grouping.len())
            .all(|x| second_grouping.contains(&x.index)));

        // Isolated samples have neither a core nor a reachability distance.
        let anomaly = samples.orderings.iter().find(|x| x.index == 3).unwrap();
        assert!(anomaly.core_distance.is_none());
        assert!(anomaly.reachability_distance.is_none());

        let anomaly = samples.orderings.iter().find(|x| x.index == 11).unwrap();
        assert!(anomaly.core_distance.is_none());
        assert!(anomaly.reachability_distance.is_none());

        let anomaly = samples.orderings.iter().find(|x| x.index == 12).unwrap();
        assert!(anomaly.core_distance.is_none());
        assert!(anomaly.reachability_distance.is_none());
    }

    #[test]
    fn dataset_too_small() {
        let params = Optics::params(4);
        let data = vec![1.0, 2.0, 3.0];
        let data: Array2<f64> = Array2::from_shape_vec((data.len(), 1), data).unwrap();

        let samples = params.transform(data.view()).unwrap();

        assert!(samples
            .orderings
            .iter()
            .all(|x| x.core_distance.is_none() && x.reachability_distance.is_none()));
    }

    #[test]
    fn invalid_params() {
        let params = Optics::params(1);
        let data = vec![1.0, 2.0, 3.0];
        let data: Array2<f64> = Array2::from_shape_vec((data.len(), 1), data).unwrap();
        assert!(params.transform(data.view()).is_err());

        let params = Optics::params(2);
        assert!(params.transform(data.view()).is_ok());

        let params = params.tolerance(0.0);
        assert!(params.transform(data.view()).is_err());
    }

    #[test]
    fn find_neighbors_test() {
        let data = vec![1.0, 2.0, 10.0, 15.0, 13.0];
        let data: Array2<f64> = Array2::from_shape_vec((data.len(), 1), data).unwrap();

        let param = Optics::params(3).tolerance(6.0).check_unwrap();
        let nn = CommonNearestNeighbour::KdTree
            .from_batch(&data, L2Dist)
            .unwrap();

        let neighbors = param.find_neighbors(&*nn, data.row(0));
        assert_eq!(neighbors.len(), 2);
        assert_eq!(
            vec![0, 1],
            neighbors
                .iter()
                .map(|x| x.reachability_distance.unwrap() as u32)
                .collect::<Vec<u32>>()
        );
        assert!(neighbors.iter().all(|x| x.core_distance.is_none()));

        let neighbors = param.find_neighbors(&*nn, data.row(4));
        assert_eq!(neighbors.len(), 3);
        assert!(neighbors.iter().all(|x| x.core_distance.is_none()));
        assert_eq!(
            vec![0, 2, 3],
            neighbors
                .iter()
                .map(|x| x.reachability_distance.unwrap() as u32)
                .collect::<Vec<u32>>()
        );
    }

    #[test]
    fn get_seeds_test() {
        let data = vec![1.0, 2.0, 10.0, 15.0, 13.0];
        let data: Array2<f64> = Array2::from_shape_vec((data.len(), 1), data).unwrap();

        let param = Optics::params(3).tolerance(6.0).check_unwrap();
        let nn = CommonNearestNeighbour::KdTree
            .from_batch(&data, L2Dist)
            .unwrap();

        let mut points = (0..data.nrows()).map(Sample::new).collect::<Vec<_>>();

        // Sample 0 has only two neighbours within the tolerance, so it is not a core point.
        let neighbors = param.find_neighbors(&*nn, data.row(0));
        param.set_core_distance(&mut points[0], &neighbors, data.view());
        assert!(points[0].core_distance.is_none());

        // Sample 4 has three neighbours (including itself), so it is a core point.
        let neighbors = param.find_neighbors(&*nn, data.row(4));
        param.set_core_distance(&mut points[4], &neighbors, data.view());
        assert!(points[4].core_distance.is_some());

        // Nothing processed yet: every neighbour of sample 4 becomes a seed.
        let mut seeds = vec![];
        let mut processed = BTreeSet::new();
        param.get_seeds(
            data.view(),
            points[4].clone(),
            &neighbors,
            &mut points,
            &processed,
            &mut seeds,
        );

        assert_eq!(seeds, vec![4, 3, 2]);

        // Processed samples are never seeded again.
        let mut points = (0..data.nrows()).map(Sample::new).collect::<Vec<_>>();

        param.set_core_distance(&mut points[4], &neighbors, data.view());
        processed.insert(3);
        seeds.clear();

        param.get_seeds(
            data.view(),
            points[4].clone(),
            &neighbors,
            &mut points,
            &processed,
            &mut seeds,
        );

        assert_eq!(seeds, vec![4, 2]);

        // A sample that already has a reachability distance is not pushed again.
        let mut points = (0..data.nrows()).map(Sample::new).collect::<Vec<_>>();

        processed.clear();
        param.set_core_distance(&mut points[4], &neighbors, data.view());
        points[2].reachability_distance = Some(0.001);
        seeds.clear();

        param.get_seeds(
            data.view(),
            points[4].clone(),
            &neighbors,
            &mut points,
            &processed,
            &mut seeds,
        );

        assert_eq!(seeds, vec![4, 3]);
    }
}