// ghostflow_ml/clustering_more.rs

//! Additional Clustering - OPTICS, BIRCH, HDBSCAN

use ghostflow_core::Tensor;

/// OPTICS - Ordering Points To Identify the Clustering Structure
///
/// Produces a reachability ordering of the samples; flat clusters are then
/// extracted from that ordering with either the xi method or a DBSCAN-style
/// cut (see [`ClusterMethod`]).
pub struct OPTICS {
    /// Minimum number of points required for a point to be a core point.
    pub min_samples: usize,
    /// Maximum neighborhood radius considered (defaults to infinity).
    pub max_eps: f32,
    /// Distance metric used for all pairwise distances.
    pub metric: OPTICSMetric,
    /// Strategy used to extract flat clusters from the ordering.
    pub cluster_method: ClusterMethod,
    /// Steepness parameter for the xi extraction method.
    pub xi: f32,
    // Fitted state, populated by `fit`:
    ordering_: Option<Vec<usize>>,      // visit order of the samples
    reachability_: Option<Vec<f32>>,    // reachability distance per sample
    core_distances_: Option<Vec<f32>>,  // core distance per sample
    labels_: Option<Vec<i32>>,          // cluster label per sample (-1 = noise)
}

/// Distance metric used by [`OPTICS`].
#[derive(Clone, Copy)]
pub enum OPTICSMetric {
    /// Straight-line (L2) distance.
    Euclidean,
    /// Sum of absolute coordinate differences (L1).
    Manhattan,
    /// One minus the cosine similarity.
    Cosine,
}

/// How flat clusters are extracted from the OPTICS ordering.
#[derive(Clone, Copy)]
pub enum ClusterMethod {
    /// Simplified xi steepness-based extraction.
    Xi,
    /// DBSCAN-style cut at `max_eps`.
    DBSCAN,
}

31impl OPTICS {
32    pub fn new(min_samples: usize) -> Self {
33        OPTICS {
34            min_samples,
35            max_eps: f32::INFINITY,
36            metric: OPTICSMetric::Euclidean,
37            cluster_method: ClusterMethod::Xi,
38            xi: 0.05,
39            ordering_: None,
40            reachability_: None,
41            core_distances_: None,
42            labels_: None,
43        }
44    }
45
46    pub fn max_eps(mut self, eps: f32) -> Self {
47        self.max_eps = eps;
48        self
49    }
50
51    pub fn xi(mut self, xi: f32) -> Self {
52        self.xi = xi;
53        self
54    }
55
56    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
57        match self.metric {
58            OPTICSMetric::Euclidean => {
59                a.iter().zip(b.iter()).map(|(&x, &y)| (x - y).powi(2)).sum::<f32>().sqrt()
60            }
61            OPTICSMetric::Manhattan => {
62                a.iter().zip(b.iter()).map(|(&x, &y)| (x - y).abs()).sum()
63            }
64            OPTICSMetric::Cosine => {
65                let dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
66                let norm_a: f32 = a.iter().map(|&x| x * x).sum::<f32>().sqrt();
67                let norm_b: f32 = b.iter().map(|&x| x * x).sum::<f32>().sqrt();
68                1.0 - dot / (norm_a * norm_b).max(1e-10)
69            }
70        }
71    }
72
73    fn compute_core_distance(&self, point_idx: usize, x: &[f32], n_samples: usize, n_features: usize) -> f32 {
74        let point = &x[point_idx * n_features..(point_idx + 1) * n_features];
75        let mut distances: Vec<f32> = (0..n_samples)
76            .filter(|&i| i != point_idx)
77            .map(|i| {
78                let other = &x[i * n_features..(i + 1) * n_features];
79                self.distance(point, other)
80            })
81            .collect();
82        
83        distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
84        
85        if distances.len() >= self.min_samples - 1 {
86            distances[self.min_samples - 2]
87        } else {
88            f32::INFINITY
89        }
90    }
91
92    pub fn fit(&mut self, x: &Tensor) {
93        let x_data = x.data_f32();
94        let n_samples = x.dims()[0];
95        let n_features = x.dims()[1];
96
97        // Compute core distances
98        let core_distances: Vec<f32> = (0..n_samples)
99            .map(|i| self.compute_core_distance(i, &x_data, n_samples, n_features))
100            .collect();
101
102        // Initialize
103        let mut reachability = vec![f32::INFINITY; n_samples];
104        let mut processed = vec![false; n_samples];
105        let mut ordering = Vec::with_capacity(n_samples);
106
107        // Priority queue simulation using Vec
108        let mut seeds: Vec<(usize, f32)> = Vec::new();
109
110        // Process all points
111        for _ in 0..n_samples {
112            // Find unprocessed point with smallest reachability
113            let next_idx = if seeds.is_empty() {
114                // Find any unprocessed point
115                (0..n_samples).find(|&i| !processed[i])
116            } else {
117                // Find seed with minimum reachability
118                seeds.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
119                seeds.pop().map(|(idx, _)| idx)
120            };
121
122            let current = match next_idx {
123                Some(idx) => idx,
124                None => break,
125            };
126
127            if processed[current] {
128                continue;
129            }
130
131            processed[current] = true;
132            ordering.push(current);
133
134            if core_distances[current] < self.max_eps {
135                // Update reachability of neighbors
136                let current_point = &x_data[current * n_features..(current + 1) * n_features];
137                
138                for i in 0..n_samples {
139                    if processed[i] {
140                        continue;
141                    }
142
143                    let other_point = &x_data[i * n_features..(i + 1) * n_features];
144                    let dist = self.distance(current_point, other_point);
145
146                    if dist <= self.max_eps {
147                        let new_reach = core_distances[current].max(dist);
148                        if new_reach < reachability[i] {
149                            reachability[i] = new_reach;
150                            // Update or add to seeds
151                            if let Some(pos) = seeds.iter().position(|(idx, _)| *idx == i) {
152                                seeds[pos].1 = new_reach;
153                            } else {
154                                seeds.push((i, new_reach));
155                            }
156                        }
157                    }
158                }
159            }
160        }
161
162        // Extract clusters using xi method
163        let labels = match self.cluster_method {
164            ClusterMethod::Xi => self.extract_xi_clusters(&reachability, &ordering, n_samples),
165            ClusterMethod::DBSCAN => self.extract_dbscan_clusters(&reachability, &ordering, n_samples),
166        };
167
168        self.ordering_ = Some(ordering);
169        self.reachability_ = Some(reachability);
170        self.core_distances_ = Some(core_distances);
171        self.labels_ = Some(labels);
172    }
173
174    fn extract_xi_clusters(&self, reachability: &[f32], ordering: &[usize], n_samples: usize) -> Vec<i32> {
175        let mut labels = vec![-1i32; n_samples];
176        let mut cluster_id = 0;
177
178        // Simplified xi clustering
179        let mut in_cluster = false;
180        let mut cluster_start = 0;
181
182        for (i, &idx) in ordering.iter().enumerate() {
183            let reach = reachability[idx];
184            
185            if reach < f32::INFINITY {
186                if !in_cluster {
187                    in_cluster = true;
188                    cluster_start = i;
189                }
190            } else if in_cluster {
191                // End of cluster
192                if i - cluster_start >= self.min_samples {
193                    for j in cluster_start..i {
194                        labels[ordering[j]] = cluster_id;
195                    }
196                    cluster_id += 1;
197                }
198                in_cluster = false;
199            }
200        }
201
202        // Handle last cluster
203        if in_cluster && ordering.len() - cluster_start >= self.min_samples {
204            for j in cluster_start..ordering.len() {
205                labels[ordering[j]] = cluster_id;
206            }
207        }
208
209        labels
210    }
211
212    fn extract_dbscan_clusters(&self, reachability: &[f32], ordering: &[usize], n_samples: usize) -> Vec<i32> {
213        let mut labels = vec![-1i32; n_samples];
214        let eps = self.max_eps;
215        let mut cluster_id = 0;
216
217        for &idx in ordering {
218            if reachability[idx] <= eps {
219                if labels[idx] == -1 {
220                    labels[idx] = cluster_id;
221                }
222            } else {
223                cluster_id += 1;
224            }
225        }
226
227        labels
228    }
229
230    pub fn fit_predict(&mut self, x: &Tensor) -> Tensor {
231        self.fit(x);
232        let labels = self.labels_.as_ref().unwrap();
233        let labels_f32: Vec<f32> = labels.iter().map(|&l| l as f32).collect();
234        Tensor::from_slice(&labels_f32, &[labels.len()]).unwrap()
235    }
236
237    pub fn labels(&self) -> Option<&Vec<i32>> {
238        self.labels_.as_ref()
239    }
240
241    pub fn reachability(&self) -> Option<&Vec<f32>> {
242        self.reachability_.as_ref()
243    }
244}
245
/// BIRCH - Balanced Iterative Reducing and Clustering using Hierarchies
///
/// Simplified variant: maintains a flat list of clustering-feature (CF)
/// nodes instead of a full CF tree, and optionally reduces the CF centroids
/// to `n_clusters` with k-means.
pub struct BIRCH {
    /// Maximum radius a CF subcluster may reach when absorbing a point.
    pub threshold: f32,
    /// Branching factor; the flat node list is capped at 10x this value.
    pub branching_factor: usize,
    /// Optional number of final clusters (k-means over CF centroids).
    pub n_clusters: Option<usize>,
    // Fitted state, populated by `fit`:
    centroids_: Option<Vec<Vec<f32>>>,  // final cluster centroids
    labels_: Option<Vec<i32>>,          // cluster label per sample
    n_features_: usize,                 // feature count seen during fit
}

/// Clustering feature: the sufficient statistics (count, linear sum,
/// squared sum) of the points absorbed by one BIRCH subcluster.
#[derive(Clone)]
struct CFNode {
    n: usize,
    ls: Vec<f32>,  // Linear sum
    ss: f32,       // Squared sum
}

impl CFNode {
    /// Empty feature over `n_features` dimensions.
    fn new(n_features: usize) -> Self {
        CFNode { n: 0, ls: vec![0.0; n_features], ss: 0.0 }
    }

    /// Absorbs a single point into the statistics.
    fn add_point(&mut self, point: &[f32]) {
        self.n += 1;
        for (idx, &coord) in point.iter().enumerate() {
            self.ls[idx] += coord;
            self.ss += coord * coord;
        }
    }

    /// Folds another feature's statistics into this one.
    fn merge(&mut self, other: &CFNode) {
        self.n += other.n;
        for (idx, &sum) in other.ls.iter().enumerate() {
            self.ls[idx] += sum;
        }
        self.ss += other.ss;
    }

    /// Mean of the absorbed points (all zeros while empty).
    fn centroid(&self) -> Vec<f32> {
        match self.n {
            0 => self.ls.clone(),
            n => self.ls.iter().map(|&v| v / n as f32).collect(),
        }
    }

    /// RMS deviation of the absorbed points from their centroid.
    fn radius(&self) -> f32 {
        if self.n <= 1 {
            return 0.0;
        }
        let c = self.centroid();
        let c_sq: f32 = c.iter().map(|&v| v * v).sum();
        (self.ss / self.n as f32 - c_sq).max(0.0).sqrt()
    }

    /// Euclidean distance from this feature's centroid to `point`.
    fn distance_to(&self, point: &[f32]) -> f32 {
        self.centroid()
            .iter()
            .zip(point.iter())
            .map(|(&c, &p)| (c - p).powi(2))
            .sum::<f32>()
            .sqrt()
    }
}

313impl BIRCH {
314    pub fn new() -> Self {
315        BIRCH {
316            threshold: 0.5,
317            branching_factor: 50,
318            n_clusters: None,
319            centroids_: None,
320            labels_: None,
321            n_features_: 0,
322        }
323    }
324
325    pub fn threshold(mut self, t: f32) -> Self {
326        self.threshold = t;
327        self
328    }
329
330    pub fn n_clusters(mut self, n: usize) -> Self {
331        self.n_clusters = Some(n);
332        self
333    }
334
335    pub fn fit(&mut self, x: &Tensor) {
336        let x_data = x.data_f32();
337        let n_samples = x.dims()[0];
338        let n_features = x.dims()[1];
339        self.n_features_ = n_features;
340
341        // Build CF tree (simplified as a flat list of CF nodes)
342        let mut cf_nodes: Vec<CFNode> = Vec::new();
343
344        for i in 0..n_samples {
345            let point = &x_data[i * n_features..(i + 1) * n_features];
346            
347            // Find closest CF node
348            let mut closest_idx = None;
349            let mut min_dist = f32::INFINITY;
350
351            for (j, node) in cf_nodes.iter().enumerate() {
352                let dist = node.distance_to(point);
353                if dist < min_dist {
354                    min_dist = dist;
355                    closest_idx = Some(j);
356                }
357            }
358
359            // Check if we can add to existing node
360            if let Some(idx) = closest_idx {
361                let mut test_node = cf_nodes[idx].clone();
362                test_node.add_point(point);
363                
364                if test_node.radius() <= self.threshold {
365                    cf_nodes[idx].add_point(point);
366                    continue;
367                }
368            }
369
370            // Create new CF node
371            let mut new_node = CFNode::new(n_features);
372            new_node.add_point(point);
373            cf_nodes.push(new_node);
374
375            // Merge nodes if too many (simplified)
376            while cf_nodes.len() > self.branching_factor * 10 {
377                // Find two closest nodes and merge
378                let mut min_dist = f32::INFINITY;
379                let mut merge_i = 0;
380                let mut merge_j = 1;
381
382                for i in 0..cf_nodes.len() {
383                    for j in (i + 1)..cf_nodes.len() {
384                        let ci = cf_nodes[i].centroid();
385                        let cj = cf_nodes[j].centroid();
386                        let dist: f32 = ci.iter().zip(cj.iter())
387                            .map(|(&a, &b)| (a - b).powi(2))
388                            .sum::<f32>()
389                            .sqrt();
390                        if dist < min_dist {
391                            min_dist = dist;
392                            merge_i = i;
393                            merge_j = j;
394                        }
395                    }
396                }
397
398                let node_j = cf_nodes.remove(merge_j);
399                cf_nodes[merge_i].merge(&node_j);
400            }
401        }
402
403        // Extract centroids
404        let centroids: Vec<Vec<f32>> = cf_nodes.iter().map(|n| n.centroid()).collect();
405
406        // Apply final clustering if n_clusters specified
407        let final_centroids = if let Some(k) = self.n_clusters {
408            self.kmeans_on_centroids(&centroids, k)
409        } else {
410            centroids
411        };
412
413        // Assign labels
414        let mut labels = vec![0i32; n_samples];
415        for i in 0..n_samples {
416            let point = &x_data[i * n_features..(i + 1) * n_features];
417            let mut min_dist = f32::INFINITY;
418            let mut best_cluster = 0;
419
420            for (j, centroid) in final_centroids.iter().enumerate() {
421                let dist: f32 = point.iter().zip(centroid.iter())
422                    .map(|(&p, &c)| (p - c).powi(2))
423                    .sum::<f32>()
424                    .sqrt();
425                if dist < min_dist {
426                    min_dist = dist;
427                    best_cluster = j as i32;
428                }
429            }
430            labels[i] = best_cluster;
431        }
432
433        self.centroids_ = Some(final_centroids);
434        self.labels_ = Some(labels);
435    }
436
437    fn kmeans_on_centroids(&self, centroids: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
438        if centroids.len() <= k {
439            return centroids.to_vec();
440        }
441
442        let n_features = self.n_features_;
443        let mut cluster_centers: Vec<Vec<f32>> = centroids[..k].to_vec();
444
445        for _ in 0..100 {
446            // Assign centroids to clusters
447            let mut assignments = vec![0usize; centroids.len()];
448            for (i, c) in centroids.iter().enumerate() {
449                let mut min_dist = f32::INFINITY;
450                for (j, center) in cluster_centers.iter().enumerate() {
451                    let dist: f32 = c.iter().zip(center.iter())
452                        .map(|(&a, &b)| (a - b).powi(2))
453                        .sum::<f32>()
454                        .sqrt();
455                    if dist < min_dist {
456                        min_dist = dist;
457                        assignments[i] = j;
458                    }
459                }
460            }
461
462            // Update cluster centers
463            let mut new_centers = vec![vec![0.0f32; n_features]; k];
464            let mut counts = vec![0usize; k];
465
466            for (i, &cluster) in assignments.iter().enumerate() {
467                counts[cluster] += 1;
468                for (j, &val) in centroids[i].iter().enumerate() {
469                    new_centers[cluster][j] += val;
470                }
471            }
472
473            for i in 0..k {
474                if counts[i] > 0 {
475                    for j in 0..n_features {
476                        new_centers[i][j] /= counts[i] as f32;
477                    }
478                } else {
479                    new_centers[i] = cluster_centers[i].clone();
480                }
481            }
482
483            cluster_centers = new_centers;
484        }
485
486        cluster_centers
487    }
488
489    pub fn fit_predict(&mut self, x: &Tensor) -> Tensor {
490        self.fit(x);
491        let labels = self.labels_.as_ref().unwrap();
492        let labels_f32: Vec<f32> = labels.iter().map(|&l| l as f32).collect();
493        Tensor::from_slice(&labels_f32, &[labels.len()]).unwrap()
494    }
495
496    pub fn labels(&self) -> Option<&Vec<i32>> {
497        self.labels_.as_ref()
498    }
499}
500
501impl Default for BIRCH {
502    fn default() -> Self { Self::new() }
503}
504
/// HDBSCAN - Hierarchical DBSCAN
///
/// Simplified: builds a minimum spanning tree over mutual-reachability
/// distances, then extracts flat clusters with a union-find pass.
pub struct HDBSCAN {
    /// Minimum number of points a component needs to become a cluster.
    pub min_cluster_size: usize,
    /// Neighbor count for core distances (defaults to `min_cluster_size`).
    pub min_samples: Option<usize>,
    /// Distance threshold applied when merging MST edges (0.0 = no cut).
    pub cluster_selection_epsilon: f32,
    // Fitted state, populated by `fit`:
    labels_: Option<Vec<i32>>,         // cluster label per sample (-1 = noise)
    probabilities_: Option<Vec<f32>>,  // 1.0 for clustered points, 0.0 for noise
}

514impl HDBSCAN {
515    pub fn new(min_cluster_size: usize) -> Self {
516        HDBSCAN {
517            min_cluster_size,
518            min_samples: None,
519            cluster_selection_epsilon: 0.0,
520            labels_: None,
521            probabilities_: None,
522        }
523    }
524
525    pub fn min_samples(mut self, n: usize) -> Self {
526        self.min_samples = Some(n);
527        self
528    }
529
530    fn mutual_reachability_distance(&self, core_distances: &[f32], i: usize, j: usize, 
531                                     x: &[f32], n_features: usize) -> f32 {
532        let pi = &x[i * n_features..(i + 1) * n_features];
533        let pj = &x[j * n_features..(j + 1) * n_features];
534        let dist: f32 = pi.iter().zip(pj.iter())
535            .map(|(&a, &b)| (a - b).powi(2))
536            .sum::<f32>()
537            .sqrt();
538        
539        core_distances[i].max(core_distances[j]).max(dist)
540    }
541
542    pub fn fit(&mut self, x: &Tensor) {
543        let x_data = x.data_f32();
544        let n_samples = x.dims()[0];
545        let n_features = x.dims()[1];
546
547        let min_samples = self.min_samples.unwrap_or(self.min_cluster_size);
548
549        // Compute core distances
550        let mut core_distances = vec![0.0f32; n_samples];
551        for i in 0..n_samples {
552            let pi = &x_data[i * n_features..(i + 1) * n_features];
553            let mut distances: Vec<f32> = (0..n_samples)
554                .filter(|&j| j != i)
555                .map(|j| {
556                    let pj = &x_data[j * n_features..(j + 1) * n_features];
557                    pi.iter().zip(pj.iter())
558                        .map(|(&a, &b)| (a - b).powi(2))
559                        .sum::<f32>()
560                        .sqrt()
561                })
562                .collect();
563            distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
564            core_distances[i] = if distances.len() >= min_samples {
565                distances[min_samples - 1]
566            } else {
567                f32::INFINITY
568            };
569        }
570
571        // Build minimum spanning tree using Prim's algorithm
572        let mut in_tree = vec![false; n_samples];
573        let mut min_edge = vec![f32::INFINITY; n_samples];
574        let mut parent = vec![0usize; n_samples];
575        let mut edges: Vec<(usize, usize, f32)> = Vec::new();
576
577        in_tree[0] = true;
578        for j in 1..n_samples {
579            min_edge[j] = self.mutual_reachability_distance(&core_distances, 0, j, &x_data, n_features);
580            parent[j] = 0;
581        }
582
583        for _ in 1..n_samples {
584            // Find minimum edge
585            let mut min_val = f32::INFINITY;
586            let mut min_idx = 0;
587            for j in 0..n_samples {
588                if !in_tree[j] && min_edge[j] < min_val {
589                    min_val = min_edge[j];
590                    min_idx = j;
591                }
592            }
593
594            if min_val == f32::INFINITY {
595                break;
596            }
597
598            in_tree[min_idx] = true;
599            edges.push((parent[min_idx], min_idx, min_val));
600
601            // Update minimum edges
602            for j in 0..n_samples {
603                if !in_tree[j] {
604                    let dist = self.mutual_reachability_distance(&core_distances, min_idx, j, &x_data, n_features);
605                    if dist < min_edge[j] {
606                        min_edge[j] = dist;
607                        parent[j] = min_idx;
608                    }
609                }
610            }
611        }
612
613        // Sort edges by weight (descending for hierarchical clustering)
614        edges.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
615
616        // Build hierarchy and extract clusters using single-linkage
617        let mut labels = vec![-1i32; n_samples];
618        let mut cluster_id = 0;
619
620        // Union-Find structure
621        let mut uf_parent: Vec<usize> = (0..n_samples).collect();
622        let mut uf_size = vec![1usize; n_samples];
623
624        fn find(uf_parent: &mut [usize], i: usize) -> usize {
625            if uf_parent[i] != i {
626                uf_parent[i] = find(uf_parent, uf_parent[i]);
627            }
628            uf_parent[i]
629        }
630
631        fn union(uf_parent: &mut [usize], uf_size: &mut [usize], i: usize, j: usize) {
632            let ri = find(uf_parent, i);
633            let rj = find(uf_parent, j);
634            if ri != rj {
635                if uf_size[ri] < uf_size[rj] {
636                    uf_parent[ri] = rj;
637                    uf_size[rj] += uf_size[ri];
638                } else {
639                    uf_parent[rj] = ri;
640                    uf_size[ri] += uf_size[rj];
641                }
642            }
643        }
644
645        // Process edges in order of increasing weight
646        edges.reverse();
647        for (i, j, weight) in edges {
648            if weight > self.cluster_selection_epsilon || self.cluster_selection_epsilon == 0.0 {
649                union(&mut uf_parent, &mut uf_size, i, j);
650            }
651        }
652
653        // Assign cluster labels
654        let mut root_to_cluster: std::collections::HashMap<usize, i32> = std::collections::HashMap::new();
655        for i in 0..n_samples {
656            let root = find(&mut uf_parent, i);
657            if uf_size[root] >= self.min_cluster_size {
658                let cluster = *root_to_cluster.entry(root).or_insert_with(|| {
659                    let c = cluster_id;
660                    cluster_id += 1;
661                    c
662                });
663                labels[i] = cluster;
664            }
665        }
666
667        // Compute probabilities (simplified)
668        let probabilities: Vec<f32> = labels.iter()
669            .map(|&l| if l >= 0 { 1.0 } else { 0.0 })
670            .collect();
671
672        self.labels_ = Some(labels);
673        self.probabilities_ = Some(probabilities);
674    }
675
676    pub fn fit_predict(&mut self, x: &Tensor) -> Tensor {
677        self.fit(x);
678        let labels = self.labels_.as_ref().unwrap();
679        let labels_f32: Vec<f32> = labels.iter().map(|&l| l as f32).collect();
680        Tensor::from_slice(&labels_f32, &[labels.len()]).unwrap()
681    }
682
683    pub fn labels(&self) -> Option<&Vec<i32>> {
684        self.labels_.as_ref()
685    }
686
687    pub fn probabilities(&self) -> Option<&Vec<f32>> {
688        self.probabilities_.as_ref()
689    }
690}
691
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore] // Complex algorithm - needs more work
    fn test_optics() {
        // Six 2-D points: two well-separated groups of three. The shape was
        // previously declared as [4, 2], which does not match the 12 data
        // values and made `from_slice(..).unwrap()` fail before the
        // algorithm even ran.
        let x = Tensor::from_slice(&[0.0f32, 0.0, 0.1, 0.1, 0.2, 0.0,
            5.0, 5.0, 5.1, 5.1, 5.2, 5.0,
        ], &[6, 2]).unwrap();
        
        let mut optics = OPTICS::new(2);
        let labels = optics.fit_predict(&x);
        assert_eq!(labels.dims()[0], 6);
    }

    #[test]
    fn test_birch() {
        // Four 2-D points forming two well-separated pairs.
        let x = Tensor::from_slice(&[0.0f32, 0.0, 0.1, 0.1,
            5.0, 5.0, 5.1, 5.1,
        ], &[4, 2]).unwrap();
        
        let mut birch = BIRCH::new().n_clusters(2);
        let labels = birch.fit_predict(&x);
        assert_eq!(labels.dims()[0], 4);
    }

    #[test]
    #[ignore] // Complex algorithm - needs more work
    fn test_hdbscan() {
        // Six 2-D points (shape fixed from [4, 2] to match the 12 values).
        let x = Tensor::from_slice(&[0.0f32, 0.0, 0.1, 0.1, 0.2, 0.0,
            5.0, 5.0, 5.1, 5.1, 5.2, 5.0,
        ], &[6, 2]).unwrap();
        
        let mut hdbscan = HDBSCAN::new(2);
        let labels = hdbscan.fit_predict(&x);
        assert_eq!(labels.dims()[0], 6);
    }
}