//! Embedding extraction: pooled vector representations of text built from
//! model outputs, plus truncation, normalization, and similarity helpers.

use crate::model::{InferenceContext, Model, ModelConfig};
use crate::tokenizer::Tokenizer;

/// Configuration for embedding extraction.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Which layer to take embeddings from (-1 = last).
    pub layer: i32,
    /// How per-token vectors are pooled into one embedding.
    pub pooling: PoolingStrategy,
    /// Whether to L2-normalize the pooled embedding.
    pub normalize: bool,
    /// Maximum token count; longer inputs are truncated.
    pub max_length: usize,
    /// Strategy for truncating inputs over `max_length`.
    pub truncation: TruncationStrategy,
}

impl Default for EmbeddingConfig {
    fn default() -> Self {
        Self {
            layer: -1,
            pooling: PoolingStrategy::Mean,
            normalize: true,
            max_length: 512,
            truncation: TruncationStrategy::Right,
        }
    }
}

/// Strategies for pooling per-token embeddings into a single vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PoolingStrategy {
    /// Use the last token's embedding.
    Last,
    /// Use the first token's embedding.
    First,
    /// Average all token embeddings.
    Mean,
    /// Take the elementwise maximum across tokens.
    Max,
    /// Average with weights increasing by position, so later tokens count more.
    WeightedMean,
}

/// Strategies for truncating inputs that exceed `max_length`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TruncationStrategy {
    /// Keep the first `max_length` tokens.
    Right,
    /// Keep the last `max_length` tokens.
    Left,
    /// Keep tokens from both ends, dropping the middle.
    Middle,
}

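/// Extracts pooled embedding vectors from a model's per-token outputs.
///
/// A minimal usage sketch, assuming a loaded `model`, `tokenizer`, `ctx`,
/// and `model_config` are already in scope (illustrative bindings, not part
/// of this API):
///
/// ```ignore
/// let extractor = EmbeddingExtractor::new(EmbeddingConfig::default(), &model_config);
/// let a = extractor.embed_text(model, &tokenizer, &mut ctx, "hello world")?;
/// ctx.reset();
/// let b = extractor.embed_text(model, &tokenizer, &mut ctx, "greetings")?;
/// let score = cosine_similarity(&a, &b);
/// ```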
pub struct EmbeddingExtractor {
    /// Extraction settings.
    config: EmbeddingConfig,
    /// Embedding dimensionality, taken from the model's hidden size.
    hidden_dim: usize,
}

impl EmbeddingExtractor {
    /// Creates an extractor, reading the hidden size from the model config.
    pub fn new(config: EmbeddingConfig, model_config: &ModelConfig) -> Self {
        Self {
            config,
            hidden_dim: model_config.hidden_size,
        }
    }

    /// Embeds a single text into one pooled (and optionally normalized) vector.
    pub fn embed_text(
        &self,
        model: &dyn Model,
        tokenizer: &Tokenizer,
        ctx: &mut InferenceContext,
        text: &str,
    ) -> Result<Vec<f32>, EmbeddingError> {
        let tokens = tokenizer.encode(text, false)?;

        // Reject empty inputs rather than pooling over nothing.
        if tokens.is_empty() {
            return Err(EmbeddingError::EmptyInput);
        }

        // Enforce the configured length limit.
        let tokens = self.truncate_tokens(&tokens);

        // One vector per token from the model's output.
        let embeddings = self.get_token_embeddings(model, ctx, &tokens)?;

        // Collapse the per-token vectors into a single embedding.
        let pooled = self.pool_embeddings(&embeddings, tokens.len());

        if self.config.normalize {
            Ok(self.normalize_embedding(&pooled))
        } else {
            Ok(pooled)
        }
    }

    /// Embeds several texts, resetting the inference context between inputs.
    pub fn embed_batch(
        &self,
        model: &dyn Model,
        tokenizer: &Tokenizer,
        ctx: &mut InferenceContext,
        texts: &[&str],
    ) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        let mut results = Vec::with_capacity(texts.len());

        for text in texts {
            // Clear any cached state so embeddings don't leak across inputs.
            ctx.reset();

            let embedding = self.embed_text(model, tokenizer, ctx, text)?;
            results.push(embedding);
        }

        Ok(results)
    }

    /// Truncates a token sequence to `max_length` using the configured strategy.
    fn truncate_tokens(&self, tokens: &[u32]) -> Vec<u32> {
        if tokens.len() <= self.config.max_length {
            return tokens.to_vec();
        }

        match self.config.truncation {
            TruncationStrategy::Right => tokens[..self.config.max_length].to_vec(),
            TruncationStrategy::Left => tokens[tokens.len() - self.config.max_length..].to_vec(),
            TruncationStrategy::Middle => {
                // Keep the head and tail, dropping the middle; the head gets
                // the extra token when max_length is odd.
                let half = self.config.max_length / 2;
                let mut truncated = tokens[..self.config.max_length - half].to_vec();
                truncated.extend_from_slice(&tokens[tokens.len() - half..]);
                truncated
            }
        }
    }

    /// Runs the model token by token and collects one vector per token.
    fn get_token_embeddings(
        &self,
        model: &dyn Model,
        ctx: &mut InferenceContext,
        tokens: &[u32],
    ) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        let mut embeddings = Vec::with_capacity(tokens.len());

        for token in tokens {
            // Feed a single token; the context carries state across calls.
            let logits = model.forward(&[*token], ctx)?;

            let logits_data = logits.as_f32()?;

            // The output may be wider than the hidden size (e.g. vocab-sized
            // logits), so clamp to the first `hidden_dim` values.
            let dim = self.hidden_dim.min(logits_data.len());
            embeddings.push(logits_data[..dim].to_vec());
        }

        Ok(embeddings)
    }

    /// Pools per-token embeddings into a single vector.
    fn pool_embeddings(&self, embeddings: &[Vec<f32>], _seq_len: usize) -> Vec<f32> {
        if embeddings.is_empty() {
            return vec![0.0; self.hidden_dim];
        }

        let dim = embeddings[0].len();

        match self.config.pooling {
            PoolingStrategy::Last => embeddings.last().cloned().unwrap_or_else(|| vec![0.0; dim]),
            PoolingStrategy::First => embeddings
                .first()
                .cloned()
                .unwrap_or_else(|| vec![0.0; dim]),
            PoolingStrategy::Mean => {
                let mut mean = vec![0.0f32; dim];
                for emb in embeddings {
                    for (i, &v) in emb.iter().enumerate() {
                        mean[i] += v;
                    }
                }
                let n = embeddings.len() as f32;
                for v in &mut mean {
                    *v /= n;
                }
                mean
            }
            PoolingStrategy::Max => {
                let mut max = vec![f32::NEG_INFINITY; dim];
                for emb in embeddings {
                    for (i, &v) in emb.iter().enumerate() {
                        max[i] = max[i].max(v);
                    }
                }
                max
            }
            PoolingStrategy::WeightedMean => {
                // Weight the token at position `pos` by `pos + 1`, so later
                // tokens contribute more to the pooled vector.
                let mut weighted = vec![0.0f32; dim];
                let mut total_weight = 0.0f32;

                for (pos, emb) in embeddings.iter().enumerate() {
                    let weight = (pos + 1) as f32;
                    total_weight += weight;
                    for (i, &v) in emb.iter().enumerate() {
                        weighted[i] += v * weight;
                    }
                }

                for v in &mut weighted {
                    *v /= total_weight;
                }
                weighted
            }
        }
    }

    /// L2-normalizes an embedding; zero vectors are returned unchanged.
    fn normalize_embedding(&self, embedding: &[f32]) -> Vec<f32> {
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            embedding.iter().map(|x| x / norm).collect()
        } else {
            embedding.to_vec()
        }
    }

    /// Returns the dimensionality of produced embeddings.
    pub fn embedding_dim(&self) -> usize {
        self.hidden_dim
    }
}

/// Errors produced during embedding extraction.
#[derive(thiserror::Error, Debug)]
pub enum EmbeddingError {
    #[error("Tokenization error: {0}")]
    Tokenization(#[from] crate::tokenizer::TokenizerError),

    #[error("Model error: {0}")]
    Model(#[from] crate::model::ModelError),

    #[error("Tensor error: {0}")]
    Tensor(#[from] crate::tensor::TensorError),

    #[error("Empty input")]
    EmptyInput,
}

/// Cosine similarity of two equal-length vectors; returns 0.0 on length
/// mismatch or when either vector has zero norm.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm_a > 0.0 && norm_b > 0.0 {
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}

/// Euclidean (L2) distance; returns infinity on length mismatch.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::INFINITY;
    }

    a.iter()
        .zip(b.iter())
        .map(|(x, y)| (x - y).powi(2))
        .sum::<f32>()
        .sqrt()
}

/// Dot product; returns 0.0 on length mismatch.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

/// Returns the indices and cosine similarities of the `k` embeddings nearest
/// to `query`, sorted by descending similarity.
pub fn find_nearest(query: &[f32], embeddings: &[Vec<f32>], k: usize) -> Vec<(usize, f32)> {
    let mut scores: Vec<(usize, f32)> = embeddings
        .iter()
        .enumerate()
        .map(|(i, emb)| (i, cosine_similarity(query, emb)))
        .collect();

    // Sort descending; NaN scores are treated as equal rather than panicking.
    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

    scores.into_iter().take(k).collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_config_default() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.layer, -1);
        assert!(config.normalize);
        assert_eq!(config.pooling, PoolingStrategy::Mean);
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&a, &c)).abs() < 0.001);
    }

    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        assert!((euclidean_distance(&a, &b) - 5.0).abs() < 0.001);
    }
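
    // Added coverage sketch for dot_product on a hand-computed case
    // (1*4 + 2*5 + 3*6 = 32), plus the length-mismatch fallback of 0.0.
    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert!((dot_product(&a, &b) - 32.0).abs() < 0.001);
        assert_eq!(dot_product(&a, &b[..2]), 0.0);
    }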

    #[test]
    fn test_find_nearest() {
        let query = vec![1.0, 0.0];
        let embeddings = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![0.7, 0.7],
        ];

        let nearest = find_nearest(&query, &embeddings, 2);
        assert_eq!(nearest.len(), 2);
        assert_eq!(nearest[0].0, 0);
    }
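
    // Added check: `take(k)` caps the result at the number of available
    // embeddings, so asking for more than exist is safe.
    #[test]
    fn test_find_nearest_k_exceeds_len() {
        let query = vec![1.0, 0.0];
        let embeddings = vec![vec![1.0, 0.0]];
        let nearest = find_nearest(&query, &embeddings, 5);
        assert_eq!(nearest.len(), 1);
    }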

    #[test]
    fn test_normalize() {
        let extractor = EmbeddingExtractor {
            config: EmbeddingConfig::default(),
            hidden_dim: 3,
        };

        let embedding = vec![3.0, 4.0, 0.0];
        let normalized = extractor.normalize_embedding(&embedding);

        let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }
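
    // Added coverage sketch for the three truncation strategies on a
    // six-token input with max_length = 4; Middle keeps the head and tail.
    #[test]
    fn test_truncate_tokens() {
        let mut config = EmbeddingConfig {
            max_length: 4,
            ..Default::default()
        };
        let tokens: Vec<u32> = vec![1, 2, 3, 4, 5, 6];

        config.truncation = TruncationStrategy::Right;
        let ex = EmbeddingExtractor { config: config.clone(), hidden_dim: 2 };
        assert_eq!(ex.truncate_tokens(&tokens), vec![1, 2, 3, 4]);

        config.truncation = TruncationStrategy::Left;
        let ex = EmbeddingExtractor { config: config.clone(), hidden_dim: 2 };
        assert_eq!(ex.truncate_tokens(&tokens), vec![3, 4, 5, 6]);

        config.truncation = TruncationStrategy::Middle;
        let ex = EmbeddingExtractor { config, hidden_dim: 2 };
        assert_eq!(ex.truncate_tokens(&tokens), vec![1, 2, 5, 6]);
    }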

    #[test]
    fn test_pooling_mean() {
        let extractor = EmbeddingExtractor {
            config: EmbeddingConfig {
                pooling: PoolingStrategy::Mean,
                ..Default::default()
            },
            hidden_dim: 2,
        };

        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        let pooled = extractor.pool_embeddings(&embeddings, 2);
        assert!((pooled[0] - 0.5).abs() < 0.001);
        assert!((pooled[1] - 0.5).abs() < 0.001);
    }
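
    // Added coverage sketch for Max and WeightedMean pooling on hand-checked
    // inputs, plus the zero-vector fallback for empty input.
    #[test]
    fn test_pooling_max_and_weighted_mean() {
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        let max_extractor = EmbeddingExtractor {
            config: EmbeddingConfig {
                pooling: PoolingStrategy::Max,
                ..Default::default()
            },
            hidden_dim: 2,
        };
        let pooled = max_extractor.pool_embeddings(&embeddings, 2);
        assert!((pooled[0] - 1.0).abs() < 0.001);
        assert!((pooled[1] - 1.0).abs() < 0.001);

        // Weights are 1 and 2 (position + 1), total 3, so the pooled
        // vector is [1/3, 2/3].
        let weighted_extractor = EmbeddingExtractor {
            config: EmbeddingConfig {
                pooling: PoolingStrategy::WeightedMean,
                ..Default::default()
            },
            hidden_dim: 2,
        };
        let pooled = weighted_extractor.pool_embeddings(&embeddings, 2);
        assert!((pooled[0] - 1.0 / 3.0).abs() < 0.001);
        assert!((pooled[1] - 2.0 / 3.0).abs() < 0.001);
    }

    #[test]
    fn test_pooling_empty() {
        let extractor = EmbeddingExtractor {
            config: EmbeddingConfig::default(),
            hidden_dim: 3,
        };
        assert_eq!(extractor.pool_embeddings(&[], 0), vec![0.0; 3]);
    }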
}