//! phago_embeddings/embedder.rs — embedding backends and their shared error type.

use thiserror::Error;
4
/// Errors produced while loading models, embedding text, or comparing vectors.
#[derive(Debug, Error)]
pub enum EmbeddingError {
    /// The backing model is not loaded; payload identifies the model.
    #[error("Model not loaded: {0}")]
    ModelNotLoaded(String),

    /// Input text could not be tokenized; payload describes the failure.
    #[error("Tokenization failed: {0}")]
    TokenizationFailed(String),

    /// The model ran but inference failed; payload describes the failure.
    #[error("Inference failed: {0}")]
    InferenceFailed(String),

    /// A remote embedding API returned an error; payload carries its message.
    #[error("API error: {0}")]
    ApiError(String),

    /// Caller supplied input the embedder cannot accept.
    #[error("Invalid input: {0}")]
    InvalidInput(String),

    /// Two vectors with different lengths were compared (see `Embedder::similarity`).
    #[error("Dimension mismatch: expected {expected}, got {got}")]
    DimensionMismatch { expected: usize, got: usize },

    /// Underlying I/O failure, converted automatically via `#[from]`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
29
/// Convenience alias for results whose error type is [`EmbeddingError`].
pub type EmbeddingResult<T> = Result<T, EmbeddingError>;
32
33pub trait Embedder: Send + Sync {
37 fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>>;
39
40 fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
42 texts.iter().map(|t| self.embed(t)).collect()
44 }
45
46 fn dimension(&self) -> usize;
48
49 fn model_name(&self) -> &str;
51
52 fn similarity(&self, a: &[f32], b: &[f32]) -> EmbeddingResult<f32> {
54 if a.len() != b.len() {
55 return Err(EmbeddingError::DimensionMismatch {
56 expected: a.len(),
57 got: b.len(),
58 });
59 }
60
61 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
62 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
63 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
64
65 if norm_a == 0.0 || norm_b == 0.0 {
66 return Ok(0.0);
67 }
68
69 Ok(dot / (norm_a * norm_b))
70 }
71
72 fn most_similar<'a>(
74 &self,
75 query: &str,
76 candidates: &[&'a str],
77 ) -> EmbeddingResult<Option<(&'a str, f32)>> {
78 if candidates.is_empty() {
79 return Ok(None);
80 }
81
82 let query_vec = self.embed(query)?;
83 let candidate_vecs = self.embed_batch(candidates)?;
84
85 let mut best: Option<(usize, f32)> = None;
86 for (i, vec) in candidate_vecs.iter().enumerate() {
87 let sim = self.similarity(&query_vec, vec)?;
88 if best.is_none() || sim > best.unwrap().1 {
89 best = Some((i, sim));
90 }
91 }
92
93 Ok(best.map(|(i, sim)| (candidates[i], sim)))
94 }
95
96 fn top_k_similar<'a>(
98 &self,
99 query: &str,
100 candidates: &[&'a str],
101 k: usize,
102 ) -> EmbeddingResult<Vec<(&'a str, f32)>> {
103 if candidates.is_empty() || k == 0 {
104 return Ok(vec![]);
105 }
106
107 let query_vec = self.embed(query)?;
108 let candidate_vecs = self.embed_batch(candidates)?;
109
110 let mut scores: Vec<(usize, f32)> = candidate_vecs
111 .iter()
112 .enumerate()
113 .map(|(i, vec)| {
114 let sim = self.similarity(&query_vec, vec).unwrap_or(0.0);
115 (i, sim)
116 })
117 .collect();
118
119 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
121
122 Ok(scores
123 .into_iter()
124 .take(k)
125 .map(|(i, sim)| (candidates[i], sim))
126 .collect())
127 }
128}
129
/// An embedding vector together with optional provenance metadata.
#[derive(Debug, Clone)]
pub struct Embedding {
    /// The dense vector itself.
    pub vector: Vec<f32>,
    /// The text that produced this vector, when known.
    pub text: Option<String>,
    /// Token count of the source text; 0 when unknown.
    pub tokens: usize,
}

impl Embedding {
    /// Build an embedding from a bare vector, with no source text and a
    /// token count of zero.
    pub fn new(vector: Vec<f32>) -> Self {
        Embedding {
            tokens: 0,
            text: None,
            vector,
        }
    }

    /// Build an embedding that records the source text and its token count.
    pub fn with_text(vector: Vec<f32>, text: String, tokens: usize) -> Self {
        Embedding {
            tokens,
            text: Some(text),
            vector,
        }
    }

    /// Length of the stored vector.
    pub fn dimension(&self) -> usize {
        self.vector.len()
    }
}