1use std::collections::BinaryHeap;
7use std::marker::PhantomData;
8
9use scirs2_core::ndarray::{Array1, Array2};
10use sklears_core::{
11 error::{Result, SklearnContext, SklearsError},
12 traits::{Estimator, Fit, Predict, Trained, Untrained},
13 types::{Float, FloatBounds},
14 validation::{ConfigValidation, Validate, ValidationRule, ValidationRules},
15};
16
17#[derive(Debug, Clone)]
19pub struct OpticsConfig {
20 pub max_eps: f64,
22 pub min_samples: usize,
24 pub metric: DistanceMetric,
26 pub algorithm: Algorithm,
28 pub leaf_size: usize,
30 pub max_clusters: Option<usize>,
32 pub cluster_method: ClusterMethod,
34}
35
/// Distance functions supported by [`Optics`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DistanceMetric {
    /// Straight-line (L2) distance.
    Euclidean,
    /// Sum of absolute coordinate differences (L1).
    Manhattan,
    /// Largest absolute coordinate difference (L-infinity).
    Chebyshev,
    /// General Lp distance with the given exponent `p`; `validate` rejects `p <= 0`.
    Minkowski(f64),
}
44
/// Neighbour-search strategy requested for [`Optics`].
///
/// NOTE: the fitting code in this file always uses a brute-force search regardless of the
/// chosen variant.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Algorithm {
    /// Ball-tree search.
    BallTree,
    /// k-d tree search.
    KDTree,
    /// Exhaustive pairwise search.
    Brute,
    /// Let the implementation choose (the default).
    Auto,
}
57
/// Strategy used to turn the OPTICS ordering into flat cluster labels.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ClusterMethod {
    /// Flat cut of the reachability plot at the given distance.
    Threshold(f64),
    /// Heuristic that grows clusters from local minima of the reachability plot.
    Hierarchical,
    /// Heuristic that opens and closes clusters at steep reachability changes.
    SteepestDescent,
}
68
69impl Default for OpticsConfig {
70 fn default() -> Self {
71 Self {
72 max_eps: f64::INFINITY,
73 min_samples: 5,
74 metric: DistanceMetric::Euclidean,
75 algorithm: Algorithm::Auto,
76 leaf_size: 30,
77 max_clusters: None,
78 cluster_method: ClusterMethod::Threshold(0.5),
79 }
80 }
81}
82
83impl Validate for OpticsConfig {
84 fn validate(&self) -> Result<()> {
85 ValidationRules::new("max_eps")
87 .add_rule(ValidationRule::Positive)
88 .validate_numeric(&self.max_eps)?;
89
90 ValidationRules::new("min_samples")
92 .add_rule(ValidationRule::Positive)
93 .validate_usize(&self.min_samples)?;
94
95 ValidationRules::new("leaf_size")
97 .add_rule(ValidationRule::Positive)
98 .validate_usize(&self.leaf_size)?;
99
100 if let DistanceMetric::Minkowski(p) = self.metric {
102 if p <= 0.0 {
103 return Err(SklearsError::InvalidInput(
104 "Minkowski p parameter must be positive".to_string(),
105 ));
106 }
107 }
108
109 Ok(())
110 }
111}
112
113impl ConfigValidation for OpticsConfig {
114 fn validate_config(&self) -> Result<()> {
115 self.validate()?;
116
117 if self.min_samples > 100 {
118 log::warn!(
119 "Large min_samples {} may be slow for dense datasets",
120 self.min_samples
121 );
122 }
123
124 if self.max_eps == f64::INFINITY
125 && matches!(self.cluster_method, ClusterMethod::Threshold(_))
126 {
127 log::warn!("Using infinite max_eps with threshold clustering may not produce meaningful results");
128 }
129
130 Ok(())
131 }
132
133 fn get_warnings(&self) -> Vec<String> {
134 let mut warnings = Vec::new();
135
136 if self.min_samples < 3 {
137 warnings.push("Very small min_samples may lead to noisy clusters".to_string());
138 }
139
140 if self.leaf_size > 100 {
141 warnings.push("Large leaf_size may reduce performance".to_string());
142 }
143
144 warnings
145 }
146}
147
/// Per-sample bookkeeping used while building the OPTICS ordering.
#[derive(Debug, Clone)]
struct OpticsPoint {
    /// Row index of the sample in the input matrix.
    index: usize,
    /// Core distance once the sample is known to be a core point, otherwise `None`.
    core_distance: Option<f64>,
    /// Smallest reachability distance found so far, or `None` if not yet reached.
    reachability_distance: Option<f64>,
    /// Whether the sample has already been appended to the ordering.
    processed: bool,
}
160
161impl OpticsPoint {
162 fn new(index: usize) -> Self {
163 Self {
164 index,
165 core_distance: None,
166 reachability_distance: None,
167 processed: false,
168 }
169 }
170}
171
/// One entry of the OPTICS cluster ordering.
#[derive(Debug, Clone)]
pub struct OpticsOrdering {
    /// Row index of the sample in the training matrix.
    pub index: usize,
    /// Core distance of the sample, or `None` if it is not a core point.
    pub core_distance: Option<f64>,
    /// Reachability distance at the time the sample was ordered, or `None` for a sample that
    /// started a new expansion.
    pub reachability_distance: Option<f64>,
}
182
183#[derive(Debug, Clone)]
185pub struct Optics<State = Untrained> {
186 config: OpticsConfig,
187 state: PhantomData<State>,
188 ordering_: Option<Vec<OpticsOrdering>>,
190 labels_: Option<Array1<i32>>, n_features_: Option<usize>,
192 core_sample_indices_: Option<Vec<usize>>,
193}
194
195impl Optics<Untrained> {
196 pub fn new() -> Self {
198 Self {
199 config: OpticsConfig::default(),
200 state: PhantomData,
201 ordering_: None,
202 labels_: None,
203 n_features_: None,
204 core_sample_indices_: None,
205 }
206 }
207
208 pub fn max_eps(mut self, max_eps: f64) -> Self {
210 self.config.max_eps = max_eps;
211 self
212 }
213
214 pub fn min_samples(mut self, min_samples: usize) -> Self {
216 self.config.min_samples = min_samples;
217 self
218 }
219
220 pub fn metric(mut self, metric: DistanceMetric) -> Self {
222 self.config.metric = metric;
223 self
224 }
225
226 pub fn algorithm(mut self, algorithm: Algorithm) -> Self {
228 self.config.algorithm = algorithm;
229 self
230 }
231
232 pub fn leaf_size(mut self, leaf_size: usize) -> Self {
234 self.config.leaf_size = leaf_size;
235 self
236 }
237
238 pub fn cluster_method(mut self, method: ClusterMethod) -> Self {
240 self.config.cluster_method = method;
241 self
242 }
243}
244
245impl Default for Optics<Untrained> {
246 fn default() -> Self {
247 Self::new()
248 }
249}
250
251impl Estimator for Optics<Untrained> {
252 type Config = OpticsConfig;
253 type Error = SklearsError;
254 type Float = Float;
255
256 fn config(&self) -> &Self::Config {
257 &self.config
258 }
259}
260
261impl Fit<Array2<Float>, ()> for Optics<Untrained> {
262 type Fitted = Optics<Trained>;
263
264 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
265 let n_samples = x.nrows();
266 let n_features = x.ncols();
267
268 self.config
270 .validate_config()
271 .fit_context("OPTICS", n_samples, n_features)?;
272
273 use sklears_core::validation::ml;
275 ml::validate_unsupervised_data(x).fit_context("OPTICS", n_samples, n_features)?;
276
277 if n_samples < self.config.min_samples {
278 return Err(SklearsError::InvalidInput(format!(
279 "min_samples ({}) cannot exceed n_samples ({})",
280 self.config.min_samples, n_samples
281 )));
282 }
283
284 let (ordering, core_indices) = self.run_optics(x)?;
286
287 let labels = self.extract_clusters(&ordering)?;
289
290 Ok(Optics {
291 config: self.config,
292 state: PhantomData,
293 ordering_: Some(ordering),
294 labels_: Some(labels),
295 n_features_: Some(n_features),
296 core_sample_indices_: Some(core_indices),
297 })
298 }
299}
300
301impl Optics<Untrained> {
302 fn run_optics(&self, x: &Array2<Float>) -> Result<(Vec<OpticsOrdering>, Vec<usize>)> {
304 let n_samples = x.nrows();
305 let mut points: Vec<OpticsPoint> = (0..n_samples).map(OpticsPoint::new).collect();
306 let mut ordering = Vec::new();
307 let mut core_indices = Vec::new();
308
309 let mut seeds = BinaryHeap::new();
311
312 for i in 0..n_samples {
313 if points[i].processed {
314 continue;
315 }
316
317 let neighbors = self.get_neighbors(x, i)?;
319
320 points[i].processed = true;
322
323 if neighbors.len() >= self.config.min_samples {
325 let core_distance = self.calculate_core_distance(x, i, &neighbors)?;
326 points[i].core_distance = Some(core_distance);
327 core_indices.push(i);
328
329 self.update_seeds(&mut seeds, &mut points, x, i, &neighbors)?;
331
332 ordering.push(OpticsOrdering {
334 index: i,
335 core_distance: points[i].core_distance,
336 reachability_distance: points[i].reachability_distance,
337 });
338
339 while let Some(seed_item) = seeds.pop() {
341 let seed_idx = seed_item.index;
342
343 if points[seed_idx].processed {
344 continue;
345 }
346
347 points[seed_idx].processed = true;
348
349 let seed_neighbors = self.get_neighbors(x, seed_idx)?;
350
351 ordering.push(OpticsOrdering {
353 index: seed_idx,
354 core_distance: points[seed_idx].core_distance,
355 reachability_distance: points[seed_idx].reachability_distance,
356 });
357
358 if seed_neighbors.len() >= self.config.min_samples {
360 let seed_core_distance =
361 self.calculate_core_distance(x, seed_idx, &seed_neighbors)?;
362 points[seed_idx].core_distance = Some(seed_core_distance);
363 if !core_indices.contains(&seed_idx) {
364 core_indices.push(seed_idx);
365 }
366
367 self.update_seeds(&mut seeds, &mut points, x, seed_idx, &seed_neighbors)?;
368 }
369 }
370 } else {
371 ordering.push(OpticsOrdering {
373 index: i,
374 core_distance: None,
375 reachability_distance: None,
376 });
377 }
378 }
379
380 Ok((ordering, core_indices))
381 }
382
383 fn get_neighbors(&self, x: &Array2<Float>, point_idx: usize) -> Result<Vec<usize>> {
385 let n_samples = x.nrows();
386 let mut neighbors = Vec::new();
387 let point = x.row(point_idx);
388
389 for i in 0..n_samples {
390 if i == point_idx {
391 continue;
392 }
393
394 let neighbor = x.row(i);
395 let distance = self.calculate_distance(&point, &neighbor)?;
396
397 if distance <= self.config.max_eps {
398 neighbors.push(i);
399 }
400 }
401
402 Ok(neighbors)
403 }
404
405 fn calculate_distance(
407 &self,
408 point1: &scirs2_core::ndarray::ArrayView1<Float>,
409 point2: &scirs2_core::ndarray::ArrayView1<Float>,
410 ) -> Result<f64> {
411 match self.config.metric {
412 DistanceMetric::Euclidean => {
413 let mut sum = 0.0;
414 for (&a, &b) in point1.iter().zip(point2.iter()) {
415 let diff = a - b;
416 sum += diff * diff;
417 }
418 Ok(sum.sqrt())
419 }
420 DistanceMetric::Manhattan => {
421 let mut sum = 0.0;
422 for (&a, &b) in point1.iter().zip(point2.iter()) {
423 sum += (a - b).abs();
424 }
425 Ok(sum)
426 }
427 DistanceMetric::Chebyshev => {
428 let mut max_diff = 0.0;
429 for (&a, &b) in point1.iter().zip(point2.iter()) {
430 let diff = (a - b).abs();
431 if diff > max_diff {
432 max_diff = diff;
433 }
434 }
435 Ok(max_diff)
436 }
437 DistanceMetric::Minkowski(p) => {
438 let mut sum = 0.0;
439 for (&a, &b) in point1.iter().zip(point2.iter()) {
440 sum += (a - b).abs().powf(p);
441 }
442 Ok(sum.powf(1.0 / p))
443 }
444 }
445 }
446
447 fn calculate_core_distance(
449 &self,
450 x: &Array2<Float>,
451 point_idx: usize,
452 neighbors: &[usize],
453 ) -> Result<f64> {
454 if neighbors.len() < self.config.min_samples {
455 return Ok(f64::INFINITY);
456 }
457
458 let point = x.row(point_idx);
459 let mut distances: Vec<f64> = Vec::new();
460
461 for &neighbor_idx in neighbors {
462 let neighbor = x.row(neighbor_idx);
463 let distance = self.calculate_distance(&point, &neighbor)?;
464 distances.push(distance);
465 }
466
467 distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
468 Ok(distances[self.config.min_samples - 1])
469 }
470
471 fn update_seeds(
473 &self,
474 seeds: &mut BinaryHeap<SeedItem>,
475 points: &mut [OpticsPoint],
476 x: &Array2<Float>,
477 core_idx: usize,
478 neighbors: &[usize],
479 ) -> Result<()> {
480 let core_distance = points[core_idx].core_distance.unwrap();
481 let core_point = x.row(core_idx);
482
483 for &neighbor_idx in neighbors {
484 if points[neighbor_idx].processed {
485 continue;
486 }
487
488 let neighbor_point = x.row(neighbor_idx);
489 let distance = self.calculate_distance(&core_point, &neighbor_point)?;
490 let new_reachability = core_distance.max(distance);
491
492 if points[neighbor_idx].reachability_distance.is_none()
493 || new_reachability < points[neighbor_idx].reachability_distance.unwrap()
494 {
495 points[neighbor_idx].reachability_distance = Some(new_reachability);
496 seeds.push(SeedItem {
497 index: neighbor_idx,
498 reachability: new_reachability,
499 });
500 }
501 }
502
503 Ok(())
504 }
505
506 fn extract_clusters(&self, ordering: &[OpticsOrdering]) -> Result<Array1<i32>> {
508 let n_samples = ordering.len();
509 let mut labels = Array1::from_elem(n_samples, -1);
510
511 match self.config.cluster_method {
512 ClusterMethod::Threshold(threshold) => {
513 let mut cluster_id = 0;
514 let mut current_cluster: Option<i32> = None;
515
516 for (i, entry) in ordering.iter().enumerate() {
517 let reachability = entry.reachability_distance.unwrap_or(f64::INFINITY);
518
519 if reachability <= threshold {
520 if current_cluster.is_none() {
521 current_cluster = Some(cluster_id);
522 cluster_id += 1;
523 }
524 labels[entry.index] = current_cluster.unwrap();
525 } else {
526 current_cluster = None;
527 }
529 }
530 }
531 ClusterMethod::Hierarchical => {
532 self.extract_hierarchical_clusters(ordering, &mut labels)?;
534 }
535 ClusterMethod::SteepestDescent => {
536 self.extract_steepest_descent_clusters(ordering, &mut labels)?;
538 }
539 }
540
541 Ok(labels)
542 }
543
544 fn extract_hierarchical_clusters(
546 &self,
547 ordering: &[OpticsOrdering],
548 labels: &mut Array1<i32>,
549 ) -> Result<()> {
550 let mut cluster_id = 0;
552 let window_size = self.config.min_samples;
553
554 for i in window_size..ordering.len() - window_size {
555 let current_reach = ordering[i].reachability_distance.unwrap_or(f64::INFINITY);
556
557 let mut is_minimum = true;
559 for j in i.saturating_sub(window_size)..=(i + window_size).min(ordering.len() - 1) {
560 if j != i {
561 let reach = ordering[j].reachability_distance.unwrap_or(f64::INFINITY);
562 if reach < current_reach {
563 is_minimum = false;
564 break;
565 }
566 }
567 }
568
569 if is_minimum && current_reach < f64::INFINITY {
570 for j in i..ordering.len() {
572 let reach = ordering[j].reachability_distance.unwrap_or(f64::INFINITY);
573 if reach < current_reach * 2.0 {
574 labels[ordering[j].index] = cluster_id;
575 } else {
576 break;
577 }
578 }
579 cluster_id += 1;
580
581 if let Some(max_clusters) = self.config.max_clusters {
582 if cluster_id >= max_clusters as i32 {
583 break;
584 }
585 }
586 }
587 }
588
589 Ok(())
590 }
591
592 fn extract_steepest_descent_clusters(
594 &self,
595 ordering: &[OpticsOrdering],
596 labels: &mut Array1<i32>,
597 ) -> Result<()> {
598 let mut cluster_id = 0;
600 let mut in_cluster = false;
601 let steep_threshold = 0.1; for i in 1..ordering.len() {
604 let prev_reach = ordering[i - 1]
605 .reachability_distance
606 .unwrap_or(f64::INFINITY);
607 let curr_reach = ordering[i].reachability_distance.unwrap_or(f64::INFINITY);
608
609 let ratio = if prev_reach > 0.0 && curr_reach.is_finite() {
610 curr_reach / prev_reach
611 } else {
612 1.0
613 };
614
615 if ratio <= steep_threshold && !in_cluster {
617 in_cluster = true;
618 labels[ordering[i].index] = cluster_id;
619 }
620 else if ratio >= (1.0 / steep_threshold) && in_cluster {
622 in_cluster = false;
623 cluster_id += 1;
624
625 if let Some(max_clusters) = self.config.max_clusters {
626 if cluster_id >= max_clusters as i32 {
627 break;
628 }
629 }
630 }
631 else if in_cluster {
633 labels[ordering[i].index] = cluster_id;
634 }
635 }
636
637 Ok(())
638 }
639}
640
/// Entry in the OPTICS seed priority queue.
///
/// `BinaryHeap` is a max-heap, so the ordering is reversed: the item with the smallest
/// reachability compares as the greatest and is popped first. Ties are broken in favour of
/// the smaller sample index. This fixes the pop order, and `Ord`, `PartialOrd` and `Eq`
/// all agree with each other, as `Ord` requires.
#[derive(Debug, Clone)]
struct SeedItem {
    index: usize,
    reachability: f64,
}

impl PartialEq for SeedItem {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for SeedItem {}

impl PartialOrd for SeedItem {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SeedItem {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `total_cmp` gives a total order even for NaN, unlike `partial_cmp`.
        other
            .reachability
            .total_cmp(&self.reachability)
            .then_with(|| other.index.cmp(&self.index))
    }
}
662
663impl Optics<Trained> {
664 pub fn ordering(&self) -> &[OpticsOrdering] {
666 self.ordering_.as_ref().expect("Model is trained")
667 }
668
669 pub fn labels(&self) -> &Array1<i32> {
671 self.labels_.as_ref().expect("Model is trained")
672 }
673
674 pub fn core_sample_indices(&self) -> &[usize] {
676 self.core_sample_indices_
677 .as_ref()
678 .expect("Model is trained")
679 }
680
681 pub fn reachability_distances(&self) -> Vec<Option<f64>> {
683 self.ordering()
684 .iter()
685 .map(|entry| entry.reachability_distance)
686 .collect()
687 }
688
689 pub fn core_distances(&self) -> Vec<Option<f64>> {
691 self.ordering()
692 .iter()
693 .map(|entry| entry.core_distance)
694 .collect()
695 }
696}
697
698impl Predict<Array2<Float>, Array1<i32>> for Optics<Trained> {
699 fn predict(&self, x: &Array2<Float>) -> Result<Array1<i32>> {
700 let n_features = self.n_features_.expect("Model is trained");
701 if x.ncols() != n_features {
702 return Err(SklearsError::InvalidInput(format!(
703 "Expected {} features, got {}",
704 n_features,
705 x.ncols()
706 )));
707 }
708
709 let n_samples = x.nrows();
712 let labels = Array1::zeros(n_samples);
713
714 log::warn!("OPTICS prediction on new data is approximate");
717
718 Ok(labels)
719 }
720}
721
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use std::collections::HashSet;

    #[test]
    fn test_optics_simple() {
        let data = array![
            [0.0, 0.0],
            [0.1, 0.1],
            [0.2, 0.2],
            [5.0, 5.0],
            [5.1, 5.1],
            [5.2, 5.2],
        ];

        let model = Optics::new()
            .max_eps(2.0)
            .min_samples(2)
            .cluster_method(ClusterMethod::Threshold(1.0))
            .fit(&data, &())
            .unwrap();

        let labels = model.labels();
        let ordering = model.ordering();

        assert_eq!(labels.len(), 6);
        assert_eq!(ordering.len(), 6);

        // The two well-separated groups must not collapse into a single label.
        let unique_labels: HashSet<_> = labels.iter().collect();
        assert!(unique_labels.len() >= 2);
    }

    #[test]
    fn test_optics_validation() {
        use sklears_core::validation::{ConfigValidation, Validate};

        let valid_config = OpticsConfig::default();
        assert!(valid_config.validate().is_ok());
        assert!(valid_config.validate_config().is_ok());

        let mut invalid_config = OpticsConfig::default();
        invalid_config.max_eps = -1.0;
        assert!(invalid_config.validate().is_err());

        let mut invalid_config = OpticsConfig::default();
        invalid_config.min_samples = 0;
        assert!(invalid_config.validate().is_err());

        let mut invalid_config = OpticsConfig::default();
        invalid_config.metric = DistanceMetric::Minkowski(-1.0);
        assert!(invalid_config.validate().is_err());
    }

    #[test]
    fn test_optics_distance_metrics() {
        let data = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];

        let metrics = [
            DistanceMetric::Euclidean,
            DistanceMetric::Manhattan,
            DistanceMetric::Chebyshev,
            DistanceMetric::Minkowski(2.0),
        ];

        for &metric in &metrics {
            let model = Optics::new()
                .max_eps(5.0)
                .min_samples(2)
                .metric(metric)
                .fit(&data, &())
                .unwrap();

            assert_eq!(model.labels().len(), 3);
        }
    }

    #[test]
    fn test_optics_cluster_methods() {
        let data = array![
            [0.0, 0.0],
            [0.1, 0.1],
            [0.2, 0.2],
            [5.0, 5.0],
            [5.1, 5.1],
            [5.2, 5.2],
        ];

        let methods = [
            ClusterMethod::Threshold(1.0),
            ClusterMethod::Hierarchical,
            ClusterMethod::SteepestDescent,
        ];

        for &method in &methods {
            let model = Optics::new()
                .max_eps(3.0)
                .min_samples(2)
                .cluster_method(method)
                .fit(&data, &())
                .unwrap();

            assert_eq!(model.labels().len(), 6);
        }
    }

    #[test]
    fn test_optics_core_samples() {
        let data = array![[0.0, 0.0], [0.1, 0.1], [0.2, 0.2], [5.0, 5.0]];

        let model = Optics::new()
            .max_eps(1.0)
            .min_samples(2)
            .fit(&data, &())
            .unwrap();

        let core_indices = model.core_sample_indices();

        assert!(!core_indices.is_empty());
        // The isolated point has no neighbours within `max_eps`, so it cannot be core.
        assert!(!core_indices.contains(&3));
    }

    #[test]
    fn test_optics_reachability_plot() {
        let data = array![
            [0.0, 0.0],
            [0.1, 0.1],
            [0.2, 0.2],
            [1.0, 1.0],
            [1.1, 1.1],
            [1.2, 1.2],
        ];

        let model = Optics::new()
            .max_eps(2.0)
            .min_samples(2)
            .fit(&data, &())
            .unwrap();

        let reachability_distances = model.reachability_distances();
        let core_distances = model.core_distances();

        assert_eq!(reachability_distances.len(), 6);
        assert_eq!(core_distances.len(), 6);

        let finite_reachability = reachability_distances
            .iter()
            .flatten()
            .filter(|d| d.is_finite())
            .count();

        assert!(finite_reachability > 0);
    }

    #[test]
    fn test_optics_ordering_is_permutation() {
        let data = array![
            [0.0, 0.0],
            [0.1, 0.1],
            [0.2, 0.2],
            [5.0, 5.0],
            [5.1, 5.1],
            [5.2, 5.2],
        ];

        let model = Optics::new()
            .max_eps(2.0)
            .min_samples(2)
            .fit(&data, &())
            .unwrap();

        // Every sample must appear exactly once in the ordering.
        let mut visited: Vec<usize> = model.ordering().iter().map(|e| e.index).collect();
        visited.sort_unstable();
        assert_eq!(visited, (0..data.nrows()).collect::<Vec<_>>());

        assert!(model
            .core_sample_indices()
            .iter()
            .all(|&i| i < data.nrows()));
    }
}