
// oxirs_embed/batch_encoder.rs

//! Batch text encoding pipeline with chunking and pooling strategies.
//!
//! Provides deterministic text-to-embedding conversion with configurable pooling,
//! normalization, and similarity computation — all without external ML dependencies.
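//!
//! # Example
//!
//! A minimal usage sketch, assuming this module is exposed as
//! `oxirs_embed::batch_encoder` (the import path is inferred from the file
//! location, not confirmed here):
//!
//! ```
//! use oxirs_embed::batch_encoder::{BatchEncoder, EncodingConfig};
//!
//! let mut encoder = BatchEncoder::new(EncodingConfig::default());
//! let embedding = encoder.encode_single("hello world");
//! assert_eq!(embedding.len(), 128);
//! ```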

use std::collections::HashMap;
use std::f64::consts::PI;

/// Pooling strategy for aggregating token-level embeddings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PoolingStrategy {
    /// Mean of all token embeddings
    Mean,
    /// Element-wise maximum across token embeddings
    Max,
    /// Use the first token (CLS-like) embedding
    CLS,
    /// Use the last token embedding
    Last,
}

/// Configuration for the batch encoder.
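///
/// # Example
///
/// Overriding a single field via struct-update syntax (same import-path
/// assumption as the module-level example):
///
/// ```
/// use oxirs_embed::batch_encoder::{EncodingConfig, PoolingStrategy};
///
/// let config = EncodingConfig {
///     pooling: PoolingStrategy::Max,
///     ..EncodingConfig::default()
/// };
/// assert_eq!(config.max_length, 128);
/// ```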
#[derive(Debug, Clone)]
pub struct EncodingConfig {
    /// Maximum number of tokens per text (truncated if exceeded)
    pub max_length: usize,
    /// Number of texts to process per batch
    pub batch_size: usize,
    /// Pooling strategy to aggregate token embeddings
    pub pooling: PoolingStrategy,
    /// Whether to L2-normalise the final embedding
    pub normalize: bool,
}

impl Default for EncodingConfig {
    fn default() -> Self {
        Self {
            max_length: 128,
            batch_size: 32,
            pooling: PoolingStrategy::Mean,
            normalize: true,
        }
    }
}

/// A tokenised representation of a single text string.
#[derive(Debug, Clone)]
pub struct TokenizedText {
    /// Raw token strings
    pub tokens: Vec<String>,
    /// Sequential integer IDs assigned to each token
    pub ids: Vec<u32>,
    /// Attention mask (1 = real token, 0 = padding — always 1 here)
    pub attention_mask: Vec<u8>,
}

/// The output of encoding a batch of texts.
#[derive(Debug, Clone)]
pub struct EncodedBatch {
    /// One embedding vector per input text
    pub embeddings: Vec<Vec<f32>>,
    /// Number of tokens for each input text (after truncation)
    pub token_counts: Vec<usize>,
    /// The actual number of texts in this batch
    pub batch_size: usize,
}

/// Embedding dimensionality produced by this encoder.
const EMBED_DIM: usize = 128;

/// A large prime used in the deterministic ID hash to spread token IDs.
const HASH_PRIME: u32 = 7919;

/// Batch text encoder: tokenises, embeds, pools, and normalises text.
pub struct BatchEncoder {
    config: EncodingConfig,
    /// Stable token vocabulary built lazily (token string → ID).
    vocab: HashMap<String, u32>,
    /// Next available vocabulary ID.
    next_id: u32,
}

impl BatchEncoder {
    /// Create a new encoder with the given configuration.
    pub fn new(config: EncodingConfig) -> Self {
        Self {
            config,
            vocab: HashMap::new(),
            next_id: 1, // 0 reserved for unknown/padding
        }
    }

    /// Tokenise `text` by splitting on whitespace, truncating to `max_length`,
    /// and assigning sequential IDs from a growing vocabulary.
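    ///
    /// # Example
    ///
    /// A sketch (same import-path assumption as the module-level example):
    ///
    /// ```
    /// use oxirs_embed::batch_encoder::{BatchEncoder, EncodingConfig};
    ///
    /// let mut enc = BatchEncoder::new(EncodingConfig::default());
    /// let t = enc.tokenize("Hello WORLD");
    /// assert_eq!(t.tokens, vec!["hello", "world"]); // lowercased
    /// assert_eq!(t.attention_mask, vec![1, 1]);
    /// ```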
    pub fn tokenize(&mut self, text: &str) -> TokenizedText {
        let raw_tokens: Vec<String> = text.split_whitespace().map(|t| t.to_lowercase()).collect();

        let truncated: Vec<String> = raw_tokens
            .into_iter()
            .take(self.config.max_length)
            .collect();

        let ids: Vec<u32> = truncated
            .iter()
            .map(|tok| {
                if let Some(&id) = self.vocab.get(tok) {
                    id
                } else {
                    let id = self.next_id;
                    self.vocab.insert(tok.clone(), id);
                    self.next_id = self.next_id.saturating_add(1);
                    id
                }
            })
            .collect();

        let attention_mask = vec![1u8; truncated.len()];

        TokenizedText {
            tokens: truncated,
            ids,
            attention_mask,
        }
    }

    /// Produce a deterministic 128-dimensional embedding for a single token ID.
    ///
    /// Each dimension `d` is computed as:
    ///   `cos(2π * ((id * HASH_PRIME + d) mod 997) / 997)`
    /// for even dimensions, and the sine counterpart for odd dimensions.
    /// This ensures distinctness across tokens without any randomness.
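    ///
    /// For example, `id = 1`, `d = 0` gives `(1 * 7919 + 0) mod 997 = 940`,
    /// so dimension 0 is `cos(2π * 940 / 997)`.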
    fn token_embedding(id: u32) -> Vec<f32> {
        let mut emb = Vec::with_capacity(EMBED_DIM);
        for d in 0..EMBED_DIM {
            let phase = ((id.wrapping_mul(HASH_PRIME).wrapping_add(d as u32)) % 997) as f64 / 997.0
                * 2.0
                * PI;
            let val = if d % 2 == 0 { phase.cos() } else { phase.sin() };
            emb.push(val as f32);
        }
        emb
    }

    /// Encode a single text string into a 128-dimensional embedding.
    ///
    /// Steps: tokenise → produce per-token embeddings → pool → optionally normalise.
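    ///
    /// # Example
    ///
    /// A sketch (same import-path assumption as the module-level example):
    ///
    /// ```
    /// use oxirs_embed::batch_encoder::{BatchEncoder, EncodingConfig};
    ///
    /// let mut enc = BatchEncoder::new(EncodingConfig::default());
    /// let emb = enc.encode_single("hello world");
    /// assert_eq!(emb.len(), 128);
    /// // The default config L2-normalises the result.
    /// let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
    /// assert!((norm - 1.0).abs() < 1e-5);
    /// ```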
    pub fn encode_single(&mut self, text: &str) -> Vec<f32> {
        let tokenized = self.tokenize(text);

        if tokenized.ids.is_empty() {
            // Return zero vector for empty input
            return vec![0.0f32; EMBED_DIM];
        }

        let token_embs: Vec<Vec<f32>> = tokenized
            .ids
            .iter()
            .map(|&id| Self::token_embedding(id))
            .collect();

        let mut pooled = Self::pool(token_embs, &self.config.pooling);

        if self.config.normalize {
            Self::normalize_l2(&mut pooled);
        }

        pooled
    }

    /// Encode a slice of text strings in chunks of `batch_size`.
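    ///
    /// # Example
    ///
    /// A sketch (same import-path assumption as the module-level example):
    ///
    /// ```
    /// use oxirs_embed::batch_encoder::{BatchEncoder, EncodingConfig};
    ///
    /// let mut enc = BatchEncoder::new(EncodingConfig::default());
    /// let batch = enc.encode_batch(&["one", "two three"]);
    /// assert_eq!(batch.batch_size, 2);
    /// assert_eq!(batch.token_counts, vec![1, 2]);
    /// ```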
    pub fn encode_batch(&mut self, texts: &[&str]) -> EncodedBatch {
        let mut embeddings = Vec::with_capacity(texts.len());
        let mut token_counts = Vec::with_capacity(texts.len());

        // Process in chunks of batch_size
        for chunk in texts.chunks(self.config.batch_size) {
            for &text in chunk {
                let tokenized = self.tokenize(text);
                let count = tokenized.ids.len();
                token_counts.push(count);

                if tokenized.ids.is_empty() {
                    embeddings.push(vec![0.0f32; EMBED_DIM]);
                    continue;
                }

                let token_embs: Vec<Vec<f32>> = tokenized
                    .ids
                    .iter()
                    .map(|&id| Self::token_embedding(id))
                    .collect();

                let mut pooled = Self::pool(token_embs, &self.config.pooling);

                if self.config.normalize {
                    Self::normalize_l2(&mut pooled);
                }

                embeddings.push(pooled);
            }
        }

        let batch_size = embeddings.len();
        EncodedBatch {
            embeddings,
            token_counts,
            batch_size,
        }
    }

    /// Aggregate a list of per-token embedding vectors according to `strategy`.
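    ///
    /// # Example
    ///
    /// A sketch (same import-path assumption as the module-level example):
    ///
    /// ```
    /// use oxirs_embed::batch_encoder::{BatchEncoder, PoolingStrategy};
    ///
    /// let embs = vec![vec![1.0, 4.0], vec![3.0, 2.0]];
    /// let mean = BatchEncoder::pool(embs.clone(), &PoolingStrategy::Mean);
    /// assert_eq!(mean, vec![2.0, 3.0]);
    /// let max = BatchEncoder::pool(embs, &PoolingStrategy::Max);
    /// assert_eq!(max, vec![3.0, 4.0]);
    /// ```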
    pub fn pool(token_embeddings: Vec<Vec<f32>>, strategy: &PoolingStrategy) -> Vec<f32> {
        if token_embeddings.is_empty() {
            return vec![0.0f32; EMBED_DIM];
        }

        let dim = token_embeddings[0].len();
        let n = token_embeddings.len();

        match strategy {
            PoolingStrategy::Mean => {
                let mut result = vec![0.0f32; dim];
                for emb in &token_embeddings {
                    for (r, &v) in result.iter_mut().zip(emb.iter()) {
                        *r += v;
                    }
                }
                for r in result.iter_mut() {
                    *r /= n as f32;
                }
                result
            }
            PoolingStrategy::Max => {
                let mut result = vec![f32::NEG_INFINITY; dim];
                for emb in &token_embeddings {
                    for (r, &v) in result.iter_mut().zip(emb.iter()) {
                        if v > *r {
                            *r = v;
                        }
                    }
                }
                result
            }
            PoolingStrategy::CLS => {
                // First token (index 0)
                token_embeddings[0].clone()
            }
            PoolingStrategy::Last => {
                // Last token
                token_embeddings[n - 1].clone()
            }
        }
    }

    /// Normalise a vector in-place to unit L2 norm.
    /// If the norm is zero (or numerically negligible), the vector is left unchanged.
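    ///
    /// # Example
    ///
    /// A sketch (same import-path assumption as the module-level example):
    ///
    /// ```
    /// use oxirs_embed::batch_encoder::BatchEncoder;
    ///
    /// let mut v = [3.0f32, 4.0];
    /// BatchEncoder::normalize_l2(&mut v);
    /// assert_eq!(v, [0.6, 0.8]); // a 3-4-5 triangle scaled to unit length
    /// ```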
    pub fn normalize_l2(v: &mut [f32]) {
        let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for x in v.iter_mut() {
                *x /= norm;
            }
        }
    }

    /// Cosine similarity between two embedding vectors.
    /// Returns 0.0 if either vector has (near-)zero norm or if the lengths differ.
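    ///
    /// # Example
    ///
    /// A sketch (same import-path assumption as the module-level example):
    ///
    /// ```
    /// use oxirs_embed::batch_encoder::BatchEncoder;
    ///
    /// let sim = BatchEncoder::similarity(&[1.0, 0.0], &[0.0, 1.0]);
    /// assert!(sim.abs() < 1e-12); // orthogonal vectors
    /// ```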
    pub fn similarity(a: &[f32], b: &[f32]) -> f64 {
        if a.len() != b.len() || a.is_empty() {
            return 0.0;
        }
        let dot: f64 = a
            .iter()
            .zip(b.iter())
            .map(|(&x, &y)| x as f64 * y as f64)
            .sum();
        let norm_a: f64 = a
            .iter()
            .map(|&x| (x as f64) * (x as f64))
            .sum::<f64>()
            .sqrt();
        let norm_b: f64 = b
            .iter()
            .map(|&x| (x as f64) * (x as f64))
            .sum::<f64>()
            .sqrt();
        if norm_a < 1e-10 || norm_b < 1e-10 {
            return 0.0;
        }
        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
    }

    /// Return the number of unique tokens in the vocabulary so far.
    pub fn vocab_size(&self) -> usize {
        self.vocab.len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn default_encoder() -> BatchEncoder {
        BatchEncoder::new(EncodingConfig::default())
    }

    // --- Tokenization tests ---

    #[test]
    fn test_tokenize_basic() {
        let mut enc = default_encoder();
        let t = enc.tokenize("hello world");
        assert_eq!(t.tokens, vec!["hello", "world"]);
        assert_eq!(t.ids.len(), 2);
        assert_eq!(t.attention_mask, vec![1, 1]);
    }

    #[test]
    fn test_tokenize_empty_string() {
        let mut enc = default_encoder();
        let t = enc.tokenize("");
        assert!(t.tokens.is_empty());
        assert!(t.ids.is_empty());
        assert!(t.attention_mask.is_empty());
    }

    #[test]
    fn test_tokenize_single_token() {
        let mut enc = default_encoder();
        let t = enc.tokenize("rust");
        assert_eq!(t.tokens, vec!["rust"]);
        assert_eq!(t.ids.len(), 1);
    }

    #[test]
    fn test_tokenize_lowercases() {
        let mut enc = default_encoder();
        let t = enc.tokenize("Hello WORLD");
        assert_eq!(t.tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_tokenize_truncation() {
        let config = EncodingConfig {
            max_length: 3,
            ..EncodingConfig::default()
        };
        let mut enc = BatchEncoder::new(config);
        let t = enc.tokenize("a b c d e");
        assert_eq!(t.tokens.len(), 3);
        assert_eq!(t.ids.len(), 3);
    }

    #[test]
    fn test_tokenize_max_length_exact() {
        let config = EncodingConfig {
            max_length: 2,
            ..EncodingConfig::default()
        };
        let mut enc = BatchEncoder::new(config);
        let t = enc.tokenize("x y");
        assert_eq!(t.tokens.len(), 2);
    }

    #[test]
    fn test_tokenize_consistent_ids() {
        let mut enc = default_encoder();
        let t1 = enc.tokenize("hello");
        let t2 = enc.tokenize("hello");
        assert_eq!(t1.ids, t2.ids);
    }

    #[test]
    fn test_tokenize_different_words_different_ids() {
        let mut enc = default_encoder();
        let t1 = enc.tokenize("foo");
        let t2 = enc.tokenize("bar");
        assert_ne!(t1.ids[0], t2.ids[0]);
    }

    // --- Encode single tests ---

    #[test]
    fn test_encode_single_returns_128_dim() {
        let mut enc = default_encoder();
        let emb = enc.encode_single("hello world");
        assert_eq!(emb.len(), EMBED_DIM);
    }

    #[test]
    fn test_encode_single_deterministic() {
        let mut enc1 = default_encoder();
        let mut enc2 = default_encoder();
        let e1 = enc1.encode_single("deterministic test");
        let e2 = enc2.encode_single("deterministic test");
        assert_eq!(e1, e2);
    }

    #[test]
    fn test_encode_single_normalized_when_flag_set() {
        let mut enc = default_encoder();
        let emb = enc.encode_single("normalize me please");
        let norm: f32 = emb.iter().map(|&x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Expected unit norm, got {norm}");
    }

    #[test]
    fn test_encode_single_no_normalize() {
        let config = EncodingConfig {
            normalize: false,
            ..EncodingConfig::default()
        };
        let mut enc = BatchEncoder::new(config);
        let emb = enc.encode_single("no norm");
        let norm: f32 = emb.iter().map(|&x| x * x).sum::<f32>().sqrt();
        // Not necessarily unit norm
        assert!(norm >= 0.0);
    }

    #[test]
    fn test_encode_single_empty_returns_zeros() {
        let mut enc = default_encoder();
        let emb = enc.encode_single("");
        assert_eq!(emb.len(), EMBED_DIM);
        assert!(emb.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_encode_single_different_texts_different_embeddings() {
        let mut enc = default_encoder();
        let e1 = enc.encode_single("apple banana cherry");
        let e2 = enc.encode_single("dog cat fish");
        // With the same encoder, same tokens get same IDs; different tokens → different embeddings
        assert_ne!(e1, e2);
    }

    // --- Encode batch tests ---

    #[test]
    fn test_encode_batch_count() {
        let mut enc = default_encoder();
        let texts = ["one", "two", "three"];
        let batch = enc.encode_batch(&texts);
        assert_eq!(batch.batch_size, 3);
        assert_eq!(batch.embeddings.len(), 3);
        assert_eq!(batch.token_counts.len(), 3);
    }

    #[test]
    fn test_encode_batch_each_128_dim() {
        let mut enc = default_encoder();
        let texts = ["alpha", "beta gamma", "delta epsilon zeta"];
        let batch = enc.encode_batch(&texts);
        for emb in &batch.embeddings {
            assert_eq!(emb.len(), EMBED_DIM);
        }
    }

    #[test]
    fn test_encode_batch_token_counts_correct() {
        let mut enc = BatchEncoder::new(EncodingConfig {
            max_length: 10,
            ..EncodingConfig::default()
        });
        let texts = ["a b c", "x", "one two three four"];
        let batch = enc.encode_batch(&texts);
        assert_eq!(batch.token_counts[0], 3);
        assert_eq!(batch.token_counts[1], 1);
        assert_eq!(batch.token_counts[2], 4);
    }

    #[test]
    fn test_encode_batch_chunking() {
        let config = EncodingConfig {
            batch_size: 2,
            ..EncodingConfig::default()
        };
        let mut enc = BatchEncoder::new(config);
        let texts: Vec<&str> = (0..5).map(|_| "hello world").collect();
        let batch = enc.encode_batch(&texts);
        assert_eq!(batch.batch_size, 5);
    }

    #[test]
    fn test_encode_batch_empty_texts() {
        let mut enc = default_encoder();
        let texts: Vec<&str> = vec![];
        let batch = enc.encode_batch(&texts);
        assert_eq!(batch.batch_size, 0);
    }

    #[test]
    fn test_encode_batch_single_text() {
        let mut enc = default_encoder();
        let texts = ["only one"];
        let batch = enc.encode_batch(&texts);
        assert_eq!(batch.batch_size, 1);
    }

    // --- Pooling strategy tests ---

    fn sample_token_embeddings() -> Vec<Vec<f32>> {
        vec![
            vec![1.0, 0.0, 2.0, -1.0],
            vec![0.0, 3.0, -1.0, 2.0],
            vec![2.0, 1.0, 0.0, 0.5],
        ]
    }

    #[test]
    fn test_pool_mean() {
        let embs = sample_token_embeddings();
        let result = BatchEncoder::pool(embs, &PoolingStrategy::Mean);
        let expected = [1.0, 4.0 / 3.0, 1.0 / 3.0, 0.5];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-5, "{r} != {e}");
        }
    }

    #[test]
    fn test_pool_max() {
        let embs = sample_token_embeddings();
        let result = BatchEncoder::pool(embs, &PoolingStrategy::Max);
        let expected = vec![2.0f32, 3.0, 2.0, 2.0];
        assert_eq!(result, expected);
    }

    #[test]
    fn test_pool_cls() {
        let embs = sample_token_embeddings();
        let result = BatchEncoder::pool(embs, &PoolingStrategy::CLS);
        assert_eq!(result, vec![1.0, 0.0, 2.0, -1.0]);
    }

    #[test]
    fn test_pool_last() {
        let embs = sample_token_embeddings();
        let result = BatchEncoder::pool(embs, &PoolingStrategy::Last);
        assert_eq!(result, vec![2.0, 1.0, 0.0, 0.5]);
    }

    #[test]
    fn test_pool_empty() {
        let result = BatchEncoder::pool(vec![], &PoolingStrategy::Mean);
        assert_eq!(result.len(), EMBED_DIM);
        assert!(result.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_pool_single_token_mean() {
        let embs = vec![vec![1.0, 2.0, 3.0]];
        let result = BatchEncoder::pool(embs.clone(), &PoolingStrategy::Mean);
        assert_eq!(result, embs[0]);
    }

    #[test]
    fn test_pool_single_token_max() {
        let embs = vec![vec![4.0, 5.0, 6.0]];
        let result = BatchEncoder::pool(embs.clone(), &PoolingStrategy::Max);
        assert_eq!(result, embs[0]);
    }

    // --- Normalize tests ---

    #[test]
    fn test_normalize_unit_norm() {
        let mut v = vec![3.0f32, 4.0, 0.0];
        BatchEncoder::normalize_l2(&mut v);
        let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
        assert!((v[0] - 0.6).abs() < 1e-5);
        assert!((v[1] - 0.8).abs() < 1e-5);
    }

    #[test]
    fn test_normalize_zero_vector() {
        let mut v = vec![0.0f32, 0.0, 0.0];
        BatchEncoder::normalize_l2(&mut v);
        // Should remain zero
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_normalize_already_unit() {
        let mut v = vec![1.0f32, 0.0, 0.0];
        BatchEncoder::normalize_l2(&mut v);
        assert!((v[0] - 1.0).abs() < 1e-6);
    }

    // --- Similarity tests ---

    #[test]
    fn test_similarity_identical_vectors() {
        let v = vec![1.0f32, 0.0, 0.0];
        let sim = BatchEncoder::similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_similarity_orthogonal_vectors() {
        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![0.0f32, 1.0, 0.0];
        let sim = BatchEncoder::similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_similarity_opposite_vectors() {
        let a = vec![1.0f32, 0.0];
        let b = vec![-1.0f32, 0.0];
        let sim = BatchEncoder::similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn test_similarity_zero_vector() {
        let a = vec![1.0f32, 0.0];
        let b = vec![0.0f32, 0.0];
        let sim = BatchEncoder::similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_similarity_mismatched_len() {
        let a = vec![1.0f32, 0.0];
        let b = vec![1.0f32, 0.0, 0.5];
        let sim = BatchEncoder::similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_similarity_empty_vectors() {
        let sim = BatchEncoder::similarity(&[], &[]);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_similarity_bounded() {
        let mut enc = default_encoder();
        let e1 = enc.encode_single("semantic similarity test");
        let e2 = enc.encode_single("another sentence here");
        let sim = BatchEncoder::similarity(&e1, &e2);
        assert!((-1.0..=1.0).contains(&sim));
    }

    // --- Vocab tests ---

    #[test]
    fn test_vocab_grows() {
        let mut enc = default_encoder();
        assert_eq!(enc.vocab_size(), 0);
        enc.tokenize("alpha beta gamma");
        assert_eq!(enc.vocab_size(), 3);
        enc.tokenize("alpha delta"); // "alpha" already known
        assert_eq!(enc.vocab_size(), 4);
    }

    #[test]
    fn test_encode_batch_matches_single() {
        let mut enc = default_encoder();
        let texts = ["hello world", "foo bar baz"];
        let e_single_a = enc.encode_single(texts[0]);
        let e_single_b = enc.encode_single(texts[1]);
        let batch = enc.encode_batch(&texts);
        assert_eq!(batch.embeddings[0], e_single_a);
        assert_eq!(batch.embeddings[1], e_single_b);
    }
}