1use std::collections::HashMap;
32use std::sync::Arc;
33
34use sphereql_core::SphericalPoint;
35
36use crate::ann::{AnnConfig, AnnIndex};
37use crate::projection::{
38 Projection, ProjectionError, SplitMix64, dot, normalize_vec, project_xyz_to_spherical,
39};
40use crate::types::{Embedding, ProjectedPoint, RadialContext, RadialStrategy};
41
/// Tunable parameters for fitting a [`UmapSphereProjection`].
#[derive(Debug, Clone)]
pub struct UmapConfig {
    /// Requested neighbours per point for the kNN graph. `UmapGraph::build`
    /// limits it to at most `n - 1` and at least 1.
    pub n_neighbors: usize,
    /// Number of optimisation epochs run by `fit_from_graph`.
    pub n_epochs: usize,
    /// Base step size; `fit_from_graph` scales it down linearly as epochs advance.
    pub learning_rate: f64,
    /// Number of random repulsive draws made per kNN edge in every epoch.
    pub negative_sample_rate: usize,
    /// Weight of the category term. The term is only active when this is
    /// greater than zero and categories are supplied to the fit.
    pub category_weight: f64,
    /// Passed to `find_ab_params` to derive the `(a, b)` curve parameters used
    /// by the attractive and repulsive gradients.
    pub min_dist: f64,
    /// Weight of the pull toward each point's warm-start position; the pull is
    /// skipped when this is not greater than zero.
    pub warm_start_anchor: f64,
    /// Seed for the `SplitMix64` generator that drives sampling during the fit.
    pub seed: u64,
}
86
87impl Default for UmapConfig {
88 fn default() -> Self {
89 Self {
90 n_neighbors: 15,
91 n_epochs: 200,
92 learning_rate: 0.05,
93 negative_sample_rate: 5,
94 category_weight: 0.0,
95 min_dist: 0.1,
96 warm_start_anchor: 0.0,
97 seed: 0xA1B2_C3D4,
98 }
99 }
100}
101
/// Precomputed, config-independent inputs for a UMAP fit, produced by
/// [`UmapGraph::build`] and consumed by `UmapSphereProjection::fit_from_graph`.
#[derive(Clone)]
pub struct UmapGraph {
    /// For each point, the indices of its `k` nearest neighbours.
    pub(crate) knn: Vec<Vec<usize>>,
    /// Edge weights from `fuzzy_simplicial_weights`, aligned element-for-element
    /// with `knn`.
    pub(crate) weights: Vec<Vec<f64>>,
    /// The result of `Embedding::normalized` for every input, in corpus order.
    pub(crate) normalized: Vec<Vec<f64>>,
    /// Initial 3-D position of every point, produced by `pca_warm_start`.
    pub(crate) warm_start: Vec<[f64; 3]>,
    /// Dimension shared by all input embeddings.
    pub(crate) dim: usize,
    /// Neighbour count actually used (the requested count after limiting).
    pub(crate) k: usize,
    /// ANN index built by `build_knn_graph`; `None` when the brute-force path
    /// was taken.
    pub(crate) ann: Option<Arc<AnnIndex>>,
}
132
133impl UmapGraph {
134 pub fn build(embeddings: &[Embedding], n_neighbors: usize) -> Result<Self, ProjectionError> {
140 if embeddings.is_empty() {
141 return Err(ProjectionError::EmptyCorpus);
142 }
143 let dim = embeddings[0].dimension();
144 if dim < 3 {
145 return Err(ProjectionError::DimensionTooLow {
146 got: dim,
147 required: 3,
148 });
149 }
150 for (i, e) in embeddings.iter().enumerate() {
151 if e.dimension() != dim {
152 return Err(ProjectionError::InconsistentDimension {
153 index: i,
154 expected: dim,
155 got: e.dimension(),
156 });
157 }
158 }
159 let n = embeddings.len();
160 if n < 4 {
161 return Err(ProjectionError::TooFewEmbeddings {
162 got: n,
163 required: 4,
164 });
165 }
166
167 let normalized: Vec<Vec<f64>> = embeddings.iter().map(|e| e.normalized()).collect();
168 let k = n_neighbors.min(n - 1).max(1);
169 let (knn, dists, ann) = build_knn_graph(&normalized, k);
170 let weights = fuzzy_simplicial_weights(&knn, &dists);
171 let warm_start = pca_warm_start(embeddings, &normalized)?;
172
173 Ok(Self {
174 knn,
175 weights,
176 normalized,
177 warm_start,
178 dim,
179 k,
180 ann,
181 })
182 }
183}
184
/// A fitted UMAP layout on the sphere, able to project both training and
/// unseen embeddings.
#[derive(Clone)]
pub struct UmapSphereProjection {
    /// Optimised 3-D position of every training point.
    fitted_points: Vec<[f64; 3]>,
    /// Normalised training vectors (copied from the graph), index-aligned with
    /// `fitted_points`.
    fitted_normalized: Vec<Vec<f64>>,
    /// Maps `hash_normalized` of a training vector to the indices of the
    /// distinct training vectors with that hash; duplicates keep only their
    /// first index.
    exact_lookup: HashMap<u64, Vec<usize>>,
    /// ANN index shared with the graph, used for neighbour queries when present.
    ann: Option<Arc<AnnIndex>>,
    /// Expected dimension of embeddings passed to `project`.
    dim: usize,
    /// Strategy used to compute the radial coordinate of projected points.
    radial: RadialStrategy,
    /// Number of neighbours blended when projecting an unseen embedding.
    n_neighbors: usize,
    /// kNN recall of the fitted layout, computed once at fit time.
    quality: f64,
}
216
217impl UmapSphereProjection {
218 pub fn fit_default(embeddings: &[Embedding]) -> Result<Self, ProjectionError> {
220 Self::fit(
221 embeddings,
222 None,
223 RadialStrategy::default(),
224 UmapConfig::default(),
225 )
226 }
227
    /// Optimises a layout from a prebuilt [`UmapGraph`] and wraps it in a
    /// projection.
    ///
    /// The graph is only read, so one graph can be reused across several
    /// configurations. `categories`, when given, must hold one label per graph
    /// point; it only influences the fit when `config.category_weight > 0`.
    ///
    /// # Errors
    ///
    /// Returns `SliceLengthMismatch` when `categories` is present but its
    /// length differs from the number of points in the graph.
    pub fn fit_from_graph(
        graph: &UmapGraph,
        categories: Option<&[u32]>,
        radial: RadialStrategy,
        config: UmapConfig,
    ) -> Result<Self, ProjectionError> {
        let n = graph.normalized.len();

        if let Some(cats) = categories
            && cats.len() != n
        {
            return Err(ProjectionError::SliceLengthMismatch {
                expected: n,
                got: cats.len(),
            });
        }

        // Start from the warm-start positions and a seeded RNG, so the whole
        // fit is a deterministic function of the graph and the config.
        let mut points = graph.warm_start.clone();
        let mut rng = SplitMix64::new(config.seed);
        let (ka, kb) = find_ab_params(config.min_dist);
        // The category term needs both a positive weight and labels.
        let cat_active = config.category_weight > 0.0 && categories.is_some();
        let anchor_active = config.warm_start_anchor > 0.0;

        // Group point indices by category once: `buckets[b]` lists the members
        // of bucket `b`, and `bucket_of[i]` is the bucket of point `i`.
        let cat_buckets: Option<(Vec<Vec<usize>>, Vec<usize>)> = cat_active.then(|| {
            // `cat_active` implies `categories.is_some()`, so this cannot panic.
            let cats = categories.unwrap();
            let mut id_to_bucket: HashMap<u32, usize> = HashMap::new();
            let mut buckets: Vec<Vec<usize>> = Vec::new();
            let mut bucket_of = Vec::with_capacity(n);
            for (i, &c) in cats.iter().enumerate() {
                let b = *id_to_bucket.entry(c).or_insert_with(|| {
                    buckets.push(Vec::new());
                    buckets.len() - 1
                });
                buckets[b].push(i);
                bucket_of.push(b);
            }
            (buckets, bucket_of)
        });

        // Per-edge attraction weights. An edge that appears in both endpoints'
        // neighbour lists is visited twice per epoch, so its weight is halved.
        let attract: Vec<Vec<f64>> = graph
            .knn
            .iter()
            .enumerate()
            .map(|(i, neighbors)| {
                neighbors
                    .iter()
                    .zip(&graph.weights[i])
                    .map(|(&j, &w)| {
                        if graph.knn[j].contains(&i) {
                            0.5 * w
                        } else {
                            w
                        }
                    })
                    .collect()
            })
            .collect();

        // Adam-style optimiser state: running first (`m`) and second (`v`)
        // moment estimates for every coordinate of every point.
        let mut m = vec![[0.0f64; 3]; n];
        let mut v = vec![[0.0f64; 3]; n];
        let beta1 = 0.9;
        let beta2 = 0.999;
        let eps = 1e-8;

        for epoch in 1..=config.n_epochs {
            // Linear decay; note the factor reaches exactly zero in the final epoch.
            let lr = config.learning_rate * (1.0 - (epoch as f64 / config.n_epochs as f64));
            let mut grads = vec![[0.0f64; 3]; n];

            for (i, neighbors) in graph.knn.iter().enumerate() {
                for (idx, &j) in neighbors.iter().enumerate() {
                    // Attraction along the kNN edge (i, j), applied to both ends.
                    let w = attract[i][idx];
                    let (gi, gj) = attractive_grad(&points[i], &points[j], ka, kb);
                    add3_scaled(&mut grads[i], &gi, w);
                    add3_scaled(&mut grads[j], &gj, w);

                    // Negative sampling: repel `i` from uniformly drawn points,
                    // scaled by the same edge weight. A draw that lands on `i`
                    // itself is skipped, not redrawn.
                    for _ in 0..config.negative_sample_rate {
                        let nj = (rng.next_u64() as usize) % n;
                        if nj == i {
                            continue;
                        }
                        let (gi_r, gj_r) = repulsive_grad(&points[i], &points[nj], ka, kb);
                        add3_scaled(&mut grads[i], &gi_r, w);
                        add3_scaled(&mut grads[nj], &gj_r, w);
                    }
                }
            }

            if let Some((buckets, bucket_of)) = &cat_buckets {
                let w = config.category_weight;
                for i in 0..n {
                    let bucket = &buckets[bucket_of[i]];
                    if bucket.len() > 1 {
                        // Draw a same-category partner other than `i`: sample
                        // from all slots but the last, and substitute the last
                        // member if the draw landed on `i` itself.
                        let idx = (rng.next_u64() as usize) % (bucket.len() - 1);
                        let j = if bucket[idx] == i {
                            bucket[bucket.len() - 1]
                        } else {
                            bucket[idx]
                        };
                        let (gi, gj) = attractive_grad(&points[i], &points[j], ka, kb);
                        add3_scaled(&mut grads[i], &gi, w);
                        add3_scaled(&mut grads[j], &gj, w);
                    }
                    // Try a bounded number of draws to find a point from a
                    // different category, and repel from the first one found.
                    for _ in 0..MAX_CROSS_CATEGORY_DRAWS {
                        let j = (rng.next_u64() as usize) % n;
                        if bucket_of[j] != bucket_of[i] {
                            let (gi, gj) = repulsive_grad(&points[i], &points[j], ka, kb);
                            add3_scaled(&mut grads[i], &gi, w);
                            add3_scaled(&mut grads[j], &gj, w);
                            break;
                        }
                    }
                }
            }

            if anchor_active {
                // Pull each point back toward its own warm-start position; the
                // anchor itself is fixed, so its half of the gradient is dropped.
                let w = config.warm_start_anchor;
                for i in 0..n {
                    let (gi, _) = attractive_grad(&points[i], &graph.warm_start[i], ka, kb);
                    add3_scaled(&mut grads[i], &gi, w);
                }
            }

            for i in 0..n {
                // Remove the gradient component along the point's own position
                // vector before feeding it to the moment estimates.
                let g_tan = project_to_tangent(&points[i], &grads[i]);
                for d in 0..3 {
                    m[i][d] = beta1 * m[i][d] + (1.0 - beta1) * g_tan[d];
                    v[i][d] = beta2 * v[i][d] + (1.0 - beta2) * g_tan[d] * g_tan[d];
                }
                // Bias correction for the zero-initialised moments.
                let t = epoch as f64;
                let bc1 = 1.0 - beta1.powf(t);
                let bc2 = 1.0 - beta2.powf(t);
                let mut step = [0.0f64; 3];
                for d in 0..3 {
                    let m_hat = m[i][d] / bc1;
                    let v_hat = v[i][d] / bc2;
                    step[d] = lr * m_hat / (v_hat.sqrt() + eps);
                }
                let mut next = [
                    points[i][0] - step[0],
                    points[i][1] - step[1],
                    points[i][2] - step[2],
                ];
                // Renormalise to unit length; a step that collapses the point
                // to (near) zero is discarded and the old position is kept.
                let mag = (next[0] * next[0] + next[1] * next[1] + next[2] * next[2]).sqrt();
                if mag > f64::EPSILON {
                    next[0] /= mag;
                    next[1] /= mag;
                    next[2] /= mag;
                    points[i] = next;
                }
            }
        }

        let quality = knn_recall_score(&points, &graph.knn);

        // Exact-match table: hash -> indices of distinct training vectors.
        // A vector equal to one already in its bucket is not added, so
        // duplicates resolve to the first occurrence.
        let mut exact_lookup: HashMap<u64, Vec<usize>> = HashMap::new();
        for (i, vec) in graph.normalized.iter().enumerate() {
            let bucket = exact_lookup.entry(hash_normalized(vec)).or_default();
            if !bucket.iter().any(|&j| graph.normalized[j] == *vec) {
                bucket.push(i);
            }
        }

        Ok(Self {
            fitted_points: points,
            fitted_normalized: graph.normalized.clone(),
            exact_lookup,
            ann: graph.ann.clone(),
            dim: graph.dim,
            radial,
            n_neighbors: graph.k,
            quality,
        })
    }
455
456 pub fn fit(
465 embeddings: &[Embedding],
466 categories: Option<&[u32]>,
467 radial: RadialStrategy,
468 config: UmapConfig,
469 ) -> Result<Self, ProjectionError> {
470 let graph = UmapGraph::build(embeddings, config.n_neighbors)?;
471 Self::fit_from_graph(&graph, categories, radial, config)
472 }
473
    /// Returns the quality score recorded at fit time.
    ///
    /// Despite the name, the stored value is the kNN recall computed by
    /// `knn_recall_score` in `fit_from_graph`, not a variance ratio.
    pub fn explained_variance_ratio(&self) -> f64 {
        self.quality
    }
492
493 fn nearest_fitted(&self, normalized: &[f64]) -> Vec<(usize, f64)> {
497 if let Some(ann) = &self.ann {
498 return ann.query(normalized, self.n_neighbors);
499 }
500 let mut sims: Vec<(usize, f64)> = self
501 .fitted_normalized
502 .iter()
503 .enumerate()
504 .map(|(i, v)| (i, dot(normalized, v)))
505 .collect();
506 sims.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
507 sims.truncate(self.n_neighbors);
508 sims
509 }
510
511 fn exact_fitted(&self, normalized: &[f64]) -> Option<usize> {
512 self.exact_lookup
513 .get(&hash_normalized(normalized))?
514 .iter()
515 .copied()
516 .find(|&i| self.fitted_normalized[i] == normalized)
517 }
518
519 fn project_xyz(&self, embedding: &Embedding) -> ([f64; 3], f64) {
520 let normalized = embedding.normalized();
521 if let Some(idx) = self.exact_fitted(&normalized) {
522 return (self.fitted_points[idx], 1.0);
523 }
524 let neighbors = self.nearest_fitted(&normalized);
525
526 let max_sim = neighbors
528 .iter()
529 .map(|(_, s)| *s)
530 .fold(f64::NEG_INFINITY, f64::max);
531 let mut weights: Vec<f64> = neighbors
532 .iter()
533 .map(|(_, s)| ((s - max_sim) * 8.0).exp())
534 .collect();
535 let total: f64 = weights.iter().sum();
536 if total > f64::EPSILON {
537 for w in &mut weights {
538 *w /= total;
539 }
540 } else {
541 let n = weights.len() as f64;
542 for w in &mut weights {
543 *w = 1.0 / n;
544 }
545 }
546
547 let mut acc = [0.0f64; 3];
548 for ((idx, _), w) in neighbors.iter().zip(weights.iter()) {
549 let p = self.fitted_points[*idx];
550 acc[0] += w * p[0];
551 acc[1] += w * p[1];
552 acc[2] += w * p[2];
553 }
554 let mag = (acc[0] * acc[0] + acc[1] * acc[1] + acc[2] * acc[2]).sqrt();
555 let certainty = mag.clamp(0.0, 1.0);
556 (acc, certainty)
557 }
558}
559
560impl Projection for UmapSphereProjection {
561 fn project(&self, embedding: &Embedding) -> SphericalPoint {
562 assert_eq!(
564 embedding.dimension(),
565 self.dim,
566 "expected dimension {}, got {}",
567 self.dim,
568 embedding.dimension()
569 );
570 let (xyz, certainty) = self.project_xyz(embedding);
571 let projection_magnitude = (xyz[0] * xyz[0] + xyz[1] * xyz[1] + xyz[2] * xyz[2]).sqrt();
572 let intensity = embedding.magnitude();
573 let r = self.radial.compute_rich(&RadialContext::full(
574 intensity,
575 projection_magnitude,
576 certainty,
577 ));
578 project_xyz_to_spherical(xyz[0], xyz[1], xyz[2], r)
579 }
580
581 fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
582 assert_eq!(
584 embedding.dimension(),
585 self.dim,
586 "expected dimension {}, got {}",
587 self.dim,
588 embedding.dimension()
589 );
590 let (xyz, certainty) = self.project_xyz(embedding);
591 let projection_magnitude = (xyz[0] * xyz[0] + xyz[1] * xyz[1] + xyz[2] * xyz[2]).sqrt();
592 let intensity = embedding.magnitude();
593 let r = self.radial.compute_rich(&RadialContext::full(
594 intensity,
595 projection_magnitude,
596 certainty,
597 ));
598 let position = project_xyz_to_spherical(xyz[0], xyz[1], xyz[2], r);
599 ProjectedPoint::new(position, certainty, intensity, projection_magnitude)
600 }
601
602 fn dimensionality(&self) -> usize {
603 self.dim
604 }
605}
606
/// Gradient of the attractive term for the pair `(xi, xj)` under curve
/// parameters `a` and `b`.
///
/// Returns `(grad_i, grad_j)` with `grad_j == -grad_i`. Only the squared
/// distance in the numerator is floored at `1e-6`, which keeps the power with
/// exponent `b - 1` finite for coincident points when `b < 1`.
fn attractive_grad(xi: &[f64; 3], xj: &[f64; 3], a: f64, b: f64) -> ([f64; 3], [f64; 3]) {
    let delta: [f64; 3] = std::array::from_fn(|k| xi[k] - xj[k]);
    let dist_sq = delta[0] * delta[0] + delta[1] * delta[1] + delta[2] * delta[2];
    let scale = 2.0 * a * b * dist_sq.max(1e-6).powf(b - 1.0) / (1.0 + a * dist_sq.powf(b));
    let grad_i = delta.map(|c| scale * c);
    let grad_j = grad_i.map(|c| -c);
    (grad_i, grad_j)
}
629
/// Gradient of the repulsive term for the pair `(xi, xj)` under curve
/// parameters `a` and `b`.
///
/// Returns `(grad_i, grad_j)` with `grad_j == -grad_i`. The squared distance
/// is floored at `1e-6` so the division stays finite for coincident points.
fn repulsive_grad(xi: &[f64; 3], xj: &[f64; 3], a: f64, b: f64) -> ([f64; 3], [f64; 3]) {
    let delta: [f64; 3] = std::array::from_fn(|k| xi[k] - xj[k]);
    let dist_sq = (delta[0] * delta[0] + delta[1] * delta[1] + delta[2] * delta[2]).max(1e-6);
    let scale = -2.0 * b / (dist_sq * (1.0 + a * dist_sq.powf(b)));
    let grad_i = delta.map(|c| scale * c);
    let grad_j = grad_i.map(|c| -c);
    (grad_i, grad_j)
}
640
/// Fits the curve `1 / (1 + a * d^(2b))` to a target that is `1` up to
/// `min_dist` and decays as `exp(-(d - min_dist))` beyond it.
///
/// The fit is a coarse-to-fine grid search over `log10(a)` and `b`,
/// minimising squared error on 300 evenly spaced distances in `[0, 3]`.
/// Returns the best `(a, b)` found.
fn find_ab_params(min_dist: f64) -> (f64, f64) {
    const SAMPLES: usize = 300;
    const D_MAX: f64 = 3.0;
    const LA_MIN: f64 = -3.0;
    const LA_MAX: f64 = 1.0;
    const B_MIN: f64 = 0.1;
    const B_MAX: f64 = 2.5;
    const ROUNDS: usize = 6;

    /// Sum of squared errors of the candidate curve against the targets.
    fn fit_error(targets: &[(f64, f64)], a: f64, b: f64) -> f64 {
        targets
            .iter()
            .map(|&(d, f)| {
                let phi = 1.0 / (1.0 + a * d.powf(2.0 * b));
                (phi - f) * (phi - f)
            })
            .sum()
    }

    /// `i`-th of `steps + 1` evenly spaced grid values spanning `[lo, hi]`.
    fn grid_value(lo: f64, hi: f64, i: usize, steps: usize) -> f64 {
        lo + (hi - lo) * i as f64 / steps as f64
    }

    let mut targets: Vec<(f64, f64)> = Vec::with_capacity(SAMPLES);
    for i in 0..SAMPLES {
        let d = D_MAX * i as f64 / (SAMPLES - 1) as f64;
        let f = if d <= min_dist {
            1.0
        } else {
            (-(d - min_dist)).exp()
        };
        targets.push((d, f));
    }

    let (mut la_lo, mut la_hi) = (LA_MIN, LA_MAX);
    let (mut b_lo, mut b_hi) = (B_MIN, B_MAX);
    let mut best = (1.0f64, 1.0f64);
    let mut best_err = f64::INFINITY;

    for round in 0..ROUNDS {
        // A dense first pass locates the basin; later passes refine it.
        let steps: usize = if round == 0 { 24 } else { 8 };
        for i in 0..=steps {
            let a = 10f64.powf(grid_value(la_lo, la_hi, i, steps));
            for j in 0..=steps {
                let b = grid_value(b_lo, b_hi, j, steps);
                let err = fit_error(&targets, a, b);
                if err < best_err {
                    best_err = err;
                    best = (a, b);
                }
            }
        }

        // Halve each window, centre it on the incumbent, and keep it inside
        // the global search bounds.
        let la_half = (la_hi - la_lo) / 4.0;
        let la_best = best.0.log10();
        la_lo = (la_best - la_half).max(LA_MIN);
        la_hi = (la_best + la_half).min(LA_MAX);
        let b_half = (b_hi - b_lo) / 4.0;
        b_lo = (best.1 - b_half).max(B_MIN);
        b_hi = (best.1 + b_half).min(B_MAX);
    }
    best
}
709
/// Removes from `g` its component along `x`, i.e. returns `g - (x . g) x`.
///
/// For a unit-length `x` the result is tangent to the sphere at `x`.
fn project_to_tangent(x: &[f64; 3], g: &[f64; 3]) -> [f64; 3] {
    let along_x = x[0] * g[0] + x[1] * g[1] + x[2] * g[2];
    std::array::from_fn(|k| g[k] - along_x * x[k])
}
719
/// In-place `a += s * b` on 3-vectors.
fn add3_scaled(a: &mut [f64; 3], b: &[f64; 3], s: f64) {
    for (dst, &src) in a.iter_mut().zip(b) {
        *dst += s * src;
    }
}
725
/// Maximum random draws per point per epoch when `fit_from_graph` looks for a
/// partner from a different category; the search gives up after this many.
const MAX_CROSS_CATEGORY_DRAWS: usize = 8;

/// Corpus size at which kNN construction and recall scoring switch from exact
/// brute force to an `AnnIndex`.
const ANN_BRUTE_FORCE_THRESHOLD: usize = 2000;

/// Iteration cap for the sigma search in `smooth_knn_calibrate`.
const SIGMA_SEARCH_ITERS: usize = 64;

/// The sigma search stops early once the membership sum is within this
/// distance of its target.
const SMOOTH_K_TOLERANCE: f64 = 1e-5;

/// Lower bound on sigma expressed as a fraction of the row's mean distance.
const MIN_K_DIST_SCALE: f64 = 1e-3;
/// Absolute lower bound on sigma, applied when the mean-based bound is smaller.
const SIGMA_ABS_FLOOR: f64 = 1e-8;
752
/// Hashes a vector of `f64` for the exact-match lookup table.
///
/// This is an FNV-1a-style fold over the 64-bit pattern of each component:
/// XOR the word into the state, then multiply by the FNV prime.
///
/// Callers confirm a hash hit with `==` on the vectors, and under `f64`
/// equality `-0.0 == 0.0` even though the two have different bit patterns.
/// Both zeros are therefore hashed as `+0.0`, so vectors that compare equal
/// always land in the same bucket. Hashes of vectors without a negative zero
/// are unchanged.
fn hash_normalized(v: &[f64]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let mut h = FNV_OFFSET_BASIS;
    for &x in v {
        // `x == 0.0` is true for both signed zeros; NaN falls through and
        // keeps its own bits (it never compares equal to anything anyway).
        let bits = if x == 0.0 { 0.0f64.to_bits() } else { x.to_bits() };
        h ^= bits;
        h = h.wrapping_mul(FNV_PRIME);
    }
    h
}
765
/// Converts a cosine similarity into a distance, `1 - sim`, floored at zero
/// so that similarities slightly above 1 do not yield negative distances.
fn cosine_distance(sim: f64) -> f64 {
    let gap = 1.0 - sim;
    gap.max(0.0)
}
771
772fn split_knn_rows(rows: Vec<Vec<(usize, f64)>>) -> (Vec<Vec<usize>>, Vec<Vec<f64>>) {
775 let mut knn = Vec::with_capacity(rows.len());
776 let mut dists = Vec::with_capacity(rows.len());
777 for row in rows {
778 knn.push(row.iter().map(|&(j, _)| j).collect());
779 dists.push(row.iter().map(|&(_, s)| cosine_distance(s)).collect());
780 }
781 (knn, dists)
782}
783
784#[allow(clippy::type_complexity)]
785fn build_knn_graph(
786 normalized: &[Vec<f64>],
787 k: usize,
788) -> (Vec<Vec<usize>>, Vec<Vec<f64>>, Option<Arc<AnnIndex>>) {
789 let n = normalized.len();
790 if n < ANN_BRUTE_FORCE_THRESHOLD {
791 let rows: Vec<Vec<(usize, f64)>> = (0..n)
792 .map(|i| {
793 let mut sims: Vec<(usize, f64)> = (0..n)
794 .filter(|&j| j != i)
795 .map(|j| (j, dot(&normalized[i], &normalized[j])))
796 .collect();
797 sims.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
798 sims.truncate(k);
799 sims
800 })
801 .collect();
802 let (knn, dists) = split_knn_rows(rows);
803 return (knn, dists, None);
804 }
805
806 let index = Arc::new(AnnIndex::build_normalized(
809 normalized.to_vec(),
810 &AnnConfig::default(),
811 ));
812 let (knn, dists) = split_knn_rows(index.knn_graph_with_sims(k));
813 (knn, dists, Some(index))
814}
815
816fn smooth_knn_calibrate(dists: &[f64]) -> (f64, f64) {
828 let rho = dists.iter().copied().fold(f64::INFINITY, f64::min);
829 let target = (dists.len() as f64).log2();
830 let mut lo = 0.0f64;
831 let mut hi = f64::INFINITY;
832 let mut sigma = 1.0f64;
833 for _ in 0..SIGMA_SEARCH_ITERS {
834 let sum: f64 = dists
835 .iter()
836 .map(|&d| (-((d - rho).max(0.0)) / sigma).exp())
837 .sum();
838 if (sum - target).abs() < SMOOTH_K_TOLERANCE {
839 break;
840 }
841 if sum > target {
842 hi = sigma;
843 } else {
844 lo = sigma;
845 }
846 sigma = if hi.is_finite() {
847 (lo + hi) / 2.0
848 } else {
849 sigma * 2.0
850 };
851 }
852 let mean = dists.iter().sum::<f64>() / dists.len() as f64;
853 (
854 rho,
855 sigma.max((MIN_K_DIST_SCALE * mean).max(SIGMA_ABS_FLOOR)),
856 )
857}
858
859fn fuzzy_simplicial_weights(knn: &[Vec<usize>], dists: &[Vec<f64>]) -> Vec<Vec<f64>> {
865 let directed: Vec<Vec<f64>> = dists
866 .iter()
867 .map(|d| {
868 if d.is_empty() {
869 return Vec::new();
870 }
871 let (rho, sigma) = smooth_knn_calibrate(d);
872 d.iter()
873 .map(|&x| (-((x - rho).max(0.0)) / sigma).exp())
874 .collect()
875 })
876 .collect();
877
878 knn.iter()
879 .enumerate()
880 .map(|(i, neighbors)| {
881 neighbors
882 .iter()
883 .enumerate()
884 .map(|(idx, &j)| {
885 let a = directed[i][idx];
886 let b = knn[j]
887 .iter()
888 .position(|&x| x == i)
889 .map_or(0.0, |p| directed[j][p]);
890 (a + b - a * b).min(1.0)
892 })
893 .collect()
894 })
895 .collect()
896}
897
898fn pca_warm_start(
899 embeddings: &[Embedding],
900 normalized: &[Vec<f64>],
901) -> Result<Vec<[f64; 3]>, ProjectionError> {
902 use crate::projection::PcaProjection;
903 use sphereql_core::spherical_to_cartesian;
904
905 let pca = PcaProjection::fit(embeddings, RadialStrategy::Fixed(1.0))?;
906 let mut out: Vec<[f64; 3]> = Vec::with_capacity(embeddings.len());
907 for (i, e) in embeddings.iter().enumerate() {
908 let pp = pca.project_rich(e);
915 if pp.projection_magnitude > f64::EPSILON {
916 let cart = spherical_to_cartesian(&pp.position);
917 out.push([cart.x, cart.y, cart.z]);
918 continue;
919 }
920 let row = &normalized[i];
925 let mut v = [row[0], row[1], row[2]];
926 normalize_vec(&mut v);
927 if v[0] == 0.0 && v[1] == 0.0 && v[2] == 0.0 {
928 v = [1.0, 0.0, 0.0];
929 }
930 out.push(v);
931 }
932 Ok(out)
933}
934
935fn knn_recall_score(points: &[[f64; 3]], knn: &[Vec<usize>]) -> f64 {
947 let n = points.len();
948 if n < 2 {
949 return 1.0;
950 }
951
952 let ann = (n >= ANN_BRUTE_FORCE_THRESHOLD).then(|| {
953 let coords: Vec<Vec<f64>> = points.iter().map(|p| p.to_vec()).collect();
954 AnnIndex::build_normalized(coords, &AnnConfig::default())
955 });
956
957 let mut total = 0.0;
958 let mut counted = 0usize;
959 for (i, original) in knn.iter().enumerate() {
960 let k = original.len();
961 if k == 0 {
962 continue;
963 }
964 let spherical: Vec<usize> = match &ann {
965 Some(index) => index
966 .query_by_index(i, k)
967 .into_iter()
968 .map(|(j, _)| j)
969 .collect(),
970 None => {
971 let mut sims: Vec<(usize, f64)> = (0..n)
972 .filter(|&j| j != i)
973 .map(|j| (j, dot(&points[i], &points[j])))
974 .collect();
975 sims.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
976 sims.into_iter().take(k).map(|(j, _)| j).collect()
977 }
978 };
979 let hits = spherical.iter().filter(|j| original.contains(j)).count();
980 total += hits as f64 / k as f64;
981 counted += 1;
982 }
983 if counted == 0 {
984 1.0
985 } else {
986 total / counted as f64
987 }
988}
989
990#[cfg(test)]
991mod tests {
992 use super::*;
993 use sphereql_core::angular_distance;
994
995 fn emb(vals: &[f64]) -> Embedding {
996 Embedding::new(vals.to_vec())
997 }
998
999 fn cluster_corpus() -> Vec<Embedding> {
1000 let mut out = Vec::new();
1002 for i in 0..8 {
1003 let t = i as f64 * 0.01;
1004 out.push(emb(&[1.0 + t, 0.5 + t, 0.0, 0.0, 0.0, 0.0]));
1005 }
1006 for i in 0..8 {
1007 let t = i as f64 * 0.01;
1008 out.push(emb(&[0.0, 0.0, 0.0, 1.0 + t, 0.5 + t, 0.0]));
1009 }
1010 out
1011 }
1012
1013 #[test]
1014 fn fit_default_runs_and_produces_valid_points() {
1015 let corpus = cluster_corpus();
1016 let proj = UmapSphereProjection::fit_default(&corpus).unwrap();
1017 for e in &corpus {
1018 let sp = proj.project(e);
1019 assert!(sp.r >= 0.0);
1020 assert!(sp.theta.is_finite());
1021 assert!(sp.phi.is_finite());
1022 }
1023 }
1024
1025 #[test]
1026 fn quality_score_in_unit_interval() {
1027 let corpus = cluster_corpus();
1028 let proj = UmapSphereProjection::fit_default(&corpus).unwrap();
1029 let q = proj.explained_variance_ratio();
1030 assert!((0.0..=1.0).contains(&q), "got {q}");
1031 }
1032
1033 #[test]
1034 fn well_separated_clusters_score_high_recall() {
1035 let corpus = cluster_corpus();
1036 let proj = UmapSphereProjection::fit(
1037 &corpus,
1038 None,
1039 RadialStrategy::Fixed(1.0),
1040 UmapConfig {
1041 n_neighbors: 5,
1042 ..UmapConfig::default()
1043 },
1044 )
1045 .unwrap();
1046 let q = proj.explained_variance_ratio();
1047 assert!(
1048 q > 0.5,
1049 "expected high recall for separated clusters, got {q}"
1050 );
1051 }
1052
    /// Shuffling the fitted positions among the points must lower the recall
    /// score compared with the actual fit.
    #[test]
    fn shuffled_positions_score_lower_recall() {
        let corpus = cluster_corpus();
        let config = UmapConfig {
            n_neighbors: 5,
            ..UmapConfig::default()
        };
        let graph = UmapGraph::build(&corpus, config.n_neighbors).unwrap();
        let proj =
            UmapSphereProjection::fit_from_graph(&graph, None, RadialStrategy::Fixed(1.0), config)
                .unwrap();
        let fitted = proj.explained_variance_ratio();

        let mut shuffled = proj.fitted_points.clone();
        // Fisher-Yates shuffle with a fixed seed so the test is deterministic.
        let mut rng = SplitMix64::new(0xD15C);
        for i in (1..shuffled.len()).rev() {
            let j = (rng.next_u64() as usize) % (i + 1);
            shuffled.swap(i, j);
        }
        // Score the shuffled layout against the unchanged original kNN graph.
        let broken = knn_recall_score(&shuffled, &graph.knn);
        assert!(broken < fitted, "shuffled={broken}, fitted={fitted}");
    }
1078
1079 #[test]
1080 fn empty_corpus_errors() {
1081 assert!(matches!(
1082 UmapSphereProjection::fit_default(&[]),
1083 Err(ProjectionError::EmptyCorpus)
1084 ));
1085 }
1086
1087 #[test]
1088 fn dimension_too_low_errors() {
1089 let bad = vec![emb(&[1.0, 2.0]); 8];
1090 assert!(matches!(
1091 UmapSphereProjection::fit_default(&bad),
1092 Err(ProjectionError::DimensionTooLow { .. })
1093 ));
1094 }
1095
1096 #[test]
1097 fn too_few_embeddings_errors() {
1098 let small = vec![emb(&[1.0, 2.0, 3.0, 4.0]); 3];
1099 assert!(matches!(
1100 UmapSphereProjection::fit_default(&small),
1101 Err(ProjectionError::TooFewEmbeddings {
1102 got: 3,
1103 required: 4
1104 })
1105 ));
1106 }
1107
1108 #[test]
1109 fn ann_backed_knn_routes_to_correct_cluster() {
1110 use crate::ann::{AnnConfig, AnnIndex};
1120
1121 let corpus = cluster_corpus();
1122 let normalized: Vec<Vec<f64>> = corpus.iter().map(|e| e.normalized()).collect();
1123
1124 let config = AnnConfig {
1125 n_trees: 8,
1126 max_leaf_size: 8,
1127 seed: 42,
1128 };
1129 let index = AnnIndex::build_normalized(normalized.clone(), &config);
1130 let ann: Vec<Vec<usize>> = index.knn_graph(5);
1131
1132 for (i, neighbors) in ann.iter().enumerate() {
1133 let own_cluster = if i < 8 { 0..8 } else { 8..16 };
1134 for &n in neighbors {
1135 assert!(
1136 own_cluster.contains(&n),
1137 "item {i} got neighbor {n} from the wrong cluster"
1138 );
1139 }
1140 }
1141 }
1142
1143 #[test]
1144 fn category_term_pulls_same_class_together() {
1145 let corpus = cluster_corpus();
1146 let cats: Vec<u32> = (0..corpus.len())
1147 .map(|i| if i < 8 { 0 } else { 1 })
1148 .collect();
1149
1150 let unsupervised = UmapSphereProjection::fit(
1151 &corpus,
1152 None,
1153 RadialStrategy::Fixed(1.0),
1154 UmapConfig {
1155 n_epochs: 100,
1156 category_weight: 0.0,
1157 ..UmapConfig::default()
1158 },
1159 )
1160 .unwrap();
1161
1162 let supervised = UmapSphereProjection::fit(
1163 &corpus,
1164 Some(&cats),
1165 RadialStrategy::Fixed(1.0),
1166 UmapConfig {
1167 n_epochs: 100,
1168 category_weight: 2.0,
1169 ..UmapConfig::default()
1170 },
1171 )
1172 .unwrap();
1173
1174 let within_unsup = mean_within_class(&unsupervised.fitted_points, &cats);
1176 let within_sup = mean_within_class(&supervised.fitted_points, &cats);
1177 assert!(
1178 within_sup <= within_unsup + 1e-6,
1179 "supervised within-class={within_sup}, unsupervised={within_unsup}"
1180 );
1181 }
1182
1183 #[test]
1184 fn category_term_tightens_classes_at_many_categories() {
1185 let mut corpus = Vec::new();
1192 let mut cats: Vec<u32> = Vec::new();
1193 for c in 0..8u32 {
1194 for i in 0..4 {
1195 let mut v = vec![0.0; 8];
1196 v[c as usize] = 1.0 + i as f64 * 0.05;
1197 v[(c as usize + 1) % 8] = 0.1 + i as f64 * 0.02;
1198 corpus.push(emb(&v));
1199 cats.push(c);
1200 }
1201 }
1202
1203 let config = |category_weight: f64| UmapConfig {
1204 n_neighbors: 3,
1205 n_epochs: 100,
1206 category_weight,
1207 ..UmapConfig::default()
1208 };
1209
1210 let unsupervised =
1211 UmapSphereProjection::fit(&corpus, None, RadialStrategy::Fixed(1.0), config(0.0))
1212 .unwrap();
1213 let supervised = UmapSphereProjection::fit(
1214 &corpus,
1215 Some(&cats),
1216 RadialStrategy::Fixed(1.0),
1217 config(2.0),
1218 )
1219 .unwrap();
1220
1221 let within_unsup = mean_within_class(&unsupervised.fitted_points, &cats);
1222 let within_sup = mean_within_class(&supervised.fitted_points, &cats);
1223 assert!(
1224 within_sup < within_unsup,
1225 "supervised within-class={within_sup}, unsupervised={within_unsup}"
1226 );
1227 }
1228
1229 fn mean_within_class(points: &[[f64; 3]], cats: &[u32]) -> f64 {
1230 let mut total = 0.0;
1231 let mut count = 0;
1232 for i in 0..points.len() {
1233 for j in (i + 1)..points.len() {
1234 if cats[i] == cats[j] {
1235 let pi = SphericalPoint::new_unchecked(
1236 1.0,
1237 points[i][1]
1238 .atan2(points[i][0])
1239 .rem_euclid(std::f64::consts::TAU),
1240 points[i][2].clamp(-1.0, 1.0).acos(),
1241 );
1242 let pj = SphericalPoint::new_unchecked(
1243 1.0,
1244 points[j][1]
1245 .atan2(points[j][0])
1246 .rem_euclid(std::f64::consts::TAU),
1247 points[j][2].clamp(-1.0, 1.0).acos(),
1248 );
1249 total += angular_distance(&pi, &pj);
1250 count += 1;
1251 }
1252 }
1253 }
1254 if count == 0 {
1255 0.0
1256 } else {
1257 total / count as f64
1258 }
1259 }
1260
1261 #[test]
1262 fn sigma_calibration_hits_log2_k() {
1263 let dists = [0.1, 0.2, 0.3, 0.4, 0.5];
1264 let (rho, sigma) = smooth_knn_calibrate(&dists);
1265 assert_eq!(rho, 0.1);
1266 let sum: f64 = dists
1267 .iter()
1268 .map(|&d| (-((d - rho).max(0.0)) / sigma).exp())
1269 .sum();
1270 let target = 5.0f64.log2();
1271 assert!((sum - target).abs() < 1e-4, "sum={sum}, target={target}");
1272 }
1273
1274 #[test]
1275 fn nearest_neighbor_weight_is_one() {
1276 let corpus = cluster_corpus();
1277 let graph = UmapGraph::build(&corpus, 5).unwrap();
1278 for (i, neighbors) in graph.knn.iter().enumerate() {
1279 assert!(!neighbors.is_empty());
1280 let w = graph.weights[i][0];
1284 assert!((w - 1.0).abs() < 1e-9, "point {i}: nearest weight {w}");
1285 }
1286 }
1287
1288 #[test]
1289 fn duplicate_heavy_corpus_fits_with_finite_weights() {
1290 let mut corpus = vec![emb(&[1.0, 0.5, 0.0, 0.0, 0.0, 0.0]); 6];
1294 corpus.extend(cluster_corpus());
1295
1296 let graph = UmapGraph::build(&corpus, 5).unwrap();
1297 for row in &graph.weights {
1298 for &w in row {
1299 assert!(w.is_finite() && (0.0..=1.0).contains(&w), "weight {w}");
1300 }
1301 }
1302
1303 let proj = UmapSphereProjection::fit_from_graph(
1304 &graph,
1305 None,
1306 RadialStrategy::Fixed(1.0),
1307 UmapConfig {
1308 n_neighbors: 5,
1309 n_epochs: 30,
1310 ..UmapConfig::default()
1311 },
1312 )
1313 .unwrap();
1314 for p in &proj.fitted_points {
1315 assert!(p.iter().all(|c| c.is_finite()));
1316 }
1317 assert!(proj.explained_variance_ratio().is_finite());
1318 }
1319
    /// Edges inside a cluster of exact duplicates must carry weights near one,
    /// while edges inside a spread-out cluster must decay below that.
    #[test]
    fn dense_cluster_edges_outweigh_diffuse_cluster_edges() {
        // Points 0..8: eight identical embeddings (the tight cluster).
        let mut corpus = vec![emb(&[1.0, 0.2, 0.0, 0.0, 0.0, 0.0]); 8];
        // Points 8..16: same dominant axis, increasingly offset on a second
        // axis (the diffuse cluster).
        for i in 0..8 {
            let mut v = vec![0.0; 6];
            v[3] = 1.0;
            v[4] = 0.15 * i as f64;
            corpus.push(emb(&v));
        }
        let graph = UmapGraph::build(&corpus, 5).unwrap();

        // Mean weight over edges whose both endpoints lie inside `range`.
        let mean_intra = |range: std::ops::Range<usize>| {
            let mut total = 0.0;
            let mut count = 0usize;
            for i in range.clone() {
                for (idx, &j) in graph.knn[i].iter().enumerate() {
                    if range.contains(&j) {
                        total += graph.weights[i][idx];
                        count += 1;
                    }
                }
            }
            assert!(count > 0, "no intra-cluster edges in {range:?}");
            total / count as f64
        };

        let tight = mean_intra(0..8);
        let diffuse = mean_intra(8..16);
        assert!(
            tight >= diffuse,
            "tight mean weight {tight} < diffuse mean weight {diffuse}"
        );
        assert!(
            tight > 0.99,
            "duplicate-cluster edges should be ~1, got {tight}"
        );
        assert!(diffuse < 0.95, "diffuse edges should decay, got {diffuse}");
    }
1364
1365 #[test]
1366 fn dimensionality_reports_input_dim() {
1367 let corpus = cluster_corpus();
1368 let proj = UmapSphereProjection::fit_default(&corpus).unwrap();
1369 assert_eq!(proj.dimensionality(), 6);
1370 }
1371
1372 #[test]
1373 fn fit_from_graph_matches_full_fit() {
1374 let corpus = cluster_corpus();
1375 let config = UmapConfig {
1376 n_epochs: 50,
1377 category_weight: 0.0,
1378 seed: 42,
1379 ..UmapConfig::default()
1380 };
1381
1382 let full =
1383 UmapSphereProjection::fit(&corpus, None, RadialStrategy::Fixed(1.0), config.clone())
1384 .unwrap();
1385
1386 let graph = UmapGraph::build(&corpus, config.n_neighbors).unwrap();
1387 let split =
1388 UmapSphereProjection::fit_from_graph(&graph, None, RadialStrategy::Fixed(1.0), config)
1389 .unwrap();
1390
1391 assert!(
1392 (full.explained_variance_ratio() - split.explained_variance_ratio()).abs() < 1e-6,
1393 "full={}, split={}",
1394 full.explained_variance_ratio(),
1395 split.explained_variance_ratio()
1396 );
1397 }
1398
1399 #[test]
1400 fn graph_reusable_across_configs() {
1401 let corpus = cluster_corpus();
1402 let graph = UmapGraph::build(&corpus, 5).unwrap();
1403
1404 let config1 = UmapConfig {
1405 n_epochs: 30,
1406 category_weight: 0.0,
1407 seed: 1,
1408 ..UmapConfig::default()
1409 };
1410 let config2 = UmapConfig {
1411 n_epochs: 60,
1412 category_weight: 1.0,
1413 seed: 2,
1414 ..UmapConfig::default()
1415 };
1416
1417 let p1 =
1418 UmapSphereProjection::fit_from_graph(&graph, None, RadialStrategy::Fixed(1.0), config1)
1419 .unwrap();
1420 let p2 =
1421 UmapSphereProjection::fit_from_graph(&graph, None, RadialStrategy::Fixed(1.0), config2)
1422 .unwrap();
1423
1424 assert!((0.0..=1.0).contains(&p1.explained_variance_ratio()));
1425 assert!((0.0..=1.0).contains(&p2.explained_variance_ratio()));
1426 }
1427
1428 fn assert_projects_to_fitted(proj: &UmapSphereProjection, e: &Embedding, idx: usize) {
1429 use sphereql_core::spherical_to_cartesian;
1430
1431 let pp = proj.project_rich(e);
1432 assert_eq!(pp.certainty, 1.0, "exact match must report certainty 1.0");
1433 let cart = spherical_to_cartesian(&pp.position);
1434 let expected = proj.fitted_points[idx];
1435 assert!((cart.x - expected[0]).abs() < 1e-12);
1436 assert!((cart.y - expected[1]).abs() < 1e-12);
1437 assert!((cart.z - expected[2]).abs() < 1e-12);
1438 }
1439
1440 #[test]
1441 fn projecting_training_embedding_returns_exact_fitted_position() {
1442 let corpus = cluster_corpus();
1443 let proj = UmapSphereProjection::fit(
1444 &corpus,
1445 None,
1446 RadialStrategy::Fixed(1.0),
1447 UmapConfig::default(),
1448 )
1449 .unwrap();
1450
1451 for (i, e) in corpus.iter().enumerate() {
1456 assert_projects_to_fitted(&proj, e, i);
1457 }
1458 }
1459
1460 #[test]
1461 fn duplicate_training_embedding_maps_to_first_fitted_index() {
1462 let mut corpus = cluster_corpus();
1463 corpus.push(corpus[0].clone());
1464 let proj = UmapSphereProjection::fit(
1465 &corpus,
1466 None,
1467 RadialStrategy::Fixed(1.0),
1468 UmapConfig::default(),
1469 )
1470 .unwrap();
1471
1472 assert_projects_to_fitted(&proj, &corpus[0], 0);
1473 assert_projects_to_fitted(&proj, &corpus[16], 0);
1474
1475 let a = proj.project_rich(&corpus[0]);
1476 let b = proj.project_rich(&corpus[16]);
1477 assert_eq!(a.position.theta, b.position.theta);
1478 assert_eq!(a.position.phi, b.position.phi);
1479 assert_eq!(a.position.r, b.position.r);
1480 }
1481
/// An embedding that is not part of the training corpus must still project
/// to a finite position on the unit sphere, with a certainty strictly
/// between 0 and 1 and a projection magnitude in `(0, 1]` (up to 1e-12).
#[test]
fn unseen_embedding_interpolates_on_sphere() {
    let corpus = cluster_corpus();
    let config = UmapConfig::default();
    let proj =
        UmapSphereProjection::fit(&corpus, None, RadialStrategy::Fixed(1.0), config).unwrap();

    let coords = [1.0, 0.55, 0.02, 0.0, 0.0, 0.0];
    let unseen = emb(&coords);
    let pp = proj.project_rich(&unseen);

    // Certainty is strictly inside the open interval (0, 1).
    assert!(pp.certainty > 0.0 && pp.certainty < 1.0);

    // Angles are finite and the radius matches the fixed radial strategy.
    let position = &pp.position;
    assert!(position.theta.is_finite());
    assert!(position.phi.is_finite());
    assert!((position.r - 1.0).abs() < 1e-12);

    assert!(pp.projection_magnitude > 0.0 && pp.projection_magnitude <= 1.0 + 1e-12);
}
1501
/// Fits on `ANN_BRUTE_FORCE_THRESHOLD` random 8-dimensional embeddings and
/// checks that the projection carries an ANN index, that a training point
/// still maps to its fitted position, and that an unseen point projects to
/// finite coordinates with certainty in `(0, 1]`.
#[test]
fn ann_backed_transform_above_threshold() {
    let mut rng = SplitMix64::new(0x5EED);
    // Draws `dim` normal samples from the shared RNG and wraps them as an
    // embedding; call order determines the generated corpus.
    let mut random_emb = |dim: usize| {
        let mut vals: Vec<f64> = Vec::new();
        for _ in 0..dim {
            vals.push(rng.normal());
        }
        emb(&vals)
    };

    let mut corpus: Vec<Embedding> = Vec::new();
    for _ in 0..ANN_BRUTE_FORCE_THRESHOLD {
        corpus.push(random_emb(8));
    }

    // Keep the optimisation cheap: few neighbours, two epochs, one negative sample.
    let config = UmapConfig {
        n_neighbors: 5,
        n_epochs: 2,
        negative_sample_rate: 1,
        ..UmapConfig::default()
    };
    let proj =
        UmapSphereProjection::fit(&corpus, None, RadialStrategy::Fixed(1.0), config).unwrap();
    assert!(proj.ann.is_some(), "expected ANN index above threshold");

    // NOTE(review): index 1234 presumes the threshold exceeds 1234 — confirm
    // against the constant's definition.
    assert_projects_to_fitted(&proj, &corpus[1234], 1234);

    let unseen = random_emb(8);
    let pp = proj.project_rich(&unseen);
    assert!(pp.certainty > 0.0 && pp.certainty <= 1.0);

    let position = &pp.position;
    assert!(position.theta.is_finite());
    assert!(position.phi.is_finite());
    assert!(position.r.is_finite());
}
1531
/// With `a = b = 1`, `attractive_grad` and `repulsive_grad` must reproduce
/// the closed forms spelled out here:
/// attractive `2 / (1 + d²) · dx`, repulsive `-2 / (d² (1 + d²)) · dx`
/// (with `d²` floored at `1e-6`), and the second gradient of each pair must
/// be the negation of the first.
#[test]
fn gradients_at_unit_ab_match_old_hardcoded_forms() {
    let pairs: [([f64; 3], [f64; 3]); 4] = [
        ([1.0, 0.0, 0.0], [0.6, 0.8, 0.0]),
        ([0.0, 1.0, 0.0], [0.0, 0.0, 1.0]),
        ([1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]),
        ([0.36, 0.48, 0.8], [0.48, 0.36, 0.8]),
    ];

    for (xi, xj) in pairs {
        // Difference vector and its squared length.
        let delta = [xi[0] - xj[0], xi[1] - xj[1], xi[2] - xj[2]];
        let dist_sq = delta[0] * delta[0] + delta[1] * delta[1] + delta[2] * delta[2];

        // Attractive term.
        let (gi, gj) = attractive_grad(&xi, &xj, 1.0, 1.0);
        let attract = 2.0 / (1.0 + dist_sq);
        for (d, &component) in delta.iter().enumerate() {
            let want = attract * component;
            assert!((gi[d] - want).abs() < 1e-12, "attractive gi[{d}]");
            assert!((gj[d] + want).abs() < 1e-12, "attractive gj[{d}]");
        }

        // Repulsive term; the squared distance is floored to avoid a blow-up.
        let (ri, rj) = repulsive_grad(&xi, &xj, 1.0, 1.0);
        let floored = dist_sq.max(1e-6);
        let repel = -2.0 / (floored * (1.0 + floored));
        for (d, &component) in delta.iter().enumerate() {
            let want = repel * component;
            assert!((ri[d] - want).abs() < 1e-12, "repulsive ri[{d}]");
            assert!((rj[d] + want).abs() < 1e-12, "repulsive rj[{d}]");
        }
    }
}
1563
/// At `min_dist = 0.1`, `find_ab_params` must return `a` within
/// `[1.3, 1.9]` and `b` within `[0.78, 1.0]`.
#[test]
fn ab_fit_matches_canonical_anchor_at_default_min_dist() {
    let a_band = 1.3..=1.9;
    let b_band = 0.78..=1.0;

    let (a, b) = find_ab_params(0.1);

    assert!(a_band.contains(&a), "a = {a}");
    assert!(b_band.contains(&b), "b = {b}");
}
1573
/// The kernel `1 / (1 + a·d^(2b))` fitted for `min_dist = 0.5` must sit
/// strictly above the one fitted for `min_dist = 0.0` at every sampled
/// distance in `(0, 1]`.
#[test]
fn larger_min_dist_flattens_kernel_near_origin() {
    /// Evaluates the kernel for parameters `(a, b)` at distance `d`.
    fn kernel(a: f64, b: f64, d: f64) -> f64 {
        1.0 / (1.0 + a * d.powf(2.0 * b))
    }

    let (a0, b0) = find_ab_params(0.0);
    let (a5, b5) = find_ab_params(0.5);

    for d in [0.1, 0.25, 0.5, 0.75, 1.0] {
        let wide = kernel(a5, b5, d);
        let narrow = kernel(a0, b0, d);
        assert!(
            wide > narrow,
            "Phi at d={d}: min_dist=0.5 gives {}, min_dist=0.0 gives {}",
            wide,
            narrow
        );
    }
}
1591
/// Fitting with `min_dist = 0.5` must give a larger `mean_within_class`
/// value than fitting with `min_dist = 0.0` on the same corpus, where the
/// first eight embeddings are labelled class 0 and the rest class 1.
#[test]
fn larger_min_dist_spreads_fitted_points() {
    let corpus = cluster_corpus();

    // Class labels: indices 0..8 stay 0, everything after becomes 1.
    let mut cats: Vec<u32> = vec![0; corpus.len()];
    for label in cats.iter_mut().skip(8) {
        *label = 1;
    }

    // Fits the corpus with the given `min_dist`, all else being equal.
    let fit_at = |min_dist: f64| {
        let config = UmapConfig {
            n_neighbors: 5,
            min_dist,
            ..UmapConfig::default()
        };
        UmapSphereProjection::fit(&corpus, None, RadialStrategy::Fixed(1.0), config).unwrap()
    };

    let tight_fit = fit_at(0.0);
    let tight = mean_within_class(&tight_fit.fitted_points, &cats);
    let spread_fit = fit_at(0.5);
    let spread = mean_within_class(&spread_fit.fitted_points, &cats);

    assert!(
        spread > tight,
        "min_dist=0.5 within-class spread {spread} should exceed min_dist=0.0's {tight}"
    );
}
1621
1622 fn mean_warm_start_displacement(graph: &UmapGraph, points: &[[f64; 3]]) -> f64 {
1623 points
1624 .iter()
1625 .zip(&graph.warm_start)
1626 .map(|(p, w)| {
1627 let cos = (p[0] * w[0] + p[1] * w[1] + p[2] * w[2]).clamp(-1.0, 1.0);
1628 cos.acos()
1629 })
1630 .sum::<f64>()
1631 / points.len() as f64
1632 }
1633
/// With `warm_start_anchor = 0.0`, two fits from the same prebuilt graph and
/// one full `fit` on the raw corpus must all produce exactly equal
/// `fitted_points`.
#[test]
fn zero_anchor_is_bit_identical_and_rng_neutral() {
    let corpus = cluster_corpus();
    let config = UmapConfig {
        n_neighbors: 5,
        n_epochs: 30,
        warm_start_anchor: 0.0,
        seed: 42,
        ..UmapConfig::default()
    };

    let graph = UmapGraph::build(&corpus, config.n_neighbors).unwrap();

    // Fits from the shared graph with a fresh copy of the configuration.
    let fit_from_shared_graph = || {
        UmapSphereProjection::fit_from_graph(
            &graph,
            None,
            RadialStrategy::Fixed(1.0),
            config.clone(),
        )
        .unwrap()
    };
    let a = fit_from_shared_graph();
    let b = fit_from_shared_graph();

    // The end-to-end path builds its own graph internally.
    let full =
        UmapSphereProjection::fit(&corpus, None, RadialStrategy::Fixed(1.0), config).unwrap();

    assert_eq!(a.fitted_points, b.fitted_points);
    assert_eq!(a.fitted_points, full.fitted_points);
}
1670
/// On a graph whose two clusters (indices 0..8 and 8..16) share no kNN
/// edges, fitting with `warm_start_anchor = 0.05` must leave points closer,
/// on average, to their warm-start positions than fitting with no anchor.
#[test]
fn warm_start_anchor_limits_component_drift() {
    let corpus = cluster_corpus();
    let graph = UmapGraph::build(&corpus, 5).unwrap();

    // Precondition: every neighbour stays inside its own cluster.
    for (i, neighbors) in graph.knn.iter().enumerate() {
        let own_cluster = match i {
            0..=7 => 0..8,
            _ => 8..16,
        };
        for j in neighbors.iter().copied() {
            assert!(
                own_cluster.contains(&j),
                "expected disconnected components, but {i} links to {j}"
            );
        }
    }

    // Fits from the shared graph, varying only the anchor strength.
    let fit_at = |anchor: f64| {
        let config = UmapConfig {
            n_neighbors: 5,
            n_epochs: 100,
            warm_start_anchor: anchor,
            seed: 42,
            ..UmapConfig::default()
        };
        UmapSphereProjection::fit_from_graph(&graph, None, RadialStrategy::Fixed(1.0), config)
            .unwrap()
    };

    let unanchored_fit = fit_at(0.0);
    let free = mean_warm_start_displacement(&graph, &unanchored_fit.fitted_points);
    let anchored_fit = fit_at(0.05);
    let anchored = mean_warm_start_displacement(&graph, &anchored_fit.fitted_points);

    assert!(
        anchored < free,
        "anchored displacement {anchored} should be below unanchored {free}"
    );
}
1716
/// A fit with a non-zero warm-start anchor must still yield finite,
/// unit-length fitted points (to within 1e-9) and an explained-variance
/// ratio inside `[0, 1]`.
#[test]
fn anchored_fit_stays_on_sphere_with_valid_quality() {
    let corpus = cluster_corpus();
    let config = UmapConfig {
        n_neighbors: 5,
        n_epochs: 50,
        warm_start_anchor: 0.05,
        ..UmapConfig::default()
    };
    let proj =
        UmapSphereProjection::fit(&corpus, None, RadialStrategy::Fixed(1.0), config).unwrap();

    for point in proj.fitted_points.iter() {
        assert!(point.iter().all(|coord| coord.is_finite()));
        let norm_sq = point[0] * point[0] + point[1] * point[1] + point[2] * point[2];
        let mag = norm_sq.sqrt();
        assert!((mag - 1.0).abs() < 1e-9, "off-sphere magnitude {mag}");
    }

    let q = proj.explained_variance_ratio();
    let unit_interval = 0.0..=1.0;
    assert!(unit_interval.contains(&q), "quality {q}");
}
1740}