Skip to main content

inference/
batch.rs

1//! Batch processing utilities for efficient embedding generation.
2
3use crate::error::{InferenceError, Result};
4use crate::models::EmbeddingModel;
5use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
6use tracing::{debug, instrument};
7
/// Prepared batch of tokenized inputs ready for ORT inference.
///
/// All token arrays are flat (row-major) with shape `[batch_size, seq_len]`: the token at
/// position `s` of item `b` lives at index `b * seq_len + s`.
/// Values are `i64` because ONNX Runtime BERT models expect `int64` inputs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreparedBatch {
    /// Input token IDs, flat `[batch_size * seq_len]`.
    pub input_ids: Vec<i64>,
    /// Attention mask, flat `[batch_size * seq_len]` (expected to hold 0 or 1).
    pub attention_mask: Vec<i64>,
    /// Token type (segment) IDs, flat `[batch_size * seq_len]`.
    pub token_type_ids: Vec<i64>,
    /// Number of items in this batch.
    pub batch_size: usize,
    /// Sequence length shared by every item (uniform after padding).
    pub seq_len: usize,
    /// Byte length (`str::len`) of each original input text, in batch order.
    /// Kept for debugging; these are not token counts.
    pub original_lengths: Vec<usize>,
}
27
28/// Batch processor for preparing text inputs for embedding models.
29pub struct BatchProcessor {
30    tokenizer: Tokenizer,
31    model: EmbeddingModel,
32    max_batch_size: usize,
33}
34
35impl BatchProcessor {
36    /// Create a new batch processor.
37    pub fn new(mut tokenizer: Tokenizer, model: EmbeddingModel, max_batch_size: usize) -> Self {
38        // Configure padding
39        let padding = PaddingParams {
40            strategy: PaddingStrategy::BatchLongest,
41            pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
42            pad_token: tokenizer
43                .get_padding()
44                .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
45            ..Default::default()
46        };
47        tokenizer.with_padding(Some(padding));
48
49        // Configure truncation
50        let truncation = TruncationParams {
51            max_length: model.max_seq_length(),
52            ..Default::default()
53        };
54        let _ = tokenizer.with_truncation(Some(truncation));
55
56        Self {
57            tokenizer,
58            model,
59            max_batch_size,
60        }
61    }
62
63    /// Get the maximum batch size.
64    pub fn max_batch_size(&self) -> usize {
65        self.max_batch_size
66    }
67
68    /// Prepare texts for embedding, optionally applying model-specific prefixes.
69    #[instrument(skip(self, texts), fields(count = texts.len()))]
70    pub fn prepare_texts(&self, texts: &[String], is_query: bool) -> Vec<String> {
71        let prefix = if is_query {
72            self.model.query_prefix()
73        } else {
74            self.model.document_prefix()
75        };
76
77        match prefix {
78            Some(p) => texts.iter().map(|t| format!("{}{}", p, t)).collect(),
79            None => texts.to_vec(),
80        }
81    }
82
83    /// Tokenize a batch of texts and prepare flat i64 arrays for ORT inference.
84    #[instrument(skip(self, texts), fields(count = texts.len()))]
85    pub fn tokenize_batch(&self, texts: &[String]) -> Result<PreparedBatch> {
86        if texts.is_empty() {
87            return Err(InferenceError::InvalidInput("Empty text batch".into()));
88        }
89
90        if texts.len() > self.max_batch_size {
91            return Err(InferenceError::InvalidInput(format!(
92                "Batch size {} exceeds maximum {}",
93                texts.len(),
94                self.max_batch_size
95            )));
96        }
97
98        let original_lengths: Vec<usize> = texts.iter().map(|t| t.len()).collect();
99
100        debug!(
101            "Tokenizing {} texts, max length: {}",
102            texts.len(),
103            original_lengths.iter().max().unwrap_or(&0)
104        );
105
106        // Tokenize all texts
107        let encodings = self
108            .tokenizer
109            .encode_batch(texts.to_vec(), true)
110            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
111
112        let batch_size = encodings.len();
113        let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
114
115        debug!("Tokenized: batch_size={}, seq_len={}", batch_size, seq_len);
116
117        // Extract and flatten as i64 (ORT BERT models require int64)
118        let mut input_ids = Vec::with_capacity(batch_size * seq_len);
119        let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
120        let mut token_type_ids = Vec::with_capacity(batch_size * seq_len);
121
122        for enc in &encodings {
123            input_ids.extend(enc.get_ids().iter().map(|&id| id as i64));
124            attention_mask.extend(enc.get_attention_mask().iter().map(|&m| m as i64));
125
126            let type_ids = enc.get_type_ids();
127            if type_ids.is_empty() {
128                token_type_ids.extend(std::iter::repeat_n(0i64, seq_len));
129            } else {
130                token_type_ids.extend(type_ids.iter().map(|&t| t as i64));
131            }
132        }
133
134        Ok(PreparedBatch {
135            input_ids,
136            attention_mask,
137            token_type_ids,
138            batch_size,
139            seq_len,
140            original_lengths,
141        })
142    }
143
144    /// Split texts into batches of maximum size.
145    pub fn split_into_batches<'a>(&self, texts: &'a [String]) -> Vec<&'a [String]> {
146        texts.chunks(self.max_batch_size).collect()
147    }
148}
149
150/// Apply mean pooling to ORT last-hidden-state output.
151///
152/// `last_hidden_state` is a flat row-major slice with shape `[batch, seq_len, hidden_size]`.
153/// `attention_mask` is flat `[batch * seq_len]` with i64 values (0 or 1).
154///
155/// Returns `Vec<Vec<f32>>` with shape `[batch, hidden_size]`.
156#[instrument(skip_all, fields(batch_size, seq_len, hidden_size))]
157pub fn mean_pooling(
158    last_hidden_state: &[f32],
159    batch_size: usize,
160    seq_len: usize,
161    hidden_size: usize,
162    attention_mask: &[i64],
163) -> Vec<Vec<f32>> {
164    let mut result = vec![vec![0.0f32; hidden_size]; batch_size];
165
166    for b in 0..batch_size {
167        // Sum of mask weights for this batch item
168        let mask_sum: f32 = (0..seq_len)
169            .map(|s| attention_mask[b * seq_len + s] as f32)
170            .sum::<f32>()
171            .max(1e-9);
172
173        for (h, cell) in result[b].iter_mut().enumerate() {
174            let weighted_sum: f32 = (0..seq_len)
175                .map(|s| {
176                    let lhs_idx = b * seq_len * hidden_size + s * hidden_size + h;
177                    last_hidden_state[lhs_idx] * attention_mask[b * seq_len + s] as f32
178                })
179                .sum();
180            *cell = weighted_sum / mask_sum;
181        }
182    }
183
184    debug!(
185        "Mean pooled: batch={}, hidden={}",
186        result.len(),
187        result.first().map(|v| v.len()).unwrap_or(0)
188    );
189
190    result
191}
192
193/// Normalize embeddings to unit length (L2 normalization), in-place.
194#[instrument(skip_all, fields(count = embeddings.len()))]
195pub fn normalize_embeddings(embeddings: &mut [Vec<f32>]) {
196    for emb in embeddings.iter_mut() {
197        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
198        for v in emb.iter_mut() {
199            *v /= norm;
200        }
201    }
202    debug!("Normalized {} embeddings", embeddings.len());
203}
204
205/// Truncate a full-dimension embedding to `target_dim` MRL sub-dimension, then re-normalise.
206///
207/// Matryoshka Representation Learning (MRL) trains models such that the first `N` dimensions
208/// of the embedding capture the most semantic information.  Truncating to `N` followed by
209/// L2 re-normalisation is semantically valid and gives a high-quality lower-dimensional
210/// representation.
211///
212/// Returns an error if `target_dim > embedding.len()` or `target_dim == 0`.
213pub fn truncate_mrl(embedding: &[f32], target_dim: usize) -> crate::error::Result<Vec<f32>> {
214    if target_dim == 0 || target_dim > embedding.len() {
215        return Err(crate::error::InferenceError::InvalidInput(format!(
216            "MRL target_dim={} is out of range for embedding of length {}",
217            target_dim,
218            embedding.len()
219        )));
220    }
221    let mut truncated = embedding[..target_dim].to_vec();
222    let norm: f32 = truncated
223        .iter()
224        .map(|x| x * x)
225        .sum::<f32>()
226        .sqrt()
227        .max(1e-12);
228    for v in truncated.iter_mut() {
229        *v /= norm;
230    }
231    Ok(truncated)
232}
233
/// Batch input accumulator that respects a total *token budget* instead of a fixed item count.
///
/// By default a text's token count is estimated as `max(text.len() / 4, 1)`, where
/// `text.len()` is the UTF-8 byte length (a rough bytes-per-token heuristic). A production
/// deployment can inject a real tokeniser via `TokenBudgetBatcher::with_token_fn`.
///
/// The budget given to `TokenBudgetBatcher::new` is overridden by the `DAKERA_TOKEN_BUDGET`
/// environment variable when that holds a positive integer.
///
/// A text whose estimate alone exceeds the budget is never split; it ends up in a batch
/// of its own.
///
/// # Example
///
/// ```no_run
/// use inference::batch::TokenBudgetBatcher;
///
/// let mut batcher = TokenBudgetBatcher::new(2048);
/// batcher.push("short text".to_string());
/// batcher.push("another short text".to_string());
/// let batches = batcher.finish();
/// assert_eq!(batches.len(), 1);
/// ```
pub struct TokenBudgetBatcher {
    /// Maximum estimated tokens per batch.
    token_budget: usize,
    /// Texts accumulated since the last flush.
    current_batch: Vec<String>,
    /// Estimated token total of `current_batch`.
    current_tokens: usize,
    /// Batches already closed, in push order.
    finished_batches: Vec<Vec<String>>,
    /// Token-count estimator. Defaults to `max(text.len() / 4, 1)`.
    token_count_fn: Box<dyn Fn(&str) -> usize + Send + Sync>,
}

impl std::fmt::Debug for TokenBudgetBatcher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The boxed token counter is opaque, so it is elided (`..`) from the output.
        f.debug_struct("TokenBudgetBatcher")
            .field("token_budget", &self.token_budget)
            .field("current_batch", &self.current_batch)
            .field("current_tokens", &self.current_tokens)
            .field("finished_batches", &self.finished_batches)
            .finish_non_exhaustive()
    }
}
258
259impl TokenBudgetBatcher {
260    /// Create a batcher with the given token budget.
261    ///
262    /// Defaults to character-count estimation (`len / 4`).
263    /// Use `with_token_fn` to inject a real tokeniser.
264    pub fn new(token_budget: usize) -> Self {
265        let budget = std::env::var("DAKERA_TOKEN_BUDGET")
266            .ok()
267            .and_then(|v| v.parse::<usize>().ok())
268            .filter(|&n| n > 0)
269            .unwrap_or(token_budget)
270            .max(1);
271
272        Self {
273            token_budget: budget,
274            current_batch: Vec::new(),
275            current_tokens: 0,
276            finished_batches: Vec::new(),
277            token_count_fn: Box::new(|text| (text.len() / 4).max(1)),
278        }
279    }
280
281    /// Replace the default character-count estimator with a real token counter.
282    pub fn with_token_fn(mut self, f: impl Fn(&str) -> usize + Send + Sync + 'static) -> Self {
283        self.token_count_fn = Box::new(f);
284        self
285    }
286
287    /// Add a text to the current batch.
288    ///
289    /// If adding `text` would exceed the token budget, the current batch is flushed first and a
290    /// new batch starting with `text` is begun.
291    pub fn push(&mut self, text: String) {
292        let tokens = (self.token_count_fn)(&text);
293        if !self.current_batch.is_empty() && self.current_tokens + tokens > self.token_budget {
294            // Flush current batch
295            let batch = std::mem::take(&mut self.current_batch);
296            self.finished_batches.push(batch);
297            self.current_tokens = 0;
298        }
299        self.current_tokens += tokens;
300        self.current_batch.push(text);
301    }
302
303    /// Add multiple texts at once.
304    pub fn push_all(&mut self, texts: impl IntoIterator<Item = String>) {
305        for t in texts {
306            self.push(t);
307        }
308    }
309
310    /// Flush any pending batch and return all accumulated batches.
311    ///
312    /// The batcher is reset after this call and can be reused.
313    pub fn finish(&mut self) -> Vec<Vec<String>> {
314        if !self.current_batch.is_empty() {
315            let batch = std::mem::take(&mut self.current_batch);
316            self.finished_batches.push(batch);
317            self.current_tokens = 0;
318        }
319        std::mem::take(&mut self.finished_batches)
320    }
321
322    /// Number of texts accumulated in the current (unflushed) batch.
323    pub fn pending_count(&self) -> usize {
324        self.current_batch.len()
325    }
326
327    /// Total tokens accumulated in the current (unflushed) batch.
328    pub fn pending_tokens(&self) -> usize {
329        self.current_tokens
330    }
331}
332
333#[cfg(test)]
334mod tests {
335    use super::*;
336
337    /// Create a minimal tokenizer for unit tests (no network required).
338    /// `prepare_texts` only uses the model's prefix logic, not the tokenizer,
339    /// so any valid tokenizer works here.
340    fn dummy_tokenizer() -> Tokenizer {
341        use tokenizers::models::bpe::BPE;
342        Tokenizer::new(BPE::default())
343    }
344
345    /// Create a WordLevel tokenizer with a small known vocabulary.
346    /// Unlike BPE::default(), this tokenizer can actually encode words,
347    /// enabling tests of the tokenize_batch happy path.
348    fn simple_tokenizer() -> Tokenizer {
349        use std::collections::HashMap;
350        use tokenizers::models::wordlevel::WordLevel;
351        use tokenizers::pre_tokenizers::whitespace::Whitespace;
352
353        let mut vocab: HashMap<String, u32> = HashMap::new();
354        for (i, w) in [
355            "[PAD]", "[UNK]", "hello", "world", "test", "text", "one", "two", "foo", "bar", "baz",
356        ]
357        .iter()
358        .enumerate()
359        {
360            vocab.insert(w.to_string(), i as u32);
361        }
362
363        let model = WordLevel::builder()
364            .vocab(vocab)
365            .unk_token("[UNK]".to_string())
366            .build()
367            .unwrap();
368
369        let mut tok = Tokenizer::new(model);
370        tok.with_pre_tokenizer(Some(Whitespace {}));
371        tok
372    }
373
374    #[test]
375    fn test_prepare_texts_with_prefix() {
376        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::E5Small, 32);
377
378        let texts = vec!["Hello world".to_string(), "Test query".to_string()];
379        let prepared = processor.prepare_texts(&texts, true);
380
381        assert_eq!(prepared[0], "query: Hello world");
382        assert_eq!(prepared[1], "query: Test query");
383    }
384
385    #[test]
386    fn test_prepare_texts_no_prefix() {
387        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);
388
389        let texts = vec!["Hello world".to_string()];
390        let prepared = processor.prepare_texts(&texts, true);
391
392        assert_eq!(prepared[0], "Hello world");
393    }
394
395    #[test]
396    fn test_prepare_texts_document_prefix_e5() {
397        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::E5Small, 32);
398        let texts = vec!["Some document".to_string(), "Another doc".to_string()];
399        let prepared = processor.prepare_texts(&texts, false);
400        assert_eq!(prepared[0], "passage: Some document");
401        assert_eq!(prepared[1], "passage: Another doc");
402    }
403
404    #[test]
405    fn test_prepare_texts_bge_no_prefix_query() {
406        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::BgeSmall, 32);
407        let texts = vec!["Test".to_string()];
408        let prepared = processor.prepare_texts(&texts, true);
409        assert_eq!(prepared[0], "Test");
410    }
411
412    #[test]
413    fn test_prepare_texts_bge_no_prefix_document() {
414        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::BgeSmall, 32);
415        let texts = vec!["Doc text".to_string()];
416        let prepared = processor.prepare_texts(&texts, false);
417        assert_eq!(prepared[0], "Doc text");
418    }
419
420    #[test]
421    fn test_prepare_texts_empty_input() {
422        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);
423        let texts: Vec<String> = vec![];
424        let prepared = processor.prepare_texts(&texts, true);
425        assert!(prepared.is_empty());
426    }
427
428    #[test]
429    fn test_max_batch_size() {
430        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 64);
431        assert_eq!(processor.max_batch_size(), 64);
432    }
433
434    #[test]
435    fn test_max_batch_size_default() {
436        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::BgeSmall, 32);
437        assert_eq!(processor.max_batch_size(), 32);
438    }
439
440    #[test]
441    fn test_split_into_batches_exact_multiple() {
442        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 4);
443        let texts: Vec<String> = (0..8).map(|i| format!("text {i}")).collect();
444        let batches = processor.split_into_batches(&texts);
445        assert_eq!(batches.len(), 2);
446        assert_eq!(batches[0].len(), 4);
447        assert_eq!(batches[1].len(), 4);
448    }
449
450    #[test]
451    fn test_split_into_batches_partial_last() {
452        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 4);
453        let texts: Vec<String> = (0..6).map(|i| format!("text {i}")).collect();
454        let batches = processor.split_into_batches(&texts);
455        assert_eq!(batches.len(), 2);
456        assert_eq!(batches[0].len(), 4);
457        assert_eq!(batches[1].len(), 2);
458    }
459
460    #[test]
461    fn test_split_into_batches_smaller_than_max() {
462        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);
463        let texts: Vec<String> = (0..5).map(|i| format!("text {i}")).collect();
464        let batches = processor.split_into_batches(&texts);
465        assert_eq!(batches.len(), 1);
466        assert_eq!(batches[0].len(), 5);
467    }
468
469    #[test]
470    fn test_split_into_batches_empty() {
471        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);
472        let texts: Vec<String> = vec![];
473        let batches = processor.split_into_batches(&texts);
474        assert!(batches.is_empty());
475    }
476
477    #[test]
478    fn test_split_into_batches_preserves_content() {
479        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 3);
480        let texts = vec![
481            "a".to_string(),
482            "b".to_string(),
483            "c".to_string(),
484            "d".to_string(),
485        ];
486        let batches = processor.split_into_batches(&texts);
487        assert_eq!(batches[0], &["a", "b", "c"]);
488        assert_eq!(batches[1], &["d"]);
489    }
490
491    #[test]
492    fn test_tokenize_batch_empty_error() {
493        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);
494        let result = processor.tokenize_batch(&[]);
495        assert!(result.is_err());
496        let err = result.unwrap_err();
497        assert!(matches!(err, InferenceError::InvalidInput(_)));
498        assert!(err.to_string().contains("Empty text batch"));
499    }
500
501    #[test]
502    fn test_tokenize_batch_exceeds_max_size_error() {
503        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 2);
504        let texts: Vec<String> = (0..5).map(|i| format!("text {i}")).collect();
505        let result = processor.tokenize_batch(&texts);
506        assert!(result.is_err());
507        let err = result.unwrap_err();
508        assert!(matches!(err, InferenceError::InvalidInput(_)));
509        assert!(err.to_string().contains("exceeds maximum"));
510    }
511
512    #[test]
513    fn test_tokenize_batch_exactly_at_max_size_does_not_error_before_encode() {
514        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 2);
515        let texts = vec!["text one".to_string(), "text two".to_string()];
516        let result = processor.tokenize_batch(&texts);
517        // The BPE default tokenizer may fail at encode — that is fine,
518        // what matters is it does NOT return an InvalidInput size error.
519        if let Err(InferenceError::InvalidInput(msg)) = &result {
520            assert!(
521                !msg.contains("exceeds maximum"),
522                "Batch at exactly max_size should pass size check, got: {msg}"
523            );
524        }
525    }
526
527    // ── mean_pooling tests ──────────────────────────────────────────────────
528
529    #[test]
530    fn test_mean_pooling_output_shape() {
531        // batch=2, seq_len=3, hidden=4 → should produce 2 embeddings of size 4
532        let lhs = vec![0.0f32; 2 * 3 * 4]; // all zeros
533        let mask = vec![1i64; 2 * 3]; // all active
534        let result = mean_pooling(&lhs, 2, 3, 4, &mask);
535        assert_eq!(result.len(), 2);
536        assert_eq!(result[0].len(), 4);
537        assert_eq!(result[1].len(), 4);
538    }
539
540    #[test]
541    fn test_mean_pooling_uniform_hidden_all_ones_mask() {
542        // batch=1, seq_len=4, hidden=3 — all hidden values = 2.0, all mask = 1
543        // Mean pool should return 2.0 for every dimension.
544        let lhs = vec![2.0f32; 4 * 3];
545        let mask = vec![1i64; 4];
546        let result = mean_pooling(&lhs, 1, 4, 3, &mask);
547        assert_eq!(result.len(), 1);
548        for v in &result[0] {
549            assert!((v - 2.0).abs() < 1e-5, "expected 2.0, got {v}");
550        }
551    }
552
553    #[test]
554    fn test_mean_pooling_masked_tokens_ignored() {
555        // batch=1, seq_len=2, hidden=2
556        // Token 0: hidden=[1.0, 1.0], mask=1; Token 1: hidden=[9.0, 9.0], mask=0
557        // Mean pool should give [1.0, 1.0]
558        let lhs = vec![1.0f32, 1.0, 9.0, 9.0];
559        let mask = vec![1i64, 0i64];
560        let result = mean_pooling(&lhs, 1, 2, 2, &mask);
561        assert!(
562            (result[0][0] - 1.0).abs() < 1e-5,
563            "expected 1.0, got {}",
564            result[0][0]
565        );
566        assert!(
567            (result[0][1] - 1.0).abs() < 1e-5,
568            "expected 1.0, got {}",
569            result[0][1]
570        );
571    }
572
573    #[test]
574    fn test_mean_pooling_batch_independence() {
575        // batch=2, seq_len=1, hidden=2
576        // Batch 0: hidden=[3.0, 4.0], mask=1
577        // Batch 1: hidden=[6.0, 8.0], mask=1
578        // Each should pool independently
579        let lhs = vec![3.0f32, 4.0, 6.0, 8.0];
580        let mask = vec![1i64, 1i64];
581        let result = mean_pooling(&lhs, 2, 1, 2, &mask);
582        assert_eq!(result.len(), 2);
583        assert!((result[0][0] - 3.0).abs() < 1e-5);
584        assert!((result[0][1] - 4.0).abs() < 1e-5);
585        assert!((result[1][0] - 6.0).abs() < 1e-5);
586        assert!((result[1][1] - 8.0).abs() < 1e-5);
587    }
588
589    // ── normalize_embeddings tests ──────────────────────────────────────────
590
591    #[test]
592    fn test_normalize_embeddings_unit_length() {
593        // After normalization, each row's L2 norm should be ≈ 1.0
594        // [3, 4] → L2 norm = 5.0 → normalized = [0.6, 0.8]
595        let mut embeddings = vec![vec![3.0f32, 4.0]];
596        normalize_embeddings(&mut embeddings);
597        let norm: f32 = embeddings[0].iter().map(|x| x * x).sum::<f32>().sqrt();
598        assert!(
599            (norm - 1.0).abs() < 1e-5,
600            "L2 norm should be 1.0, got {norm}"
601        );
602    }
603
604    #[test]
605    fn test_normalize_embeddings_values() {
606        let mut embeddings = vec![vec![3.0f32, 4.0]];
607        normalize_embeddings(&mut embeddings);
608        assert!(
609            (embeddings[0][0] - 0.6).abs() < 1e-5,
610            "expected 0.6, got {}",
611            embeddings[0][0]
612        );
613        assert!(
614            (embeddings[0][1] - 0.8).abs() < 1e-5,
615            "expected 0.8, got {}",
616            embeddings[0][1]
617        );
618    }
619
620    #[test]
621    fn test_normalize_embeddings_batch() {
622        // Multiple rows — each should be independently normalized
623        let mut embeddings = vec![vec![1.0f32, 0.0], vec![0.0f32, 1.0]];
624        normalize_embeddings(&mut embeddings);
625        let norm0: f32 = embeddings[0].iter().map(|x| x * x).sum::<f32>().sqrt();
626        let norm1: f32 = embeddings[1].iter().map(|x| x * x).sum::<f32>().sqrt();
627        assert!((norm0 - 1.0).abs() < 1e-5);
628        assert!((norm1 - 1.0).abs() < 1e-5);
629    }
630
631    #[test]
632    fn test_normalize_embeddings_output_shape() {
633        let mut embeddings: Vec<Vec<f32>> = (1..=3)
634            .map(|i| (1..=4).map(|j| (i * j) as f32).collect())
635            .collect();
636        normalize_embeddings(&mut embeddings);
637        assert_eq!(embeddings.len(), 3);
638        assert!(embeddings.iter().all(|v| v.len() == 4));
639    }
640
641    #[test]
642    fn test_normalize_embeddings_near_zero_safe() {
643        // Near-zero vector should not produce NaN/Inf due to clamp
644        let mut embeddings = vec![vec![1e-14f32, 1e-14]];
645        normalize_embeddings(&mut embeddings);
646        for v in &embeddings[0] {
647            assert!(v.is_finite(), "expected finite value, got {v}");
648        }
649    }
650
651    // ── tokenize_batch happy-path (WordLevel tokenizer) ──────────────────────
652
653    #[test]
654    fn test_tokenize_batch_single_text_success() {
655        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
656        let texts = vec!["hello world".to_string()];
657        let result = processor.tokenize_batch(&texts);
658        assert!(result.is_ok(), "Expected Ok, got {:?}", result);
659        let batch = result.unwrap();
660        assert_eq!(batch.batch_size, 1);
661        assert_eq!(batch.original_lengths, vec![11]);
662    }
663
664    #[test]
665    fn test_tokenize_batch_tensor_shapes_single() {
666        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
667        let texts = vec!["hello world".to_string()];
668        let batch = processor.tokenize_batch(&texts).unwrap();
669        assert_eq!(batch.batch_size, 1);
670        assert_eq!(batch.input_ids.len(), batch.batch_size * batch.seq_len);
671        assert_eq!(batch.attention_mask.len(), batch.batch_size * batch.seq_len);
672        assert_eq!(batch.token_type_ids.len(), batch.batch_size * batch.seq_len);
673    }
674
675    #[test]
676    fn test_tokenize_batch_multiple_texts_batch_dim() {
677        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
678        let texts = vec!["hello".to_string(), "hello world test".to_string()];
679        let batch = processor.tokenize_batch(&texts).unwrap();
680        assert_eq!(batch.batch_size, 2);
681        assert_eq!(batch.original_lengths.len(), 2);
682        assert_eq!(batch.input_ids.len(), batch.batch_size * batch.seq_len);
683    }
684
685    #[test]
686    fn test_tokenize_batch_token_type_ids_default_zeros() {
687        // WordLevel tokenizer returns no type_ids → code fills with zeros
688        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
689        let texts = vec!["hello world".to_string()];
690        let batch = processor.tokenize_batch(&texts).unwrap();
691        for &v in &batch.token_type_ids {
692            assert_eq!(v, 0, "Expected zero token_type_id from WordLevel, got {v}");
693        }
694    }
695
696    #[test]
697    fn test_tokenize_batch_original_lengths_preserved() {
698        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
699        let texts = vec!["hello".to_string(), "hello world".to_string()];
700        let batch = processor.tokenize_batch(&texts).unwrap();
701        assert_eq!(batch.original_lengths[0], 5);
702        assert_eq!(batch.original_lengths[1], 11);
703    }
704
705    #[test]
706    fn test_tokenize_batch_three_texts_batch_size_field() {
707        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
708        let texts = vec!["hello".to_string(), "world".to_string(), "test".to_string()];
709        let batch = processor.tokenize_batch(&texts).unwrap();
710        assert_eq!(batch.batch_size, 3);
711    }
712
713    #[test]
714    fn test_tokenize_batch_all_arrays_consistent_length() {
715        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
716        let texts = vec!["foo bar".to_string(), "baz".to_string()];
717        let batch = processor.tokenize_batch(&texts).unwrap();
718        let expected_len = batch.batch_size * batch.seq_len;
719        assert_eq!(batch.input_ids.len(), expected_len);
720        assert_eq!(batch.attention_mask.len(), expected_len);
721        assert_eq!(batch.token_type_ids.len(), expected_len);
722    }
723
724    #[test]
725    fn test_tokenize_batch_ids_are_i64() {
726        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
727        let texts = vec!["hello world".to_string()];
728        let batch = processor.tokenize_batch(&texts).unwrap();
729        // Verify all IDs are non-negative i64 (u32 upcast)
730        for &id in &batch.input_ids {
731            assert!(id >= 0, "input_id should be non-negative, got {id}");
732        }
733        for &m in &batch.attention_mask {
734            assert!(m == 0 || m == 1, "attention_mask should be 0 or 1, got {m}");
735        }
736    }
737
738    // ── TokenBudgetBatcher ───────────────────────────────────────────────────
739
740    /// Helper: use exact token counts so tests are deterministic.
741    fn exact_batcher(budget: usize) -> TokenBudgetBatcher {
742        TokenBudgetBatcher::new(budget).with_token_fn(|text| text.len())
743    }
744
745    #[test]
746    fn test_token_budget_batcher_empty_finish() {
747        let mut batcher = exact_batcher(100);
748        let batches = batcher.finish();
749        assert!(batches.is_empty());
750    }
751
752    #[test]
753    fn test_token_budget_batcher_single_text_single_batch() {
754        let mut batcher = exact_batcher(100);
755        batcher.push("hello".to_string()); // 5 tokens
756        let batches = batcher.finish();
757        assert_eq!(batches.len(), 1);
758        assert_eq!(batches[0], vec!["hello".to_string()]);
759    }
760
761    #[test]
762    fn test_token_budget_batcher_fits_small_texts_in_one_batch() {
763        let mut batcher = exact_batcher(50);
764        for i in 0..5 {
765            batcher.push(format!("t{i}")); // 2 tokens each → 10 total, fits in 50
766        }
767        let batches = batcher.finish();
768        assert_eq!(batches.len(), 1);
769        assert_eq!(batches[0].len(), 5);
770    }
771
772    #[test]
773    fn test_token_budget_batcher_splits_on_budget_exceeded() {
774        // budget=10; first 5 texts of 2 tokens each = 10 → 5th text stays in batch
775        // 6th text of 2 tokens → would exceed 12, so flush first
776        let mut batcher = exact_batcher(10);
777        for _ in 0..5 {
778            batcher.push("ab".to_string()); // 2 tokens
779        }
780        // Now at budget exactly; push one more
781        batcher.push("cd".to_string()); // 2 tokens
782        let batches = batcher.finish();
783        // First batch: 5 × 2 = 10 tokens (fits exactly)
784        // Second batch: "cd"
785        assert_eq!(batches.len(), 2);
786        assert_eq!(batches[0].len(), 5);
787        assert_eq!(batches[1].len(), 1);
788    }
789
790    #[test]
791    fn test_token_budget_batcher_large_single_text_gets_own_batch() {
792        let mut batcher = exact_batcher(10);
793        batcher.push("small".to_string()); // 5 tokens
794        batcher.push("a".repeat(50)); // 50 tokens > budget → flushes "small", starts new batch
795        let batches = batcher.finish();
796        assert_eq!(batches.len(), 2);
797        assert_eq!(batches[0][0], "small");
798    }
799
800    #[test]
801    fn test_token_budget_batcher_finish_resets_state() {
802        let mut batcher = exact_batcher(100);
803        batcher.push("hello".to_string());
804        let _first = batcher.finish();
805        batcher.push("world".to_string());
806        let second = batcher.finish();
807        assert_eq!(second.len(), 1);
808        assert_eq!(second[0][0], "world");
809    }
810
811    #[test]
812    fn test_token_budget_batcher_push_all() {
813        let mut batcher = exact_batcher(100);
814        batcher.push_all(vec!["a".to_string(), "b".to_string(), "c".to_string()]);
815        let batches = batcher.finish();
816        assert_eq!(batches.len(), 1);
817        assert_eq!(batches[0].len(), 3);
818    }
819
820    #[test]
821    fn test_token_budget_batcher_pending_count() {
822        let mut batcher = exact_batcher(100);
823        assert_eq!(batcher.pending_count(), 0);
824        batcher.push("hello".to_string());
825        assert_eq!(batcher.pending_count(), 1);
826        batcher.push("world".to_string());
827        assert_eq!(batcher.pending_count(), 2);
828    }
829
830    // ── truncate_mrl tests ───────────────────────────────────────────────────
831
832    #[test]
833    fn test_mrl_truncation_basic() {
834        let embedding = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
835        let truncated = truncate_mrl(&embedding, 4).unwrap();
836        assert_eq!(truncated.len(), 4);
837    }
838
839    #[test]
840    fn test_mrl_truncation_normalized() {
841        let embedding = vec![3.0f32, 4.0, 0.0, 0.0];
842        let truncated = truncate_mrl(&embedding, 2).unwrap();
843        // [3, 4] norm = 5 → [0.6, 0.8]
844        assert!((truncated[0] - 0.6).abs() < 1e-5);
845        assert!((truncated[1] - 0.8).abs() < 1e-5);
846    }
847
848    #[test]
849    fn test_mrl_truncation_256_from_1024() {
850        let embedding: Vec<f32> = (0..1024).map(|i| i as f32).collect();
851        let truncated = truncate_mrl(&embedding, 256).unwrap();
852        assert_eq!(truncated.len(), 256);
853        // L2 norm should be ~1.0
854        let norm: f32 = truncated.iter().map(|x| x * x).sum::<f32>().sqrt();
855        assert!((norm - 1.0).abs() < 1e-4, "norm={norm}");
856    }
857
858    #[test]
859    fn test_mrl_truncation_full_dimension_is_noop_shape() {
860        let embedding = vec![0.0f32; 1024];
861        // Near-zero → won't change direction but shape is preserved
862        let truncated = truncate_mrl(&embedding, 1024).unwrap();
863        assert_eq!(truncated.len(), 1024);
864    }
865
866    #[test]
867    fn test_mrl_truncation_zero_target_dim_error() {
868        let embedding = vec![1.0f32; 10];
869        let result = truncate_mrl(&embedding, 0);
870        assert!(result.is_err());
871    }
872
873    #[test]
874    fn test_mrl_truncation_target_exceeds_length_error() {
875        let embedding = vec![1.0f32; 4];
876        let result = truncate_mrl(&embedding, 5);
877        assert!(result.is_err());
878    }
879
880    #[test]
881    fn test_mrl_preserves_semantic_direction() {
882        // MRL property: truncated+renormed embedding should point in the same direction
883        // as the full embedding. Mathematically, dot(truncated_renormed, full_first_256)
884        // equals the L2 norm of the first target_dim slice of the full embedding.
885        // For the assertion to reach >0.9, the first 256 dims must contain >81% of
886        // total squared norm — here achieved by zeroing dims 256..1024.
887        let mut embedding: Vec<f32> = (0..1024)
888            .map(|i| if i < 256 { (i % 16) as f32 + 1.0 } else { 0.0 })
889            .collect();
890        let norm: f32 = embedding
891            .iter()
892            .map(|x| x * x)
893            .sum::<f32>()
894            .sqrt()
895            .max(1e-12);
896        for v in embedding.iter_mut() {
897            *v /= norm;
898        }
899        let truncated = truncate_mrl(&embedding, 256).unwrap();
900        // dot(truncated_renormed, first_256_of_unit_embedding) = partial_norm = 1.0 here
901        let dot: f32 = truncated
902            .iter()
903            .zip(embedding.iter().take(256))
904            .map(|(a, b)| a * b)
905            .sum();
906        assert!(dot > 0.9, "cosine similarity {dot} should be >0.9");
907    }
908}