Skip to main content

sphereql_embed/
query.rs

1use std::collections::{BinaryHeap, HashMap};
2use std::sync::{Arc, Mutex};
3
4use sphereql_core::*;
5use sphereql_index::*;
6
7use crate::category::BridgeClassification;
8use crate::projection::Projection;
9use crate::types::{Embedding, ProjectedPoint};
10
/// k-NN adjacency snapshot cached between calls to
/// [`EmbeddingIndex::concept_path`] and its bridged variant.
///
/// The graph is undirected and built for a specific `k`. Different `k`
/// invalidates; so does any `&mut self` mutation on the index.
struct KnnCache {
    /// Neighbor count the snapshot was built with. A lookup with a
    /// different `k` is treated as a cache miss.
    k: usize,
    /// Adjacency lists: `adj[i]` holds `(neighbor_index, distance)` pairs,
    /// where indices are positions in the item slice the graph was built
    /// from. Held in an `Arc` so a snapshot can be handed to callers
    /// without copying the graph.
    adj: Arc<Vec<Vec<(usize, f64)>>>,
}
20
/// One indexed embedding: its identifier, where it landed after projection,
/// and the metadata captured at insertion time.
#[derive(Debug, Clone)]
pub struct EmbeddingItem {
    /// Caller-supplied identifier; returned by the `SpatialItem::id` impl.
    pub id: String,
    /// Projected spherical position; returned by the `SpatialItem::position` impl.
    pub position: SphericalPoint,
    /// Value of `embedding.magnitude()` captured when the item was inserted.
    pub original_magnitude: f64,
    /// Rich projection metadata. Both `insert` and `insert_with_radius` in
    /// this file store `Some(..)`; `None` only occurs for items constructed
    /// directly without metadata, in which case the accessor methods fall
    /// back to defaults.
    pub projected: Option<ProjectedPoint>,
}
29
/// Lets the spatial index store `EmbeddingItem`s, keyed by their string id
/// and located by their projected position.
impl SpatialItem for EmbeddingItem {
    type Id = String;
    /// Identifier the item is stored under.
    fn id(&self) -> &String {
        &self.id
    }
    /// Spherical coordinates of the item.
    fn position(&self) -> &SphericalPoint {
        &self.position
    }
}
39
40impl EmbeddingItem {
41    /// Certainty of this point's projection. Falls back to 1.0 if no rich metadata.
42    pub fn certainty(&self) -> f64 {
43        self.projected.map_or(1.0, |p| p.certainty)
44    }
45
46    /// Intensity (pre-normalization magnitude) of the original embedding.
47    pub fn intensity(&self) -> f64 {
48        self.projected
49            .map_or(self.original_magnitude, |p| p.intensity)
50    }
51
52    /// PCA projection magnitude — how strongly this point projects onto
53    /// the 3 principal components. Low values indicate ambiguous points.
54    pub fn projection_magnitude(&self) -> f64 {
55        self.projected.map_or(1.0, |p| p.projection_magnitude)
56    }
57}
58
/// Builder for [`EmbeddingIndex`]: pairs a projection with the configuration
/// of the underlying spatial index.
#[must_use = "builders must be terminated with `.build()` to construct the index"]
pub struct EmbeddingIndexBuilder<P> {
    /// Projection that will be moved into the finished index.
    projection: P,
    /// Spatial-index builder that receives the shell and angular-division settings.
    inner: SpatialIndexBuilder,
}
64
65impl<P: Projection> EmbeddingIndexBuilder<P> {
66    pub fn new(projection: P) -> Self {
67        Self {
68            projection,
69            inner: SpatialIndexBuilder::new(),
70        }
71    }
72
73    pub fn shell_boundary(mut self, r: f64) -> Self {
74        self.inner = self.inner.shell_boundary(r);
75        self
76    }
77
78    pub fn uniform_shells(mut self, count: usize, max_r: f64) -> Self {
79        self.inner = self.inner.uniform_shells(count, max_r);
80        self
81    }
82
83    pub fn theta_divisions(mut self, n: usize) -> Self {
84        self.inner = self.inner.theta_divisions(n);
85        self
86    }
87
88    pub fn phi_divisions(mut self, n: usize) -> Self {
89        self.inner = self.inner.phi_divisions(n);
90        self
91    }
92
93    pub fn build(self) -> EmbeddingIndex<P> {
94        EmbeddingIndex {
95            projection: self.projection,
96            index: self.inner.build(),
97            knn_cache: Mutex::new(None),
98        }
99    }
100}
101
/// Spatial index over projected embeddings, plus a memoized k-NN graph used
/// by the concept-path queries.
pub struct EmbeddingIndex<P> {
    /// Projection applied to every inserted embedding and every query.
    projection: P,
    /// Underlying spatial index holding the projected items.
    index: SpatialIndex<EmbeddingItem>,
    /// k-NN adjacency cache for `concept_path` and friends. Shared
    /// behind `Mutex<Option<_>>` so the `&self` query methods can
    /// memoize across calls; `&mut self` mutations clear it via
    /// `get_mut`, bypassing the lock.
    knn_cache: Mutex<Option<KnnCache>>,
}
111
112impl<P: Projection> EmbeddingIndex<P> {
    /// Convenience entry point: equivalent to `EmbeddingIndexBuilder::new(projection)`.
    pub fn builder(projection: P) -> EmbeddingIndexBuilder<P> {
        EmbeddingIndexBuilder::new(projection)
    }
116
117    pub fn insert(&mut self, id: impl Into<String>, embedding: &Embedding) {
118        let rich = self.projection.project_rich(embedding);
119        self.index.insert(EmbeddingItem {
120            id: id.into(),
121            position: rich.position,
122            original_magnitude: embedding.magnitude(),
123            projected: Some(rich),
124        });
125        self.invalidate_knn_cache();
126    }
127
128    /// Insert with an explicit radial value, overriding the projection's RadialStrategy.
129    /// The angular coordinates (theta, phi) are still determined by the projection.
130    /// Use this for metadata-driven radius: recency scores, importance weights, etc.
131    pub fn insert_with_radius(&mut self, id: impl Into<String>, embedding: &Embedding, r: f64) {
132        let rich = self.projection.project_rich(embedding);
133        let position = SphericalPoint::new_unchecked(r, rich.position.theta, rich.position.phi);
134        self.index.insert(EmbeddingItem {
135            id: id.into(),
136            position,
137            original_magnitude: embedding.magnitude(),
138            projected: Some(ProjectedPoint { position, ..rich }),
139        });
140        self.invalidate_knn_cache();
141    }
142
143    /// Drop any cached k-NN adjacency. Called by every `&mut self`
144    /// mutation that could change the graph. Uses `get_mut` to skip
145    /// the lock when we already hold `&mut self`.
146    fn invalidate_knn_cache(&mut self) {
147        if let Ok(slot) = self.knn_cache.get_mut() {
148            *slot = None;
149        }
150    }
151
152    /// Return the k-NN adjacency snapshot for the given `k`, rebuilding
153    /// only on cache miss.
154    ///
155    /// The shared `Arc` is cheap to clone, so callers can drop the lock
156    /// while they run Dijkstra. Previously `concept_path` rebuilt the
157    /// entire graph on every call — O(n² · k) per query.
158    fn knn_adjacency(&self, items: &[&EmbeddingItem], k: usize) -> Arc<Vec<Vec<(usize, f64)>>> {
159        {
160            let cache = self.knn_cache.lock().expect("knn cache mutex poisoned");
161            if let Some(cached) = cache.as_ref()
162                && cached.k == k
163                && cached.adj.len() == items.len()
164            {
165                return Arc::clone(&cached.adj);
166            }
167        }
168
169        // Miss — build a fresh undirected adjacency. Symmetrize in one
170        // O(E) pass using a HashSet instead of the previous O(n · k²)
171        // linear scan.
172        let n = items.len();
173        let id_to_idx: HashMap<&str, usize> = items
174            .iter()
175            .enumerate()
176            .map(|(i, item)| (item.id.as_str(), i))
177            .collect();
178        let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::with_capacity(k); n];
179        let mut seen: std::collections::HashSet<(usize, usize)> =
180            std::collections::HashSet::with_capacity(n * k);
181        for (i, item) in items.iter().enumerate() {
182            let nearest = self.index.nearest(item.position(), k + 1);
183            for result in &nearest {
184                let Some(&j) = id_to_idx.get(result.item.id.as_str()) else {
185                    continue;
186                };
187                if i == j {
188                    continue;
189                }
190                let key = if i < j { (i, j) } else { (j, i) };
191                if seen.insert(key) {
192                    adj[i].push((j, result.distance));
193                    adj[j].push((i, result.distance));
194                }
195            }
196        }
197
198        let adj = Arc::new(adj);
199        let mut cache = self.knn_cache.lock().expect("knn cache mutex poisoned");
200        *cache = Some(KnnCache {
201            k,
202            adj: Arc::clone(&adj),
203        });
204        adj
205    }
206
207    /// Find the k embeddings whose projected directions are closest to the query.
208    pub fn search_nearest(&self, query: &Embedding, k: usize) -> Vec<NearestResult<EmbeddingItem>> {
209        let projected = self.projection.project(query);
210        self.index.nearest(&projected, k)
211    }
212
213    /// Find all embeddings whose projected cosine similarity to the query
214    /// is at least `min_cosine_similarity`.
215    ///
216    /// Internally maps cos(sim) → angular distance and uses `within_distance`.
217    pub fn search_similar(
218        &self,
219        query: &Embedding,
220        min_cosine_similarity: f64,
221    ) -> SpatialQueryResult<EmbeddingItem> {
222        let projected = self.projection.project(query);
223        let max_angle = min_cosine_similarity.clamp(-1.0, 1.0).acos();
224        self.index.within_distance(&projected, max_angle)
225    }
226
    /// Return what the underlying spatial index reports for `region`.
    ///
    /// Pure pass-through to the index's `query_region`; the projection is
    /// not involved.
    pub fn search_region(&self, region: &Region) -> SpatialQueryResult<EmbeddingItem> {
        self.index.query_region(region)
    }
230
231    pub fn remove(&mut self, id: &str) -> Option<EmbeddingItem> {
232        let removed = self.index.remove(&id.to_string());
233        if removed.is_some() {
234            self.invalidate_knn_cache();
235        }
236        removed
237    }
238
    /// Look up an item by id.
    ///
    /// Builds a temporary `String` from `id` because the index lookup takes
    /// a `&String` key.
    pub fn get(&self, id: &str) -> Option<&EmbeddingItem> {
        self.index.get(&id.to_string())
    }
242
    /// Item count as reported by the underlying spatial index.
    pub fn len(&self) -> usize {
        self.index.len()
    }
246
    /// Whether the underlying spatial index reports itself as empty.
    pub fn is_empty(&self) -> bool {
        self.index.is_empty()
    }
250
    /// Borrow the projection this index was built with.
    pub fn projection(&self) -> &P {
        &self.projection
    }
254
    /// Borrow every stored item, in whatever order the underlying index
    /// returns them.
    pub fn all_items(&self) -> Vec<&EmbeddingItem> {
        self.index.all_items()
    }
258
259    /// Find the shortest semantic path between two items through a k-NN graph.
260    ///
261    /// Builds a k-nearest-neighbor graph over all indexed embeddings, then
262    /// runs Dijkstra's algorithm weighted by angular distance. The resulting
263    /// path traces the chain of closest intermediate concepts connecting
264    /// the source to the target.
265    ///
266    /// The k-NN graph is memoized per `k`: the first call at a given `k`
267    /// builds it in O(n · log n · k) (index-assisted) and every
268    /// subsequent call reuses the snapshot until the index mutates.
269    /// Dijkstra itself is O((n + E) · log n) via a binary heap.
270    pub fn concept_path(&self, source_id: &str, target_id: &str, k: usize) -> Option<ConceptPath> {
271        let items = self.index.all_items();
272        let n = items.len();
273        if n < 2 {
274            return None;
275        }
276
277        let id_to_idx: HashMap<&str, usize> = items
278            .iter()
279            .enumerate()
280            .map(|(i, item)| (item.id.as_str(), i))
281            .collect();
282
283        let source_idx = *id_to_idx.get(source_id)?;
284        let target_idx = *id_to_idx.get(target_id)?;
285
286        let adj = self.knn_adjacency(&items, k);
287
288        // Dijkstra (min-heap via reversed Ord)
289        let mut dist = vec![f64::INFINITY; n];
290        let mut prev: Vec<Option<usize>> = vec![None; n];
291        let mut heap = BinaryHeap::new();
292
293        dist[source_idx] = 0.0;
294        heap.push(DijkstraEntry {
295            dist: 0.0,
296            node: source_idx,
297        });
298
299        while let Some(entry) = heap.pop() {
300            let u = entry.node;
301            if entry.dist > dist[u] {
302                continue;
303            }
304            if u == target_idx {
305                break;
306            }
307            for &(v, w) in &adj[u] {
308                let nd = dist[u] + w;
309                if nd < dist[v] {
310                    dist[v] = nd;
311                    prev[v] = Some(u);
312                    heap.push(DijkstraEntry { dist: nd, node: v });
313                }
314            }
315        }
316
317        if dist[target_idx].is_infinite() {
318            return None;
319        }
320
321        // Reconstruct
322        let mut path = Vec::new();
323        let mut cur = target_idx;
324        loop {
325            let hop_distance = prev[cur]
326                .and_then(|p| adj[p].iter().find(|&&(v, _)| v == cur).map(|&(_, d)| d))
327                .unwrap_or(0.0);
328            path.push(PathStep {
329                id: items[cur].id.clone(),
330                cumulative_distance: dist[cur],
331                hop_distance,
332                category: None,
333                bridge_strength: None,
334            });
335            match prev[cur] {
336                Some(p) => cur = p,
337                None => break,
338            }
339        }
340        path.reverse();
341
342        Some(ConceptPath {
343            total_distance: dist[target_idx],
344            steps: path,
345        })
346    }
347
348    /// Find a semantic path that prefers hops with strong conceptual bridges.
349    ///
350    /// Like [`concept_path`](Self::concept_path), but when a hop crosses a
351    /// category boundary, the edge weight is penalized based on the bridge's
352    /// classification:
353    /// - [`BridgeClassification::Genuine`]: `angular_dist / (strength + 0.1)`
354    /// - [`BridgeClassification::Weak`]: `angular_dist / (strength + 0.01)`
355    /// - [`BridgeClassification::OverlapArtifact`]: `angular_dist * 2.0`
356    ///   (shared-territory bridges are actively discouraged — they aren't
357    ///   real connectors).
358    ///
359    /// - `categories`: maps item ID → category index.
360    /// - `bridge_strengths`: maps `(cat_a, cat_b) → (max_bridge_strength, classification)`.
361    ///   Missing entries are treated as a weak no-bridge (strength 0, Weak).
362    ///
363    /// Same-category hops use raw angular distance.
364    pub fn concept_path_bridged(
365        &self,
366        source_id: &str,
367        target_id: &str,
368        k: usize,
369        categories: &HashMap<&str, usize>,
370        bridge_strengths: &HashMap<(usize, usize), (f64, BridgeClassification)>,
371    ) -> Option<ConceptPath> {
372        let items = self.index.all_items();
373        let n = items.len();
374        if n < 2 {
375            return None;
376        }
377
378        let id_to_idx: HashMap<&str, usize> = items
379            .iter()
380            .enumerate()
381            .map(|(i, item)| (item.id.as_str(), i))
382            .collect();
383
384        let source_idx = *id_to_idx.get(source_id)?;
385        let target_idx = *id_to_idx.get(target_id)?;
386
387        // Look up category for each item index
388        let item_cats: Vec<Option<usize>> = items
389            .iter()
390            .map(|item| categories.get(item.id.as_str()).copied())
391            .collect();
392
393        // Reuse the cached raw-angular k-NN adjacency; bridge-aware
394        // weights are derived per edge at Dijkstra time. The previous
395        // implementation materialized a second n-row Vec<Vec<...>>
396        // that duplicated the neighborhood structure — this version
397        // shares the angular graph with `concept_path`.
398        let adj = self.knn_adjacency(&items, k);
399
400        // Dijkstra on effective weights
401        let mut dist = vec![f64::INFINITY; n];
402        let mut prev: Vec<Option<usize>> = vec![None; n];
403        let mut heap = BinaryHeap::new();
404
405        dist[source_idx] = 0.0;
406        heap.push(DijkstraEntry {
407            dist: 0.0,
408            node: source_idx,
409        });
410
411        while let Some(entry) = heap.pop() {
412            let u = entry.node;
413            if entry.dist > dist[u] {
414                continue;
415            }
416            if u == target_idx {
417                break;
418            }
419            for &(v, raw_d) in &adj[u] {
420                let (w, _) = cross_category_weight(raw_d, &item_cats, u, v, bridge_strengths);
421                let nd = dist[u] + w;
422                if nd < dist[v] {
423                    dist[v] = nd;
424                    prev[v] = Some(u);
425                    heap.push(DijkstraEntry { dist: nd, node: v });
426                }
427            }
428        }
429
430        if dist[target_idx].is_infinite() {
431            return None;
432        }
433
434        // Reconstruct with bridge metadata
435        let mut path = Vec::new();
436        let mut cur = target_idx;
437        loop {
438            let edge_info = prev[cur].and_then(|p| {
439                adj[p].iter().find(|&&(v, _)| v == cur).map(|&(_, raw_d)| {
440                    let (_, bs) =
441                        cross_category_weight(raw_d, &item_cats, p, cur, bridge_strengths);
442                    (raw_d, bs)
443                })
444            });
445            let hop_distance = edge_info.map_or(0.0, |(d, _)| d);
446            let bridge_str = edge_info.and_then(|(_, bs)| bs);
447
448            path.push(PathStep {
449                id: items[cur].id.clone(),
450                cumulative_distance: dist[cur],
451                hop_distance,
452                category: item_cats[cur],
453                bridge_strength: bridge_str,
454            });
455            match prev[cur] {
456                Some(p) => cur = p,
457                None => break,
458            }
459        }
460        path.reverse();
461
462        Some(ConceptPath {
463            total_distance: dist[target_idx],
464            steps: path,
465        })
466    }
467}
468
469// --- Concept path types ---
470
/// Result of a concept-path query: the ordered steps from source to target.
#[derive(Debug, Clone)]
pub struct ConceptPath {
    /// Steps in source → target order; the first step is the source item.
    pub steps: Vec<PathStep>,
    /// Dijkstra distance of the target, i.e. the summed edge weights along
    /// the path (penalized weights when produced by `concept_path_bridged`).
    pub total_distance: f64,
}
476
/// One item along a [`ConceptPath`].
#[derive(Debug, Clone)]
pub struct PathStep {
    /// Id of the item at this step.
    pub id: String,
    /// Dijkstra distance from the source to this item.
    pub cumulative_distance: f64,
    /// Angular distance of this hop (0.0 for the first step).
    pub hop_distance: f64,
    /// Category index of this item (None if no category info was provided).
    pub category: Option<usize>,
    /// Bridge strength used on the hop *to* this step (None for same-category
    /// hops or the first step). Present only when `concept_path_bridged` is used.
    pub bridge_strength: Option<f64>,
}
489
/// Heap entry for Dijkstra: a tentative distance and the node it belongs to.
///
/// `BinaryHeap` is a max-heap, so the ordering is reversed: the entry with
/// the *smallest* `dist` compares greatest and is popped first. Ties on
/// `dist` are broken by `node`, smaller index first, so the pop order is
/// deterministic.
///
/// Equality is defined through `cmp`, so `PartialEq`, `Eq`, `PartialOrd`
/// and `Ord` all agree, as the standard library requires. `f64::total_cmp`
/// makes the order total even for NaN: a NaN with a positive sign sorts
/// after every finite distance and is therefore popped last.
struct DijkstraEntry {
    dist: f64,
    node: usize,
}

impl PartialEq for DijkstraEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for DijkstraEntry {}

impl PartialOrd for DijkstraEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DijkstraEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Operands are swapped on purpose: smaller dist = higher priority.
        other
            .dist
            .total_cmp(&self.dist)
            .then_with(|| other.node.cmp(&self.node))
    }
}
515
516// --- Slicing manifold ---
517
/// A 2D plane fitted through the 3D projected point cloud that captures
/// the maximum variance. Found by PCA on the Cartesian coordinates of
/// the projected embeddings.
///
/// The plane is defined by:
/// - `centroid`: the mean of all 3D points
/// - `basis_u`, `basis_v`: orthonormal vectors spanning the plane (directions of max variance)
/// - `normal`: vector perpendicular to the plane (direction of minimum variance)
#[derive(Debug, Clone)]
pub struct SlicingManifold {
    /// Mean of the fitted points; origin for `project_2d` and `distance`.
    pub centroid: [f64; 3],
    /// Eigenvector of the smallest covariance eigenvalue.
    pub normal: [f64; 3],
    /// Eigenvector of the largest covariance eigenvalue.
    pub basis_u: [f64; 3],
    /// Eigenvector of the second-largest covariance eigenvalue.
    pub basis_v: [f64; 3],
    /// Share of total variance captured by the plane: `(λ₀ + λ₁) / (λ₀ + λ₁ + λ₂)`,
    /// or `1.0` when the total is not positive.
    pub variance_ratio: f64,
}
534
535impl SlicingManifold {
536    /// Fit the optimal slicing plane to a set of 3D points.
537    /// Each point is (x, y, z) in Cartesian coordinates.
538    ///
539    /// Callers must supply at least 3 points; `fit_local` guarantees this via
540    /// `k.max(3)`. Passing fewer is a programming error, not a runtime failure.
541    pub fn fit(points: &[[f64; 3]]) -> Self {
542        let n = points.len() as f64;
543        debug_assert!(n >= 3.0, "need at least 3 points to fit a plane");
544
545        // Centroid
546        let mut c = [0.0; 3];
547        for p in points {
548            for i in 0..3 {
549                c[i] += p[i];
550            }
551        }
552        for ci in &mut c {
553            *ci /= n;
554        }
555
556        // 3×3 covariance matrix (symmetric)
557        let mut cov = [[0.0f64; 3]; 3];
558        for p in points {
559            let d = [p[0] - c[0], p[1] - c[1], p[2] - c[2]];
560            for i in 0..3 {
561                for j in 0..3 {
562                    cov[i][j] += d[i] * d[j];
563                }
564            }
565        }
566        for row in &mut cov {
567            for v in row.iter_mut() {
568                *v /= n;
569            }
570        }
571
572        // Eigendecomposition of 3×3 symmetric matrix via Jacobi iteration
573        let (eigenvalues, eigenvectors) = eigen_symmetric_3x3(&cov);
574
575        // eigenvalues are sorted descending: λ₀ ≥ λ₁ ≥ λ₂
576        // basis_u = eigenvector of λ₀, basis_v = eigenvector of λ₁, normal = eigenvector of λ₂
577        let total_var = eigenvalues[0] + eigenvalues[1] + eigenvalues[2];
578        let variance_ratio = if total_var > 0.0 {
579            (eigenvalues[0] + eigenvalues[1]) / total_var
580        } else {
581            1.0
582        };
583
584        Self {
585            centroid: c,
586            normal: eigenvectors[2],
587            basis_u: eigenvectors[0],
588            basis_v: eigenvectors[1],
589            variance_ratio,
590        }
591    }
592
593    /// Project a 3D point onto the plane, returning (u, v) coordinates.
594    pub fn project_2d(&self, point: &[f64; 3]) -> (f64, f64) {
595        let d = [
596            point[0] - self.centroid[0],
597            point[1] - self.centroid[1],
598            point[2] - self.centroid[2],
599        ];
600        let u = d[0] * self.basis_u[0] + d[1] * self.basis_u[1] + d[2] * self.basis_u[2];
601        let v = d[0] * self.basis_v[0] + d[1] * self.basis_v[1] + d[2] * self.basis_v[2];
602        (u, v)
603    }
604
605    /// Signed distance from the plane (positive = same side as normal).
606    pub fn distance(&self, point: &[f64; 3]) -> f64 {
607        let d = [
608            point[0] - self.centroid[0],
609            point[1] - self.centroid[1],
610            point[2] - self.centroid[2],
611        ];
612        d[0] * self.normal[0] + d[1] * self.normal[1] + d[2] * self.normal[2]
613    }
614
615    /// Fit a local manifold around a query point using its k nearest neighbors.
616    ///
617    /// The local plane captures the shape of the semantic neighborhood:
618    /// - If variance_ratio ≈ 1.0, the neighborhood is flat (concepts spread in a plane)
619    /// - If variance_ratio ≈ 0.67, concepts are uniformly distributed (spherical)
620    /// - The normal direction reveals which semantic axis is least relevant locally
621    ///
622    /// This enables directional search narrowing: once you know the local geometry,
623    /// you can restrict subsequent queries to the dominant plane, cutting the
624    /// effective search dimensionality from 3D to 2D in that region.
625    pub fn fit_local(query: &[f64; 3], all_points: &[[f64; 3]], k: usize) -> Self {
626        let mut dists: Vec<(usize, f64)> = all_points
627            .iter()
628            .enumerate()
629            .map(|(i, p)| (i, dist3(query, p)))
630            .collect();
631        // `total_cmp` gives a total order over all f64 including NaN
632        // (which sorts to the end). Previously `.partial_cmp().unwrap()`
633        // panicked on NaN — and NaN is reachable whenever `all_points`
634        // contains a degenerate entry from a lossy projection, making
635        // this one of the few query-path panic sites in the crate.
636        dists.sort_by(|a, b| a.1.total_cmp(&b.1));
637
638        let neighborhood: Vec<[f64; 3]> = dists
639            .iter()
640            .take(k.max(3))
641            .map(|&(i, _)| all_points[i])
642            .collect();
643
644        Self::fit(&neighborhood)
645    }
646}
647
/// Eigendecomposition of a 3×3 symmetric matrix by cyclic Jacobi rotations.
///
/// Each sweep picks the largest off-diagonal entry in the upper triangle and
/// applies the plane rotation that zeroes it. Iteration stops once that
/// entry drops below `1e-15` or after 50 sweeps.
///
/// Returns `(eigenvalues, eigenvectors)` sorted by decreasing eigenvalue;
/// `eigenvectors[i]` is the vector belonging to `eigenvalues[i]`.
fn eigen_symmetric_3x3(m: &[[f64; 3]; 3]) -> ([f64; 3], [[f64; 3]; 3]) {
    const MAX_SWEEPS: usize = 50;
    // Upper-triangle pairs examined after the initial (0, 1) candidate.
    const LATER_PAIRS: [(usize, usize); 2] = [(0, 2), (1, 2)];

    let mut work = *m;
    // Accumulated rotations; the eigenvectors end up in its columns.
    let mut basis = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

    #[allow(clippy::needless_range_loop)]
    for _ in 0..MAX_SWEEPS {
        // Pivot: largest |off-diagonal|, earliest pair winning ties.
        let (mut p, mut q) = (0, 1);
        let mut peak = work[0][1].abs();
        for &(i, j) in &LATER_PAIRS {
            let magnitude = work[i][j].abs();
            if magnitude > peak {
                peak = magnitude;
                p = i;
                q = j;
            }
        }
        if peak < 1e-15 {
            break;
        }

        // Angle that zeroes work[p][q]; equal diagonal entries give π/4.
        let gap = work[p][p] - work[q][q];
        let theta = if gap.abs() < 1e-30 {
            std::f64::consts::FRAC_PI_4
        } else {
            0.5 * (2.0 * work[p][q] / gap).atan()
        };
        let c = theta.cos();
        let s = theta.sin();

        // work ← Gᵀ · work · G, columns first …
        let mut rotated = work;
        for (dst, src) in rotated.iter_mut().zip(&work) {
            dst[p] = c * src[p] + s * src[q];
            dst[q] = -s * src[p] + c * src[q];
        }
        // … then rows, reading from copies taken after the column pass.
        let row_p = rotated[p];
        let row_q = rotated[q];
        for j in 0..3 {
            rotated[p][j] = c * row_p[j] + s * row_q[j];
            rotated[q][j] = -s * row_p[j] + c * row_q[j];
        }
        // The pivot is zero by construction; store it exactly.
        rotated[p][q] = 0.0;
        rotated[q][p] = 0.0;
        work = rotated;

        // basis ← basis · G
        let previous = basis;
        for (dst, src) in basis.iter_mut().zip(&previous) {
            dst[p] = c * src[p] + s * src[q];
            dst[q] = -s * src[p] + c * src[q];
        }
    }

    let diagonal = [work[0][0], work[1][1], work[2][2]];

    // Column order by descending eigenvalue (stable sort, so ties keep
    // their original column order).
    let mut order = [0usize, 1, 2];
    order.sort_by(|&a, &b| diagonal[b].total_cmp(&diagonal[a]));

    let sorted_vals = order.map(|col| diagonal[col]);
    let sorted_vecs = order.map(|col| [basis[0][col], basis[1][col], basis[2][col]]);

    (sorted_vals, sorted_vecs)
}
728
729// --- Concept Globs (spherical k-means + silhouette auto-k) ---
730
/// A cluster of semantically related embeddings in the projected 3D space.
#[derive(Debug, Clone)]
pub struct ConceptGlob {
    /// Index of the cluster this glob was built from.
    pub id: usize,
    /// Normalized sum of the members' normalized positions (mean direction).
    pub centroid: [f64; 3],
    /// Ids of the member points.
    pub member_ids: Vec<String>,
    /// Angular distance of each member from `centroid`, in the same order
    /// as `member_ids`.
    pub member_distances: Vec<f64>,
    /// Largest value in `member_distances` (0.0 if none exceeds it).
    pub radius: f64,
}
740
/// Result of glob detection: the set of all globs plus quality metrics.
#[derive(Debug, Clone)]
pub struct GlobResult {
    /// Detected clusters; clusters that ended up with no members are omitted.
    pub globs: Vec<ConceptGlob>,
    /// Number of clusters k-means was run with (requested or auto-selected).
    pub k: usize,
    /// Silhouette score of the chosen clustering.
    pub silhouette: f64,
}
748
749impl GlobResult {
750    /// Detect concept globs from 3D projected points.
751    ///
752    /// If `k` is `Some`, uses that many clusters.
753    /// If `None`, auto-selects k ∈ [2, max_k] by maximizing the silhouette score.
754    pub fn detect(points: &[[f64; 3]], ids: &[String], k: Option<usize>, max_k: usize) -> Self {
755        let n = points.len();
756        // Both slices are built from the same pipeline corpus, so
757        // mismatched lengths or an empty corpus are caller bugs.
758        debug_assert_eq!(n, ids.len());
759        debug_assert!(n >= 2, "need at least 2 points for clustering");
760
761        // Silhouette-driven auto-search needs at least k = 2 to be well-defined.
762        // Clamping here means callers can pass any non-negative max_k without
763        // tripping the `k.clamp(2, max_k)` panic when max < min.
764        let max_k = max_k.max(2).min(n);
765
766        if let Some(k) = k {
767            let k = k.clamp(2, max_k);
768            let (assignments, silhouette) = kmeans_3d(points, k);
769            let globs = build_globs(points, ids, &assignments, k);
770            return Self {
771                globs,
772                k,
773                silhouette,
774            };
775        }
776
777        // Auto-detect: try k = 2..=max_k, pick best silhouette
778        let mut best_k = 2;
779        let mut best_sil = f64::NEG_INFINITY;
780        let mut best_assignments = vec![0usize; n];
781
782        for trial_k in 2..=max_k {
783            let (assignments, sil) = kmeans_3d(points, trial_k);
784            if sil > best_sil {
785                best_sil = sil;
786                best_k = trial_k;
787                best_assignments = assignments;
788            }
789        }
790
791        let globs = build_globs(points, ids, &best_assignments, best_k);
792        Self {
793            globs,
794            k: best_k,
795            silhouette: best_sil,
796        }
797    }
798}
799
800fn kmeans_3d(points: &[[f64; 3]], k: usize) -> (Vec<usize>, f64) {
801    let n = points.len();
802    let max_iter = 50;
803
804    // Init: spread initial centers evenly across the point set
805    let mut centers: Vec<[f64; 3]> = (0..k).map(|i| points[i * n / k]).collect();
806
807    let mut assignments = vec![0usize; n];
808
809    for _ in 0..max_iter {
810        let mut changed = false;
811
812        // Assign by angular distance (direction, not position)
813        for (i, p) in points.iter().enumerate() {
814            let mut best = 0;
815            let mut best_d = f64::MAX;
816            for (j, c) in centers.iter().enumerate() {
817                let d = angular_dist3(p, c);
818                if d < best_d {
819                    best_d = d;
820                    best = j;
821                }
822            }
823            if assignments[i] != best {
824                assignments[i] = best;
825                changed = true;
826            }
827        }
828
829        if !changed {
830            break;
831        }
832
833        // Update centers: mean direction (Euclidean mean of unit vectors, then normalize).
834        // This is the Fréchet mean on S² for concentrated clusters.
835        let mut sums = vec![[0.0f64; 3]; k];
836        let mut counts = vec![0usize; k];
837        for (i, &a) in assignments.iter().enumerate() {
838            let norm_p = normalize3(&points[i]);
839            for (d, &np) in norm_p.iter().enumerate() {
840                sums[a][d] += np;
841            }
842            counts[a] += 1;
843        }
844        for j in 0..k {
845            if counts[j] > 0 {
846                centers[j] = normalize3(&sums[j]);
847            }
848        }
849    }
850
851    let sil = silhouette_score(points, &assignments, k);
852    (assignments, sil)
853}
854
855fn silhouette_score(points: &[[f64; 3]], assignments: &[usize], k: usize) -> f64 {
856    let n = points.len();
857    if n <= 1 || k <= 1 {
858        return 0.0;
859    }
860
861    let mut total = 0.0;
862    for i in 0..n {
863        let ci = assignments[i];
864
865        // a(i): mean angular dist to same-cluster members
866        let mut a_sum = 0.0;
867        let mut a_cnt = 0;
868        for j in 0..n {
869            if j != i && assignments[j] == ci {
870                a_sum += angular_dist3(&points[i], &points[j]);
871                a_cnt += 1;
872            }
873        }
874        let a = if a_cnt > 0 { a_sum / a_cnt as f64 } else { 0.0 };
875
876        // b(i): min mean angular dist to any other cluster
877        let mut b = f64::MAX;
878        for ck in 0..k {
879            if ck == ci {
880                continue;
881            }
882            let mut b_sum = 0.0;
883            let mut b_cnt = 0;
884            for j in 0..n {
885                if assignments[j] == ck {
886                    b_sum += angular_dist3(&points[i], &points[j]);
887                    b_cnt += 1;
888                }
889            }
890            if b_cnt > 0 {
891                b = b.min(b_sum / b_cnt as f64);
892            }
893        }
894        if b == f64::MAX {
895            b = 0.0;
896        }
897
898        let denom = a.max(b);
899        total += if denom > 0.0 { (b - a) / denom } else { 0.0 };
900    }
901
902    total / n as f64
903}
904
905fn build_globs(
906    points: &[[f64; 3]],
907    ids: &[String],
908    assignments: &[usize],
909    k: usize,
910) -> Vec<ConceptGlob> {
911    let mut globs = Vec::with_capacity(k);
912
913    for cluster_id in 0..k {
914        let member_indices: Vec<usize> = assignments
915            .iter()
916            .enumerate()
917            .filter(|&(_, &a)| a == cluster_id)
918            .map(|(i, _)| i)
919            .collect();
920
921        if member_indices.is_empty() {
922            continue;
923        }
924
925        // Centroid: mean direction (normalize to get angular centroid)
926        let mut centroid = [0.0; 3];
927        for &i in &member_indices {
928            let norm_p = normalize3(&points[i]);
929            for (d, c) in centroid.iter_mut().enumerate() {
930                *c += norm_p[d];
931            }
932        }
933        centroid = normalize3(&centroid);
934
935        // Member angular distances from centroid
936        let member_distances: Vec<f64> = member_indices
937            .iter()
938            .map(|&i| angular_dist3(&points[i], &centroid))
939            .collect();
940
941        let radius = member_distances.iter().cloned().fold(0.0f64, f64::max);
942
943        let member_ids: Vec<String> = member_indices.iter().map(|&i| ids[i].clone()).collect();
944
945        globs.push(ConceptGlob {
946            id: cluster_id,
947            centroid,
948            member_ids,
949            member_distances,
950            radius,
951        });
952    }
953
954    globs
955}
956
/// Straight-line (Euclidean) distance between two points in 3D.
fn dist3(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let squared = a.iter().zip(b.iter()).fold(0.0, |acc, (p, q)| {
        let delta = p - q;
        acc + delta * delta
    });
    squared.sqrt()
}
963
/// Angular distance between two 3D points treated as direction vectors.
///
/// Returns the angle in radians, in `[0, π]`. Only direction matters: the
/// cosine is computed from the dot product divided by both magnitudes, then
/// clamped to `[-1, 1]` so rounding error cannot push `acos` out of domain.
///
/// A vector whose magnitude is below `f64::EPSILON` has no usable direction
/// (the same threshold `normalize3` uses); the distance is then defined as
/// `0.0`. The check is applied to each magnitude separately rather than to
/// their product, so two short-but-nonzero vectors still get their true angle.
fn angular_dist3(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let dot = a[0] * b[0] + a[1] * b[1] + a[2] * b[2];
    let ma = (a[0] * a[0] + a[1] * a[1] + a[2] * a[2]).sqrt();
    let mb = (b[0] * b[0] + b[1] * b[1] + b[2] * b[2]).sqrt();
    if ma < f64::EPSILON || mb < f64::EPSILON {
        return 0.0;
    }
    (dot / (ma * mb)).clamp(-1.0, 1.0).acos()
}
976
/// Scales a 3D vector to unit length.
///
/// Vectors whose magnitude is below `f64::EPSILON` are treated as having no
/// direction and map to the zero vector.
fn normalize3(v: &[f64; 3]) -> [f64; 3] {
    let [x, y, z] = *v;
    let mag = (x * x + y * y + z * z).sqrt();
    if mag < f64::EPSILON {
        [0.0; 3]
    } else {
        [x / mag, y / mag, z / mag]
    }
}
985
986/// Compute the effective edge weight for a hop between two items.
987///
988/// Same-category hops use raw angular distance. Cross-category hops are
989/// weighted by the bridge's quality classification:
990/// - `Genuine`:         `angular_dist / (strength + 0.1)`
991/// - `Weak`:            `angular_dist / (strength + 0.01)`  (harsher penalty)
992/// - `OverlapArtifact`: `angular_dist * 2.0`  (actively discouraged — the
993///   two categories share territory, so the "bridge" isn't real)
994///
995/// Unknown cross-category pairs (no entry in `bridge_strengths`) are treated
996/// as `Weak` with strength 0.
997///
998/// Returns (effective_weight, Option<bridge_strength>).
999/// The bridge_strength is None for same-category hops.
1000fn cross_category_weight(
1001    angular_dist: f64,
1002    item_cats: &[Option<usize>],
1003    i: usize,
1004    j: usize,
1005    bridge_strengths: &HashMap<(usize, usize), (f64, BridgeClassification)>,
1006) -> (f64, Option<f64>) {
1007    match (item_cats[i], item_cats[j]) {
1008        (Some(ci), Some(cj)) if ci != cj => {
1009            let (strength, classification) = bridge_strengths
1010                .get(&(ci, cj))
1011                .or_else(|| bridge_strengths.get(&(cj, ci)))
1012                .copied()
1013                .unwrap_or((0.0, BridgeClassification::Weak));
1014            let weight = match classification {
1015                BridgeClassification::Genuine => angular_dist / (strength + 0.1),
1016                BridgeClassification::OverlapArtifact => angular_dist * 2.0,
1017                BridgeClassification::Weak => angular_dist / (strength + 0.01),
1018            };
1019            (weight, Some(strength))
1020        }
1021        _ => (angular_dist, None),
1022    }
1023}
1024
1025/// Builds SphereQL [`Region`]s from semantic constraints on embeddings.
1026pub struct SemanticQuery;
1027
1028impl SemanticQuery {
1029    /// Spherical cap: all points within `max_angular_distance` radians of the query.
1030    pub fn within_angle<P: Projection>(
1031        query: &Embedding,
1032        projection: &P,
1033        max_angular_distance: f64,
1034    ) -> Region {
1035        let point = projection.project(query);
1036        // Clamp to (0, π] so Cap::new never sees an out-of-range angle.
1037        let half_angle = max_angular_distance.clamp(1e-10, std::f64::consts::PI);
1038        Region::Cap(
1039            Cap::new(
1040                SphericalPoint::new_unchecked(1.0, point.theta, point.phi),
1041                half_angle,
1042            )
1043            // Invariant: half_angle is in (0, π] after the clamp above.
1044            .expect("clamped half_angle is always a valid Cap angle"),
1045        )
1046    }
1047
1048    /// Spherical cap from a cosine similarity threshold.
1049    /// cos_sim >= threshold ↔ angular_distance <= arccos(threshold).
1050    pub fn above_similarity<P: Projection>(
1051        query: &Embedding,
1052        projection: &P,
1053        min_similarity: f64,
1054    ) -> Region {
1055        let half_angle = min_similarity.clamp(-1.0, 1.0).acos();
1056        Self::within_angle(query, projection, half_angle)
1057    }
1058
1059    /// Radial shell: embeddings whose projected radius falls in [inner, outer].
1060    ///
1061    /// # Panics
1062    ///
1063    /// Panics if `inner > outer` or either bound is negative. Both are
1064    /// caller contracts; use `Shell::new` directly if you need a `Result`.
1065    pub fn in_shell(inner: f64, outer: f64) -> Region {
1066        Region::Shell(
1067            Shell::new(inner, outer).expect("inner must be <= outer and both non-negative"),
1068        )
1069    }
1070
1071    /// Intersection of a similarity cap with a radial shell.
1072    /// "Semantically similar AND within a magnitude/metadata range."
1073    pub fn similar_in_shell<P: Projection>(
1074        query: &Embedding,
1075        projection: &P,
1076        min_similarity: f64,
1077        shell_inner: f64,
1078        shell_outer: f64,
1079    ) -> Region {
1080        Region::intersection(vec![
1081            Self::above_similarity(query, projection, min_similarity),
1082            Self::in_shell(shell_inner, shell_outer),
1083        ])
1084    }
1085}
1086
1087#[cfg(test)]
1088mod tests {
1089    use super::*;
1090    use crate::projection::{PcaProjection, RandomProjection};
1091    use crate::types::RadialStrategy;
1092    use sphereql_core::angular_distance;
1093
1094    fn emb(vals: &[f64]) -> Embedding {
1095        Embedding::new(vals.to_vec())
1096    }
1097
1098    fn test_corpus() -> Vec<Embedding> {
1099        vec![
1100            emb(&[1.0, 0.0, 0.0, 0.1, 0.0]),
1101            emb(&[0.0, 1.0, 0.0, 0.0, 0.1]),
1102            emb(&[0.0, 0.0, 1.0, 0.1, 0.0]),
1103            emb(&[1.0, 1.0, 0.0, 0.05, 0.05]),
1104            emb(&[-1.0, 0.0, 0.0, -0.1, 0.0]),
1105            emb(&[0.0, -1.0, 0.0, 0.0, -0.1]),
1106        ]
1107    }
1108
1109    // --- EmbeddingIndex ---
1110
1111    #[test]
1112    fn insert_and_get() {
1113        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1114        let mut idx = EmbeddingIndex::builder(rp)
1115            .theta_divisions(4)
1116            .phi_divisions(3)
1117            .build();
1118
1119        idx.insert("a", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
1120        idx.insert("b", &emb(&[0.0, 1.0, 0.0, 0.0, 0.0]));
1121
1122        assert_eq!(idx.len(), 2);
1123        assert!(!idx.is_empty());
1124        assert!(idx.get("a").is_some());
1125        assert!(idx.get("b").is_some());
1126        assert!(idx.get("c").is_none());
1127    }
1128
1129    #[test]
1130    fn remove() {
1131        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1132        let mut idx = EmbeddingIndex::builder(rp).build();
1133
1134        idx.insert("a", &emb(&[1.0; 5]));
1135        assert_eq!(idx.len(), 1);
1136
1137        let removed = idx.remove("a");
1138        assert!(removed.is_some());
1139        assert_eq!(removed.unwrap().id, "a");
1140        assert_eq!(idx.len(), 0);
1141        assert!(idx.get("a").is_none());
1142    }
1143
1144    #[test]
1145    fn remove_nonexistent() {
1146        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1147        let mut idx = EmbeddingIndex::builder(rp).build();
1148        assert!(idx.remove("nope").is_none());
1149    }
1150
1151    #[test]
1152    fn search_nearest_returns_sorted() {
1153        let corpus = test_corpus();
1154        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
1155        let mut idx = EmbeddingIndex::builder(pca)
1156            .theta_divisions(4)
1157            .phi_divisions(3)
1158            .build();
1159
1160        for (i, e) in corpus.iter().enumerate() {
1161            idx.insert(format!("item-{i}"), e);
1162        }
1163
1164        let query = emb(&[0.95, 0.1, 0.0, 0.05, 0.0]);
1165        let results = idx.search_nearest(&query, 3);
1166
1167        assert_eq!(results.len(), 3);
1168        assert!(results[0].distance <= results[1].distance);
1169        assert!(results[1].distance <= results[2].distance);
1170    }
1171
1172    #[test]
1173    fn search_similar_respects_threshold() {
1174        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1175        let mut idx = EmbeddingIndex::builder(rp)
1176            .theta_divisions(4)
1177            .phi_divisions(3)
1178            .build();
1179
1180        idx.insert("close_a", &emb(&[1.0, 0.1, 0.0, 0.0, 0.0]));
1181        idx.insert("close_b", &emb(&[0.9, 0.2, 0.0, 0.0, 0.0]));
1182        idx.insert("far", &emb(&[-1.0, 0.0, 0.0, 0.0, 0.0]));
1183
1184        let query = emb(&[1.0, 0.0, 0.0, 0.0, 0.0]);
1185        let projected_query = idx.projection().project(&query);
1186        let result = idx.search_similar(&query, 0.5);
1187
1188        let max_angle = 0.5_f64.acos();
1189        for item in &result.items {
1190            let d = angular_distance(&projected_query, item.position());
1191            assert!(d <= max_angle + 1e-10, "item {} too far: {d}", item.id);
1192        }
1193    }
1194
1195    #[test]
1196    fn insert_with_radius_overrides() {
1197        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1198        let mut idx = EmbeddingIndex::builder(rp).build();
1199
1200        idx.insert_with_radius("custom", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]), 42.0);
1201        let item = idx.get("custom").unwrap();
1202        assert!((item.position.r - 42.0).abs() < 1e-12);
1203    }
1204
1205    #[test]
1206    fn original_magnitude_stored() {
1207        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1208        let mut idx = EmbeddingIndex::builder(rp).build();
1209
1210        let e = emb(&[3.0, 4.0, 0.0, 0.0, 0.0]);
1211        idx.insert("vec", &e);
1212        let item = idx.get("vec").unwrap();
1213        assert!((item.original_magnitude - 5.0).abs() < 1e-10);
1214    }
1215
1216    #[test]
1217    fn magnitude_radial_with_shell_query() {
1218        let corpus = test_corpus();
1219        let pca = PcaProjection::fit(&corpus, RadialStrategy::Magnitude).unwrap();
1220        let mut idx = EmbeddingIndex::builder(pca)
1221            .uniform_shells(5, 10.0)
1222            .theta_divisions(4)
1223            .phi_divisions(3)
1224            .build();
1225
1226        idx.insert("small", &emb(&[0.1, 0.0, 0.0, 0.0, 0.0]));
1227        idx.insert("medium", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
1228        idx.insert("large", &emb(&[5.0, 0.0, 0.0, 0.0, 0.0]));
1229
1230        let shell = Shell::new(0.5, 2.0).unwrap();
1231        let result = idx.search_region(&Region::Shell(shell));
1232
1233        let ids: Vec<&str> = result.items.iter().map(|i| i.id.as_str()).collect();
1234        assert!(
1235            ids.contains(&"medium"),
1236            "medium (mag=1.0) should be in [0.5, 2.0]"
1237        );
1238        assert!(
1239            !ids.contains(&"large"),
1240            "large (mag=5.0) should not be in [0.5, 2.0]"
1241        );
1242    }
1243
1244    #[test]
1245    fn empty_index() {
1246        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1247        let idx = EmbeddingIndex::builder(rp).build();
1248
1249        assert!(idx.is_empty());
1250        assert_eq!(idx.len(), 0);
1251        assert!(idx.get("x").is_none());
1252
1253        let results = idx.search_nearest(&emb(&[1.0; 5]), 5);
1254        assert!(results.is_empty());
1255    }
1256
1257    #[test]
1258    fn projection_accessor() {
1259        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1260        let idx = EmbeddingIndex::builder(rp).build();
1261        assert_eq!(idx.projection().dimensionality(), 5);
1262    }
1263
1264    // --- SemanticQuery ---
1265
1266    #[test]
1267    fn above_similarity_creates_cap() {
1268        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1269        let region = SemanticQuery::above_similarity(&emb(&[1.0; 5]), &rp, 0.8);
1270        assert!(matches!(region, Region::Cap(_)));
1271    }
1272
1273    #[test]
1274    fn within_angle_creates_cap() {
1275        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1276        let region = SemanticQuery::within_angle(&emb(&[1.0; 5]), &rp, 0.5);
1277        assert!(matches!(region, Region::Cap(_)));
1278    }
1279
1280    #[test]
1281    fn in_shell_creates_shell() {
1282        let region = SemanticQuery::in_shell(1.0, 5.0);
1283        assert!(matches!(region, Region::Shell(_)));
1284    }
1285
1286    #[test]
1287    fn similar_in_shell_creates_intersection() {
1288        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1289        let region = SemanticQuery::similar_in_shell(&emb(&[1.0; 5]), &rp, 0.7, 1.0, 5.0);
1290
1291        match region {
1292            Region::Intersection(parts) => {
1293                assert_eq!(parts.len(), 2);
1294                assert!(matches!(parts[0], Region::Cap(_)));
1295                assert!(matches!(parts[1], Region::Shell(_)));
1296            }
1297            other => panic!("expected Intersection, got {other:?}"),
1298        }
1299    }
1300
1301    #[test]
1302    fn semantic_query_region_used_in_index() {
1303        let corpus = test_corpus();
1304        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
1305        let projection_clone = pca.clone();
1306        let mut idx = EmbeddingIndex::builder(pca)
1307            .theta_divisions(4)
1308            .phi_divisions(3)
1309            .build();
1310
1311        for (i, e) in corpus.iter().enumerate() {
1312            idx.insert(format!("item-{i}"), e);
1313        }
1314
1315        let query = emb(&[1.0, 0.0, 0.0, 0.05, 0.0]);
1316        let region = SemanticQuery::above_similarity(&query, &projection_clone, 0.5);
1317        let result = idx.search_region(&region);
1318
1319        for item in &result.items {
1320            assert!(region.contains(item.position()));
1321        }
1322    }
1323
1324    // --- concept_path PathStep fields ---
1325
1326    #[test]
1327    fn concept_path_populates_hop_distance() {
1328        let corpus = test_corpus();
1329        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
1330        let mut idx = EmbeddingIndex::builder(pca)
1331            .theta_divisions(4)
1332            .phi_divisions(3)
1333            .build();
1334
1335        for (i, e) in corpus.iter().enumerate() {
1336            idx.insert(format!("item-{i}"), e);
1337        }
1338
1339        if let Some(path) = idx.concept_path("item-0", "item-4", 3) {
1340            assert!(path.steps[0].hop_distance == 0.0, "first step has no hop");
1341            for step in &path.steps[1..] {
1342                assert!(
1343                    step.hop_distance > 0.0,
1344                    "subsequent steps should have a hop distance"
1345                );
1346            }
1347            assert!(path.steps.iter().all(|s| s.category.is_none()));
1348            assert!(path.steps.iter().all(|s| s.bridge_strength.is_none()));
1349        }
1350    }
1351
1352    // --- concept_path_bridged ---
1353
1354    #[test]
1355    fn concept_path_bridged_same_category_equals_unbridged() {
1356        let corpus = test_corpus();
1357        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
1358        let mut idx = EmbeddingIndex::builder(pca)
1359            .theta_divisions(4)
1360            .phi_divisions(3)
1361            .build();
1362
1363        for (i, e) in corpus.iter().enumerate() {
1364            idx.insert(format!("item-{i}"), e);
1365        }
1366
1367        // All items in the same category — bridged path should equal unbridged
1368        let categories: HashMap<&str, usize> = (0..6)
1369            .map(|i| {
1370                (
1371                    ["item-0", "item-1", "item-2", "item-3", "item-4", "item-5"][i],
1372                    0,
1373                )
1374            })
1375            .collect();
1376        let bridges = HashMap::new();
1377
1378        let unbridged = idx.concept_path("item-0", "item-3", 3);
1379        let bridged = idx.concept_path_bridged("item-0", "item-3", 3, &categories, &bridges);
1380
1381        match (unbridged, bridged) {
1382            (Some(u), Some(b)) => {
1383                assert_eq!(u.steps.len(), b.steps.len());
1384                assert!((u.total_distance - b.total_distance).abs() < 1e-10);
1385                for step in &b.steps {
1386                    assert_eq!(step.category, Some(0));
1387                    assert!(step.bridge_strength.is_none());
1388                }
1389            }
1390            (None, None) => {} // both unreachable is fine
1391            _ => panic!("bridged and unbridged should agree on reachability"),
1392        }
1393    }
1394
1395    #[test]
1396    fn concept_path_bridged_penalizes_weak_bridges() {
1397        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1398        let mut idx = EmbeddingIndex::builder(rp)
1399            .theta_divisions(4)
1400            .phi_divisions(3)
1401            .build();
1402
1403        // Create two clusters in different categories
1404        // Category 0: items close to [1, 0, 0, 0, 0]
1405        idx.insert("a0", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
1406        idx.insert("a1", &emb(&[0.9, 0.1, 0.0, 0.0, 0.0]));
1407        // Category 1: items close to [0, 1, 0, 0, 0]
1408        idx.insert("b0", &emb(&[0.0, 1.0, 0.0, 0.0, 0.0]));
1409        idx.insert("b1", &emb(&[0.1, 0.9, 0.0, 0.0, 0.0]));
1410
1411        let mut categories: HashMap<&str, usize> = HashMap::new();
1412        categories.insert("a0", 0);
1413        categories.insert("a1", 0);
1414        categories.insert("b0", 1);
1415        categories.insert("b1", 1);
1416
1417        // Weak bridge between categories
1418        let mut weak_bridges = HashMap::new();
1419        weak_bridges.insert((0, 1), (0.05, BridgeClassification::Weak));
1420
1421        // Strong bridge between categories
1422        let mut strong_bridges = HashMap::new();
1423        strong_bridges.insert((0, 1), (0.95, BridgeClassification::Genuine));
1424
1425        let weak_path = idx.concept_path_bridged("a0", "b0", 3, &categories, &weak_bridges);
1426        let strong_path = idx.concept_path_bridged("a0", "b0", 3, &categories, &strong_bridges);
1427
1428        // Both should find a path (same topology)
1429        // But weak bridge should have higher total_distance
1430        if let (Some(weak), Some(strong)) = (weak_path, strong_path) {
1431            assert!(
1432                weak.total_distance > strong.total_distance,
1433                "weak bridge ({:.4}) should produce higher cost than strong ({:.4})",
1434                weak.total_distance,
1435                strong.total_distance
1436            );
1437        }
1438    }
1439
1440    #[test]
1441    fn concept_path_bridged_populates_bridge_metadata() {
1442        let rp = RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42);
1443        let mut idx = EmbeddingIndex::builder(rp)
1444            .theta_divisions(4)
1445            .phi_divisions(3)
1446            .build();
1447
1448        idx.insert("a", &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]));
1449        idx.insert("b", &emb(&[0.5, 0.5, 0.0, 0.0, 0.0]));
1450        idx.insert("c", &emb(&[0.0, 1.0, 0.0, 0.0, 0.0]));
1451
1452        let mut categories: HashMap<&str, usize> = HashMap::new();
1453        categories.insert("a", 0);
1454        categories.insert("b", 0);
1455        categories.insert("c", 1);
1456
1457        let mut bridges = HashMap::new();
1458        bridges.insert((0, 1), (0.7, BridgeClassification::Genuine));
1459
1460        if let Some(path) = idx.concept_path_bridged("a", "c", 3, &categories, &bridges) {
1461            // Each step should have category metadata
1462            for step in &path.steps {
1463                assert!(step.category.is_some());
1464            }
1465            // At least one cross-category hop should have bridge_strength
1466            let has_bridge = path.steps.iter().any(|s| s.bridge_strength.is_some());
1467            assert!(
1468                has_bridge,
1469                "should record bridge strength on cross-category hop"
1470            );
1471        }
1472    }
1473
1474    // --- cross_category_weight ---
1475
1476    #[test]
1477    fn cross_category_weight_same_category() {
1478        let cats = vec![Some(0), Some(0)];
1479        let bridges = HashMap::new();
1480        let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
1481        assert!((weight - 0.5).abs() < 1e-10);
1482        assert!(bs.is_none());
1483    }
1484
1485    #[test]
1486    fn cross_category_weight_different_categories_no_bridge() {
1487        let cats = vec![Some(0), Some(1)];
1488        let bridges = HashMap::new();
1489        let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
1490        // Missing entry → treated as Weak with strength 0: 0.5 / (0 + 0.01) = 50.0
1491        assert!((weight - 50.0).abs() < 1e-10);
1492        assert_eq!(bs, Some(0.0));
1493    }
1494
1495    #[test]
1496    fn cross_category_weight_genuine_bridge() {
1497        let cats = vec![Some(0), Some(1)];
1498        let mut bridges = HashMap::new();
1499        bridges.insert((0, 1), (0.9, BridgeClassification::Genuine));
1500        let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
1501        // Genuine: 0.5 / (0.9 + 0.1) = 0.5
1502        assert!((weight - 0.5).abs() < 1e-10);
1503        assert_eq!(bs, Some(0.9));
1504    }
1505
1506    #[test]
1507    fn cross_category_weight_weak_bridge() {
1508        let cats = vec![Some(0), Some(1)];
1509        let mut bridges = HashMap::new();
1510        bridges.insert((0, 1), (0.3, BridgeClassification::Weak));
1511        let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
1512        // Weak: 0.5 / (0.3 + 0.01) ≈ 1.6129
1513        assert!((weight - 0.5 / 0.31).abs() < 1e-10);
1514        assert_eq!(bs, Some(0.3));
1515    }
1516
1517    #[test]
1518    fn cross_category_weight_overlap_artifact_discouraged() {
1519        let cats = vec![Some(0), Some(1)];
1520        let mut bridges = HashMap::new();
1521        bridges.insert((0, 1), (0.8, BridgeClassification::OverlapArtifact));
1522        let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
1523        // OverlapArtifact: 0.5 * 2.0 = 1.0 (penalty, not reward)
1524        assert!((weight - 1.0).abs() < 1e-10);
1525        assert_eq!(bs, Some(0.8));
1526    }
1527
1528    #[test]
1529    fn cross_category_weight_no_category_info() {
1530        let cats = vec![None, Some(1)];
1531        let bridges = HashMap::new();
1532        let (weight, bs) = cross_category_weight(0.5, &cats, 0, 1, &bridges);
1533        assert!((weight - 0.5).abs() < 1e-10);
1534        assert!(bs.is_none());
1535    }
1536}