// rig_cat/embedding/mod.rs

//! Embedding model trait and vector types.

use comp_cat_rs::effect::io::Io;

use crate::error::Error;

7/// A dense vector embedding.
8#[derive(Debug, Clone)]
9pub struct Embedding {
10    values: Vec<f64>,
11}
12
13impl Embedding {
14    #[must_use]
15    pub fn new(values: Vec<f64>) -> Self { Self { values } }
16
17    #[must_use]
18    pub fn values(&self) -> &[f64] { &self.values }
19
20    #[must_use]
21    pub fn dimension(&self) -> usize { self.values.len() }
22
23    /// Cosine similarity between two embeddings.
24    ///
25    /// # Errors
26    ///
27    /// Returns `Error::DimensionMismatch` if dimensions differ.
28    pub fn cosine_similarity(&self, other: &Self) -> Result<f64, Error> {
29        if self.dimension() == other.dimension() {
30            let dot: f64 = self.values.iter()
31                .zip(other.values.iter())
32                .map(|(a, b)| a * b)
33                .sum();
34            let norm_a: f64 = self.values.iter().map(|x| x * x).sum::<f64>().sqrt();
35            let norm_b: f64 = other.values.iter().map(|x| x * x).sum::<f64>().sqrt();
36            let denom = norm_a * norm_b;
37            Ok(if denom == 0.0 { 0.0 } else { dot / denom })
38        } else {
39            Err(Error::DimensionMismatch {
40                expected: self.dimension(),
41                got: other.dimension(),
42            })
43        }
44    }
45}
46
/// An embedding request: one or more texts to embed.
#[derive(Debug, Clone)]
pub struct EmbeddingRequest {
    /// Texts to embed, in the order supplied by the caller.
    texts: Vec<String>,
}

impl EmbeddingRequest {
    /// Creates a request that owns `texts`.
    ///
    /// The vector is stored as given; an empty vector is accepted.
    #[must_use]
    pub fn new(texts: Vec<String>) -> Self { Self { texts } }

    /// Creates a request containing exactly one text.
    #[must_use]
    pub fn single(text: String) -> Self { Self { texts: vec![text] } }

    /// Returns the texts in this request as a slice.
    #[must_use]
    pub fn texts(&self) -> &[String] { &self.texts }
}

64/// The core embedding abstraction: send text, get vectors.
65pub trait EmbeddingModel {
66    /// Embed one or more texts.
67    fn embed(&self, request: EmbeddingRequest) -> Io<Error, Vec<Embedding>>;
68}
69
// Unit tests for `Embedding::cosine_similarity`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Two equal unit vectors point the same way: similarity 1.
    #[test]
    fn identical_vectors_have_similarity_one() -> Result<(), Error> {
        let a = Embedding::new(vec![1.0, 0.0, 0.0]);
        let b = Embedding::new(vec![1.0, 0.0, 0.0]);
        let sim = a.cosine_similarity(&b)?;
        assert!((sim - 1.0).abs() < 1e-10);
        Ok(())
    }

    /// Perpendicular axes: similarity 0.
    #[test]
    fn orthogonal_vectors_have_similarity_zero() -> Result<(), Error> {
        let a = Embedding::new(vec![1.0, 0.0]);
        let b = Embedding::new(vec![0.0, 1.0]);
        let sim = a.cosine_similarity(&b)?;
        assert!(sim.abs() < 1e-10);
        Ok(())
    }

    /// Opposite directions: similarity -1.
    #[test]
    fn opposite_vectors_have_similarity_negative_one() -> Result<(), Error> {
        let a = Embedding::new(vec![1.0, 0.0]);
        let b = Embedding::new(vec![-1.0, 0.0]);
        let sim = a.cosine_similarity(&b)?;
        assert!((sim + 1.0).abs() < 1e-10);
        Ok(())
    }

    /// Vectors of different length are rejected with an error.
    #[test]
    fn dimension_mismatch_returns_error() {
        let a = Embedding::new(vec![1.0, 0.0]);
        let b = Embedding::new(vec![1.0, 0.0, 0.0]);
        assert!(a.cosine_similarity(&b).is_err());
    }

    /// A zero vector yields similarity 0 rather than a division by zero.
    #[test]
    fn zero_vector_similarity_is_zero() -> Result<(), Error> {
        let a = Embedding::new(vec![0.0, 0.0]);
        let b = Embedding::new(vec![1.0, 0.0]);
        let sim = a.cosine_similarity(&b)?;
        assert!(sim.abs() < 1e-10);
        Ok(())
    }
}
116}