Skip to main content

llama_gguf/model/
embeddings.rs

//! Embedding extraction API
//!
//! This module provides functionality to extract embeddings from models,
//! useful for:
//! - Semantic similarity/search
//! - Retrieval-Augmented Generation (RAG)
//! - Clustering and classification
//! - Vector databases
9
10use crate::model::{InferenceContext, Model, ModelConfig};
11use crate::tokenizer::Tokenizer;
12
13/// Embedding extraction configuration
14#[derive(Debug, Clone)]
15pub struct EmbeddingConfig {
16    /// Which layer to extract embeddings from (-1 = last layer)
17    pub layer: i32,
18    /// Pooling strategy for sequence embeddings
19    pub pooling: PoolingStrategy,
20    /// Whether to normalize embeddings
21    pub normalize: bool,
22    /// Maximum sequence length
23    pub max_length: usize,
24    /// Truncation strategy
25    pub truncation: TruncationStrategy,
26}
27
28impl Default for EmbeddingConfig {
29    fn default() -> Self {
30        Self {
31            layer: -1,
32            pooling: PoolingStrategy::Mean,
33            normalize: true,
34            max_length: 512,
35            truncation: TruncationStrategy::Right,
36        }
37    }
38}
39
/// How a sequence of per-token vectors is reduced to a single vector.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PoolingStrategy {
    /// Take the vector of the final token (common for decoder models).
    Last,
    /// Take the vector of the first token (CLS-token style).
    First,
    /// Element-wise average over all tokens.
    Mean,
    /// Element-wise maximum over all tokens.
    Max,
    /// Element-wise weighted average over all tokens.
    WeightedMean,
}
54
/// Which part of an over-long token sequence is discarded.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TruncationStrategy {
    /// Drop tokens from the right, keeping the beginning.
    Right,
    /// Drop tokens from the left, keeping the end.
    Left,
    /// Drop tokens from the middle, keeping both ends.
    Middle,
}
65
66/// Embedding extractor for a model
67pub struct EmbeddingExtractor {
68    /// Embedding configuration
69    config: EmbeddingConfig,
70    /// Model hidden dimension
71    hidden_dim: usize,
72}
73
74impl EmbeddingExtractor {
75    /// Create a new embedding extractor
76    pub fn new(config: EmbeddingConfig, model_config: &ModelConfig) -> Self {
77        Self {
78            config,
79            hidden_dim: model_config.hidden_size,
80        }
81    }
82
83    /// Extract embedding for a single text
84    pub fn embed_text(
85        &self,
86        model: &dyn Model,
87        tokenizer: &Tokenizer,
88        ctx: &mut InferenceContext,
89        text: &str,
90    ) -> Result<Vec<f32>, EmbeddingError> {
91        // Tokenize (without BOS token for embeddings)
92        let tokens = tokenizer.encode(text, false)?;
93
94        // Truncate if needed
95        let tokens = self.truncate_tokens(&tokens);
96
97        // Get embeddings for all tokens
98        let embeddings = self.get_token_embeddings(model, ctx, &tokens)?;
99
100        // Pool to single embedding
101        let pooled = self.pool_embeddings(&embeddings, tokens.len());
102
103        // Normalize if requested
104        if self.config.normalize {
105            Ok(self.normalize_embedding(&pooled))
106        } else {
107            Ok(pooled)
108        }
109    }
110
111    /// Extract embeddings for multiple texts (batched)
112    pub fn embed_batch(
113        &self,
114        model: &dyn Model,
115        tokenizer: &Tokenizer,
116        ctx: &mut InferenceContext,
117        texts: &[&str],
118    ) -> Result<Vec<Vec<f32>>, EmbeddingError> {
119        let mut results = Vec::with_capacity(texts.len());
120
121        for text in texts {
122            // Reset context for each text
123            ctx.reset();
124            let embedding = self.embed_text(model, tokenizer, ctx, text)?;
125            results.push(embedding);
126        }
127
128        Ok(results)
129    }
130
131    /// Truncate tokens based on configuration
132    fn truncate_tokens(&self, tokens: &[u32]) -> Vec<u32> {
133        if tokens.len() <= self.config.max_length {
134            return tokens.to_vec();
135        }
136
137        match self.config.truncation {
138            TruncationStrategy::Right => tokens[..self.config.max_length].to_vec(),
139            TruncationStrategy::Left => tokens[tokens.len() - self.config.max_length..].to_vec(),
140            TruncationStrategy::Middle => {
141                let half = self.config.max_length / 2;
142                let mut truncated = tokens[..half].to_vec();
143                truncated.extend_from_slice(&tokens[tokens.len() - half..]);
144                truncated
145            }
146        }
147    }
148
149    /// Get embeddings for each token
150    fn get_token_embeddings(
151        &self,
152        model: &dyn Model,
153        ctx: &mut InferenceContext,
154        tokens: &[u32],
155    ) -> Result<Vec<Vec<f32>>, EmbeddingError> {
156        let mut embeddings = Vec::with_capacity(tokens.len());
157
158        // Process tokens one at a time and capture hidden states
159        for token in tokens {
160            let logits = model.forward(&[*token], ctx)?;
161
162            // For now, use logits as a proxy for hidden state
163            // A full implementation would capture the actual hidden states
164            // from the model's internal layers
165            let logits_data = logits.as_f32()?;
166
167            // Create embedding from the beginning of logits (approximation)
168            let dim = self.hidden_dim.min(logits_data.len());
169            embeddings.push(logits_data[..dim].to_vec());
170        }
171
172        Ok(embeddings)
173    }
174
175    /// Pool token embeddings into a single embedding
176    fn pool_embeddings(&self, embeddings: &[Vec<f32>], _seq_len: usize) -> Vec<f32> {
177        if embeddings.is_empty() {
178            return vec![0.0; self.hidden_dim];
179        }
180
181        let dim = embeddings[0].len();
182
183        match self.config.pooling {
184            PoolingStrategy::Last => embeddings.last().cloned().unwrap_or_else(|| vec![0.0; dim]),
185            PoolingStrategy::First => embeddings
186                .first()
187                .cloned()
188                .unwrap_or_else(|| vec![0.0; dim]),
189            PoolingStrategy::Mean => {
190                let mut mean = vec![0.0f32; dim];
191                for emb in embeddings {
192                    for (i, &v) in emb.iter().enumerate() {
193                        mean[i] += v;
194                    }
195                }
196                let n = embeddings.len() as f32;
197                for v in &mut mean {
198                    *v /= n;
199                }
200                mean
201            }
202            PoolingStrategy::Max => {
203                let mut max = vec![f32::NEG_INFINITY; dim];
204                for emb in embeddings {
205                    for (i, &v) in emb.iter().enumerate() {
206                        max[i] = max[i].max(v);
207                    }
208                }
209                max
210            }
211            PoolingStrategy::WeightedMean => {
212                // Simple linear weighting - later tokens get more weight
213                let mut weighted = vec![0.0f32; dim];
214                let mut total_weight = 0.0f32;
215
216                for (pos, emb) in embeddings.iter().enumerate() {
217                    let weight = (pos + 1) as f32;
218                    total_weight += weight;
219                    for (i, &v) in emb.iter().enumerate() {
220                        weighted[i] += v * weight;
221                    }
222                }
223
224                for v in &mut weighted {
225                    *v /= total_weight;
226                }
227                weighted
228            }
229        }
230    }
231
232    /// L2 normalize an embedding
233    fn normalize_embedding(&self, embedding: &[f32]) -> Vec<f32> {
234        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
235        if norm > 0.0 {
236            embedding.iter().map(|x| x / norm).collect()
237        } else {
238            embedding.to_vec()
239        }
240    }
241
242    /// Get embedding dimension
243    pub fn embedding_dim(&self) -> usize {
244        self.hidden_dim
245    }
246}
247
248/// Embedding-specific error type
249#[derive(thiserror::Error, Debug)]
250pub enum EmbeddingError {
251    #[error("Tokenization error: {0}")]
252    Tokenization(#[from] crate::tokenizer::TokenizerError),
253
254    #[error("Model error: {0}")]
255    Model(#[from] crate::model::ModelError),
256
257    #[error("Tensor error: {0}")]
258    Tensor(#[from] crate::tensor::TensorError),
259
260    #[error("Empty input")]
261    EmptyInput,
262}
263
/// Cosine similarity of two embeddings.
///
/// Returns `0.0` when the slices differ in length or when either vector's
/// L2 norm is not strictly positive (for example an all-zero vector).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // L2 norm of a slice.
    let l2 = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    let magnitude_a = l2(a);
    let magnitude_b = l2(b);
    // Written as a negated conjunction so NaN norms also take this branch.
    if !(magnitude_a > 0.0 && magnitude_b > 0.0) {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (magnitude_a * magnitude_b)
}
280
/// Euclidean (L2) distance between two embeddings.
///
/// Returns `f32::INFINITY` when the slices differ in length.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let squared_sum: f32 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
        squared_sum.sqrt()
    } else {
        f32::INFINITY
    }
}
293
/// Dot product of two embeddings.
///
/// Returns `0.0` when the slices differ in length.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs).sum()
    } else {
        0.0
    }
}
302
303/// Find k nearest neighbors from a set of embeddings
304pub fn find_nearest(query: &[f32], embeddings: &[Vec<f32>], k: usize) -> Vec<(usize, f32)> {
305    let mut scores: Vec<(usize, f32)> = embeddings
306        .iter()
307        .enumerate()
308        .map(|(i, emb)| (i, cosine_similarity(query, emb)))
309        .collect();
310
311    // Sort by similarity (descending)
312    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
313
314    scores.into_iter().take(k).collect()
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons.
    const EPS: f32 = 0.001;

    /// Build an extractor directly from its fields, so no `ModelConfig` is needed.
    fn extractor_with(pooling: PoolingStrategy, hidden_dim: usize) -> EmbeddingExtractor {
        EmbeddingExtractor {
            config: EmbeddingConfig {
                pooling,
                ..Default::default()
            },
            hidden_dim,
        }
    }

    #[test]
    fn test_embedding_config_default() {
        let defaults = EmbeddingConfig::default();
        assert_eq!(defaults.layer, -1);
        assert!(defaults.normalize);
        assert_eq!(defaults.pooling, PoolingStrategy::Mean);
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis = vec![1.0, 0.0, 0.0];
        let same = vec![1.0, 0.0, 0.0];
        let y_axis = vec![0.0, 1.0, 0.0];

        // Identical direction -> 1, orthogonal -> 0.
        assert!((cosine_similarity(&x_axis, &same) - 1.0).abs() < EPS);
        assert!((cosine_similarity(&x_axis, &y_axis)).abs() < EPS);
    }

    #[test]
    fn test_euclidean_distance() {
        // Classic 3-4-5 triangle.
        let origin = vec![0.0, 0.0];
        let point = vec![3.0, 4.0];
        assert!((euclidean_distance(&origin, &point) - 5.0).abs() < EPS);
    }

    #[test]
    fn test_find_nearest() {
        let query = vec![1.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0], // Most similar
            vec![0.0, 1.0], // Orthogonal
            vec![0.7, 0.7], // Somewhat similar
        ];

        let top = find_nearest(&query, &candidates, 2);
        assert_eq!(top.len(), 2);
        assert_eq!(top[0].0, 0); // First embedding is most similar
    }

    #[test]
    fn test_normalize() {
        let extractor = extractor_with(PoolingStrategy::Mean, 3);

        let unit = extractor.normalize_embedding(&[3.0, 4.0, 0.0]);

        let length: f32 = unit.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((length - 1.0).abs() < EPS);
    }

    #[test]
    fn test_pooling_mean() {
        let extractor = extractor_with(PoolingStrategy::Mean, 2);

        let rows = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let pooled = extractor.pool_embeddings(&rows, 2);

        assert!((pooled[0] - 0.5).abs() < EPS);
        assert!((pooled[1] - 0.5).abs() < EPS);
    }
}