1use comp_cat_rs::effect::io::Io;
4
5use crate::error::Error;
6
7#[derive(Debug, Clone)]
9pub struct Embedding {
10 values: Vec<f64>,
11}
12
13impl Embedding {
14 #[must_use]
15 pub fn new(values: Vec<f64>) -> Self { Self { values } }
16
17 #[must_use]
18 pub fn values(&self) -> &[f64] { &self.values }
19
20 #[must_use]
21 pub fn dimension(&self) -> usize { self.values.len() }
22
23 pub fn cosine_similarity(&self, other: &Self) -> Result<f64, Error> {
29 if self.dimension() == other.dimension() {
30 let dot: f64 = self.values.iter()
31 .zip(other.values.iter())
32 .map(|(a, b)| a * b)
33 .sum();
34 let norm_a: f64 = self.values.iter().map(|x| x * x).sum::<f64>().sqrt();
35 let norm_b: f64 = other.values.iter().map(|x| x * x).sum::<f64>().sqrt();
36 let denom = norm_a * norm_b;
37 Ok(if denom == 0.0 { 0.0 } else { dot / denom })
38 } else {
39 Err(Error::DimensionMismatch {
40 expected: self.dimension(),
41 got: other.dimension(),
42 })
43 }
44 }
45}
46
/// A batch of input texts to be embedded.
#[derive(Debug, Clone)]
pub struct EmbeddingRequest {
    texts: Vec<String>,
}

impl EmbeddingRequest {
    /// Builds a request from an owned list of texts, kept in the given order.
    #[must_use]
    pub fn new(texts: Vec<String>) -> Self {
        Self { texts }
    }

    /// Builds a request that carries exactly one text.
    #[must_use]
    pub fn single(text: String) -> Self {
        Self::new(vec![text])
    }

    /// Borrows the texts held by this request.
    #[must_use]
    pub fn texts(&self) -> &[String] {
        self.texts.as_slice()
    }
}
63
/// A model that produces [`Embedding`] vectors for the texts in an
/// [`EmbeddingRequest`].
pub trait EmbeddingModel {
    /// Embeds the texts carried by `request`.
    ///
    /// The request is taken by value. The result is an
    /// `Io<Error, Vec<Embedding>>`, i.e. the embeddings (or an [`Error`]) are
    /// delivered through the `Io` effect type rather than returned directly.
    ///
    /// NOTE(review): presumably implementors return one embedding per input
    /// text, in request order — the trait itself does not enforce this;
    /// confirm against the implementations.
    fn embed(&self, request: EmbeddingRequest) -> Io<Error, Vec<Embedding>>;
}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing similarities.
    const TOLERANCE: f64 = 1e-10;

    /// Builds two embeddings from the given components and compares them.
    fn similarity(lhs: &[f64], rhs: &[f64]) -> Result<f64, Error> {
        let left = Embedding::new(lhs.to_vec());
        let right = Embedding::new(rhs.to_vec());
        left.cosine_similarity(&right)
    }

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    #[test]
    fn identical_vectors_have_similarity_one() -> Result<(), Error> {
        assert_close(similarity(&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0])?, 1.0);
        Ok(())
    }

    #[test]
    fn orthogonal_vectors_have_similarity_zero() -> Result<(), Error> {
        assert_close(similarity(&[1.0, 0.0], &[0.0, 1.0])?, 0.0);
        Ok(())
    }

    #[test]
    fn opposite_vectors_have_similarity_negative_one() -> Result<(), Error> {
        assert_close(similarity(&[1.0, 0.0], &[-1.0, 0.0])?, -1.0);
        Ok(())
    }

    #[test]
    fn dimension_mismatch_returns_error() {
        let outcome = similarity(&[1.0, 0.0], &[1.0, 0.0, 0.0]);
        assert!(matches!(outcome, Err(_)));
    }

    #[test]
    fn zero_vector_similarity_is_zero() -> Result<(), Error> {
        assert_close(similarity(&[0.0, 0.0], &[1.0, 0.0])?, 0.0);
        Ok(())
    }
}