Skip to main content

inference/
batch.rs

//! Batch processing utilities for efficient embedding generation.
2
3use crate::error::{InferenceError, Result};
4use crate::models::EmbeddingModel;
5use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
6use tracing::{debug, instrument};
7
/// Prepared batch of tokenized inputs ready for ORT inference.
///
/// All token arrays are flat (row-major) with shape `[batch_size, seq_len]`,
/// i.e. the value for item `b`, position `s` lives at index `b * seq_len + s`.
/// Values are `i64` because ONNX Runtime BERT models expect `int64` inputs.
///
/// The type is plain data, so it derives `Clone`, `PartialEq` and `Eq` in
/// addition to `Debug`; this lets callers copy and compare batches.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreparedBatch {
    /// Input token IDs, flat `[batch_size * seq_len]`, i64
    pub input_ids: Vec<i64>,
    /// Attention mask, flat `[batch_size * seq_len]`, i64
    pub attention_mask: Vec<i64>,
    /// Token type IDs, flat `[batch_size * seq_len]`, i64
    pub token_type_ids: Vec<i64>,
    /// Number of items in this batch
    pub batch_size: usize,
    /// Sequence length (uniform after padding)
    pub seq_len: usize,
    /// Length of each original input text in bytes (`String::len`), one entry
    /// per batch item. Kept for debugging only.
    pub original_lengths: Vec<usize>,
}
27
28/// Batch processor for preparing text inputs for embedding models.
29pub struct BatchProcessor {
30    tokenizer: Tokenizer,
31    model: EmbeddingModel,
32    max_batch_size: usize,
33}
34
35impl BatchProcessor {
36    /// Create a new batch processor.
37    pub fn new(mut tokenizer: Tokenizer, model: EmbeddingModel, max_batch_size: usize) -> Self {
38        // Configure padding
39        let padding = PaddingParams {
40            strategy: PaddingStrategy::BatchLongest,
41            pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
42            pad_token: tokenizer
43                .get_padding()
44                .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
45            ..Default::default()
46        };
47        tokenizer.with_padding(Some(padding));
48
49        // Configure truncation
50        let truncation = TruncationParams {
51            max_length: model.max_seq_length(),
52            ..Default::default()
53        };
54        let _ = tokenizer.with_truncation(Some(truncation));
55
56        Self {
57            tokenizer,
58            model,
59            max_batch_size,
60        }
61    }
62
63    /// Get the maximum batch size.
64    pub fn max_batch_size(&self) -> usize {
65        self.max_batch_size
66    }
67
68    /// Prepare texts for embedding, optionally applying model-specific prefixes.
69    #[instrument(skip(self, texts), fields(count = texts.len()))]
70    pub fn prepare_texts(&self, texts: &[String], is_query: bool) -> Vec<String> {
71        let prefix = if is_query {
72            self.model.query_prefix()
73        } else {
74            self.model.document_prefix()
75        };
76
77        match prefix {
78            Some(p) => texts.iter().map(|t| format!("{}{}", p, t)).collect(),
79            None => texts.to_vec(),
80        }
81    }
82
83    /// Tokenize a batch of texts and prepare flat i64 arrays for ORT inference.
84    #[instrument(skip(self, texts), fields(count = texts.len()))]
85    pub fn tokenize_batch(&self, texts: &[String]) -> Result<PreparedBatch> {
86        if texts.is_empty() {
87            return Err(InferenceError::InvalidInput("Empty text batch".into()));
88        }
89
90        if texts.len() > self.max_batch_size {
91            return Err(InferenceError::InvalidInput(format!(
92                "Batch size {} exceeds maximum {}",
93                texts.len(),
94                self.max_batch_size
95            )));
96        }
97
98        let original_lengths: Vec<usize> = texts.iter().map(|t| t.len()).collect();
99
100        debug!(
101            "Tokenizing {} texts, max length: {}",
102            texts.len(),
103            original_lengths.iter().max().unwrap_or(&0)
104        );
105
106        // Tokenize all texts
107        let encodings = self
108            .tokenizer
109            .encode_batch(texts.to_vec(), true)
110            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
111
112        let batch_size = encodings.len();
113        let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
114
115        debug!("Tokenized: batch_size={}, seq_len={}", batch_size, seq_len);
116
117        // Extract and flatten as i64 (ORT BERT models require int64)
118        let mut input_ids = Vec::with_capacity(batch_size * seq_len);
119        let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
120        let mut token_type_ids = Vec::with_capacity(batch_size * seq_len);
121
122        for enc in &encodings {
123            input_ids.extend(enc.get_ids().iter().map(|&id| id as i64));
124            attention_mask.extend(enc.get_attention_mask().iter().map(|&m| m as i64));
125
126            let type_ids = enc.get_type_ids();
127            if type_ids.is_empty() {
128                token_type_ids.extend(std::iter::repeat_n(0i64, seq_len));
129            } else {
130                token_type_ids.extend(type_ids.iter().map(|&t| t as i64));
131            }
132        }
133
134        Ok(PreparedBatch {
135            input_ids,
136            attention_mask,
137            token_type_ids,
138            batch_size,
139            seq_len,
140            original_lengths,
141        })
142    }
143
144    /// Split texts into batches of maximum size.
145    pub fn split_into_batches<'a>(&self, texts: &'a [String]) -> Vec<&'a [String]> {
146        texts.chunks(self.max_batch_size).collect()
147    }
148}
149
150/// Apply mean pooling to ORT last-hidden-state output.
151///
152/// `last_hidden_state` is a flat row-major slice with shape `[batch, seq_len, hidden_size]`.
153/// `attention_mask` is flat `[batch * seq_len]` with i64 values (0 or 1).
154///
155/// Returns `Vec<Vec<f32>>` with shape `[batch, hidden_size]`.
156#[instrument(skip_all, fields(batch_size, seq_len, hidden_size))]
157pub fn mean_pooling(
158    last_hidden_state: &[f32],
159    batch_size: usize,
160    seq_len: usize,
161    hidden_size: usize,
162    attention_mask: &[i64],
163) -> Vec<Vec<f32>> {
164    let mut result = vec![vec![0.0f32; hidden_size]; batch_size];
165
166    for b in 0..batch_size {
167        // Sum of mask weights for this batch item
168        let mask_sum: f32 = (0..seq_len)
169            .map(|s| attention_mask[b * seq_len + s] as f32)
170            .sum::<f32>()
171            .max(1e-9);
172
173        for (h, cell) in result[b].iter_mut().enumerate() {
174            let weighted_sum: f32 = (0..seq_len)
175                .map(|s| {
176                    let lhs_idx = b * seq_len * hidden_size + s * hidden_size + h;
177                    last_hidden_state[lhs_idx] * attention_mask[b * seq_len + s] as f32
178                })
179                .sum();
180            *cell = weighted_sum / mask_sum;
181        }
182    }
183
184    debug!(
185        "Mean pooled: batch={}, hidden={}",
186        result.len(),
187        result.first().map(|v| v.len()).unwrap_or(0)
188    );
189
190    result
191}
192
193/// Normalize embeddings to unit length (L2 normalization), in-place.
194#[instrument(skip_all, fields(count = embeddings.len()))]
195pub fn normalize_embeddings(embeddings: &mut [Vec<f32>]) {
196    for emb in embeddings.iter_mut() {
197        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
198        for v in emb.iter_mut() {
199            *v /= norm;
200        }
201    }
202    debug!("Normalized {} embeddings", embeddings.len());
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    // ── fixtures ────────────────────────────────────────────────────────────

    /// Minimal tokenizer for tests that never encode text (no network needed).
    /// `prepare_texts` and `split_into_batches` do not touch the tokenizer,
    /// so an empty BPE model is enough for them.
    fn dummy_tokenizer() -> Tokenizer {
        Tokenizer::new(tokenizers::models::bpe::BPE::default())
    }

    /// WordLevel tokenizer over a tiny fixed vocabulary, split on whitespace.
    /// Unlike the empty BPE model it can encode words, which the
    /// `tokenize_batch` happy-path tests rely on.
    fn simple_tokenizer() -> Tokenizer {
        use std::collections::HashMap;
        use tokenizers::models::wordlevel::WordLevel;
        use tokenizers::pre_tokenizers::whitespace::Whitespace;

        const WORDS: [&str; 11] = [
            "[PAD]", "[UNK]", "hello", "world", "test", "text", "one", "two", "foo", "bar", "baz",
        ];

        // Token id == position in WORDS.
        let vocab: HashMap<String, u32> = WORDS
            .iter()
            .enumerate()
            .map(|(id, word)| (word.to_string(), id as u32))
            .collect();

        let model = WordLevel::builder()
            .vocab(vocab)
            .unk_token("[UNK]".to_string())
            .build()
            .unwrap();

        let mut tokenizer = Tokenizer::new(model);
        tokenizer.with_pre_tokenizer(Some(Whitespace {}));
        tokenizer
    }

    /// Processor backed by the dummy tokenizer.
    fn dummy_processor(model: EmbeddingModel, max_batch_size: usize) -> BatchProcessor {
        BatchProcessor::new(dummy_tokenizer(), model, max_batch_size)
    }

    /// Processor backed by the WordLevel tokenizer (MiniLM, max batch 32).
    fn word_processor() -> BatchProcessor {
        BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32)
    }

    /// Owned copies of the given string literals.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// `["text 0", "text 1", ...]` with `count` entries.
    fn numbered_texts(count: usize) -> Vec<String> {
        (0..count).map(|i| format!("text {i}")).collect()
    }

    /// Euclidean norm of a vector.
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    // ── prepare_texts tests ─────────────────────────────────────────────────

    #[test]
    fn test_prepare_texts_with_prefix() {
        let processor = dummy_processor(EmbeddingModel::E5Small, 32);
        let prepared = processor.prepare_texts(&strings(&["Hello world", "Test query"]), true);

        assert_eq!(prepared[0], "query: Hello world");
        assert_eq!(prepared[1], "query: Test query");
    }

    #[test]
    fn test_prepare_texts_no_prefix() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 32);
        let prepared = processor.prepare_texts(&strings(&["Hello world"]), true);

        assert_eq!(prepared[0], "Hello world");
    }

    #[test]
    fn test_prepare_texts_document_prefix_e5() {
        let processor = dummy_processor(EmbeddingModel::E5Small, 32);
        let prepared = processor.prepare_texts(&strings(&["Some document", "Another doc"]), false);

        assert_eq!(prepared[0], "passage: Some document");
        assert_eq!(prepared[1], "passage: Another doc");
    }

    #[test]
    fn test_prepare_texts_bge_no_prefix_query() {
        let processor = dummy_processor(EmbeddingModel::BgeSmall, 32);
        let prepared = processor.prepare_texts(&strings(&["Test"]), true);

        assert_eq!(prepared[0], "Test");
    }

    #[test]
    fn test_prepare_texts_bge_no_prefix_document() {
        let processor = dummy_processor(EmbeddingModel::BgeSmall, 32);
        let prepared = processor.prepare_texts(&strings(&["Doc text"]), false);

        assert_eq!(prepared[0], "Doc text");
    }

    #[test]
    fn test_prepare_texts_empty_input() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 32);
        let prepared = processor.prepare_texts(&[], true);

        assert!(prepared.is_empty());
    }

    // ── max_batch_size / split_into_batches tests ───────────────────────────

    #[test]
    fn test_max_batch_size() {
        assert_eq!(dummy_processor(EmbeddingModel::MiniLM, 64).max_batch_size(), 64);
    }

    #[test]
    fn test_max_batch_size_default() {
        assert_eq!(dummy_processor(EmbeddingModel::BgeSmall, 32).max_batch_size(), 32);
    }

    #[test]
    fn test_split_into_batches_exact_multiple() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 4);
        let texts = numbered_texts(8);
        let batches = processor.split_into_batches(&texts);

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].len(), 4);
        assert_eq!(batches[1].len(), 4);
    }

    #[test]
    fn test_split_into_batches_partial_last() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 4);
        let texts = numbered_texts(6);
        let batches = processor.split_into_batches(&texts);

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].len(), 4);
        assert_eq!(batches[1].len(), 2);
    }

    #[test]
    fn test_split_into_batches_smaller_than_max() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 32);
        let texts = numbered_texts(5);
        let batches = processor.split_into_batches(&texts);

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 5);
    }

    #[test]
    fn test_split_into_batches_empty() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 32);
        let texts: Vec<String> = Vec::new();

        assert!(processor.split_into_batches(&texts).is_empty());
    }

    #[test]
    fn test_split_into_batches_preserves_content() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 3);
        let texts = strings(&["a", "b", "c", "d"]);
        let batches = processor.split_into_batches(&texts);

        assert_eq!(batches[0], &["a", "b", "c"]);
        assert_eq!(batches[1], &["d"]);
    }

    // ── tokenize_batch input validation ─────────────────────────────────────

    #[test]
    fn test_tokenize_batch_empty_error() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 32);
        let err = processor.tokenize_batch(&[]).unwrap_err();

        assert!(matches!(err, InferenceError::InvalidInput(_)));
        assert!(err.to_string().contains("Empty text batch"));
    }

    #[test]
    fn test_tokenize_batch_exceeds_max_size_error() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 2);
        let err = processor.tokenize_batch(&numbered_texts(5)).unwrap_err();

        assert!(matches!(err, InferenceError::InvalidInput(_)));
        assert!(err.to_string().contains("exceeds maximum"));
    }

    #[test]
    fn test_tokenize_batch_exactly_at_max_size_does_not_error_before_encode() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 2);
        let outcome = processor.tokenize_batch(&strings(&["text one", "text two"]));

        // The empty BPE tokenizer may fail while encoding, which is acceptable
        // here; the only thing under test is that the size check let it through.
        if let Err(InferenceError::InvalidInput(msg)) = &outcome {
            assert!(
                !msg.contains("exceeds maximum"),
                "Batch at exactly max_size should pass size check, got: {msg}"
            );
        }
    }

    // ── mean_pooling tests ──────────────────────────────────────────────────

    #[test]
    fn test_mean_pooling_output_shape() {
        // batch=2, seq_len=3, hidden=4 → expect 2 embeddings of 4 values each
        let hidden = vec![0.0f32; 2 * 3 * 4];
        let mask = vec![1i64; 2 * 3];
        let pooled = mean_pooling(&hidden, 2, 3, 4, &mask);

        assert_eq!(pooled.len(), 2);
        assert_eq!(pooled[0].len(), 4);
        assert_eq!(pooled[1].len(), 4);
    }

    #[test]
    fn test_mean_pooling_uniform_hidden_all_ones_mask() {
        // batch=1, seq_len=4, hidden=3; every hidden value is 2.0 and every
        // token is active, so each pooled dimension must be 2.0.
        let hidden = vec![2.0f32; 4 * 3];
        let mask = vec![1i64; 4];
        let pooled = mean_pooling(&hidden, 1, 4, 3, &mask);

        assert_eq!(pooled.len(), 1);
        for v in &pooled[0] {
            assert!((v - 2.0).abs() < 1e-5, "expected 2.0, got {v}");
        }
    }

    #[test]
    fn test_mean_pooling_masked_tokens_ignored() {
        // batch=1, seq_len=2, hidden=2
        // token 0 = [1.0, 1.0] (mask 1), token 1 = [9.0, 9.0] (mask 0)
        // → pooled value is [1.0, 1.0]
        let hidden = vec![1.0f32, 1.0, 9.0, 9.0];
        let mask = vec![1i64, 0i64];
        let pooled = mean_pooling(&hidden, 1, 2, 2, &mask);

        assert!(
            (pooled[0][0] - 1.0).abs() < 1e-5,
            "expected 1.0, got {}",
            pooled[0][0]
        );
        assert!(
            (pooled[0][1] - 1.0).abs() < 1e-5,
            "expected 1.0, got {}",
            pooled[0][1]
        );
    }

    #[test]
    fn test_mean_pooling_batch_independence() {
        // batch=2, seq_len=1, hidden=2
        // item 0 = [3.0, 4.0], item 1 = [6.0, 8.0], both active;
        // each item must be pooled on its own.
        let hidden = vec![3.0f32, 4.0, 6.0, 8.0];
        let mask = vec![1i64, 1i64];
        let pooled = mean_pooling(&hidden, 2, 1, 2, &mask);

        assert_eq!(pooled.len(), 2);
        assert!((pooled[0][0] - 3.0).abs() < 1e-5);
        assert!((pooled[0][1] - 4.0).abs() < 1e-5);
        assert!((pooled[1][0] - 6.0).abs() < 1e-5);
        assert!((pooled[1][1] - 8.0).abs() < 1e-5);
    }

    // ── normalize_embeddings tests ──────────────────────────────────────────

    #[test]
    fn test_normalize_embeddings_unit_length() {
        // [3, 4] has L2 norm 5.0, so it normalizes to [0.6, 0.8] with norm ≈ 1.0
        let mut embeddings = vec![vec![3.0f32, 4.0]];
        normalize_embeddings(&mut embeddings);

        let norm = l2_norm(&embeddings[0]);
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "L2 norm should be 1.0, got {norm}"
        );
    }

    #[test]
    fn test_normalize_embeddings_values() {
        let mut embeddings = vec![vec![3.0f32, 4.0]];
        normalize_embeddings(&mut embeddings);

        assert!(
            (embeddings[0][0] - 0.6).abs() < 1e-5,
            "expected 0.6, got {}",
            embeddings[0][0]
        );
        assert!(
            (embeddings[0][1] - 0.8).abs() < 1e-5,
            "expected 0.8, got {}",
            embeddings[0][1]
        );
    }

    #[test]
    fn test_normalize_embeddings_batch() {
        // Each row is normalized independently of the others.
        let mut embeddings = vec![vec![1.0f32, 0.0], vec![0.0f32, 1.0]];
        normalize_embeddings(&mut embeddings);

        assert!((l2_norm(&embeddings[0]) - 1.0).abs() < 1e-5);
        assert!((l2_norm(&embeddings[1]) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_normalize_embeddings_output_shape() {
        let mut embeddings: Vec<Vec<f32>> = (1..=3)
            .map(|i| (1..=4).map(|j| (i * j) as f32).collect())
            .collect();
        normalize_embeddings(&mut embeddings);

        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|v| v.len() == 4));
    }

    #[test]
    fn test_normalize_embeddings_near_zero_safe() {
        // The clamp on the norm must keep a near-zero vector free of NaN/Inf.
        let mut embeddings = vec![vec![1e-14f32, 1e-14]];
        normalize_embeddings(&mut embeddings);

        for v in &embeddings[0] {
            assert!(v.is_finite(), "expected finite value, got {v}");
        }
    }

    // ── tokenize_batch happy-path (WordLevel tokenizer) ──────────────────────

    #[test]
    fn test_tokenize_batch_single_text_success() {
        let result = word_processor().tokenize_batch(&strings(&["hello world"]));
        assert!(result.is_ok(), "Expected Ok, got {:?}", result);

        let batch = result.unwrap();
        assert_eq!(batch.batch_size, 1);
        assert_eq!(batch.original_lengths, vec![11]);
    }

    #[test]
    fn test_tokenize_batch_tensor_shapes_single() {
        let batch = word_processor()
            .tokenize_batch(&strings(&["hello world"]))
            .unwrap();

        assert_eq!(batch.batch_size, 1);
        assert_eq!(batch.input_ids.len(), batch.batch_size * batch.seq_len);
        assert_eq!(batch.attention_mask.len(), batch.batch_size * batch.seq_len);
        assert_eq!(batch.token_type_ids.len(), batch.batch_size * batch.seq_len);
    }

    #[test]
    fn test_tokenize_batch_multiple_texts_batch_dim() {
        let batch = word_processor()
            .tokenize_batch(&strings(&["hello", "hello world test"]))
            .unwrap();

        assert_eq!(batch.batch_size, 2);
        assert_eq!(batch.original_lengths.len(), 2);
        assert_eq!(batch.input_ids.len(), batch.batch_size * batch.seq_len);
    }

    #[test]
    fn test_tokenize_batch_token_type_ids_default_zeros() {
        // Single-sequence input through the WordLevel tokenizer: every token
        // type id in the prepared batch is expected to be zero.
        let batch = word_processor()
            .tokenize_batch(&strings(&["hello world"]))
            .unwrap();

        for &v in &batch.token_type_ids {
            assert_eq!(v, 0, "Expected zero token_type_id from WordLevel, got {v}");
        }
    }

    #[test]
    fn test_tokenize_batch_original_lengths_preserved() {
        let batch = word_processor()
            .tokenize_batch(&strings(&["hello", "hello world"]))
            .unwrap();

        assert_eq!(batch.original_lengths[0], 5);
        assert_eq!(batch.original_lengths[1], 11);
    }

    #[test]
    fn test_tokenize_batch_three_texts_batch_size_field() {
        let batch = word_processor()
            .tokenize_batch(&strings(&["hello", "world", "test"]))
            .unwrap();

        assert_eq!(batch.batch_size, 3);
    }

    #[test]
    fn test_tokenize_batch_all_arrays_consistent_length() {
        let batch = word_processor()
            .tokenize_batch(&strings(&["foo bar", "baz"]))
            .unwrap();

        let expected_len = batch.batch_size * batch.seq_len;
        assert_eq!(batch.input_ids.len(), expected_len);
        assert_eq!(batch.attention_mask.len(), expected_len);
        assert_eq!(batch.token_type_ids.len(), expected_len);
    }

    #[test]
    fn test_tokenize_batch_ids_are_i64() {
        let batch = word_processor()
            .tokenize_batch(&strings(&["hello world"]))
            .unwrap();

        // Ids come from u32 values widened to i64, so none may be negative.
        for &id in &batch.input_ids {
            assert!(id >= 0, "input_id should be non-negative, got {id}");
        }
        for &m in &batch.attention_mask {
            assert!(m == 0 || m == 1, "attention_mask should be 0 or 1, got {m}");
        }
    }
}