1use std::f64::consts::PI;
2
3use sphereql_core::{
4 CartesianPoint, SphericalPoint, angular_distance, cartesian_to_spherical,
5 spherical_to_cartesian,
6};
7
8use crate::traits::{DimensionMapper, LayoutStrategy};
9use crate::types::{LayoutEntry, LayoutQuality, LayoutResult};
10
/// Upper bound on the number of assign/update rounds `kmeans_spherical` performs.
const MAX_KMEANS_ITERATIONS: usize = 50;
/// Two positions whose `angular_distance` is below this value are counted as
/// overlapping by `compute_quality`.
const OVERLAP_THRESHOLD: f64 = 0.01;
13
/// Layout strategy that groups items into clusters and arranges each cluster's
/// members around its center.
///
/// Configure it with the builder-style `with_*` methods:
/// `ClusteredLayout::new().with_clusters(8).with_radius(2.0)`.
#[derive(Debug, Clone, PartialEq)]
pub struct ClusteredLayout {
    /// Requested number of clusters.
    pub num_clusters: usize,
    /// Radius given to every position produced by the layout.
    pub radius: f64,
    /// Scale of the offsets applied to members around their cluster center.
    pub intra_cluster_spread: f64,
}

impl Default for ClusteredLayout {
    /// Four clusters on a unit-radius sphere with a spread of `0.3`.
    fn default() -> Self {
        Self {
            num_clusters: 4,
            radius: 1.0,
            intra_cluster_spread: 0.3,
        }
    }
}

impl ClusteredLayout {
    /// Creates a layout with the [`Default`] configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the layout with `num_clusters` set to `n`.
    #[must_use]
    pub fn with_clusters(mut self, n: usize) -> Self {
        self.num_clusters = n;
        self
    }

    /// Returns the layout with `radius` set to `r`.
    #[must_use]
    pub fn with_radius(mut self, r: f64) -> Self {
        self.radius = r;
        self
    }

    /// Returns the layout with `intra_cluster_spread` set to `s`.
    #[must_use]
    pub fn with_spread(mut self, s: f64) -> Self {
        self.intra_cluster_spread = s;
        self
    }
}
50
51fn evenly_spaced_centers(k: usize) -> Vec<CartesianPoint> {
52 let golden_ratio = (1.0 + 5.0_f64.sqrt()) / 2.0;
53 (0..k)
54 .map(|i| {
55 let phi = (1.0 - 2.0 * (i as f64 + 0.5) / k as f64)
56 .clamp(-1.0, 1.0)
57 .acos();
58 let theta = (2.0 * PI * (i as f64) / golden_ratio).rem_euclid(2.0 * PI);
59 let sp = SphericalPoint::new_unchecked(1.0, theta, phi);
60 spherical_to_cartesian(&sp)
61 })
62 .collect()
63}
64
65fn normalized_mean(points: &[CartesianPoint]) -> CartesianPoint {
66 if points.is_empty() {
67 return CartesianPoint::new(0.0, 0.0, 1.0);
68 }
69 let mut sx = 0.0;
70 let mut sy = 0.0;
71 let mut sz = 0.0;
72 for p in points {
73 sx += p.x;
74 sy += p.y;
75 sz += p.z;
76 }
77 let mean = CartesianPoint::new(sx, sy, sz);
78 let n = mean.normalize();
79 if n.magnitude() == 0.0 {
80 points[0].normalize()
81 } else {
82 n
83 }
84}
85
86struct KMeansResult {
87 assignments: Vec<usize>,
88 centers: Vec<CartesianPoint>,
89}
90
91fn kmeans_spherical(mapped_cartesian: &[CartesianPoint], k: usize) -> KMeansResult {
92 let n = mapped_cartesian.len();
93
94 let mut centers: Vec<CartesianPoint> = if n >= k {
95 mapped_cartesian[..k]
96 .iter()
97 .map(|c| c.normalize())
98 .collect()
99 } else {
100 evenly_spaced_centers(k)
101 };
102
103 let mut assignments = vec![0usize; n];
104
105 #[inline]
113 fn dot(a: &CartesianPoint, b: &CartesianPoint) -> f64 {
114 a.x * b.x + a.y * b.y + a.z * b.z
115 }
116
117 for _ in 0..MAX_KMEANS_ITERATIONS {
118 let mut changed = false;
119
120 for (i, point) in mapped_cartesian.iter().enumerate() {
121 let mut best = 0;
122 let mut best_dot = f64::MIN;
123 for (j, center) in centers.iter().enumerate() {
124 let d = dot(point, center);
125 if d > best_dot {
126 best_dot = d;
127 best = j;
128 }
129 }
130 if assignments[i] != best {
131 assignments[i] = best;
132 changed = true;
133 }
134 }
135
136 if !changed {
137 break;
138 }
139
140 let mut cluster_points: Vec<Vec<CartesianPoint>> = vec![vec![]; k];
141 for (i, &a) in assignments.iter().enumerate() {
142 cluster_points[a].push(mapped_cartesian[i]);
143 }
144
145 for (j, cp) in cluster_points.iter().enumerate() {
146 if cp.is_empty() {
147 let mut farthest_idx = 0;
151 let mut farthest_dot = f64::MAX;
152 for (i, point) in mapped_cartesian.iter().enumerate() {
153 let d = dot(point, ¢ers[assignments[i]]);
154 if d < farthest_dot {
155 farthest_dot = d;
156 farthest_idx = i;
157 }
158 }
159 centers[j] = mapped_cartesian[farthest_idx].normalize();
160 } else {
161 centers[j] = normalized_mean(cp);
162 }
163 }
164 }
165
166 KMeansResult {
167 assignments,
168 centers,
169 }
170}
171
172fn fibonacci_sub_spiral(
173 center: &SphericalPoint,
174 count: usize,
175 spread: f64,
176 radius: f64,
177) -> Vec<SphericalPoint> {
178 if count == 0 {
179 return vec![];
180 }
181 if count == 1 {
182 return vec![SphericalPoint::new_unchecked(
183 radius,
184 center.theta,
185 center.phi,
186 )];
187 }
188
189 let golden_angle = PI * (3.0 - 5.0_f64.sqrt());
190 let center_cart = spherical_to_cartesian(&SphericalPoint::new_unchecked(
191 1.0,
192 center.theta,
193 center.phi,
194 ));
195
196 let (tangent_u, tangent_v) = local_frame(¢er_cart);
197
198 (0..count)
199 .map(|i| {
200 let frac = i as f64 / count as f64;
201 let angular_r = spread * frac.sqrt();
202 let angle = golden_angle * i as f64;
203
204 let offset_u = angular_r * angle.cos();
205 let offset_v = angular_r * angle.sin();
206
207 let displaced = CartesianPoint::new(
208 center_cart.x + offset_u * tangent_u.x + offset_v * tangent_v.x,
209 center_cart.y + offset_u * tangent_u.y + offset_v * tangent_v.y,
210 center_cart.z + offset_u * tangent_u.z + offset_v * tangent_v.z,
211 )
212 .normalize();
213
214 let sp = cartesian_to_spherical(&displaced);
215 SphericalPoint::new_unchecked(radius, sp.theta, sp.phi)
216 })
217 .collect()
218}
219
220fn local_frame(center: &CartesianPoint) -> (CartesianPoint, CartesianPoint) {
221 let up = if center.z.abs() < 0.9 {
222 CartesianPoint::new(0.0, 0.0, 1.0)
223 } else {
224 CartesianPoint::new(1.0, 0.0, 0.0)
225 };
226
227 let ux = up.y * center.z - up.z * center.y;
229 let uy = up.z * center.x - up.x * center.z;
230 let uz = up.x * center.y - up.y * center.x;
231 let u = CartesianPoint::new(ux, uy, uz).normalize();
232
233 let vx = center.y * u.z - center.z * u.y;
235 let vy = center.z * u.x - center.x * u.z;
236 let vz = center.x * u.y - center.y * u.x;
237 let v = CartesianPoint::new(vx, vy, vz).normalize();
238
239 (u, v)
240}
241
242const MAX_QUALITY_N: usize = 5000;
243
244fn compute_quality(
245 positions: &[SphericalPoint],
246 assignments: &[usize],
247 num_clusters: usize,
248) -> LayoutQuality {
249 let n = positions.len();
250
251 if n <= 1 {
252 return LayoutQuality {
253 dispersion_score: if n == 0 { 0.0 } else { 1.0 },
254 overlap_score: 0.0,
255 silhouette_score: 0.0,
256 };
257 }
258
259 let (positions, assignments, n) = if n > MAX_QUALITY_N {
260 let step = n / MAX_QUALITY_N;
261 let sampled_pos: Vec<_> = positions
262 .iter()
263 .step_by(step)
264 .take(MAX_QUALITY_N)
265 .copied()
266 .collect();
267 let sampled_asgn: Vec<_> = assignments
268 .iter()
269 .step_by(step)
270 .take(MAX_QUALITY_N)
271 .copied()
272 .collect();
273 let len = sampled_pos.len();
274 (sampled_pos, sampled_asgn, len)
275 } else {
276 (positions.to_vec(), assignments.to_vec(), n)
277 };
278
279 let mut cluster_point_sets: Vec<Vec<CartesianPoint>> = vec![vec![]; num_clusters];
281 for (i, &a) in assignments.iter().enumerate() {
282 cluster_point_sets[a].push(spherical_to_cartesian(&positions[i]));
283 }
284 let active_centers: Vec<SphericalPoint> = cluster_point_sets
285 .iter()
286 .filter(|cp| !cp.is_empty())
287 .map(|cp| cartesian_to_spherical(&normalized_mean(cp)))
288 .collect();
289
290 use rayon::prelude::*;
296 const SERIAL_THRESHOLD: usize = 128;
297
298 let dispersion_score = if active_centers.len() >= 2 {
299 let len = active_centers.len();
300 let (sum, count) = if len < SERIAL_THRESHOLD {
301 let mut s = 0.0;
302 let mut c = 0u64;
303 for i in 0..len {
304 for j in (i + 1)..len {
305 s += angular_distance(&active_centers[i], &active_centers[j]);
306 c += 1;
307 }
308 }
309 (s, c)
310 } else {
311 (0..len)
312 .into_par_iter()
313 .map(|i| {
314 let mut s = 0.0;
315 let mut c = 0u64;
316 for j in (i + 1)..len {
317 s += angular_distance(&active_centers[i], &active_centers[j]);
318 c += 1;
319 }
320 (s, c)
321 })
322 .reduce(|| (0.0, 0), |(sa, ca), (sb, cb)| (sa + sb, ca + cb))
323 };
324 (sum / count as f64 / PI).clamp(0.0, 1.0)
325 } else {
326 0.0
327 };
328
329 let total_pairs = (n * (n - 1)) / 2;
331 let overlap_count: u64 = if n < SERIAL_THRESHOLD {
332 let mut c = 0u64;
333 for i in 0..n {
334 for j in (i + 1)..n {
335 if angular_distance(&positions[i], &positions[j]) < OVERLAP_THRESHOLD {
336 c += 1;
337 }
338 }
339 }
340 c
341 } else {
342 (0..n)
343 .into_par_iter()
344 .map(|i| {
345 let mut c = 0u64;
346 for j in (i + 1)..n {
347 if angular_distance(&positions[i], &positions[j]) < OVERLAP_THRESHOLD {
348 c += 1;
349 }
350 }
351 c
352 })
353 .sum()
354 };
355 let overlap_score = if total_pairs > 0 {
356 overlap_count as f64 / total_pairs as f64
357 } else {
358 0.0
359 };
360
361 let silhouette_score = if num_clusters <= 1 || active_centers.len() <= 1 {
365 0.0
366 } else {
367 let per_point = |i: usize| -> f64 {
368 let ci = assignments[i];
369
370 let mut a_sum = 0.0;
372 let mut a_count = 0;
373 for j in 0..n {
374 if j != i && assignments[j] == ci {
375 a_sum += angular_distance(&positions[i], &positions[j]);
376 a_count += 1;
377 }
378 }
379 let a = if a_count > 0 {
380 a_sum / a_count as f64
381 } else {
382 0.0
383 };
384
385 let mut b = f64::MAX;
387 for k in 0..num_clusters {
388 if k == ci {
389 continue;
390 }
391 let mut b_sum = 0.0;
392 let mut b_count = 0;
393 for j in 0..n {
394 if assignments[j] == k {
395 b_sum += angular_distance(&positions[i], &positions[j]);
396 b_count += 1;
397 }
398 }
399 if b_count > 0 {
400 let mean_dist = b_sum / b_count as f64;
401 if mean_dist < b {
402 b = mean_dist;
403 }
404 }
405 }
406 if b == f64::MAX {
407 b = 0.0;
408 }
409
410 let denom = a.max(b);
411 if denom > 0.0 { (b - a) / denom } else { 0.0 }
412 };
413
414 let sil_sum: f64 = if n < SERIAL_THRESHOLD {
415 (0..n).map(per_point).sum()
416 } else {
417 (0..n).into_par_iter().map(per_point).sum()
418 };
419 sil_sum / n as f64
420 };
421
422 LayoutQuality {
423 dispersion_score,
424 overlap_score,
425 silhouette_score,
426 }
427}
428
429impl<T: Clone + Send + Sync> LayoutStrategy<T> for ClusteredLayout {
430 fn layout(&self, items: &[T], mapper: &dyn DimensionMapper<Item = T>) -> LayoutResult<T> {
431 if items.is_empty() {
432 return LayoutResult {
433 entries: vec![],
434 quality: LayoutQuality::default(),
435 };
436 }
437
438 let mapped: Vec<SphericalPoint> = items.iter().map(|item| mapper.map(item)).collect();
439 let mapped_cart: Vec<CartesianPoint> = mapped.iter().map(spherical_to_cartesian).collect();
440
441 let k = self.num_clusters.min(items.len()).max(1);
442 let km = kmeans_spherical(&mapped_cart, k);
443
444 let mut cluster_items: Vec<Vec<usize>> = vec![vec![]; k];
445 for (i, &a) in km.assignments.iter().enumerate() {
446 cluster_items[a].push(i);
447 }
448
449 let mut entries: Vec<(usize, LayoutEntry<T>)> = Vec::with_capacity(items.len());
450 let mut final_positions: Vec<(usize, SphericalPoint)> = Vec::with_capacity(items.len());
451 let mut final_assignments = vec![0usize; items.len()];
452
453 for (cluster_idx, member_indices) in cluster_items.iter().enumerate() {
454 let center_sp = cartesian_to_spherical(&km.centers[cluster_idx]);
455 let sub_positions = fibonacci_sub_spiral(
456 ¢er_sp,
457 member_indices.len(),
458 self.intra_cluster_spread,
459 self.radius,
460 );
461
462 for (sub_idx, &item_idx) in member_indices.iter().enumerate() {
463 let pos = sub_positions[sub_idx];
464 entries.push((
465 item_idx,
466 LayoutEntry {
467 item: items[item_idx].clone(),
468 position: pos,
469 },
470 ));
471 final_positions.push((item_idx, pos));
472 final_assignments[item_idx] = cluster_idx;
473 }
474 }
475
476 entries.sort_by_key(|(idx, _)| *idx);
477 let entries: Vec<LayoutEntry<T>> = entries.into_iter().map(|(_, e)| e).collect();
478
479 final_positions.sort_by_key(|(idx, _)| *idx);
480 let positions: Vec<SphericalPoint> = final_positions.into_iter().map(|(_, p)| p).collect();
481
482 let quality = compute_quality(&positions, &final_assignments, k);
483
484 LayoutResult { entries, quality }
485 }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Mapper that treats each item as an index into a fixed table of positions.
    struct FixedMapper {
        positions: Vec<SphericalPoint>,
    }

    impl DimensionMapper for FixedMapper {
        type Item = usize;
        fn map(&self, item: &usize) -> SphericalPoint {
            self.positions[*item]
        }
    }

    #[test]
    fn empty_items_returns_empty_result() {
        let mapper = FixedMapper {
            positions: Vec::new(),
        };
        let layout = ClusteredLayout::new();
        let result = layout.layout(&[], &mapper);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn single_item_gets_placed() {
        let mapper = FixedMapper {
            positions: vec![SphericalPoint::new_unchecked(1.0, 0.5, 1.0)],
        };
        let layout = ClusteredLayout::new().with_clusters(1);
        let result = layout.layout(&[0usize], &mapper);
        assert_eq!(result.entries.len(), 1);
        assert!((result.entries[0].position.r - 1.0).abs() < 1e-12);
    }

    #[test]
    fn correct_number_of_entries() {
        let mut positions: Vec<SphericalPoint> = Vec::with_capacity(20);
        for i in 0..20 {
            let theta = (i as f64 * 0.3) % (2.0 * PI);
            positions.push(SphericalPoint::new_unchecked(1.0, theta, 1.0));
        }
        let mapper = FixedMapper { positions };
        let items: Vec<usize> = (0..20).collect();
        let layout = ClusteredLayout::new().with_clusters(3);
        let result = layout.layout(&items, &mapper);
        assert_eq!(result.entries.len(), 20);
    }

    #[test]
    fn items_in_same_cluster_are_angularly_close() {
        // Two tight groups near opposite ends of the phi range.
        let near_low_phi =
            (0..5).map(|i| SphericalPoint::new_unchecked(1.0, 0.01 * i as f64, 0.1));
        let near_high_phi =
            (0..5).map(|i| SphericalPoint::new_unchecked(1.0, 0.01 * i as f64, PI - 0.1));
        let positions: Vec<SphericalPoint> = near_low_phi.chain(near_high_phi).collect();

        let mapper = FixedMapper { positions };
        let items: Vec<usize> = (0..10).collect();
        let layout = ClusteredLayout::new().with_clusters(2).with_spread(0.2);
        let result = layout.layout(&items, &mapper);

        let group_a = &result.entries[..5];
        for (i, first) in group_a.iter().enumerate() {
            for second in &group_a[i + 1..] {
                let d = angular_distance(&first.position, &second.position);
                assert!(d < 1.0, "Intra-cluster distance too large: {d}");
            }
        }
    }

    #[test]
    fn different_clusters_are_angularly_separated() {
        // Two tight groups half a turn apart in theta at phi = PI / 2.
        let group_one =
            (0..5).map(|i| SphericalPoint::new_unchecked(1.0, 0.01 * i as f64, PI / 2.0));
        let group_two =
            (0..5).map(|i| SphericalPoint::new_unchecked(1.0, PI + 0.01 * i as f64, PI / 2.0));
        let positions: Vec<SphericalPoint> = group_one.chain(group_two).collect();

        let mapper = FixedMapper { positions };
        let items: Vec<usize> = (0..10).collect();
        let layout = ClusteredLayout::new().with_clusters(2).with_spread(0.2);
        let result = layout.layout(&items, &mapper);

        let d = angular_distance(&result.entries[0].position, &result.entries[5].position);
        assert!(d > 1.0, "Inter-cluster distance too small: {d}");
    }

    #[test]
    fn silhouette_positive_for_well_separated_data() {
        let group_one =
            (0..10).map(|i| SphericalPoint::new_unchecked(1.0, 0.01 * i as f64, 0.2));
        let group_two =
            (0..10).map(|i| SphericalPoint::new_unchecked(1.0, PI + 0.01 * i as f64, PI - 0.2));
        let positions: Vec<SphericalPoint> = group_one.chain(group_two).collect();

        let mapper = FixedMapper { positions };
        let items: Vec<usize> = (0..20).collect();
        let layout = ClusteredLayout::new().with_clusters(2).with_spread(0.15);
        let result = layout.layout(&items, &mapper);

        let silhouette = result.quality.silhouette_score;
        assert!(
            silhouette > 0.0,
            "Silhouette should be positive for well-separated clusters, got {}",
            silhouette
        );
    }

    #[test]
    fn builder_methods_apply() {
        let layout = ClusteredLayout::new()
            .with_clusters(8)
            .with_radius(2.5)
            .with_spread(0.5);
        assert_eq!(layout.num_clusters, 8);
        assert!((layout.radius - 2.5).abs() < 1e-12);
        assert!((layout.intra_cluster_spread - 0.5).abs() < 1e-12);
    }

    #[test]
    fn output_radius_matches_configured() {
        let mapper = FixedMapper {
            positions: vec![
                SphericalPoint::new_unchecked(1.0, 0.0, 0.5),
                SphericalPoint::new_unchecked(1.0, PI, 2.0),
            ],
        };
        let layout = ClusteredLayout::new().with_radius(3.0).with_clusters(2);
        let result = layout.layout(&[0usize, 1], &mapper);
        for entry in result.entries.iter() {
            let r = entry.position.r;
            assert!((r - 3.0).abs() < 1e-12, "Expected radius 3.0, got {}", r);
        }
    }
}