Skip to main content

sphereql_embed/
query.rs

1use std::collections::{BinaryHeap, HashMap};
2
3use sphereql_core::*;
4use sphereql_index::*;
5
6use crate::projection::Projection;
7use crate::types::{Embedding, ProjectedPoint};
8
/// A single embedding stored in the spatial index, keyed by a string id.
#[derive(Debug, Clone)]
pub struct EmbeddingItem {
    /// Caller-supplied identifier; also returned as the `SpatialItem` id.
    pub id: String,
    /// Spherical position under which the item is indexed.
    pub position: SphericalPoint,
    /// Value of `Embedding::magnitude()` for the embedding at insertion time.
    pub original_magnitude: f64,
    /// Rich projection metadata. Both `EmbeddingIndex::insert` and
    /// `EmbeddingIndex::insert_with_radius` in this file store `Some(..)`;
    /// `None` only occurs for items constructed by other code, in which case
    /// the accessor methods on this type fall back to default values.
    pub projected: Option<ProjectedPoint>,
}
17
18impl SpatialItem for EmbeddingItem {
19    type Id = String;
20    fn id(&self) -> &String {
21        &self.id
22    }
23    fn position(&self) -> &SphericalPoint {
24        &self.position
25    }
26}
27
28impl EmbeddingItem {
29    /// Certainty of this point's projection. Falls back to 1.0 if no rich metadata.
30    pub fn certainty(&self) -> f64 {
31        self.projected.map_or(1.0, |p| p.certainty)
32    }
33
34    /// Intensity (pre-normalization magnitude) of the original embedding.
35    pub fn intensity(&self) -> f64 {
36        self.projected
37            .map_or(self.original_magnitude, |p| p.intensity)
38    }
39
40    /// PCA projection magnitude — how strongly this point projects onto
41    /// the 3 principal components. Low values indicate ambiguous points.
42    pub fn projection_magnitude(&self) -> f64 {
43        self.projected.map_or(1.0, |p| p.projection_magnitude)
44    }
45}
46
/// Builder for [`EmbeddingIndex`]: pairs a projection with the configuration
/// of the underlying spatial index.
pub struct EmbeddingIndexBuilder<P> {
    /// Projection handed over to the built index.
    projection: P,
    /// Builder for the underlying spatial index; all configuration calls are
    /// forwarded to it.
    inner: SpatialIndexBuilder,
}
51
52impl<P: Projection> EmbeddingIndexBuilder<P> {
53    pub fn new(projection: P) -> Self {
54        Self {
55            projection,
56            inner: SpatialIndexBuilder::new(),
57        }
58    }
59
60    pub fn shell_boundary(mut self, r: f64) -> Self {
61        self.inner = self.inner.shell_boundary(r);
62        self
63    }
64
65    pub fn uniform_shells(mut self, count: usize, max_r: f64) -> Self {
66        self.inner = self.inner.uniform_shells(count, max_r);
67        self
68    }
69
70    pub fn theta_divisions(mut self, n: usize) -> Self {
71        self.inner = self.inner.theta_divisions(n);
72        self
73    }
74
75    pub fn phi_divisions(mut self, n: usize) -> Self {
76        self.inner = self.inner.phi_divisions(n);
77        self
78    }
79
80    pub fn build(self) -> EmbeddingIndex<P> {
81        EmbeddingIndex {
82            projection: self.projection,
83            index: self.inner.build(),
84        }
85    }
86}
87
/// A spatial index over embeddings: each embedding is projected with `P` and
/// stored as an [`EmbeddingItem`].
pub struct EmbeddingIndex<P> {
    /// Projection applied to every inserted embedding and every query.
    projection: P,
    /// Underlying spatial index holding the projected items.
    index: SpatialIndex<EmbeddingItem>,
}
92
93impl<P: Projection> EmbeddingIndex<P> {
    /// Start configuring a new index that will use `projection`.
    /// Equivalent to `EmbeddingIndexBuilder::new(projection)`.
    pub fn builder(projection: P) -> EmbeddingIndexBuilder<P> {
        EmbeddingIndexBuilder::new(projection)
    }
97
98    pub fn insert(&mut self, id: impl Into<String>, embedding: &Embedding) {
99        let rich = self.projection.project_rich(embedding);
100        self.index.insert(EmbeddingItem {
101            id: id.into(),
102            position: rich.position,
103            original_magnitude: embedding.magnitude(),
104            projected: Some(rich),
105        });
106    }
107
108    /// Insert with an explicit radial value, overriding the projection's RadialStrategy.
109    /// The angular coordinates (theta, phi) are still determined by the projection.
110    /// Use this for metadata-driven radius: recency scores, importance weights, etc.
111    pub fn insert_with_radius(&mut self, id: impl Into<String>, embedding: &Embedding, r: f64) {
112        let rich = self.projection.project_rich(embedding);
113        let position = SphericalPoint::new_unchecked(r, rich.position.theta, rich.position.phi);
114        self.index.insert(EmbeddingItem {
115            id: id.into(),
116            position,
117            original_magnitude: embedding.magnitude(),
118            projected: Some(ProjectedPoint { position, ..rich }),
119        });
120    }
121
122    /// Find the k embeddings whose projected directions are closest to the query.
123    pub fn search_nearest(&self, query: &Embedding, k: usize) -> Vec<NearestResult<EmbeddingItem>> {
124        let projected = self.projection.project(query);
125        self.index.nearest(&projected, k)
126    }
127
128    /// Find all embeddings whose projected cosine similarity to the query
129    /// is at least `min_cosine_similarity`.
130    ///
131    /// Internally maps cos(sim) → angular distance and uses `within_distance`.
132    pub fn search_similar(
133        &self,
134        query: &Embedding,
135        min_cosine_similarity: f64,
136    ) -> SpatialQueryResult<EmbeddingItem> {
137        let projected = self.projection.project(query);
138        let max_angle = min_cosine_similarity.clamp(-1.0, 1.0).acos();
139        self.index.within_distance(&projected, max_angle)
140    }
141
    /// Run a region query against the underlying index and return its result
    /// unchanged.
    pub fn search_region(&self, region: &Region) -> SpatialQueryResult<EmbeddingItem> {
        self.index.query_region(region)
    }
145
146    pub fn remove(&mut self, id: &str) -> Option<EmbeddingItem> {
147        self.index.remove(&id.to_string())
148    }
149
150    pub fn get(&self, id: &str) -> Option<&EmbeddingItem> {
151        self.index.get(&id.to_string())
152    }
153
    /// Number of items, as reported by the underlying spatial index.
    pub fn len(&self) -> usize {
        self.index.len()
    }
157
    /// Whether the underlying spatial index reports itself as empty.
    pub fn is_empty(&self) -> bool {
        self.index.is_empty()
    }
161
    /// Borrow the projection used for inserts and queries.
    pub fn projection(&self) -> &P {
        &self.projection
    }
165
    /// Borrow every stored item, in whatever order the underlying index
    /// returns them.
    pub fn all_items(&self) -> Vec<&EmbeddingItem> {
        self.index.all_items()
    }
169
170    /// Find the shortest semantic path between two items through a k-NN graph.
171    ///
172    /// Builds a k-nearest-neighbor graph over all indexed embeddings, then
173    /// runs Dijkstra's algorithm weighted by angular distance. The resulting
174    /// path traces the chain of closest intermediate concepts connecting
175    /// the source to the target.
176    ///
177    /// **Complexity:** O(n^2 * k) — the k-NN graph is rebuilt from scratch
178    /// on every call. Not suitable for large indices (>5000 items) without
179    /// caching.
180    pub fn concept_path(&self, source_id: &str, target_id: &str, k: usize) -> Option<ConceptPath> {
181        let items = self.index.all_items();
182        let n = items.len();
183        if n < 2 {
184            return None;
185        }
186
187        let id_to_idx: HashMap<&str, usize> = items
188            .iter()
189            .enumerate()
190            .map(|(i, item)| (item.id.as_str(), i))
191            .collect();
192
193        let source_idx = *id_to_idx.get(source_id)?;
194        let target_idx = *id_to_idx.get(target_id)?;
195
196        // Build k-NN adjacency list (undirected)
197        let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
198        for (i, item) in items.iter().enumerate() {
199            let nearest = self.index.nearest(item.position(), k + 1);
200            for result in &nearest {
201                if let Some(&j) = id_to_idx.get(result.item.id.as_str())
202                    && i != j
203                {
204                    adj[i].push((j, result.distance));
205                }
206            }
207        }
208        // Symmetrize
209        let snapshot: Vec<Vec<(usize, f64)>> = adj.clone();
210        for (i, edges) in snapshot.iter().enumerate() {
211            for &(j, d) in edges {
212                if !adj[j].iter().any(|&(k, _)| k == i) {
213                    adj[j].push((i, d));
214                }
215            }
216        }
217
218        // Dijkstra (min-heap via reversed Ord)
219        let mut dist = vec![f64::INFINITY; n];
220        let mut prev: Vec<Option<usize>> = vec![None; n];
221        let mut heap = BinaryHeap::new();
222
223        dist[source_idx] = 0.0;
224        heap.push(DijkstraEntry {
225            dist: 0.0,
226            node: source_idx,
227        });
228
229        while let Some(entry) = heap.pop() {
230            let u = entry.node;
231            if entry.dist > dist[u] {
232                continue;
233            }
234            if u == target_idx {
235                break;
236            }
237            for &(v, w) in &adj[u] {
238                let nd = dist[u] + w;
239                if nd < dist[v] {
240                    dist[v] = nd;
241                    prev[v] = Some(u);
242                    heap.push(DijkstraEntry { dist: nd, node: v });
243                }
244            }
245        }
246
247        if dist[target_idx].is_infinite() {
248            return None;
249        }
250
251        // Reconstruct
252        let mut path = Vec::new();
253        let mut cur = target_idx;
254        loop {
255            path.push(PathStep {
256                id: items[cur].id.clone(),
257                cumulative_distance: dist[cur],
258            });
259            match prev[cur] {
260                Some(p) => cur = p,
261                None => break,
262            }
263        }
264        path.reverse();
265
266        Some(ConceptPath {
267            total_distance: dist[target_idx],
268            steps: path,
269        })
270    }
271}
272
273// --- Concept path types ---
274
/// Result of `EmbeddingIndex::concept_path`: the chain of items from source
/// to target together with the path's total length.
#[derive(Debug, Clone)]
pub struct ConceptPath {
    /// Items along the path, ordered from source to target (both included).
    pub steps: Vec<PathStep>,
    /// Accumulated edge weight from source to target.
    pub total_distance: f64,
}
280
/// One item on a [`ConceptPath`].
#[derive(Debug, Clone)]
pub struct PathStep {
    /// Id of the item visited at this step.
    pub id: String,
    /// Path length from the source up to and including this step.
    pub cumulative_distance: f64,
}
286
/// Priority-queue entry for Dijkstra's algorithm: a node and its tentative
/// distance from the source.
///
/// The ordering is *reversed* on `dist` so that `BinaryHeap` (a max-heap) pops
/// the entry with the smallest distance first. Ties on distance are broken by
/// the smaller node index, which makes the pop order deterministic.
struct DijkstraEntry {
    dist: f64,
    node: usize,
}

// Equality is defined through `cmp` so that `PartialEq`, `Eq`, `PartialOrd` and
// `Ord` all agree with one another, as the `Ord` contract requires. The
// comparison uses `f64::total_cmp`, which is a total order even for NaN, so no
// NaN guard is needed.
impl PartialEq for DijkstraEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for DijkstraEntry {}

impl PartialOrd for DijkstraEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DijkstraEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Reversed: BinaryHeap is a max-heap, so smaller dist = higher priority.
        other
            .dist
            .total_cmp(&self.dist)
            .then_with(|| other.node.cmp(&self.node))
    }
}
312
313// --- Slicing manifold ---
314
/// A 2D plane fitted through the 3D projected point cloud that captures
/// the maximum variance. Found by PCA on the Cartesian coordinates of
/// the projected embeddings.
///
/// The plane is defined by:
/// - `centroid`: the mean of all 3D points
/// - `basis_u`, `basis_v`: orthonormal vectors spanning the plane (directions of max variance)
/// - `normal`: vector perpendicular to the plane (direction of minimum variance)
#[derive(Debug, Clone)]
pub struct SlicingManifold {
    /// Mean of the fitted points; the origin of the plane's (u, v) coordinates.
    pub centroid: [f64; 3],
    /// Eigenvector of the smallest covariance eigenvalue.
    pub normal: [f64; 3],
    /// Eigenvector of the largest covariance eigenvalue.
    pub basis_u: [f64; 3],
    /// Eigenvector of the second-largest covariance eigenvalue.
    pub basis_v: [f64; 3],
    /// Share of total variance captured by the plane, (λ₀ + λ₁) / (λ₀ + λ₁ + λ₂);
    /// set to 1.0 when the total variance is not positive.
    pub variance_ratio: f64,
}
331
332impl SlicingManifold {
333    /// Fit the optimal slicing plane to a set of 3D points.
334    /// Each point is (x, y, z) in Cartesian coordinates.
335    pub fn fit(points: &[[f64; 3]]) -> Self {
336        let n = points.len() as f64;
337        assert!(n >= 3.0, "need at least 3 points to fit a plane");
338
339        // Centroid
340        let mut c = [0.0; 3];
341        for p in points {
342            for i in 0..3 {
343                c[i] += p[i];
344            }
345        }
346        for ci in &mut c {
347            *ci /= n;
348        }
349
350        // 3×3 covariance matrix (symmetric)
351        let mut cov = [[0.0f64; 3]; 3];
352        for p in points {
353            let d = [p[0] - c[0], p[1] - c[1], p[2] - c[2]];
354            for i in 0..3 {
355                for j in 0..3 {
356                    cov[i][j] += d[i] * d[j];
357                }
358            }
359        }
360        for row in &mut cov {
361            for v in row.iter_mut() {
362                *v /= n;
363            }
364        }
365
366        // Eigendecomposition of 3×3 symmetric matrix via Jacobi iteration
367        let (eigenvalues, eigenvectors) = eigen_symmetric_3x3(&cov);
368
369        // eigenvalues are sorted descending: λ₀ ≥ λ₁ ≥ λ₂
370        // basis_u = eigenvector of λ₀, basis_v = eigenvector of λ₁, normal = eigenvector of λ₂
371        let total_var = eigenvalues[0] + eigenvalues[1] + eigenvalues[2];
372        let variance_ratio = if total_var > 0.0 {
373            (eigenvalues[0] + eigenvalues[1]) / total_var
374        } else {
375            1.0
376        };
377
378        Self {
379            centroid: c,
380            normal: eigenvectors[2],
381            basis_u: eigenvectors[0],
382            basis_v: eigenvectors[1],
383            variance_ratio,
384        }
385    }
386
387    /// Project a 3D point onto the plane, returning (u, v) coordinates.
388    pub fn project_2d(&self, point: &[f64; 3]) -> (f64, f64) {
389        let d = [
390            point[0] - self.centroid[0],
391            point[1] - self.centroid[1],
392            point[2] - self.centroid[2],
393        ];
394        let u = d[0] * self.basis_u[0] + d[1] * self.basis_u[1] + d[2] * self.basis_u[2];
395        let v = d[0] * self.basis_v[0] + d[1] * self.basis_v[1] + d[2] * self.basis_v[2];
396        (u, v)
397    }
398
399    /// Signed distance from the plane (positive = same side as normal).
400    pub fn distance(&self, point: &[f64; 3]) -> f64 {
401        let d = [
402            point[0] - self.centroid[0],
403            point[1] - self.centroid[1],
404            point[2] - self.centroid[2],
405        ];
406        d[0] * self.normal[0] + d[1] * self.normal[1] + d[2] * self.normal[2]
407    }
408
409    /// Fit a local manifold around a query point using its k nearest neighbors.
410    ///
411    /// The local plane captures the shape of the semantic neighborhood:
412    /// - If variance_ratio ≈ 1.0, the neighborhood is flat (concepts spread in a plane)
413    /// - If variance_ratio ≈ 0.67, concepts are uniformly distributed (spherical)
414    /// - The normal direction reveals which semantic axis is least relevant locally
415    ///
416    /// This enables directional search narrowing: once you know the local geometry,
417    /// you can restrict subsequent queries to the dominant plane, cutting the
418    /// effective search dimensionality from 3D to 2D in that region.
419    pub fn fit_local(query: &[f64; 3], all_points: &[[f64; 3]], k: usize) -> Self {
420        let mut dists: Vec<(usize, f64)> = all_points
421            .iter()
422            .enumerate()
423            .map(|(i, p)| (i, dist3(query, p)))
424            .collect();
425        dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
426
427        let neighborhood: Vec<[f64; 3]> = dists
428            .iter()
429            .take(k.max(3))
430            .map(|&(i, _)| all_points[i])
431            .collect();
432
433        Self::fit(&neighborhood)
434    }
435}
436
/// Eigendecomposition of a 3×3 symmetric matrix via Jacobi rotations.
///
/// Returns `(eigenvalues, eigenvectors)` sorted by decreasing eigenvalue;
/// `eigenvectors[i]` is the vector belonging to `eigenvalues[i]`.
///
/// Each sweep zeroes the largest off-diagonal element of the upper triangle
/// (the input is assumed symmetric; the lower triangle is never inspected when
/// picking the pivot). Iteration stops after 50 rotations or once the largest
/// off-diagonal magnitude drops below `1e-15`.
///
/// The final ordering uses `f64::total_cmp`, so NaN input yields NaN output
/// instead of panicking inside the sort.
fn eigen_symmetric_3x3(m: &[[f64; 3]; 3]) -> ([f64; 3], [[f64; 3]; 3]) {
    let mut a = *m;
    let mut v = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]; // eigenvector matrix

    // Index loops mirror the matrix notation (rows/columns p and q).
    #[allow(clippy::needless_range_loop)]
    for _ in 0..50 {
        // Find largest off-diagonal element
        let mut p = 0;
        let mut q = 1;
        let mut max_val = a[0][1].abs();
        for i in 0..3 {
            for j in (i + 1)..3 {
                if a[i][j].abs() > max_val {
                    max_val = a[i][j].abs();
                    p = i;
                    q = j;
                }
            }
        }
        if max_val < 1e-15 {
            break;
        }

        // Jacobi rotation to zero out a[p][q]: tan(2θ) = 2·a_pq / (a_pp − a_qq).
        // Equal diagonal entries make that ratio infinite, i.e. θ = π/4.
        let theta = if (a[p][p] - a[q][q]).abs() < 1e-30 {
            std::f64::consts::FRAC_PI_4
        } else {
            0.5 * (2.0 * a[p][q] / (a[p][p] - a[q][q])).atan()
        };
        let c = theta.cos();
        let s = theta.sin();

        // Rotate a ← GᵀaG: first the columns (a·G) ...
        let mut new_a = a;
        for i in 0..3 {
            new_a[i][p] = c * a[i][p] + s * a[i][q];
            new_a[i][q] = -s * a[i][p] + c * a[i][q];
        }
        // ... then the rows (Gᵀ·(a·G)), reading from a frozen copy.
        let snapshot = new_a;
        for j in 0..3 {
            new_a[p][j] = c * snapshot[p][j] + s * snapshot[q][j];
            new_a[q][j] = -s * snapshot[p][j] + c * snapshot[q][j];
        }
        // The pivot is zero by construction; store it exactly to avoid round-off residue.
        new_a[p][q] = 0.0;
        new_a[q][p] = 0.0;
        a = new_a;

        // Rotate eigenvectors: V ← VG
        let mut new_v = v;
        for i in 0..3 {
            new_v[i][p] = c * v[i][p] + s * v[i][q];
            new_v[i][q] = -s * v[i][p] + c * v[i][q];
        }
        v = new_v;
    }

    let eigenvalues = [a[0][0], a[1][1], a[2][2]];

    // Sort by descending eigenvalue. `total_cmp` never fails, even for NaN.
    let mut order = [0usize, 1, 2];
    order.sort_by(|&lhs, &rhs| eigenvalues[rhs].total_cmp(&eigenvalues[lhs]));

    let sorted_vals = [
        eigenvalues[order[0]],
        eigenvalues[order[1]],
        eigenvalues[order[2]],
    ];
    // Eigenvectors are columns of v
    let sorted_vecs = [
        [v[0][order[0]], v[1][order[0]], v[2][order[0]]],
        [v[0][order[1]], v[1][order[1]], v[2][order[1]]],
        [v[0][order[2]], v[1][order[2]], v[2][order[2]]],
    ];

    (sorted_vals, sorted_vecs)
}
515
516// --- Concept Globs (spherical k-means + silhouette auto-k) ---
517
/// A cluster of semantically related embeddings in the projected 3D space.
#[derive(Debug, Clone)]
pub struct ConceptGlob {
    /// Cluster index assigned by k-means. Empty clusters produce no glob, so
    /// ids are not necessarily contiguous.
    pub id: usize,
    /// Normalized mean direction of the members.
    pub centroid: [f64; 3],
    /// Ids of the member points.
    pub member_ids: Vec<String>,
    /// Angular distance (radians) of each member from `centroid`, in the same
    /// order as `member_ids`.
    pub member_distances: Vec<f64>,
    /// Largest value in `member_distances`.
    pub radius: f64,
}
527
/// Result of glob detection: the set of all globs plus quality metrics.
#[derive(Debug, Clone)]
pub struct GlobResult {
    /// One entry per non-empty cluster.
    pub globs: Vec<ConceptGlob>,
    /// Number of clusters k-means was run with; may exceed `globs.len()` when
    /// some clusters ended up empty.
    pub k: usize,
    /// Mean silhouette score of the chosen clustering.
    pub silhouette: f64,
}
535
536impl GlobResult {
537    /// Detect concept globs from 3D projected points.
538    ///
539    /// If `k` is `Some`, uses that many clusters.
540    /// If `None`, auto-selects k ∈ [2, max_k] by maximizing the silhouette score.
541    pub fn detect(points: &[[f64; 3]], ids: &[String], k: Option<usize>, max_k: usize) -> Self {
542        let n = points.len();
543        assert_eq!(n, ids.len());
544        assert!(n >= 2, "need at least 2 points for clustering");
545
546        let max_k = max_k.min(n);
547
548        if let Some(k) = k {
549            let k = k.clamp(2, max_k);
550            let (assignments, silhouette) = kmeans_3d(points, k);
551            let globs = build_globs(points, ids, &assignments, k);
552            return Self {
553                globs,
554                k,
555                silhouette,
556            };
557        }
558
559        // Auto-detect: try k = 2..=max_k, pick best silhouette
560        let mut best_k = 2;
561        let mut best_sil = f64::NEG_INFINITY;
562        let mut best_assignments = vec![0usize; n];
563
564        for trial_k in 2..=max_k {
565            let (assignments, sil) = kmeans_3d(points, trial_k);
566            if sil > best_sil {
567                best_sil = sil;
568                best_k = trial_k;
569                best_assignments = assignments;
570            }
571        }
572
573        let globs = build_globs(points, ids, &best_assignments, best_k);
574        Self {
575            globs,
576            k: best_k,
577            silhouette: best_sil,
578        }
579    }
580}
581
582fn kmeans_3d(points: &[[f64; 3]], k: usize) -> (Vec<usize>, f64) {
583    let n = points.len();
584    let max_iter = 50;
585
586    // Init: spread initial centers evenly across the point set
587    let mut centers: Vec<[f64; 3]> = (0..k).map(|i| points[i * n / k]).collect();
588
589    let mut assignments = vec![0usize; n];
590
591    for _ in 0..max_iter {
592        let mut changed = false;
593
594        // Assign by angular distance (direction, not position)
595        for (i, p) in points.iter().enumerate() {
596            let mut best = 0;
597            let mut best_d = f64::MAX;
598            for (j, c) in centers.iter().enumerate() {
599                let d = angular_dist3(p, c);
600                if d < best_d {
601                    best_d = d;
602                    best = j;
603                }
604            }
605            if assignments[i] != best {
606                assignments[i] = best;
607                changed = true;
608            }
609        }
610
611        if !changed {
612            break;
613        }
614
615        // Update centers: mean direction (Euclidean mean of unit vectors, then normalize).
616        // This is the Fréchet mean on S² for concentrated clusters.
617        let mut sums = vec![[0.0f64; 3]; k];
618        let mut counts = vec![0usize; k];
619        for (i, &a) in assignments.iter().enumerate() {
620            let norm_p = normalize3(&points[i]);
621            for (d, &np) in norm_p.iter().enumerate() {
622                sums[a][d] += np;
623            }
624            counts[a] += 1;
625        }
626        for j in 0..k {
627            if counts[j] > 0 {
628                centers[j] = normalize3(&sums[j]);
629            }
630        }
631    }
632
633    let sil = silhouette_score(points, &assignments, k);
634    (assignments, sil)
635}
636
637fn silhouette_score(points: &[[f64; 3]], assignments: &[usize], k: usize) -> f64 {
638    let n = points.len();
639    if n <= 1 || k <= 1 {
640        return 0.0;
641    }
642
643    let mut total = 0.0;
644    for i in 0..n {
645        let ci = assignments[i];
646
647        // a(i): mean angular dist to same-cluster members
648        let mut a_sum = 0.0;
649        let mut a_cnt = 0;
650        for j in 0..n {
651            if j != i && assignments[j] == ci {
652                a_sum += angular_dist3(&points[i], &points[j]);
653                a_cnt += 1;
654            }
655        }
656        let a = if a_cnt > 0 { a_sum / a_cnt as f64 } else { 0.0 };
657
658        // b(i): min mean angular dist to any other cluster
659        let mut b = f64::MAX;
660        for ck in 0..k {
661            if ck == ci {
662                continue;
663            }
664            let mut b_sum = 0.0;
665            let mut b_cnt = 0;
666            for j in 0..n {
667                if assignments[j] == ck {
668                    b_sum += angular_dist3(&points[i], &points[j]);
669                    b_cnt += 1;
670                }
671            }
672            if b_cnt > 0 {
673                b = b.min(b_sum / b_cnt as f64);
674            }
675        }
676        if b == f64::MAX {
677            b = 0.0;
678        }
679
680        let denom = a.max(b);
681        total += if denom > 0.0 { (b - a) / denom } else { 0.0 };
682    }
683
684    total / n as f64
685}
686
687fn build_globs(
688    points: &[[f64; 3]],
689    ids: &[String],
690    assignments: &[usize],
691    k: usize,
692) -> Vec<ConceptGlob> {
693    let mut globs = Vec::with_capacity(k);
694
695    for cluster_id in 0..k {
696        let member_indices: Vec<usize> = assignments
697            .iter()
698            .enumerate()
699            .filter(|&(_, &a)| a == cluster_id)
700            .map(|(i, _)| i)
701            .collect();
702
703        if member_indices.is_empty() {
704            continue;
705        }
706
707        // Centroid: mean direction (normalize to get angular centroid)
708        let mut centroid = [0.0; 3];
709        for &i in &member_indices {
710            let norm_p = normalize3(&points[i]);
711            for (d, c) in centroid.iter_mut().enumerate() {
712                *c += norm_p[d];
713            }
714        }
715        centroid = normalize3(&centroid);
716
717        // Member angular distances from centroid
718        let member_distances: Vec<f64> = member_indices
719            .iter()
720            .map(|&i| angular_dist3(&points[i], &centroid))
721            .collect();
722
723        let radius = member_distances.iter().cloned().fold(0.0f64, f64::max);
724
725        let member_ids: Vec<String> = member_indices.iter().map(|&i| ids[i].clone()).collect();
726
727        globs.push(ConceptGlob {
728            id: cluster_id,
729            centroid,
730            member_ids,
731            member_distances,
732            radius,
733        });
734    }
735
736    globs
737}
738
/// Euclidean distance between two 3D points.
fn dist3(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let [dx, dy, dz] = [a[0] - b[0], a[1] - b[1], a[2] - b[2]];
    let squared = dx * dx + dy * dy + dz * dz;
    squared.sqrt()
}
745
/// Angle in radians, within [0, π], between two 3D vectors taken as directions.
/// Magnitudes do not matter; if the product of the two lengths is below
/// `f64::EPSILON` the angle is reported as 0.
fn angular_dist3(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let length = |v: &[f64; 3]| (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
    let scale = length(a) * length(b);
    if scale < f64::EPSILON {
        0.0
    } else {
        let dot = a[0] * b[0] + a[1] * b[1] + a[2] * b[2];
        // Clamp guards against round-off pushing the cosine just outside [-1, 1].
        let cosine = (dot / scale).clamp(-1.0, 1.0);
        cosine.acos()
    }
}
758
/// Scale a 3D vector to unit length. Vectors shorter than `f64::EPSILON`
/// come back as the zero vector.
fn normalize3(v: &[f64; 3]) -> [f64; 3] {
    let [x, y, z] = *v;
    let mag = (x * x + y * y + z * z).sqrt();
    if mag < f64::EPSILON {
        [0.0; 3]
    } else {
        [x / mag, y / mag, z / mag]
    }
}
767
768/// Builds SphereQL [`Region`]s from semantic constraints on embeddings.
769pub struct SemanticQuery;
770
771impl SemanticQuery {
772    /// Spherical cap: all points within `max_angular_distance` radians of the query.
773    pub fn within_angle<P: Projection>(
774        query: &Embedding,
775        projection: &P,
776        max_angular_distance: f64,
777    ) -> Region {
778        let point = projection.project(query);
779        let half_angle = max_angular_distance.clamp(1e-10, std::f64::consts::PI);
780        Region::Cap(
781            Cap::new(
782                SphericalPoint::new_unchecked(1.0, point.theta, point.phi),
783                half_angle,
784            )
785            .unwrap(),
786        )
787    }
788
789    /// Spherical cap from a cosine similarity threshold.
790    /// cos_sim >= threshold ↔ angular_distance <= arccos(threshold).
791    pub fn above_similarity<P: Projection>(
792        query: &Embedding,
793        projection: &P,
794        min_similarity: f64,
795    ) -> Region {
796        let half_angle = min_similarity.clamp(-1.0, 1.0).acos();
797        Self::within_angle(query, projection, half_angle)
798    }
799
800    /// Radial shell: embeddings whose projected radius falls in [inner, outer].
801    pub fn in_shell(inner: f64, outer: f64) -> Region {
802        Region::Shell(Shell::new(inner, outer).expect("invalid shell bounds"))
803    }
804
805    /// Intersection of a similarity cap with a radial shell.
806    /// "Semantically similar AND within a magnitude/metadata range."
807    pub fn similar_in_shell<P: Projection>(
808        query: &Embedding,
809        projection: &P,
810        min_similarity: f64,
811        shell_inner: f64,
812        shell_outer: f64,
813    ) -> Region {
814        Region::intersection(vec![
815            Self::above_similarity(query, projection, min_similarity),
816            Self::in_shell(shell_inner, shell_outer),
817        ])
818    }
819}
820
821#[cfg(test)]
822mod tests {
823    use super::*;
824    use crate::projection::{PcaProjection, RandomProjection};
825    use crate::types::RadialStrategy;
826    use sphereql_core::angular_distance;
827
828    fn emb(vals: &[f64]) -> Embedding {
829        Embedding::new(vals.to_vec())
830    }
831
832    fn test_corpus() -> Vec<Embedding> {
833        vec![
834            emb(&[1.0, 0.0, 0.0, 0.1, 0.0]),
835            emb(&[0.0, 1.0, 0.0, 0.0, 0.1]),
836            emb(&[0.0, 0.0, 1.0, 0.1, 0.0]),
837            emb(&[1.0, 1.0, 0.0, 0.05, 0.05]),
838            emb(&[-1.0, 0.0, 0.0, -0.1, 0.0]),
839            emb(&[0.0, -1.0, 0.0, 0.0, -0.1]),
840        ]
841    }
842
843    // --- EmbeddingIndex ---
844
845    #[test]
846    fn insert_and_get() {
847        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
848        let mut idx = EmbeddingIndex::builder(rp)
849            .theta_divisions(4)
850            .phi_divisions(3)
851            .build();
852
853        idx.insert("a", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
854        idx.insert("b", &emb(&[0.0, 1.0, 0.0, 0.0, 0.0]));
855
856        assert_eq!(idx.len(), 2);
857        assert!(!idx.is_empty());
858        assert!(idx.get("a").is_some());
859        assert!(idx.get("b").is_some());
860        assert!(idx.get("c").is_none());
861    }
862
863    #[test]
864    fn remove() {
865        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
866        let mut idx = EmbeddingIndex::builder(rp).build();
867
868        idx.insert("a", &emb(&[1.0; 5]));
869        assert_eq!(idx.len(), 1);
870
871        let removed = idx.remove("a");
872        assert!(removed.is_some());
873        assert_eq!(removed.unwrap().id, "a");
874        assert_eq!(idx.len(), 0);
875        assert!(idx.get("a").is_none());
876    }
877
878    #[test]
879    fn remove_nonexistent() {
880        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
881        let mut idx = EmbeddingIndex::builder(rp).build();
882        assert!(idx.remove("nope").is_none());
883    }
884
885    #[test]
886    fn search_nearest_returns_sorted() {
887        let corpus = test_corpus();
888        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0));
889        let mut idx = EmbeddingIndex::builder(pca)
890            .theta_divisions(4)
891            .phi_divisions(3)
892            .build();
893
894        for (i, e) in corpus.iter().enumerate() {
895            idx.insert(format!("item-{i}"), e);
896        }
897
898        let query = emb(&[0.95, 0.1, 0.0, 0.05, 0.0]);
899        let results = idx.search_nearest(&query, 3);
900
901        assert_eq!(results.len(), 3);
902        assert!(results[0].distance <= results[1].distance);
903        assert!(results[1].distance <= results[2].distance);
904    }
905
906    #[test]
907    fn search_similar_respects_threshold() {
908        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
909        let mut idx = EmbeddingIndex::builder(rp)
910            .theta_divisions(4)
911            .phi_divisions(3)
912            .build();
913
914        idx.insert("close_a", &emb(&[1.0, 0.1, 0.0, 0.0, 0.0]));
915        idx.insert("close_b", &emb(&[0.9, 0.2, 0.0, 0.0, 0.0]));
916        idx.insert("far", &emb(&[-1.0, 0.0, 0.0, 0.0, 0.0]));
917
918        let query = emb(&[1.0, 0.0, 0.0, 0.0, 0.0]);
919        let projected_query = idx.projection().project(&query);
920        let result = idx.search_similar(&query, 0.5);
921
922        let max_angle = 0.5_f64.acos();
923        for item in &result.items {
924            let d = angular_distance(&projected_query, item.position());
925            assert!(d <= max_angle + 1e-10, "item {} too far: {d}", item.id);
926        }
927    }
928
929    #[test]
930    fn insert_with_radius_overrides() {
931        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
932        let mut idx = EmbeddingIndex::builder(rp).build();
933
934        idx.insert_with_radius("custom", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]), 42.0);
935        let item = idx.get("custom").unwrap();
936        assert!((item.position.r - 42.0).abs() < 1e-12);
937    }
938
939    #[test]
940    fn original_magnitude_stored() {
941        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
942        let mut idx = EmbeddingIndex::builder(rp).build();
943
944        let e = emb(&[3.0, 4.0, 0.0, 0.0, 0.0]);
945        idx.insert("vec", &e);
946        let item = idx.get("vec").unwrap();
947        assert!((item.original_magnitude - 5.0).abs() < 1e-10);
948    }
949
950    #[test]
951    fn magnitude_radial_with_shell_query() {
952        let corpus = test_corpus();
953        let pca = PcaProjection::fit(&corpus, RadialStrategy::Magnitude);
954        let mut idx = EmbeddingIndex::builder(pca)
955            .uniform_shells(5, 10.0)
956            .theta_divisions(4)
957            .phi_divisions(3)
958            .build();
959
960        idx.insert("small", &emb(&[0.1, 0.0, 0.0, 0.0, 0.0]));
961        idx.insert("medium", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
962        idx.insert("large", &emb(&[5.0, 0.0, 0.0, 0.0, 0.0]));
963
964        let shell = Shell::new(0.5, 2.0).unwrap();
965        let result = idx.search_region(&Region::Shell(shell));
966
967        let ids: Vec<&str> = result.items.iter().map(|i| i.id.as_str()).collect();
968        assert!(
969            ids.contains(&"medium"),
970            "medium (mag=1.0) should be in [0.5, 2.0]"
971        );
972        assert!(
973            !ids.contains(&"large"),
974            "large (mag=5.0) should not be in [0.5, 2.0]"
975        );
976    }
977
978    #[test]
979    fn empty_index() {
980        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
981        let idx = EmbeddingIndex::builder(rp).build();
982
983        assert!(idx.is_empty());
984        assert_eq!(idx.len(), 0);
985        assert!(idx.get("x").is_none());
986
987        let results = idx.search_nearest(&emb(&[1.0; 5]), 5);
988        assert!(results.is_empty());
989    }
990
991    #[test]
992    fn projection_accessor() {
993        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
994        let idx = EmbeddingIndex::builder(rp).build();
995        assert_eq!(idx.projection().dimensionality(), 5);
996    }
997
998    // --- SemanticQuery ---
999
1000    #[test]
1001    fn above_similarity_creates_cap() {
1002        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1003        let region = SemanticQuery::above_similarity(&emb(&[1.0; 5]), &rp, 0.8);
1004        assert!(matches!(region, Region::Cap(_)));
1005    }
1006
1007    #[test]
1008    fn within_angle_creates_cap() {
1009        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1010        let region = SemanticQuery::within_angle(&emb(&[1.0; 5]), &rp, 0.5);
1011        assert!(matches!(region, Region::Cap(_)));
1012    }
1013
1014    #[test]
1015    fn in_shell_creates_shell() {
1016        let region = SemanticQuery::in_shell(1.0, 5.0);
1017        assert!(matches!(region, Region::Shell(_)));
1018    }
1019
1020    #[test]
1021    fn similar_in_shell_creates_intersection() {
1022        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1023        let region = SemanticQuery::similar_in_shell(&emb(&[1.0; 5]), &rp, 0.7, 1.0, 5.0);
1024
1025        match region {
1026            Region::Intersection(parts) => {
1027                assert_eq!(parts.len(), 2);
1028                assert!(matches!(parts[0], Region::Cap(_)));
1029                assert!(matches!(parts[1], Region::Shell(_)));
1030            }
1031            other => panic!("expected Intersection, got {other:?}"),
1032        }
1033    }
1034
1035    #[test]
1036    fn semantic_query_region_used_in_index() {
1037        let corpus = test_corpus();
1038        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0));
1039        let projection_clone = pca.clone();
1040        let mut idx = EmbeddingIndex::builder(pca)
1041            .theta_divisions(4)
1042            .phi_divisions(3)
1043            .build();
1044
1045        for (i, e) in corpus.iter().enumerate() {
1046            idx.insert(format!("item-{i}"), e);
1047        }
1048
1049        let query = emb(&[1.0, 0.0, 0.0, 0.05, 0.0]);
1050        let region = SemanticQuery::above_similarity(&query, &projection_clone, 0.5);
1051        let result = idx.search_region(&region);
1052
1053        for item in &result.items {
1054            assert!(region.contains(item.position()));
1055        }
1056    }
1057}