Skip to main content

sphereql_layout/
clustered.rs

1use std::f64::consts::PI;
2
3use sphereql_core::{
4    CartesianPoint, SphericalPoint, angular_distance, cartesian_to_spherical,
5    spherical_to_cartesian,
6};
7
8use crate::traits::{DimensionMapper, LayoutStrategy};
9use crate::types::{LayoutEntry, LayoutQuality, LayoutResult};
10
/// Upper bound on the number of assign/update rounds run by `kmeans_spherical`.
const MAX_KMEANS_ITERATIONS: usize = 50;
/// Angular distance below which `compute_quality` counts two positions as
/// overlapping (presumably radians, matching `angular_distance` — confirm).
const OVERLAP_THRESHOLD: f64 = 0.01;
13
/// Configuration for a layout that groups items into clusters on a sphere.
///
/// The `LayoutStrategy` implementation partitions the mapped items with
/// spherical k-means and then fans each cluster's members out around the
/// cluster center.
#[derive(Debug, Clone, PartialEq)]
pub struct ClusteredLayout {
    /// Requested number of clusters; `layout` clamps it to at least 1 and at
    /// most the number of items.
    pub num_clusters: usize,
    /// Radius written into every output position.
    pub radius: f64,
    /// Scale of the spiral that spreads a cluster's members around its center.
    pub intra_cluster_spread: f64,
}
19
20impl Default for ClusteredLayout {
21    fn default() -> Self {
22        Self {
23            num_clusters: 4,
24            radius: 1.0,
25            intra_cluster_spread: 0.3,
26        }
27    }
28}
29
30impl ClusteredLayout {
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    pub fn with_clusters(mut self, n: usize) -> Self {
36        self.num_clusters = n;
37        self
38    }
39
40    pub fn with_radius(mut self, r: f64) -> Self {
41        self.radius = r;
42        self
43    }
44
45    pub fn with_spread(mut self, s: f64) -> Self {
46        self.intra_cluster_spread = s;
47        self
48    }
49}
50
51fn evenly_spaced_centers(k: usize) -> Vec<CartesianPoint> {
52    let golden_ratio = (1.0 + 5.0_f64.sqrt()) / 2.0;
53    (0..k)
54        .map(|i| {
55            let phi = (1.0 - 2.0 * (i as f64 + 0.5) / k as f64)
56                .clamp(-1.0, 1.0)
57                .acos();
58            let theta = (2.0 * PI * (i as f64) / golden_ratio).rem_euclid(2.0 * PI);
59            let sp = SphericalPoint::new_unchecked(1.0, theta, phi);
60            spherical_to_cartesian(&sp)
61        })
62        .collect()
63}
64
65fn normalized_mean(points: &[CartesianPoint]) -> CartesianPoint {
66    if points.is_empty() {
67        return CartesianPoint::new(0.0, 0.0, 1.0);
68    }
69    let mut sx = 0.0;
70    let mut sy = 0.0;
71    let mut sz = 0.0;
72    for p in points {
73        sx += p.x;
74        sy += p.y;
75        sz += p.z;
76    }
77    let mean = CartesianPoint::new(sx, sy, sz);
78    let n = mean.normalize();
79    if n.magnitude() == 0.0 {
80        points[0].normalize()
81    } else {
82        n
83    }
84}
85
86struct KMeansResult {
87    assignments: Vec<usize>,
88    centers: Vec<CartesianPoint>,
89}
90
91fn kmeans_spherical(
92    mapped_cartesian: &[CartesianPoint],
93    mapped_spherical: &[SphericalPoint],
94    k: usize,
95) -> KMeansResult {
96    let n = mapped_cartesian.len();
97
98    let mut centers: Vec<CartesianPoint> = if n >= k {
99        mapped_cartesian[..k]
100            .iter()
101            .map(|c| c.normalize())
102            .collect()
103    } else {
104        evenly_spaced_centers(k)
105    };
106
107    let mut assignments = vec![0usize; n];
108
109    for _ in 0..MAX_KMEANS_ITERATIONS {
110        let mut changed = false;
111
112        for (i, sp) in mapped_spherical.iter().enumerate() {
113            let mut best = 0;
114            let mut best_dist = f64::MAX;
115            for (j, center) in centers.iter().enumerate() {
116                let center_sp = cartesian_to_spherical(center);
117                let d = angular_distance(sp, &center_sp);
118                if d < best_dist {
119                    best_dist = d;
120                    best = j;
121                }
122            }
123            if assignments[i] != best {
124                assignments[i] = best;
125                changed = true;
126            }
127        }
128
129        if !changed {
130            break;
131        }
132
133        let mut cluster_points: Vec<Vec<CartesianPoint>> = vec![vec![]; k];
134        for (i, &a) in assignments.iter().enumerate() {
135            cluster_points[a].push(mapped_cartesian[i]);
136        }
137
138        for (j, cp) in cluster_points.iter().enumerate() {
139            if cp.is_empty() {
140                let mut farthest_idx = 0;
141                let mut farthest_dist = 0.0_f64;
142                for (i, sp) in mapped_spherical.iter().enumerate() {
143                    let center_sp = cartesian_to_spherical(&centers[assignments[i]]);
144                    let d = angular_distance(sp, &center_sp);
145                    if d > farthest_dist {
146                        farthest_dist = d;
147                        farthest_idx = i;
148                    }
149                }
150                centers[j] = mapped_cartesian[farthest_idx].normalize();
151            } else {
152                centers[j] = normalized_mean(cp);
153            }
154        }
155    }
156
157    KMeansResult {
158        assignments,
159        centers,
160    }
161}
162
163fn fibonacci_sub_spiral(
164    center: &SphericalPoint,
165    count: usize,
166    spread: f64,
167    radius: f64,
168) -> Vec<SphericalPoint> {
169    if count == 0 {
170        return vec![];
171    }
172    if count == 1 {
173        return vec![SphericalPoint::new_unchecked(
174            radius,
175            center.theta,
176            center.phi,
177        )];
178    }
179
180    let golden_angle = PI * (3.0 - 5.0_f64.sqrt());
181    let center_cart = spherical_to_cartesian(&SphericalPoint::new_unchecked(
182        1.0,
183        center.theta,
184        center.phi,
185    ));
186
187    let (tangent_u, tangent_v) = local_frame(&center_cart);
188
189    (0..count)
190        .map(|i| {
191            let frac = i as f64 / count as f64;
192            let angular_r = spread * frac.sqrt();
193            let angle = golden_angle * i as f64;
194
195            let offset_u = angular_r * angle.cos();
196            let offset_v = angular_r * angle.sin();
197
198            let displaced = CartesianPoint::new(
199                center_cart.x + offset_u * tangent_u.x + offset_v * tangent_v.x,
200                center_cart.y + offset_u * tangent_u.y + offset_v * tangent_v.y,
201                center_cart.z + offset_u * tangent_u.z + offset_v * tangent_v.z,
202            )
203            .normalize();
204
205            let sp = cartesian_to_spherical(&displaced);
206            SphericalPoint::new_unchecked(radius, sp.theta, sp.phi)
207        })
208        .collect()
209}
210
211fn local_frame(center: &CartesianPoint) -> (CartesianPoint, CartesianPoint) {
212    let up = if center.z.abs() < 0.9 {
213        CartesianPoint::new(0.0, 0.0, 1.0)
214    } else {
215        CartesianPoint::new(1.0, 0.0, 0.0)
216    };
217
218    // u = normalize(up x center)
219    let ux = up.y * center.z - up.z * center.y;
220    let uy = up.z * center.x - up.x * center.z;
221    let uz = up.x * center.y - up.y * center.x;
222    let u = CartesianPoint::new(ux, uy, uz).normalize();
223
224    // v = center x u
225    let vx = center.y * u.z - center.z * u.y;
226    let vy = center.z * u.x - center.x * u.z;
227    let vz = center.x * u.y - center.y * u.x;
228    let v = CartesianPoint::new(vx, vy, vz).normalize();
229
230    (u, v)
231}
232
233const MAX_QUALITY_N: usize = 5000;
234
235fn compute_quality(
236    positions: &[SphericalPoint],
237    assignments: &[usize],
238    num_clusters: usize,
239) -> LayoutQuality {
240    let n = positions.len();
241
242    if n <= 1 {
243        return LayoutQuality {
244            dispersion_score: if n == 0 { 0.0 } else { 1.0 },
245            overlap_score: 0.0,
246            silhouette_score: 0.0,
247        };
248    }
249
250    let (positions, assignments, n) = if n > MAX_QUALITY_N {
251        let step = n / MAX_QUALITY_N;
252        let sampled_pos: Vec<_> = positions
253            .iter()
254            .step_by(step)
255            .take(MAX_QUALITY_N)
256            .copied()
257            .collect();
258        let sampled_asgn: Vec<_> = assignments
259            .iter()
260            .step_by(step)
261            .take(MAX_QUALITY_N)
262            .copied()
263            .collect();
264        let len = sampled_pos.len();
265        (sampled_pos, sampled_asgn, len)
266    } else {
267        (positions.to_vec(), assignments.to_vec(), n)
268    };
269
270    // Dispersion: average inter-cluster center distance / PI
271    let mut cluster_point_sets: Vec<Vec<CartesianPoint>> = vec![vec![]; num_clusters];
272    for (i, &a) in assignments.iter().enumerate() {
273        cluster_point_sets[a].push(spherical_to_cartesian(&positions[i]));
274    }
275    let active_centers: Vec<SphericalPoint> = cluster_point_sets
276        .iter()
277        .filter(|cp| !cp.is_empty())
278        .map(|cp| cartesian_to_spherical(&normalized_mean(cp)))
279        .collect();
280
281    let dispersion_score = if active_centers.len() >= 2 {
282        let mut sum = 0.0;
283        let mut count = 0;
284        for i in 0..active_centers.len() {
285            for j in (i + 1)..active_centers.len() {
286                sum += angular_distance(&active_centers[i], &active_centers[j]);
287                count += 1;
288            }
289        }
290        (sum / count as f64 / PI).clamp(0.0, 1.0)
291    } else {
292        0.0
293    };
294
295    // Overlap: fraction of pairs within threshold
296    let mut overlap_count = 0u64;
297    let total_pairs = (n * (n - 1)) / 2;
298    for i in 0..n {
299        for j in (i + 1)..n {
300            if angular_distance(&positions[i], &positions[j]) < OVERLAP_THRESHOLD {
301                overlap_count += 1;
302            }
303        }
304    }
305    let overlap_score = if total_pairs > 0 {
306        overlap_count as f64 / total_pairs as f64
307    } else {
308        0.0
309    };
310
311    // Silhouette coefficient
312    let silhouette_score = if num_clusters <= 1 || active_centers.len() <= 1 {
313        0.0
314    } else {
315        let mut sil_sum = 0.0;
316        for i in 0..n {
317            let ci = assignments[i];
318
319            // a(i) = mean distance to same-cluster members
320            let mut a_sum = 0.0;
321            let mut a_count = 0;
322            for j in 0..n {
323                if j != i && assignments[j] == ci {
324                    a_sum += angular_distance(&positions[i], &positions[j]);
325                    a_count += 1;
326                }
327            }
328            let a = if a_count > 0 {
329                a_sum / a_count as f64
330            } else {
331                0.0
332            };
333
334            // b(i) = min over other clusters of mean distance to that cluster
335            let mut b = f64::MAX;
336            for k in 0..num_clusters {
337                if k == ci {
338                    continue;
339                }
340                let mut b_sum = 0.0;
341                let mut b_count = 0;
342                for j in 0..n {
343                    if assignments[j] == k {
344                        b_sum += angular_distance(&positions[i], &positions[j]);
345                        b_count += 1;
346                    }
347                }
348                if b_count > 0 {
349                    let mean_dist = b_sum / b_count as f64;
350                    if mean_dist < b {
351                        b = mean_dist;
352                    }
353                }
354            }
355            if b == f64::MAX {
356                b = 0.0;
357            }
358
359            let denom = a.max(b);
360            let s = if denom > 0.0 { (b - a) / denom } else { 0.0 };
361            sil_sum += s;
362        }
363        sil_sum / n as f64
364    };
365
366    LayoutQuality {
367        dispersion_score,
368        overlap_score,
369        silhouette_score,
370    }
371}
372
373impl<T: Clone + Send + Sync> LayoutStrategy<T> for ClusteredLayout {
374    fn layout(&self, items: &[T], mapper: &dyn DimensionMapper<Item = T>) -> LayoutResult<T> {
375        if items.is_empty() {
376            return LayoutResult {
377                entries: vec![],
378                quality: LayoutQuality::default(),
379            };
380        }
381
382        let mapped: Vec<SphericalPoint> = items.iter().map(|item| mapper.map(item)).collect();
383        let mapped_cart: Vec<CartesianPoint> = mapped.iter().map(spherical_to_cartesian).collect();
384
385        let k = self.num_clusters.min(items.len()).max(1);
386        let km = kmeans_spherical(&mapped_cart, &mapped, k);
387
388        let mut cluster_items: Vec<Vec<usize>> = vec![vec![]; k];
389        for (i, &a) in km.assignments.iter().enumerate() {
390            cluster_items[a].push(i);
391        }
392
393        let mut entries: Vec<(usize, LayoutEntry<T>)> = Vec::with_capacity(items.len());
394        let mut final_positions: Vec<(usize, SphericalPoint)> = Vec::with_capacity(items.len());
395        let mut final_assignments = vec![0usize; items.len()];
396
397        for (cluster_idx, member_indices) in cluster_items.iter().enumerate() {
398            let center_sp = cartesian_to_spherical(&km.centers[cluster_idx]);
399            let sub_positions = fibonacci_sub_spiral(
400                &center_sp,
401                member_indices.len(),
402                self.intra_cluster_spread,
403                self.radius,
404            );
405
406            for (sub_idx, &item_idx) in member_indices.iter().enumerate() {
407                let pos = sub_positions[sub_idx];
408                entries.push((
409                    item_idx,
410                    LayoutEntry {
411                        item: items[item_idx].clone(),
412                        position: pos,
413                    },
414                ));
415                final_positions.push((item_idx, pos));
416                final_assignments[item_idx] = cluster_idx;
417            }
418        }
419
420        entries.sort_by_key(|(idx, _)| *idx);
421        let entries: Vec<LayoutEntry<T>> = entries.into_iter().map(|(_, e)| e).collect();
422
423        final_positions.sort_by_key(|(idx, _)| *idx);
424        let positions: Vec<SphericalPoint> = final_positions.into_iter().map(|(_, p)| p).collect();
425
426        let quality = compute_quality(&positions, &final_assignments, k);
427
428        LayoutResult { entries, quality }
429    }
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    /// Mapper that treats each item as an index into a fixed list of positions.
    struct FixedMapper {
        positions: Vec<SphericalPoint>,
    }

    impl DimensionMapper for FixedMapper {
        type Item = usize;
        fn map(&self, item: &usize) -> SphericalPoint {
            self.positions[*item]
        }
    }

    /// Unit-radius point with the given angles.
    fn unit(theta: f64, phi: f64) -> SphericalPoint {
        SphericalPoint::new_unchecked(1.0, theta, phi)
    }

    /// Two groups of `per_group` points each; within a group the points share
    /// `phi` and step `theta` by 0.01 starting from the group's base angle.
    fn two_groups(
        per_group: usize,
        first: (f64, f64),
        second: (f64, f64),
    ) -> Vec<SphericalPoint> {
        let group = |(base_theta, phi): (f64, f64)| {
            (0..per_group).map(move |i| unit(base_theta + 0.01 * i as f64, phi))
        };
        group(first).chain(group(second)).collect()
    }

    /// Runs `layout` over items `0..positions.len()`, item `i` mapping to `positions[i]`.
    fn run(layout: &ClusteredLayout, positions: Vec<SphericalPoint>) -> LayoutResult<usize> {
        let items: Vec<usize> = (0..positions.len()).collect();
        let mapper = FixedMapper { positions };
        layout.layout(&items, &mapper)
    }

    #[test]
    fn empty_items_returns_empty_result() {
        let result = run(&ClusteredLayout::new(), vec![]);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn single_item_gets_placed() {
        let layout = ClusteredLayout::new().with_clusters(1);
        let result = run(&layout, vec![unit(0.5, 1.0)]);
        assert_eq!(result.entries.len(), 1);
        assert!((result.entries[0].position.r - 1.0).abs() < 1e-12);
    }

    #[test]
    fn correct_number_of_entries() {
        let layout = ClusteredLayout::new().with_clusters(3);
        let mut positions = Vec::with_capacity(20);
        for i in 0..20 {
            positions.push(unit((i as f64 * 0.3) % (2.0 * PI), 1.0));
        }
        let result = run(&layout, positions);
        assert_eq!(result.entries.len(), 20);
    }

    #[test]
    fn items_in_same_cluster_are_angularly_close() {
        // One tight group near each pole.
        let positions = two_groups(5, (0.0, 0.1), (0.0, PI - 0.1));
        let layout = ClusteredLayout::new().with_clusters(2).with_spread(0.2);
        let result = run(&layout, positions);

        let group_a = &result.entries[..5];
        for (i, first) in group_a.iter().enumerate() {
            for second in &group_a[i + 1..] {
                let d = angular_distance(&first.position, &second.position);
                assert!(d < 1.0, "Intra-cluster distance too large: {d}");
            }
        }
    }

    #[test]
    fn different_clusters_are_angularly_separated() {
        // Two tight groups on opposite sides of the equator.
        let positions = two_groups(5, (0.0, PI / 2.0), (PI, PI / 2.0));
        let layout = ClusteredLayout::new().with_clusters(2).with_spread(0.2);
        let result = run(&layout, positions);

        let d = angular_distance(&result.entries[0].position, &result.entries[5].position);
        assert!(d > 1.0, "Inter-cluster distance too small: {d}");
    }

    #[test]
    fn silhouette_positive_for_well_separated_data() {
        let positions = two_groups(10, (0.0, 0.2), (PI, PI - 0.2));
        let layout = ClusteredLayout::new().with_clusters(2).with_spread(0.15);
        let result = run(&layout, positions);
        assert!(
            result.quality.silhouette_score > 0.0,
            "Silhouette should be positive for well-separated clusters, got {}",
            result.quality.silhouette_score
        );
    }

    #[test]
    fn builder_methods_apply() {
        let layout = ClusteredLayout::new()
            .with_clusters(8)
            .with_radius(2.5)
            .with_spread(0.5);
        assert_eq!(layout.num_clusters, 8);
        assert!((layout.radius - 2.5).abs() < 1e-12);
        assert!((layout.intra_cluster_spread - 0.5).abs() < 1e-12);
    }

    #[test]
    fn output_radius_matches_configured() {
        let layout = ClusteredLayout::new().with_radius(3.0).with_clusters(2);
        let result = run(&layout, vec![unit(0.0, 0.5), unit(PI, 2.0)]);
        for entry in &result.entries {
            assert!(
                (entry.position.r - 3.0).abs() < 1e-12,
                "Expected radius 3.0, got {}",
                entry.position.r
            );
        }
    }
}