use crate::error::{InferenceError, Result};
use crate::models::EmbeddingModel;
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use tracing::{debug, instrument};

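/// Tokenized inputs for one batch. The id/mask vectors are flattened
/// row-major with shape `[batch_size * seq_len]`; `original_lengths`
/// records the character length of each input text.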
#[derive(Debug)]
pub struct PreparedBatch {
    pub input_ids: Vec<i64>,
    pub attention_mask: Vec<i64>,
    pub token_type_ids: Vec<i64>,
    pub batch_size: usize,
    pub seq_len: usize,
    pub original_lengths: Vec<usize>,
}

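/// Wraps a `Tokenizer` and an `EmbeddingModel`, producing padded,
/// truncated batches ready for inference.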
pub struct BatchProcessor {
    tokenizer: Tokenizer,
    model: EmbeddingModel,
    max_batch_size: usize,
}

impl BatchProcessor {
    pub fn new(mut tokenizer: Tokenizer, model: EmbeddingModel, max_batch_size: usize) -> Self {
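        // Pad every sequence to the longest in the batch, reusing the
        // tokenizer's configured pad token/id when present.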
        let padding = PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
            pad_token: tokenizer
                .get_padding()
                .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));

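        // Truncate anything longer than the model's maximum sequence length.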
        let truncation = TruncationParams {
            max_length: model.max_seq_length(),
            ..Default::default()
        };
        let _ = tokenizer.with_truncation(Some(truncation));

        Self {
            tokenizer,
            model,
            max_batch_size,
        }
    }

    pub fn max_batch_size(&self) -> usize {
        self.max_batch_size
    }

    #[instrument(skip(self, texts), fields(count = texts.len()))]
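    /// Prepends the model's query or document prefix (if any) to each text.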
    pub fn prepare_texts(&self, texts: &[String], is_query: bool) -> Vec<String> {
        let prefix = if is_query {
            self.model.query_prefix()
        } else {
            self.model.document_prefix()
        };

        match prefix {
            Some(p) => texts.iter().map(|t| format!("{}{}", p, t)).collect(),
            None => texts.to_vec(),
        }
    }

    #[instrument(skip(self, texts), fields(count = texts.len()))]
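    /// Tokenizes a batch of texts into flattened, padded id tensors.
    /// Returns an error for an empty batch or one larger than `max_batch_size`.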
    pub fn tokenize_batch(&self, texts: &[String]) -> Result<PreparedBatch> {
        if texts.is_empty() {
            return Err(InferenceError::InvalidInput("Empty text batch".into()));
        }

        if texts.len() > self.max_batch_size {
            return Err(InferenceError::InvalidInput(format!(
                "Batch size {} exceeds maximum {}",
                texts.len(),
                self.max_batch_size
            )));
        }

        let original_lengths: Vec<usize> = texts.iter().map(|t| t.len()).collect();

        debug!(
            "Tokenizing {} texts, max length: {}",
            texts.len(),
            original_lengths.iter().max().unwrap_or(&0)
        );

        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;

        let batch_size = encodings.len();
        let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);

        debug!("Tokenized: batch_size={}, seq_len={}", batch_size, seq_len);

        let mut input_ids = Vec::with_capacity(batch_size * seq_len);
        let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
        let mut token_type_ids = Vec::with_capacity(batch_size * seq_len);

        for enc in &encodings {
            input_ids.extend(enc.get_ids().iter().map(|&id| id as i64));
            attention_mask.extend(enc.get_attention_mask().iter().map(|&m| m as i64));

            let type_ids = enc.get_type_ids();
            if type_ids.is_empty() {
                token_type_ids.extend(std::iter::repeat_n(0i64, seq_len));
            } else {
                token_type_ids.extend(type_ids.iter().map(|&t| t as i64));
            }
        }

        Ok(PreparedBatch {
            input_ids,
            attention_mask,
            token_type_ids,
            batch_size,
            seq_len,
            original_lengths,
        })
    }

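    /// Splits `texts` into chunks of at most `max_batch_size` items.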
    pub fn split_into_batches<'a>(&self, texts: &'a [String]) -> Vec<&'a [String]> {
        texts.chunks(self.max_batch_size).collect()
    }
}

#[instrument(skip_all, fields(batch_size, seq_len, hidden_size))]
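/// Attention-mask-weighted mean pooling over the token dimension.
///
/// For batch row `b` and hidden dimension `h`:
/// `out[b][h] = sum_s(hidden[b, s, h] * mask[b, s]) / max(sum_s mask[b, s], 1e-9)`.
/// `last_hidden_state` is expected flattened as `[batch_size, seq_len, hidden_size]`
/// and `attention_mask` as `[batch_size, seq_len]`.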
pub fn mean_pooling(
    last_hidden_state: &[f32],
    batch_size: usize,
    seq_len: usize,
    hidden_size: usize,
    attention_mask: &[i64],
) -> Vec<Vec<f32>> {
    let mut result = vec![vec![0.0f32; hidden_size]; batch_size];

    for b in 0..batch_size {
        let mask_sum: f32 = (0..seq_len)
            .map(|s| attention_mask[b * seq_len + s] as f32)
            .sum::<f32>()
            .max(1e-9);

        for (h, cell) in result[b].iter_mut().enumerate() {
            let weighted_sum: f32 = (0..seq_len)
                .map(|s| {
                    let lhs_idx = b * seq_len * hidden_size + s * hidden_size + h;
                    last_hidden_state[lhs_idx] * attention_mask[b * seq_len + s] as f32
                })
                .sum();
            *cell = weighted_sum / mask_sum;
        }
    }

    debug!(
        "Mean pooled: batch={}, hidden={}",
        result.len(),
        result.first().map(|v| v.len()).unwrap_or(0)
    );

    result
}

#[instrument(skip_all, fields(count = embeddings.len()))]
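/// Normalizes each embedding to unit L2 length in place, with a `1e-12`
/// floor on the norm to avoid dividing by zero.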
pub fn normalize_embeddings(embeddings: &mut [Vec<f32>]) {
    for emb in embeddings.iter_mut() {
        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
        for v in emb.iter_mut() {
            *v /= norm;
        }
    }
    debug!("Normalized {} embeddings", embeddings.len());
}

#[cfg(test)]
mod tests {
    use super::*;

    fn dummy_tokenizer() -> Tokenizer {
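        // A bare BPE tokenizer with an empty vocab: sufficient for tests that
        // do not depend on real encoding output.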
        use tokenizers::models::bpe::BPE;
        Tokenizer::new(BPE::default())
    }

    fn simple_tokenizer() -> Tokenizer {
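        // A tiny WordLevel tokenizer with a fixed vocabulary and whitespace
        // pre-tokenization, so tokenize_batch tests can actually encode text.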
        use std::collections::HashMap;
        use tokenizers::models::wordlevel::WordLevel;
        use tokenizers::pre_tokenizers::whitespace::Whitespace;

        let mut vocab: HashMap<String, u32> = HashMap::new();
        for (i, w) in [
            "[PAD]", "[UNK]", "hello", "world", "test", "text", "one", "two", "foo", "bar", "baz",
        ]
        .iter()
        .enumerate()
        {
            vocab.insert(w.to_string(), i as u32);
        }

        let model = WordLevel::builder()
            .vocab(vocab)
            .unk_token("[UNK]".to_string())
            .build()
            .unwrap();

        let mut tok = Tokenizer::new(model);
        tok.with_pre_tokenizer(Some(Whitespace {}));
        tok
    }

    #[test]
    fn test_prepare_texts_with_prefix() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::E5Small, 32);

        let texts = vec!["Hello world".to_string(), "Test query".to_string()];
        let prepared = processor.prepare_texts(&texts, true);

        assert_eq!(prepared[0], "query: Hello world");
        assert_eq!(prepared[1], "query: Test query");
    }

    #[test]
    fn test_prepare_texts_no_prefix() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);

        let texts = vec!["Hello world".to_string()];
        let prepared = processor.prepare_texts(&texts, true);

        assert_eq!(prepared[0], "Hello world");
    }

    #[test]
    fn test_prepare_texts_document_prefix_e5() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::E5Small, 32);
        let texts = vec!["Some document".to_string(), "Another doc".to_string()];
        let prepared = processor.prepare_texts(&texts, false);
        assert_eq!(prepared[0], "passage: Some document");
        assert_eq!(prepared[1], "passage: Another doc");
    }

    #[test]
    fn test_prepare_texts_bge_no_prefix_query() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::BgeSmall, 32);
        let texts = vec!["Test".to_string()];
        let prepared = processor.prepare_texts(&texts, true);
        assert_eq!(prepared[0], "Test");
    }

    #[test]
    fn test_prepare_texts_bge_no_prefix_document() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::BgeSmall, 32);
        let texts = vec!["Doc text".to_string()];
        let prepared = processor.prepare_texts(&texts, false);
        assert_eq!(prepared[0], "Doc text");
    }

    #[test]
    fn test_prepare_texts_empty_input() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts: Vec<String> = vec![];
        let prepared = processor.prepare_texts(&texts, true);
        assert!(prepared.is_empty());
    }

    #[test]
    fn test_max_batch_size() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 64);
        assert_eq!(processor.max_batch_size(), 64);
    }

    #[test]
    fn test_max_batch_size_default() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::BgeSmall, 32);
        assert_eq!(processor.max_batch_size(), 32);
    }

    #[test]
    fn test_split_into_batches_exact_multiple() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 4);
        let texts: Vec<String> = (0..8).map(|i| format!("text {i}")).collect();
        let batches = processor.split_into_batches(&texts);
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].len(), 4);
        assert_eq!(batches[1].len(), 4);
    }

    #[test]
    fn test_split_into_batches_partial_last() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 4);
        let texts: Vec<String> = (0..6).map(|i| format!("text {i}")).collect();
        let batches = processor.split_into_batches(&texts);
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].len(), 4);
        assert_eq!(batches[1].len(), 2);
    }

    #[test]
    fn test_split_into_batches_smaller_than_max() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts: Vec<String> = (0..5).map(|i| format!("text {i}")).collect();
        let batches = processor.split_into_batches(&texts);
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 5);
    }

    #[test]
    fn test_split_into_batches_empty() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts: Vec<String> = vec![];
        let batches = processor.split_into_batches(&texts);
        assert!(batches.is_empty());
    }

    #[test]
    fn test_split_into_batches_preserves_content() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 3);
        let texts = vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
            "d".to_string(),
        ];
        let batches = processor.split_into_batches(&texts);
        assert_eq!(batches[0], &["a", "b", "c"]);
        assert_eq!(batches[1], &["d"]);
    }

    #[test]
    fn test_tokenize_batch_empty_error() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);
        let result = processor.tokenize_batch(&[]);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, InferenceError::InvalidInput(_)));
        assert!(err.to_string().contains("Empty text batch"));
    }

    #[test]
    fn test_tokenize_batch_exceeds_max_size_error() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 2);
        let texts: Vec<String> = (0..5).map(|i| format!("text {i}")).collect();
        let result = processor.tokenize_batch(&texts);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, InferenceError::InvalidInput(_)));
        assert!(err.to_string().contains("exceeds maximum"));
    }

    #[test]
    fn test_tokenize_batch_exactly_at_max_size_does_not_error_before_encode() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 2);
        let texts = vec!["text one".to_string(), "text two".to_string()];
        let result = processor.tokenize_batch(&texts);
        if let Err(InferenceError::InvalidInput(msg)) = &result {
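            // The dummy BPE tokenizer may still fail to encode these texts;
            // we only care that the failure is not the batch-size check.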
            assert!(
                !msg.contains("exceeds maximum"),
                "Batch at exactly max_size should pass size check, got: {msg}"
            );
        }
    }

    #[test]
    fn test_mean_pooling_output_shape() {
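        // 2 sequences x 3 tokens x hidden size 4, with an all-ones mask.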
        let lhs = vec![0.0f32; 2 * 3 * 4];
        let mask = vec![1i64; 2 * 3];
        let result = mean_pooling(&lhs, 2, 3, 4, &mask);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].len(), 4);
        assert_eq!(result[1].len(), 4);
    }

    #[test]
    fn test_mean_pooling_uniform_hidden_all_ones_mask() {
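        // One sequence of 4 tokens with hidden size 3; every activation is 2.0,
        // so the mean over unmasked tokens must also be 2.0.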
        let lhs = vec![2.0f32; 1 * 4 * 3];
        let mask = vec![1i64; 1 * 4];
        let result = mean_pooling(&lhs, 1, 4, 3, &mask);
        assert_eq!(result.len(), 1);
        for v in &result[0] {
            assert!((v - 2.0).abs() < 1e-5, "expected 2.0, got {v}");
        }
    }

    #[test]
    fn test_mean_pooling_masked_tokens_ignored() {
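        // One sequence, two tokens, hidden size 2. The second token (all 9.0)
        // has mask 0, so only the first token (all 1.0) contributes.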
        let lhs = vec![1.0f32, 1.0, 9.0, 9.0];
        let mask = vec![1i64, 0i64];
        let result = mean_pooling(&lhs, 1, 2, 2, &mask);
        assert!(
            (result[0][0] - 1.0).abs() < 1e-5,
            "expected 1.0, got {}",
            result[0][0]
        );
        assert!(
            (result[0][1] - 1.0).abs() < 1e-5,
            "expected 1.0, got {}",
            result[0][1]
        );
    }

    #[test]
    fn test_mean_pooling_batch_independence() {
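        // Two batch rows, one token each, hidden size 2: row 0 is [3, 4] and
        // row 1 is [6, 8]; pooling one row must not leak into the other.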
        let lhs = vec![3.0f32, 4.0, 6.0, 8.0];
        let mask = vec![1i64, 1i64];
        let result = mean_pooling(&lhs, 2, 1, 2, &mask);
        assert_eq!(result.len(), 2);
        assert!((result[0][0] - 3.0).abs() < 1e-5);
        assert!((result[0][1] - 4.0).abs() < 1e-5);
        assert!((result[1][0] - 6.0).abs() < 1e-5);
        assert!((result[1][1] - 8.0).abs() < 1e-5);
    }

    #[test]
    fn test_normalize_embeddings_unit_length() {
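        // [3, 4] has L2 norm 5, so after normalization the norm should be 1.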
        let mut embeddings = vec![vec![3.0f32, 4.0]];
        normalize_embeddings(&mut embeddings);
        let norm: f32 = embeddings[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "L2 norm should be 1.0, got {norm}"
        );
    }

    #[test]
    fn test_normalize_embeddings_values() {
        let mut embeddings = vec![vec![3.0f32, 4.0]];
        normalize_embeddings(&mut embeddings);
        assert!(
            (embeddings[0][0] - 0.6).abs() < 1e-5,
            "expected 0.6, got {}",
            embeddings[0][0]
        );
        assert!(
            (embeddings[0][1] - 0.8).abs() < 1e-5,
            "expected 0.8, got {}",
            embeddings[0][1]
        );
    }

    #[test]
    fn test_normalize_embeddings_batch() {
        let mut embeddings = vec![vec![1.0f32, 0.0], vec![0.0f32, 1.0]];
        normalize_embeddings(&mut embeddings);
        let norm0: f32 = embeddings[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm1: f32 = embeddings[1].iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm0 - 1.0).abs() < 1e-5);
        assert!((norm1 - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_normalize_embeddings_output_shape() {
        let mut embeddings: Vec<Vec<f32>> = (1..=3)
            .map(|i| (1..=4).map(|j| (i * j) as f32).collect())
            .collect();
        normalize_embeddings(&mut embeddings);
        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|v| v.len() == 4));
    }

    #[test]
    fn test_normalize_embeddings_near_zero_safe() {
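        // The norm here (~1.4e-14) is below the 1e-12 floor used by
        // normalize_embeddings, so every component must stay finite.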
        let mut embeddings = vec![vec![1e-14f32, 1e-14]];
        normalize_embeddings(&mut embeddings);
        for v in &embeddings[0] {
            assert!(v.is_finite(), "expected finite value, got {v}");
        }
    }

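    // The remaining tests use simple_tokenizer(), which can actually encode
    // text, so tokenize_batch should succeed end to end.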
    #[test]
    fn test_tokenize_batch_single_text_success() {
        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["hello world".to_string()];
        let result = processor.tokenize_batch(&texts);
        assert!(result.is_ok(), "Expected Ok, got {:?}", result);
        let batch = result.unwrap();
        assert_eq!(batch.batch_size, 1);
        assert_eq!(batch.original_lengths, vec![11]);
    }

    #[test]
    fn test_tokenize_batch_tensor_shapes_single() {
        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["hello world".to_string()];
        let batch = processor.tokenize_batch(&texts).unwrap();
        assert_eq!(batch.batch_size, 1);
        assert_eq!(batch.input_ids.len(), batch.batch_size * batch.seq_len);
        assert_eq!(batch.attention_mask.len(), batch.batch_size * batch.seq_len);
        assert_eq!(batch.token_type_ids.len(), batch.batch_size * batch.seq_len);
    }

    #[test]
    fn test_tokenize_batch_multiple_texts_batch_dim() {
        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["hello".to_string(), "hello world test".to_string()];
        let batch = processor.tokenize_batch(&texts).unwrap();
        assert_eq!(batch.batch_size, 2);
        assert_eq!(batch.original_lengths.len(), 2);
        assert_eq!(batch.input_ids.len(), batch.batch_size * batch.seq_len);
    }

    #[test]
    fn test_tokenize_batch_token_type_ids_default_zeros() {
        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["hello world".to_string()];
        let batch = processor.tokenize_batch(&texts).unwrap();
        for &v in &batch.token_type_ids {
            assert_eq!(v, 0, "Expected zero token_type_id from WordLevel, got {v}");
        }
    }

    #[test]
    fn test_tokenize_batch_original_lengths_preserved() {
        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["hello".to_string(), "hello world".to_string()];
        let batch = processor.tokenize_batch(&texts).unwrap();
        assert_eq!(batch.original_lengths[0], 5);
        assert_eq!(batch.original_lengths[1], 11);
    }

    #[test]
    fn test_tokenize_batch_three_texts_batch_size_field() {
        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["hello".to_string(), "world".to_string(), "test".to_string()];
        let batch = processor.tokenize_batch(&texts).unwrap();
        assert_eq!(batch.batch_size, 3);
    }

    #[test]
    fn test_tokenize_batch_all_arrays_consistent_length() {
        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["foo bar".to_string(), "baz".to_string()];
        let batch = processor.tokenize_batch(&texts).unwrap();
        let expected_len = batch.batch_size * batch.seq_len;
        assert_eq!(batch.input_ids.len(), expected_len);
        assert_eq!(batch.attention_mask.len(), expected_len);
        assert_eq!(batch.token_type_ids.len(), expected_len);
    }

    #[test]
    fn test_tokenize_batch_ids_are_i64() {
        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["hello world".to_string()];
        let batch = processor.tokenize_batch(&texts).unwrap();
        for &id in &batch.input_ids {
            assert!(id >= 0, "input_id should be non-negative, got {id}");
        }
        for &m in &batch.attention_mask {
            assert!(m == 0 || m == 1, "attention_mask should be 0 or 1, got {m}");
        }
    }
}