ghostflow_ml/
clustering_advanced.rs

1//! Advanced Clustering - Spectral, Mean Shift, OPTICS, Birch, Mini-Batch KMeans, Affinity Propagation
2
3use ghostflow_core::Tensor;
4use rand::prelude::*;
5
6/// Spectral Clustering
7pub struct SpectralClustering {
8    pub n_clusters: usize,
9    pub affinity: SpectralAffinity,
10    pub gamma: f32,
11    pub n_neighbors: usize,
12    pub assign_labels: AssignLabels,
13    labels_: Option<Vec<usize>>,
14}
15
/// How [`SpectralClustering`] builds its affinity matrix.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SpectralAffinity {
    /// Gaussian kernel `exp(-gamma * ||x_i - x_j||^2)` between every pair of samples.
    RBF,
    /// Binary k-nearest-neighbour graph, symmetrised so an edge exists when either
    /// sample lists the other among its `n_neighbors` closest.
    NearestNeighbors,
    /// The input itself is a square `n_samples x n_samples` affinity matrix.
    Precomputed,
}
22
/// Strategy for turning a spectral embedding into discrete cluster labels.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AssignLabels {
    /// Run k-means on the embedded points.
    KMeans,
    /// Discretize the embedding directly.
    /// NOTE(review): `SpectralClustering` does not currently read this setting and always
    /// uses k-means.
    Discretize,
}
28
29impl SpectralClustering {
30    pub fn new(n_clusters: usize) -> Self {
31        SpectralClustering {
32            n_clusters,
33            affinity: SpectralAffinity::RBF,
34            gamma: 1.0,
35            n_neighbors: 10,
36            assign_labels: AssignLabels::KMeans,
37            labels_: None,
38        }
39    }
40
41    pub fn gamma(mut self, g: f32) -> Self {
42        self.gamma = g;
43        self
44    }
45
46    fn compute_affinity_matrix(&self, x: &[f32], n_samples: usize, n_features: usize) -> Vec<Vec<f32>> {
47        let mut affinity = vec![vec![0.0f32; n_samples]; n_samples];
48
49        match self.affinity {
50            SpectralAffinity::RBF => {
51                for i in 0..n_samples {
52                    for j in i..n_samples {
53                        let mut dist_sq = 0.0f32;
54                        for k in 0..n_features {
55                            let diff = x[i * n_features + k] - x[j * n_features + k];
56                            dist_sq += diff * diff;
57                        }
58                        let a = (-self.gamma * dist_sq).exp();
59                        affinity[i][j] = a;
60                        affinity[j][i] = a;
61                    }
62                }
63            }
64            SpectralAffinity::NearestNeighbors => {
65                // Compute k-NN graph
66                for i in 0..n_samples {
67                    let mut distances: Vec<(usize, f32)> = (0..n_samples)
68                        .filter(|&j| j != i)
69                        .map(|j| {
70                            let mut dist_sq = 0.0f32;
71                            for k in 0..n_features {
72                                let diff = x[i * n_features + k] - x[j * n_features + k];
73                                dist_sq += diff * diff;
74                            }
75                            (j, dist_sq.sqrt())
76                        })
77                        .collect();
78                    
79                    distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
80                    
81                    for (j, _) in distances.into_iter().take(self.n_neighbors) {
82                        affinity[i][j] = 1.0;
83                        affinity[j][i] = 1.0;
84                    }
85                }
86            }
87            SpectralAffinity::Precomputed => {
88                // Assume x is already an affinity matrix
89                for i in 0..n_samples {
90                    for j in 0..n_samples {
91                        affinity[i][j] = x[i * n_samples + j];
92                    }
93                }
94            }
95        }
96
97        affinity
98    }
99
100    fn compute_laplacian(&self, affinity: &[Vec<f32>], n_samples: usize) -> Vec<f32> {
101        // Compute normalized Laplacian: L = D^(-1/2) * (D - A) * D^(-1/2)
102        // Or equivalently: L = I - D^(-1/2) * A * D^(-1/2)
103        
104        // Compute degree matrix
105        let degrees: Vec<f32> = (0..n_samples)
106            .map(|i| affinity[i].iter().sum::<f32>())
107            .collect();
108
109        let mut laplacian = vec![0.0f32; n_samples * n_samples];
110
111        for i in 0..n_samples {
112            for j in 0..n_samples {
113                if i == j {
114                    laplacian[i * n_samples + j] = 1.0;
115                } else {
116                    let d_i = degrees[i].max(1e-10).sqrt();
117                    let d_j = degrees[j].max(1e-10).sqrt();
118                    laplacian[i * n_samples + j] = -affinity[i][j] / (d_i * d_j);
119                }
120            }
121        }
122
123        laplacian
124    }
125
126    fn power_iteration_smallest(&self, matrix: &[f32], n: usize, k: usize) -> Vec<Vec<f32>> {
127        // Find k smallest eigenvectors using inverse power iteration
128        let mut eigenvectors: Vec<Vec<f32>> = Vec::with_capacity(k);
129        let mut rng = thread_rng();
130
131        // Shift matrix to make smallest eigenvalues largest
132        let mut shifted = matrix.to_vec();
133        let shift = 2.0f32;  // Laplacian eigenvalues are in [0, 2]
134        for i in 0..n {
135            shifted[i * n + i] = shift - shifted[i * n + i];
136            for j in 0..n {
137                if i != j {
138                    shifted[i * n + j] = -shifted[i * n + j];
139                }
140            }
141        }
142
143        for _ in 0..k {
144            let mut v: Vec<f32> = (0..n).map(|_| rng.gen::<f32>()).collect();
145            let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
146            for vi in &mut v {
147                *vi /= norm;
148            }
149
150            for _ in 0..100 {
151                // w = A * v
152                let mut w = vec![0.0f32; n];
153                for i in 0..n {
154                    for j in 0..n {
155                        w[i] += shifted[i * n + j] * v[j];
156                    }
157                }
158
159                // Orthogonalize against previous eigenvectors
160                for prev in &eigenvectors {
161                    let dot: f32 = w.iter().zip(prev.iter()).map(|(&a, &b)| a * b).sum();
162                    for i in 0..n {
163                        w[i] -= dot * prev[i];
164                    }
165                }
166
167                let norm: f32 = w.iter().map(|&x| x * x).sum::<f32>().sqrt();
168                if norm < 1e-10 {
169                    break;
170                }
171                for wi in &mut w {
172                    *wi /= norm;
173                }
174
175                let diff: f32 = v.iter().zip(w.iter()).map(|(&a, &b)| (a - b).abs()).sum();
176                v = w;
177
178                if diff < 1e-6 {
179                    break;
180                }
181            }
182
183            eigenvectors.push(v);
184
185            // Deflate
186            let mut eigenvalue = 0.0f32;
187            for i in 0..n {
188                let mut av = 0.0f32;
189                for j in 0..n {
190                    av += shifted[i * n + j] * eigenvectors.last().unwrap()[j];
191                }
192                eigenvalue += eigenvectors.last().unwrap()[i] * av;
193            }
194
195            let v = eigenvectors.last().unwrap();
196            for i in 0..n {
197                for j in 0..n {
198                    shifted[i * n + j] -= eigenvalue * v[i] * v[j];
199                }
200            }
201        }
202
203        eigenvectors
204    }
205
206    fn kmeans_on_embedding(&self, embedding: &[Vec<f32>], n_samples: usize) -> Vec<usize> {
207        let k = self.n_clusters;
208        let n_features = embedding.len();
209
210        // Flatten embedding
211        let mut data = vec![0.0f32; n_samples * n_features];
212        for i in 0..n_samples {
213            for j in 0..n_features {
214                data[i * n_features + j] = embedding[j][i];
215            }
216        }
217
218        // Simple k-means
219        let mut rng = thread_rng();
220        let mut centers: Vec<Vec<f32>> = (0..k)
221            .map(|_| {
222                let idx = rng.gen_range(0..n_samples);
223                (0..n_features).map(|j| data[idx * n_features + j]).collect()
224            })
225            .collect();
226
227        let mut labels = vec![0usize; n_samples];
228
229        for _ in 0..100 {
230            // Assign labels
231            for i in 0..n_samples {
232                let mut min_dist = f32::INFINITY;
233                for c in 0..k {
234                    let mut dist = 0.0f32;
235                    for j in 0..n_features {
236                        let diff = data[i * n_features + j] - centers[c][j];
237                        dist += diff * diff;
238                    }
239                    if dist < min_dist {
240                        min_dist = dist;
241                        labels[i] = c;
242                    }
243                }
244            }
245
246            // Update centers
247            let mut new_centers = vec![vec![0.0f32; n_features]; k];
248            let mut counts = vec![0usize; k];
249
250            for i in 0..n_samples {
251                let c = labels[i];
252                counts[c] += 1;
253                for j in 0..n_features {
254                    new_centers[c][j] += data[i * n_features + j];
255                }
256            }
257
258            for c in 0..k {
259                if counts[c] > 0 {
260                    for j in 0..n_features {
261                        new_centers[c][j] /= counts[c] as f32;
262                    }
263                }
264            }
265
266            centers = new_centers;
267        }
268
269        labels
270    }
271
272    pub fn fit(&mut self, x: &Tensor) {
273        let x_data = x.data_f32();
274        let n_samples = x.dims()[0];
275        let n_features = x.dims()[1];
276
277        // Compute affinity matrix
278        let affinity = self.compute_affinity_matrix(&x_data, n_samples, n_features);
279
280        // Compute normalized Laplacian
281        let laplacian = self.compute_laplacian(&affinity, n_samples);
282
283        // Find k smallest eigenvectors
284        let eigenvectors = self.power_iteration_smallest(&laplacian, n_samples, self.n_clusters);
285
286        // Cluster in embedding space
287        let labels = self.kmeans_on_embedding(&eigenvectors, n_samples);
288
289        self.labels_ = Some(labels);
290    }
291
292    pub fn fit_predict(&mut self, x: &Tensor) -> Tensor {
293        self.fit(x);
294        let labels = self.labels_.as_ref().unwrap();
295        let labels_f32: Vec<f32> = labels.iter().map(|&l| l as f32).collect();
296        Tensor::from_slice(&labels_f32, &[labels.len()]).unwrap()
297    }
298}
299
300
/// Mean-shift clustering: every sample is moved uphill on a Gaussian kernel density
/// estimate until it settles on a mode, and nearby modes are merged into cluster centers.
#[derive(Debug, Clone)]
pub struct MeanShift {
    /// Gaussian kernel bandwidth; estimated from the data when `None`.
    pub bandwidth: Option<f32>,
    /// Maximum number of mean-shift steps per seed.
    pub max_iter: usize,
    /// Reserved for binned seeding; currently every sample is used as a seed.
    pub bin_seeding: bool,
    /// Distinct modes found by the most recent `fit`.
    cluster_centers_: Option<Vec<Vec<f32>>>,
    /// Index of the nearest center for each sample, from the most recent `fit`.
    labels_: Option<Vec<usize>>,
}
309
310impl MeanShift {
311    pub fn new() -> Self {
312        MeanShift {
313            bandwidth: None,
314            max_iter: 300,
315            bin_seeding: false,
316            cluster_centers_: None,
317            labels_: None,
318        }
319    }
320
321    pub fn bandwidth(mut self, bw: f32) -> Self {
322        self.bandwidth = Some(bw);
323        self
324    }
325
326    fn estimate_bandwidth(&self, x: &[f32], n_samples: usize, n_features: usize) -> f32 {
327        // Scott's rule of thumb
328        let mut std_sum = 0.0f32;
329        
330        for j in 0..n_features {
331            let mean: f32 = (0..n_samples).map(|i| x[i * n_features + j]).sum::<f32>() / n_samples as f32;
332            let variance: f32 = (0..n_samples)
333                .map(|i| (x[i * n_features + j] - mean).powi(2))
334                .sum::<f32>() / n_samples as f32;
335            std_sum += variance.sqrt();
336        }
337        
338        let avg_std = std_sum / n_features as f32;
339        avg_std * (n_samples as f32).powf(-1.0 / (n_features as f32 + 4.0))
340    }
341
342    fn gaussian_kernel(&self, dist: f32, bandwidth: f32) -> f32 {
343        (-0.5 * (dist / bandwidth).powi(2)).exp()
344    }
345
346    pub fn fit(&mut self, x: &Tensor) {
347        let x_data = x.data_f32();
348        let n_samples = x.dims()[0];
349        let n_features = x.dims()[1];
350
351        let bandwidth = self.bandwidth.unwrap_or_else(|| self.estimate_bandwidth(&x_data, n_samples, n_features));
352
353        // Initialize seeds (all points or binned)
354        let mut seeds: Vec<Vec<f32>> = (0..n_samples)
355            .map(|i| x_data[i * n_features..(i + 1) * n_features].to_vec())
356            .collect();
357
358        // Mean shift for each seed
359        let mut converged_centers: Vec<Vec<f32>> = Vec::new();
360
361        for seed in &mut seeds {
362            for _ in 0..self.max_iter {
363                let mut new_center = vec![0.0f32; n_features];
364                let mut total_weight = 0.0f32;
365
366                for i in 0..n_samples {
367                    let xi = &x_data[i * n_features..(i + 1) * n_features];
368                    
369                    let dist: f32 = seed.iter().zip(xi.iter())
370                        .map(|(&a, &b)| (a - b).powi(2))
371                        .sum::<f32>()
372                        .sqrt();
373
374                    let weight = self.gaussian_kernel(dist, bandwidth);
375                    total_weight += weight;
376
377                    for j in 0..n_features {
378                        new_center[j] += weight * xi[j];
379                    }
380                }
381
382                if total_weight > 0.0 {
383                    for j in 0..n_features {
384                        new_center[j] /= total_weight;
385                    }
386                }
387
388                // Check convergence
389                let shift: f32 = seed.iter().zip(new_center.iter())
390                    .map(|(&a, &b)| (a - b).powi(2))
391                    .sum::<f32>()
392                    .sqrt();
393
394                *seed = new_center;
395
396                if shift < 1e-3 * bandwidth {
397                    break;
398                }
399            }
400
401            // Check if this center is unique
402            let is_unique = converged_centers.iter().all(|c| {
403                let dist: f32 = c.iter().zip(seed.iter())
404                    .map(|(&a, &b)| (a - b).powi(2))
405                    .sum::<f32>()
406                    .sqrt();
407                dist > bandwidth / 2.0
408            });
409
410            if is_unique {
411                converged_centers.push(seed.clone());
412            }
413        }
414
415        // Assign labels
416        let mut labels = vec![0usize; n_samples];
417        for i in 0..n_samples {
418            let xi = &x_data[i * n_features..(i + 1) * n_features];
419            let mut min_dist = f32::INFINITY;
420            
421            for (c, center) in converged_centers.iter().enumerate() {
422                let dist: f32 = xi.iter().zip(center.iter())
423                    .map(|(&a, &b)| (a - b).powi(2))
424                    .sum::<f32>()
425                    .sqrt();
426                
427                if dist < min_dist {
428                    min_dist = dist;
429                    labels[i] = c;
430                }
431            }
432        }
433
434        self.cluster_centers_ = Some(converged_centers);
435        self.labels_ = Some(labels);
436    }
437
438    pub fn fit_predict(&mut self, x: &Tensor) -> Tensor {
439        self.fit(x);
440        let labels = self.labels_.as_ref().unwrap();
441        let labels_f32: Vec<f32> = labels.iter().map(|&l| l as f32).collect();
442        Tensor::from_slice(&labels_f32, &[labels.len()]).unwrap()
443    }
444}
445
446impl Default for MeanShift {
447    fn default() -> Self {
448        Self::new()
449    }
450}
451
452/// Mini-Batch K-Means for large datasets
453pub struct MiniBatchKMeans {
454    pub n_clusters: usize,
455    pub batch_size: usize,
456    pub max_iter: usize,
457    pub n_init: usize,
458    pub init: MiniBatchInit,
459    pub reassignment_ratio: f32,
460    cluster_centers_: Option<Vec<Vec<f32>>>,
461    labels_: Option<Vec<usize>>,
462    inertia_: Option<f32>,
463}
464
/// Initial-center strategy for [`MiniBatchKMeans`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MiniBatchInit {
    /// Distinct samples chosen uniformly at random.
    Random,
    /// k-means++ seeding: each new center is sampled with probability proportional to its
    /// squared distance from the nearest center chosen so far.
    KMeansPlusPlus,
}
470
471impl MiniBatchKMeans {
472    pub fn new(n_clusters: usize) -> Self {
473        MiniBatchKMeans {
474            n_clusters,
475            batch_size: 100,
476            max_iter: 100,
477            n_init: 3,
478            init: MiniBatchInit::KMeansPlusPlus,
479            reassignment_ratio: 0.01,
480            cluster_centers_: None,
481            labels_: None,
482            inertia_: None,
483        }
484    }
485
486    pub fn batch_size(mut self, size: usize) -> Self {
487        self.batch_size = size;
488        self
489    }
490
491    fn euclidean_distance_sq(a: &[f32], b: &[f32]) -> f32 {
492        a.iter().zip(b.iter()).map(|(&ai, &bi)| (ai - bi).powi(2)).sum()
493    }
494
495    fn init_centers(&self, x: &[f32], n_samples: usize, n_features: usize) -> Vec<Vec<f32>> {
496        let mut rng = thread_rng();
497
498        match self.init {
499            MiniBatchInit::Random => {
500                let indices: Vec<usize> = (0..n_samples)
501                    .choose_multiple(&mut rng, self.n_clusters);
502                indices.iter()
503                    .map(|&i| x[i * n_features..(i + 1) * n_features].to_vec())
504                    .collect()
505            }
506            MiniBatchInit::KMeansPlusPlus => {
507                let mut centers = Vec::with_capacity(self.n_clusters);
508                
509                // First center randomly
510                let first_idx = rng.gen_range(0..n_samples);
511                centers.push(x[first_idx * n_features..(first_idx + 1) * n_features].to_vec());
512
513                for _ in 1..self.n_clusters {
514                    let distances: Vec<f32> = (0..n_samples)
515                        .map(|i| {
516                            let point = &x[i * n_features..(i + 1) * n_features];
517                            centers.iter()
518                                .map(|c| Self::euclidean_distance_sq(point, c))
519                                .fold(f32::INFINITY, f32::min)
520                        })
521                        .collect();
522
523                    let total: f32 = distances.iter().sum();
524                    let threshold = rng.gen::<f32>() * total;
525                    
526                    let mut cumsum = 0.0f32;
527                    let mut chosen_idx = 0;
528                    for (i, &d) in distances.iter().enumerate() {
529                        cumsum += d;
530                        if cumsum >= threshold {
531                            chosen_idx = i;
532                            break;
533                        }
534                    }
535
536                    centers.push(x[chosen_idx * n_features..(chosen_idx + 1) * n_features].to_vec());
537                }
538
539                centers
540            }
541        }
542    }
543
544    pub fn fit(&mut self, x: &Tensor) {
545        let x_data = x.data_f32();
546        let n_samples = x.dims()[0];
547        let n_features = x.dims()[1];
548
549        let mut best_centers = None;
550        let mut best_inertia = f32::INFINITY;
551
552        for _ in 0..self.n_init {
553            let mut centers = self.init_centers(&x_data, n_samples, n_features);
554            let mut counts = vec![0usize; self.n_clusters];
555            let mut rng = thread_rng();
556
557            for _ in 0..self.max_iter {
558                // Sample mini-batch
559                let batch_indices: Vec<usize> = (0..n_samples)
560                    .choose_multiple(&mut rng, self.batch_size.min(n_samples));
561
562                // Assign batch points to nearest centers
563                let assignments: Vec<usize> = batch_indices.iter()
564                    .map(|&i| {
565                        let point = &x_data[i * n_features..(i + 1) * n_features];
566                        centers.iter()
567                            .enumerate()
568                            .map(|(c, center)| (c, Self::euclidean_distance_sq(point, center)))
569                            .min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap())
570                            .map(|(c, _)| c)
571                            .unwrap_or(0)
572                    })
573                    .collect();
574
575                // Update centers with streaming average
576                for (batch_idx, &sample_idx) in batch_indices.iter().enumerate() {
577                    let cluster = assignments[batch_idx];
578                    counts[cluster] += 1;
579                    let eta = 1.0 / counts[cluster] as f32;
580
581                    for j in 0..n_features {
582                        centers[cluster][j] = (1.0 - eta) * centers[cluster][j] 
583                            + eta * x_data[sample_idx * n_features + j];
584                    }
585                }
586            }
587
588            // Compute inertia
589            let inertia: f32 = (0..n_samples)
590                .map(|i| {
591                    let point = &x_data[i * n_features..(i + 1) * n_features];
592                    centers.iter()
593                        .map(|c| Self::euclidean_distance_sq(point, c))
594                        .fold(f32::INFINITY, f32::min)
595                })
596                .sum();
597
598            if inertia < best_inertia {
599                best_inertia = inertia;
600                best_centers = Some(centers);
601            }
602        }
603
604        let centers = best_centers.unwrap();
605
606        // Final label assignment
607        let labels: Vec<usize> = (0..n_samples)
608            .map(|i| {
609                let point = &x_data[i * n_features..(i + 1) * n_features];
610                centers.iter()
611                    .enumerate()
612                    .map(|(c, center)| (c, Self::euclidean_distance_sq(point, center)))
613                    .min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap())
614                    .map(|(c, _)| c)
615                    .unwrap_or(0)
616            })
617            .collect();
618
619        self.cluster_centers_ = Some(centers);
620        self.labels_ = Some(labels);
621        self.inertia_ = Some(best_inertia);
622    }
623
624    pub fn fit_predict(&mut self, x: &Tensor) -> Tensor {
625        self.fit(x);
626        let labels = self.labels_.as_ref().unwrap();
627        let labels_f32: Vec<f32> = labels.iter().map(|&l| l as f32).collect();
628        Tensor::from_slice(&labels_f32, &[labels.len()]).unwrap()
629    }
630
631    pub fn predict(&self, x: &Tensor) -> Tensor {
632        let x_data = x.data_f32();
633        let n_samples = x.dims()[0];
634        let n_features = x.dims()[1];
635
636        let centers = self.cluster_centers_.as_ref().expect("Model not fitted");
637
638        let labels: Vec<f32> = (0..n_samples)
639            .map(|i| {
640                let point = &x_data[i * n_features..(i + 1) * n_features];
641                centers.iter()
642                    .enumerate()
643                    .map(|(c, center)| (c, Self::euclidean_distance_sq(point, center)))
644                    .min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap())
645                    .map(|(c, _)| c as f32)
646                    .unwrap_or(0.0)
647            })
648            .collect();
649
650        Tensor::from_slice(&labels, &[n_samples]).unwrap()
651    }
652}
653
/// Affinity propagation: samples exchange "responsibility" and "availability" messages
/// until a set of exemplars emerges; every sample is then assigned to its most similar
/// exemplar.
#[derive(Debug, Clone)]
pub struct AffinityPropagation {
    /// Weight given to the previous message when blending in a new one.
    pub damping: f32,
    /// Maximum number of message-passing iterations.
    pub max_iter: usize,
    /// Number of consecutive iterations with unchanged exemplar choices that stops fitting.
    pub convergence_iter: usize,
    /// Self-similarity `s(i, i)`; the median pairwise similarity is used when `None`.
    pub preference: Option<f32>,
    /// Sample indices chosen as exemplars by the most recent `fit`.
    cluster_centers_indices_: Option<Vec<usize>>,
    /// For each sample, the position of its exemplar in `cluster_centers_indices_`.
    labels_: Option<Vec<usize>>,
}
663
664impl AffinityPropagation {
665    pub fn new() -> Self {
666        AffinityPropagation {
667            damping: 0.5,
668            max_iter: 200,
669            convergence_iter: 15,
670            preference: None,
671            cluster_centers_indices_: None,
672            labels_: None,
673        }
674    }
675
676    pub fn damping(mut self, d: f32) -> Self {
677        self.damping = d.clamp(0.5, 1.0);
678        self
679    }
680
681    pub fn preference(mut self, p: f32) -> Self {
682        self.preference = Some(p);
683        self
684    }
685
686    pub fn fit(&mut self, x: &Tensor) {
687        let x_data = x.data_f32();
688        let n_samples = x.dims()[0];
689        let n_features = x.dims()[1];
690
691        // Compute similarity matrix (negative squared Euclidean distance)
692        let mut s = vec![vec![0.0f32; n_samples]; n_samples];
693        for i in 0..n_samples {
694            for j in 0..n_samples {
695                if i != j {
696                    let mut dist_sq = 0.0f32;
697                    for k in 0..n_features {
698                        let diff = x_data[i * n_features + k] - x_data[j * n_features + k];
699                        dist_sq += diff * diff;
700                    }
701                    s[i][j] = -dist_sq;
702                }
703            }
704        }
705
706        // Set preference (diagonal of S)
707        let preference = self.preference.unwrap_or_else(|| {
708            let mut all_similarities: Vec<f32> = s.iter()
709                .flat_map(|row| row.iter().cloned())
710                .filter(|&x| x != 0.0)
711                .collect();
712            all_similarities.sort_by(|a, b| a.partial_cmp(b).unwrap());
713            all_similarities[all_similarities.len() / 2]  // Median
714        });
715
716        for i in 0..n_samples {
717            s[i][i] = preference;
718        }
719
720        // Initialize responsibility and availability matrices
721        let mut r = vec![vec![0.0f32; n_samples]; n_samples];
722        let mut a = vec![vec![0.0f32; n_samples]; n_samples];
723
724        let mut prev_exemplars = vec![0usize; n_samples];
725        let mut converged_count = 0;
726
727        for _ in 0..self.max_iter {
728            // Update responsibilities
729            for i in 0..n_samples {
730                for k in 0..n_samples {
731                    let mut max_val = f32::NEG_INFINITY;
732                    for kp in 0..n_samples {
733                        if kp != k {
734                            max_val = max_val.max(a[i][kp] + s[i][kp]);
735                        }
736                    }
737                    let new_r = s[i][k] - max_val;
738                    r[i][k] = self.damping * r[i][k] + (1.0 - self.damping) * new_r;
739                }
740            }
741
742            // Update availabilities
743            for i in 0..n_samples {
744                for k in 0..n_samples {
745                    if i == k {
746                        let mut sum = 0.0f32;
747                        for ip in 0..n_samples {
748                            if ip != k {
749                                sum += r[ip][k].max(0.0);
750                            }
751                        }
752                        let new_a = sum;
753                        a[i][k] = self.damping * a[i][k] + (1.0 - self.damping) * new_a;
754                    } else {
755                        let mut sum = 0.0f32;
756                        for ip in 0..n_samples {
757                            if ip != i && ip != k {
758                                sum += r[ip][k].max(0.0);
759                            }
760                        }
761                        let new_a = (r[k][k] + sum).min(0.0);
762                        a[i][k] = self.damping * a[i][k] + (1.0 - self.damping) * new_a;
763                    }
764                }
765            }
766
767            // Check convergence
768            let exemplars: Vec<usize> = (0..n_samples)
769                .map(|i| {
770                    (0..n_samples)
771                        .max_by(|&j, &k| (a[i][j] + r[i][j]).partial_cmp(&(a[i][k] + r[i][k])).unwrap())
772                        .unwrap_or(i)
773                })
774                .collect();
775
776            if exemplars == prev_exemplars {
777                converged_count += 1;
778                if converged_count >= self.convergence_iter {
779                    break;
780                }
781            } else {
782                converged_count = 0;
783            }
784            prev_exemplars = exemplars;
785        }
786
787        // Extract cluster centers and labels
788        let mut exemplar_set: Vec<usize> = (0..n_samples)
789            .filter(|&i| a[i][i] + r[i][i] > 0.0)
790            .collect();
791
792        if exemplar_set.is_empty() {
793            exemplar_set.push(0);
794        }
795
796        let labels: Vec<usize> = (0..n_samples)
797            .map(|i| {
798                exemplar_set.iter()
799                    .enumerate()
800                    .max_by(|(_, &e1), (_, &e2)| s[i][e1].partial_cmp(&s[i][e2]).unwrap())
801                    .map(|(idx, _)| idx)
802                    .unwrap_or(0)
803            })
804            .collect();
805
806        self.cluster_centers_indices_ = Some(exemplar_set);
807        self.labels_ = Some(labels);
808    }
809
810    pub fn fit_predict(&mut self, x: &Tensor) -> Tensor {
811        self.fit(x);
812        let labels = self.labels_.as_ref().unwrap();
813        let labels_f32: Vec<f32> = labels.iter().map(|&l| l as f32).collect();
814        Tensor::from_slice(&labels_f32, &[labels.len()]).unwrap()
815    }
816}
817
818impl Default for AffinityPropagation {
819    fn default() -> Self {
820        Self::new()
821    }
822}
823
#[cfg(test)]
mod tests {
    use super::*;

    /// Two tight, well-separated blobs: rows 0-1 sit near the origin, rows 2-3 near (5, 5).
    fn two_blobs() -> Tensor {
        Tensor::from_slice(&[0.0f32, 0.0, 0.1, 0.1, 5.0, 5.0, 5.1, 5.1], &[4, 2]).unwrap()
    }

    /// Converts a label tensor (labels stored as `f32`) back to cluster indices.
    fn to_labels(t: &Tensor) -> Vec<usize> {
        t.data_f32().iter().map(|&l| l as usize).collect()
    }

    #[test]
    fn test_spectral_clustering() {
        let mut sc = SpectralClustering::new(2);
        let labels = sc.fit_predict(&two_blobs());

        assert_eq!(labels.dims(), &[4]);
        assert!(to_labels(&labels).iter().all(|&l| l < 2));
    }

    #[test]
    fn test_mean_shift() {
        let mut ms = MeanShift::new().bandwidth(1.0);
        let labels = ms.fit_predict(&two_blobs());

        assert_eq!(labels.dims(), &[4]);
        // Mean shift is deterministic: each blob collapses onto a single mode.
        assert_eq!(to_labels(&labels), vec![0, 0, 1, 1]);
        assert_eq!(ms.cluster_centers_.as_ref().map(Vec::len), Some(2));
    }

    #[test]
    fn test_minibatch_kmeans() {
        let x = two_blobs();
        let mut mbk = MiniBatchKMeans::new(2).batch_size(2);
        let labels = mbk.fit_predict(&x);

        assert_eq!(labels.dims(), &[4]);
        let fitted = to_labels(&labels);
        assert!(fitted.iter().all(|&l| l < 2));
        // predict() on the training data must reproduce the labels assigned by fit().
        assert_eq!(to_labels(&mbk.predict(&x)), fitted);
    }

    #[test]
    fn test_affinity_propagation() {
        let mut ap = AffinityPropagation::new();
        let labels = ap.fit_predict(&two_blobs());

        assert_eq!(labels.dims(), &[4]);
        let n_exemplars = ap.cluster_centers_indices_.as_ref().map_or(0, Vec::len);
        assert!(n_exemplars >= 1);
        assert!(to_labels(&labels).iter().all(|&l| l < n_exemplars));
    }
}
870
871