Skip to main content

sphereql_embed/
mapper.rs

1use sphereql_core::SphericalPoint;
2use sphereql_layout::DimensionMapper;
3
4use crate::projection::Projection;
5use crate::types::Embedding;
6
/// Adapts any [`Projection`] into a [`DimensionMapper`] for use with
/// sphereql-layout's layout strategies (Uniform, Clustered, ForceDirected).
///
/// The wrapper owns the projection it adapts. `Debug` and `Clone` are
/// derived, so they are available whenever `P` itself implements them; the
/// struct definition places no bounds on `P`.
#[derive(Debug, Clone)]
pub struct EmbeddingMapper<P> {
    /// The wrapped projection that mapping requests are forwarded to.
    projection: P,
}
12
13impl<P> EmbeddingMapper<P> {
14    pub fn new(projection: P) -> Self {
15        Self { projection }
16    }
17
18    pub fn projection(&self) -> &P {
19        &self.projection
20    }
21}
22
23impl<P: Projection> DimensionMapper for EmbeddingMapper<P> {
24    type Item = Embedding;
25
26    fn map(&self, item: &Embedding) -> SphericalPoint {
27        self.projection.project(item)
28    }
29}
30
#[cfg(test)]
mod tests {
    use super::*;
    use crate::projection::RandomProjection;
    use crate::types::RadialStrategy;
    use sphereql_layout::{LayoutStrategy, UniformLayout};

    /// Mapping through the adapter yields the radius fixed by the
    /// projection's radial strategy.
    #[test]
    fn mapper_delegates_to_projection() {
        let mapper =
            EmbeddingMapper::new(RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42));
        let embedding = Embedding::new(vec![1.0, 0.0, 0.0, 0.0, 0.0]);

        let point = mapper.map(&embedding);
        let radius_error = (point.r - 1.0).abs();

        assert!(radius_error < 1e-12);
    }

    /// The adapter can be handed to a layout strategy; every input produces
    /// an entry and every entry keeps the fixed radius.
    #[test]
    fn mapper_with_layout_strategy() {
        let mapper =
            EmbeddingMapper::new(RandomProjection::new(5, RadialStrategy::Fixed(1.0), 42));

        let inputs = vec![
            Embedding::new(vec![1.0, 0.0, 0.0, 0.0, 0.0]),
            Embedding::new(vec![0.0, 1.0, 0.0, 0.0, 0.0]),
            Embedding::new(vec![0.0, 0.0, 1.0, 0.0, 0.0]),
        ];

        let strategy = UniformLayout::new();
        let laid_out = strategy.layout(&inputs, &mapper);
        assert_eq!(laid_out.entries.len(), 3);

        for placed in laid_out.entries.iter() {
            let radius_error = (placed.position.r - 1.0).abs();
            assert!(radius_error < 1e-12);
        }
    }

    /// The accessor hands back the projection the adapter was built with.
    #[test]
    fn mapper_exposes_projection() {
        let mapper =
            EmbeddingMapper::new(RandomProjection::new(8, RadialStrategy::Fixed(1.0), 99));

        let inner = mapper.projection();

        assert_eq!(inner.dimensionality(), 8);
    }
}