1use crate::error::{InferenceError, Result};
4use crate::models::EmbeddingModel;
5use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
6use tracing::{debug, instrument};
7
/// Flattened tokenizer output for one batch of texts.
///
/// `input_ids`, `attention_mask` and `token_type_ids` are row-major buffers
/// holding `batch_size` rows of `seq_len` values each, so each of them has
/// `batch_size * seq_len` elements.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreparedBatch {
    /// Token ids, widened to `i64`.
    pub input_ids: Vec<i64>,
    /// Attention mask values, widened to `i64`.
    pub attention_mask: Vec<i64>,
    /// Token type (segment) ids, widened to `i64`.
    pub token_type_ids: Vec<i64>,
    /// Number of sequences in the batch.
    pub batch_size: usize,
    /// Number of tokens per sequence (the same for every row).
    pub seq_len: usize,
    /// Byte length (`str::len`) of each input text, in input order.
    pub original_lengths: Vec<usize>,
}
27
28pub struct BatchProcessor {
30 tokenizer: Tokenizer,
31 model: EmbeddingModel,
32 max_batch_size: usize,
33}
34
35impl BatchProcessor {
36 pub fn new(mut tokenizer: Tokenizer, model: EmbeddingModel, max_batch_size: usize) -> Self {
38 let padding = PaddingParams {
40 strategy: PaddingStrategy::BatchLongest,
41 pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
42 pad_token: tokenizer
43 .get_padding()
44 .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
45 ..Default::default()
46 };
47 tokenizer.with_padding(Some(padding));
48
49 let truncation = TruncationParams {
51 max_length: model.max_seq_length(),
52 ..Default::default()
53 };
54 let _ = tokenizer.with_truncation(Some(truncation));
55
56 Self {
57 tokenizer,
58 model,
59 max_batch_size,
60 }
61 }
62
63 pub fn max_batch_size(&self) -> usize {
65 self.max_batch_size
66 }
67
68 #[instrument(skip(self, texts), fields(count = texts.len()))]
70 pub fn prepare_texts(&self, texts: &[String], is_query: bool) -> Vec<String> {
71 let prefix = if is_query {
72 self.model.query_prefix()
73 } else {
74 self.model.document_prefix()
75 };
76
77 match prefix {
78 Some(p) => texts.iter().map(|t| format!("{}{}", p, t)).collect(),
79 None => texts.to_vec(),
80 }
81 }
82
83 #[instrument(skip(self, texts), fields(count = texts.len()))]
85 pub fn tokenize_batch(&self, texts: &[String]) -> Result<PreparedBatch> {
86 if texts.is_empty() {
87 return Err(InferenceError::InvalidInput("Empty text batch".into()));
88 }
89
90 if texts.len() > self.max_batch_size {
91 return Err(InferenceError::InvalidInput(format!(
92 "Batch size {} exceeds maximum {}",
93 texts.len(),
94 self.max_batch_size
95 )));
96 }
97
98 let original_lengths: Vec<usize> = texts.iter().map(|t| t.len()).collect();
99
100 debug!(
101 "Tokenizing {} texts, max length: {}",
102 texts.len(),
103 original_lengths.iter().max().unwrap_or(&0)
104 );
105
106 let encodings = self
108 .tokenizer
109 .encode_batch(texts.to_vec(), true)
110 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
111
112 let batch_size = encodings.len();
113 let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
114
115 debug!("Tokenized: batch_size={}, seq_len={}", batch_size, seq_len);
116
117 let mut input_ids = Vec::with_capacity(batch_size * seq_len);
119 let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
120 let mut token_type_ids = Vec::with_capacity(batch_size * seq_len);
121
122 for enc in &encodings {
123 input_ids.extend(enc.get_ids().iter().map(|&id| id as i64));
124 attention_mask.extend(enc.get_attention_mask().iter().map(|&m| m as i64));
125
126 let type_ids = enc.get_type_ids();
127 if type_ids.is_empty() {
128 token_type_ids.extend(std::iter::repeat_n(0i64, seq_len));
129 } else {
130 token_type_ids.extend(type_ids.iter().map(|&t| t as i64));
131 }
132 }
133
134 Ok(PreparedBatch {
135 input_ids,
136 attention_mask,
137 token_type_ids,
138 batch_size,
139 seq_len,
140 original_lengths,
141 })
142 }
143
144 pub fn split_into_batches<'a>(&self, texts: &'a [String]) -> Vec<&'a [String]> {
146 texts.chunks(self.max_batch_size).collect()
147 }
148}
149
150#[instrument(skip_all, fields(batch_size, seq_len, hidden_size))]
157pub fn mean_pooling(
158 last_hidden_state: &[f32],
159 batch_size: usize,
160 seq_len: usize,
161 hidden_size: usize,
162 attention_mask: &[i64],
163) -> Vec<Vec<f32>> {
164 let mut result = vec![vec![0.0f32; hidden_size]; batch_size];
165
166 for b in 0..batch_size {
167 let mask_sum: f32 = (0..seq_len)
169 .map(|s| attention_mask[b * seq_len + s] as f32)
170 .sum::<f32>()
171 .max(1e-9);
172
173 for (h, cell) in result[b].iter_mut().enumerate() {
174 let weighted_sum: f32 = (0..seq_len)
175 .map(|s| {
176 let lhs_idx = b * seq_len * hidden_size + s * hidden_size + h;
177 last_hidden_state[lhs_idx] * attention_mask[b * seq_len + s] as f32
178 })
179 .sum();
180 *cell = weighted_sum / mask_sum;
181 }
182 }
183
184 debug!(
185 "Mean pooled: batch={}, hidden={}",
186 result.len(),
187 result.first().map(|v| v.len()).unwrap_or(0)
188 );
189
190 result
191}
192
193#[instrument(skip_all, fields(count = embeddings.len()))]
195pub fn normalize_embeddings(embeddings: &mut [Vec<f32>]) {
196 for emb in embeddings.iter_mut() {
197 let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
198 for v in emb.iter_mut() {
199 *v /= norm;
200 }
201 }
202 debug!("Normalized {} embeddings", embeddings.len());
203}
204
205pub fn truncate_mrl(embedding: &[f32], target_dim: usize) -> crate::error::Result<Vec<f32>> {
214 if target_dim == 0 || target_dim > embedding.len() {
215 return Err(crate::error::InferenceError::InvalidInput(format!(
216 "MRL target_dim={} is out of range for embedding of length {}",
217 target_dim,
218 embedding.len()
219 )));
220 }
221 let mut truncated = embedding[..target_dim].to_vec();
222 let norm: f32 = truncated
223 .iter()
224 .map(|x| x * x)
225 .sum::<f32>()
226 .sqrt()
227 .max(1e-12);
228 for v in truncated.iter_mut() {
229 *v /= norm;
230 }
231 Ok(truncated)
232}
233
/// Groups texts into batches whose estimated token total stays within a budget.
///
/// Texts are appended in order with [`push`](TokenBudgetBatcher::push); when
/// the next text would push the open batch over the budget, that batch is
/// closed and a new one is started. A single text larger than the whole
/// budget is never dropped or split: it ends up in a batch of its own.
pub struct TokenBudgetBatcher {
    /// Maximum estimated tokens per batch; always at least 1.
    token_budget: usize,
    /// Batch currently being filled.
    current_batch: Vec<String>,
    /// Estimated tokens held by `current_batch` (0 whenever it is empty).
    current_tokens: usize,
    /// Batches already closed and waiting to be returned by `finish`.
    finished_batches: Vec<Vec<String>>,
    /// Estimates the token count of a text.
    token_count_fn: Box<dyn Fn(&str) -> usize + Send + Sync>,
}

impl TokenBudgetBatcher {
    /// Creates a batcher with the given per-batch token budget.
    ///
    /// If the `DAKERA_TOKEN_BUDGET` environment variable is set to a positive
    /// integer it takes precedence over `token_budget`. The effective budget
    /// is never below 1.
    ///
    /// The default token estimate is `text.len() / 4` (byte length), with a
    /// minimum of 1 per text; replace it with
    /// [`with_token_fn`](TokenBudgetBatcher::with_token_fn).
    pub fn new(token_budget: usize) -> Self {
        let budget = std::env::var("DAKERA_TOKEN_BUDGET")
            .ok()
            .and_then(|v| v.parse::<usize>().ok())
            .filter(|&n| n > 0)
            .unwrap_or(token_budget)
            .max(1);

        Self {
            token_budget: budget,
            current_batch: Vec::new(),
            current_tokens: 0,
            finished_batches: Vec::new(),
            token_count_fn: Box::new(|text| (text.len() / 4).max(1)),
        }
    }

    /// Replaces the token estimator and returns the batcher (builder style).
    pub fn with_token_fn(mut self, f: impl Fn(&str) -> usize + Send + Sync + 'static) -> Self {
        self.token_count_fn = Box::new(f);
        self
    }

    /// Appends `text`, first closing the open batch if adding the text would
    /// exceed the token budget.
    ///
    /// An empty open batch is never closed, so an oversized text still gets
    /// placed (alone) rather than producing an empty batch.
    pub fn push(&mut self, text: String) {
        let tokens = (self.token_count_fn)(&text);
        // Saturating arithmetic: a custom estimator may return arbitrarily
        // large counts, and a plain `+` would panic in debug builds or wrap
        // around in release builds (wrongly keeping the text in this batch).
        let exceeds_budget = self.current_tokens.saturating_add(tokens) > self.token_budget;
        if exceeds_budget && !self.current_batch.is_empty() {
            let batch = std::mem::take(&mut self.current_batch);
            self.finished_batches.push(batch);
            self.current_tokens = 0;
        }
        self.current_tokens = self.current_tokens.saturating_add(tokens);
        self.current_batch.push(text);
    }

    /// Pushes every text from `texts`, in iteration order.
    pub fn push_all(&mut self, texts: impl IntoIterator<Item = String>) {
        for t in texts {
            self.push(t);
        }
    }

    /// Closes the open batch (if it is non-empty) and returns all batches
    /// collected so far, leaving the batcher empty and ready for reuse.
    pub fn finish(&mut self) -> Vec<Vec<String>> {
        if !self.current_batch.is_empty() {
            let batch = std::mem::take(&mut self.current_batch);
            self.finished_batches.push(batch);
            self.current_tokens = 0;
        }
        std::mem::take(&mut self.finished_batches)
    }

    /// Number of texts in the batch currently being filled.
    pub fn pending_count(&self) -> usize {
        self.current_batch.len()
    }

    /// Estimated token total of the batch currently being filled.
    pub fn pending_tokens(&self) -> usize {
        self.current_tokens
    }
}
332
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // Fixtures and small assertion helpers
    // ------------------------------------------------------------------

    /// Tokenizer around a default (empty) BPE model. Good enough for tests
    /// that never need a successful encode.
    fn dummy_tokenizer() -> Tokenizer {
        use tokenizers::models::bpe::BPE;
        Tokenizer::new(BPE::default())
    }

    /// Whitespace-split word-level tokenizer with a tiny fixed vocabulary,
    /// used by the tests that need real encodings.
    fn simple_tokenizer() -> Tokenizer {
        use std::collections::HashMap;
        use tokenizers::models::wordlevel::WordLevel;
        use tokenizers::pre_tokenizers::whitespace::Whitespace;

        let mut vocab: HashMap<String, u32> = HashMap::new();
        for (i, w) in [
            "[PAD]", "[UNK]", "hello", "world", "test", "text", "one", "two", "foo", "bar", "baz",
        ]
        .iter()
        .enumerate()
        {
            vocab.insert(w.to_string(), i as u32);
        }

        let model = WordLevel::builder()
            .vocab(vocab)
            .unk_token("[UNK]".to_string())
            .build()
            .unwrap();

        let mut tok = Tokenizer::new(model);
        tok.with_pre_tokenizer(Some(Whitespace {}));
        tok
    }

    /// Processor backed by `dummy_tokenizer`.
    fn dummy_processor(model: EmbeddingModel, max_batch_size: usize) -> BatchProcessor {
        BatchProcessor::new(dummy_tokenizer(), model, max_batch_size)
    }

    /// MiniLM processor backed by `simple_tokenizer`, batch limit 32.
    fn simple_processor() -> BatchProcessor {
        BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32)
    }

    /// Euclidean (L2) norm of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Asserts that `actual` is within `1e-5` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-5,
            "expected {expected}, got {actual}"
        );
    }

    // ------------------------------------------------------------------
    // BatchProcessor::prepare_texts
    // ------------------------------------------------------------------

    #[test]
    fn test_prepare_texts_with_prefix() {
        let processor = dummy_processor(EmbeddingModel::E5Small, 32);

        let texts = vec!["Hello world".to_string(), "Test query".to_string()];
        let prepared = processor.prepare_texts(&texts, true);

        assert_eq!(prepared[0], "query: Hello world");
        assert_eq!(prepared[1], "query: Test query");
    }

    #[test]
    fn test_prepare_texts_no_prefix() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 32);

        let texts = vec!["Hello world".to_string()];
        let prepared = processor.prepare_texts(&texts, true);

        assert_eq!(prepared[0], "Hello world");
    }

    #[test]
    fn test_prepare_texts_document_prefix_e5() {
        let processor = dummy_processor(EmbeddingModel::E5Small, 32);

        let texts = vec!["Some document".to_string(), "Another doc".to_string()];
        let prepared = processor.prepare_texts(&texts, false);

        assert_eq!(prepared[0], "passage: Some document");
        assert_eq!(prepared[1], "passage: Another doc");
    }

    #[test]
    fn test_prepare_texts_bge_no_prefix_query() {
        let processor = dummy_processor(EmbeddingModel::BgeSmall, 32);

        let texts = vec!["Test".to_string()];
        let prepared = processor.prepare_texts(&texts, true);

        assert_eq!(prepared[0], "Test");
    }

    #[test]
    fn test_prepare_texts_bge_no_prefix_document() {
        let processor = dummy_processor(EmbeddingModel::BgeSmall, 32);

        let texts = vec!["Doc text".to_string()];
        let prepared = processor.prepare_texts(&texts, false);

        assert_eq!(prepared[0], "Doc text");
    }

    #[test]
    fn test_prepare_texts_empty_input() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 32);

        let texts: Vec<String> = vec![];
        let prepared = processor.prepare_texts(&texts, true);

        assert!(prepared.is_empty());
    }

    // ------------------------------------------------------------------
    // BatchProcessor::max_batch_size / split_into_batches
    // ------------------------------------------------------------------

    #[test]
    fn test_max_batch_size() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 64);
        assert_eq!(processor.max_batch_size(), 64);
    }

    #[test]
    fn test_max_batch_size_default() {
        let processor = dummy_processor(EmbeddingModel::BgeSmall, 32);
        assert_eq!(processor.max_batch_size(), 32);
    }

    #[test]
    fn test_split_into_batches_exact_multiple() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 4);
        let texts: Vec<String> = (0..8).map(|i| format!("text {i}")).collect();

        let batches = processor.split_into_batches(&texts);

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].len(), 4);
        assert_eq!(batches[1].len(), 4);
    }

    #[test]
    fn test_split_into_batches_partial_last() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 4);
        let texts: Vec<String> = (0..6).map(|i| format!("text {i}")).collect();

        let batches = processor.split_into_batches(&texts);

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].len(), 4);
        assert_eq!(batches[1].len(), 2);
    }

    #[test]
    fn test_split_into_batches_smaller_than_max() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 32);
        let texts: Vec<String> = (0..5).map(|i| format!("text {i}")).collect();

        let batches = processor.split_into_batches(&texts);

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 5);
    }

    #[test]
    fn test_split_into_batches_empty() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 32);
        let texts: Vec<String> = vec![];

        let batches = processor.split_into_batches(&texts);

        assert!(batches.is_empty());
    }

    #[test]
    fn test_split_into_batches_preserves_content() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 3);
        let texts = vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
            "d".to_string(),
        ];

        let batches = processor.split_into_batches(&texts);

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0], &["a", "b", "c"]);
        assert_eq!(batches[1], &["d"]);
    }

    // ------------------------------------------------------------------
    // BatchProcessor::tokenize_batch - input validation
    // ------------------------------------------------------------------

    #[test]
    fn test_tokenize_batch_empty_error() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 32);

        let result = processor.tokenize_batch(&[]);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, InferenceError::InvalidInput(_)));
        assert!(err.to_string().contains("Empty text batch"));
    }

    #[test]
    fn test_tokenize_batch_exceeds_max_size_error() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 2);
        let texts: Vec<String> = (0..5).map(|i| format!("text {i}")).collect();

        let result = processor.tokenize_batch(&texts);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, InferenceError::InvalidInput(_)));
        assert!(err.to_string().contains("exceeds maximum"));
    }

    #[test]
    fn test_tokenize_batch_exactly_at_max_size_does_not_error_before_encode() {
        let processor = dummy_processor(EmbeddingModel::MiniLM, 2);
        let texts = vec!["text one".to_string(), "text two".to_string()];

        let result = processor.tokenize_batch(&texts);

        // The dummy tokenizer may or may not encode successfully; the only
        // thing under test is that the size check itself lets the batch pass.
        let rejected_for_size = matches!(
            &result,
            Err(InferenceError::InvalidInput(msg)) if msg.contains("exceeds maximum")
        );
        assert!(
            !rejected_for_size,
            "Batch at exactly max_size should pass size check, got: {result:?}"
        );
    }

    // ------------------------------------------------------------------
    // mean_pooling
    // ------------------------------------------------------------------

    #[test]
    fn test_mean_pooling_output_shape() {
        // batch=2, seq=3, hidden=4
        let lhs = vec![0.0f32; 2 * 3 * 4];
        let mask = vec![1i64; 2 * 3];

        let result = mean_pooling(&lhs, 2, 3, 4, &mask);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].len(), 4);
        assert_eq!(result[1].len(), 4);
    }

    #[test]
    fn test_mean_pooling_uniform_hidden_all_ones_mask() {
        // batch=1, seq=4, hidden=3; the mean of identical values is that value.
        let lhs = vec![2.0f32; 4 * 3];
        let mask = vec![1i64; 4];

        let result = mean_pooling(&lhs, 1, 4, 3, &mask);

        assert_eq!(result.len(), 1);
        for &v in &result[0] {
            assert_close(v, 2.0);
        }
    }

    #[test]
    fn test_mean_pooling_masked_tokens_ignored() {
        // batch=1, seq=2, hidden=2; the second token (9.0, 9.0) is masked out.
        let lhs = vec![1.0f32, 1.0, 9.0, 9.0];
        let mask = vec![1i64, 0i64];

        let result = mean_pooling(&lhs, 1, 2, 2, &mask);

        assert_close(result[0][0], 1.0);
        assert_close(result[0][1], 1.0);
    }

    #[test]
    fn test_mean_pooling_fully_masked_row_is_zero() {
        // batch=1, seq=2, hidden=2; nothing is attended, so the result is 0.
        let lhs = vec![5.0f32, 6.0, 7.0, 8.0];
        let mask = vec![0i64, 0i64];

        let result = mean_pooling(&lhs, 1, 2, 2, &mask);

        assert_eq!(result.len(), 1);
        for &v in &result[0] {
            assert!(v.is_finite(), "expected finite value, got {v}");
            assert_close(v, 0.0);
        }
    }

    #[test]
    fn test_mean_pooling_batch_independence() {
        // batch=2, seq=1, hidden=2; each row pools only its own token.
        let lhs = vec![3.0f32, 4.0, 6.0, 8.0];
        let mask = vec![1i64, 1i64];

        let result = mean_pooling(&lhs, 2, 1, 2, &mask);

        assert_eq!(result.len(), 2);
        assert_close(result[0][0], 3.0);
        assert_close(result[0][1], 4.0);
        assert_close(result[1][0], 6.0);
        assert_close(result[1][1], 8.0);
    }

    // ------------------------------------------------------------------
    // normalize_embeddings
    // ------------------------------------------------------------------

    #[test]
    fn test_normalize_embeddings_unit_length() {
        let mut embeddings = vec![vec![3.0f32, 4.0]];

        normalize_embeddings(&mut embeddings);

        let norm = l2_norm(&embeddings[0]);
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "L2 norm should be 1.0, got {norm}"
        );
    }

    #[test]
    fn test_normalize_embeddings_values() {
        // (3, 4) has norm 5, so it normalizes to (0.6, 0.8).
        let mut embeddings = vec![vec![3.0f32, 4.0]];

        normalize_embeddings(&mut embeddings);

        assert_close(embeddings[0][0], 0.6);
        assert_close(embeddings[0][1], 0.8);
    }

    #[test]
    fn test_normalize_embeddings_batch() {
        let mut embeddings = vec![vec![1.0f32, 0.0], vec![0.0f32, 1.0]];

        normalize_embeddings(&mut embeddings);

        assert_close(l2_norm(&embeddings[0]), 1.0);
        assert_close(l2_norm(&embeddings[1]), 1.0);
    }

    #[test]
    fn test_normalize_embeddings_output_shape() {
        let mut embeddings: Vec<Vec<f32>> = (1..=3)
            .map(|i| (1..=4).map(|j| (i * j) as f32).collect())
            .collect();

        normalize_embeddings(&mut embeddings);

        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|v| v.len() == 4));
    }

    #[test]
    fn test_normalize_embeddings_near_zero_safe() {
        let mut embeddings = vec![vec![1e-14f32, 1e-14]];

        normalize_embeddings(&mut embeddings);

        for v in &embeddings[0] {
            assert!(v.is_finite(), "expected finite value, got {v}");
        }
    }

    // ------------------------------------------------------------------
    // BatchProcessor::tokenize_batch - successful encodes
    // ------------------------------------------------------------------

    #[test]
    fn test_tokenize_batch_single_text_success() {
        let processor = simple_processor();
        let texts = vec!["hello world".to_string()];

        let result = processor.tokenize_batch(&texts);

        assert!(result.is_ok(), "Expected Ok, got {:?}", result);
        let batch = result.unwrap();
        assert_eq!(batch.batch_size, 1);
        assert_eq!(batch.original_lengths, vec![11]);
    }

    #[test]
    fn test_tokenize_batch_tensor_shapes_single() {
        let processor = simple_processor();
        let texts = vec!["hello world".to_string()];

        let batch = processor.tokenize_batch(&texts).unwrap();

        let expected_len = batch.batch_size * batch.seq_len;
        assert_eq!(batch.batch_size, 1);
        assert_eq!(batch.input_ids.len(), expected_len);
        assert_eq!(batch.attention_mask.len(), expected_len);
        assert_eq!(batch.token_type_ids.len(), expected_len);
    }

    #[test]
    fn test_tokenize_batch_multiple_texts_batch_dim() {
        let processor = simple_processor();
        let texts = vec!["hello".to_string(), "hello world test".to_string()];

        let batch = processor.tokenize_batch(&texts).unwrap();

        assert_eq!(batch.batch_size, 2);
        assert_eq!(batch.original_lengths.len(), 2);
        assert_eq!(batch.input_ids.len(), batch.batch_size * batch.seq_len);
    }

    #[test]
    fn test_tokenize_batch_token_type_ids_default_zeros() {
        let processor = simple_processor();
        let texts = vec!["hello world".to_string()];

        let batch = processor.tokenize_batch(&texts).unwrap();

        for &v in &batch.token_type_ids {
            assert_eq!(v, 0, "Expected zero token_type_id from WordLevel, got {v}");
        }
    }

    #[test]
    fn test_tokenize_batch_original_lengths_preserved() {
        let processor = simple_processor();
        let texts = vec!["hello".to_string(), "hello world".to_string()];

        let batch = processor.tokenize_batch(&texts).unwrap();

        assert_eq!(batch.original_lengths[0], 5);
        assert_eq!(batch.original_lengths[1], 11);
    }

    #[test]
    fn test_tokenize_batch_three_texts_batch_size_field() {
        let processor = simple_processor();
        let texts = vec!["hello".to_string(), "world".to_string(), "test".to_string()];

        let batch = processor.tokenize_batch(&texts).unwrap();

        assert_eq!(batch.batch_size, 3);
    }

    #[test]
    fn test_tokenize_batch_all_arrays_consistent_length() {
        let processor = simple_processor();
        let texts = vec!["foo bar".to_string(), "baz".to_string()];

        let batch = processor.tokenize_batch(&texts).unwrap();

        let expected_len = batch.batch_size * batch.seq_len;
        assert_eq!(batch.input_ids.len(), expected_len);
        assert_eq!(batch.attention_mask.len(), expected_len);
        assert_eq!(batch.token_type_ids.len(), expected_len);
    }

    #[test]
    fn test_tokenize_batch_ids_are_i64() {
        let processor = simple_processor();
        let texts = vec!["hello world".to_string()];

        let batch = processor.tokenize_batch(&texts).unwrap();

        for &id in &batch.input_ids {
            assert!(id >= 0, "input_id should be non-negative, got {id}");
        }
        for &m in &batch.attention_mask {
            assert!(m == 0 || m == 1, "attention_mask should be 0 or 1, got {m}");
        }
    }

    // ------------------------------------------------------------------
    // TokenBudgetBatcher
    // ------------------------------------------------------------------

    /// Batcher that counts one token per byte, so budgets in these tests are
    /// exact. NOTE: `TokenBudgetBatcher::new` lets `DAKERA_TOKEN_BUDGET`
    /// override the budget, so these tests assume that variable is unset.
    fn exact_batcher(budget: usize) -> TokenBudgetBatcher {
        TokenBudgetBatcher::new(budget).with_token_fn(|text| text.len())
    }

    #[test]
    fn test_token_budget_batcher_empty_finish() {
        let mut batcher = exact_batcher(100);

        let batches = batcher.finish();

        assert!(batches.is_empty());
    }

    #[test]
    fn test_token_budget_batcher_single_text_single_batch() {
        let mut batcher = exact_batcher(100);
        batcher.push("hello".to_string());

        let batches = batcher.finish();

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0], vec!["hello".to_string()]);
    }

    #[test]
    fn test_token_budget_batcher_fits_small_texts_in_one_batch() {
        let mut batcher = exact_batcher(50);
        for i in 0..5 {
            batcher.push(format!("t{i}"));
        }

        let batches = batcher.finish();

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 5);
    }

    #[test]
    fn test_token_budget_batcher_splits_on_budget_exceeded() {
        let mut batcher = exact_batcher(10);
        // Five 2-byte texts fill the budget exactly...
        for _ in 0..5 {
            batcher.push("ab".to_string());
        }
        // ...so the sixth has to open a new batch.
        batcher.push("cd".to_string());

        let batches = batcher.finish();

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].len(), 5);
        assert_eq!(batches[1].len(), 1);
    }

    #[test]
    fn test_token_budget_batcher_large_single_text_gets_own_batch() {
        let mut batcher = exact_batcher(10);
        batcher.push("small".to_string());
        batcher.push("a".repeat(50));

        let batches = batcher.finish();

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0][0], "small");
        assert_eq!(batches[1], vec!["a".repeat(50)]);
    }

    #[test]
    fn test_token_budget_batcher_finish_resets_state() {
        let mut batcher = exact_batcher(100);
        batcher.push("hello".to_string());
        let _first = batcher.finish();
        assert_eq!(batcher.pending_count(), 0);

        batcher.push("world".to_string());
        let second = batcher.finish();

        assert_eq!(second.len(), 1);
        assert_eq!(second[0][0], "world");
    }

    #[test]
    fn test_token_budget_batcher_push_all() {
        let mut batcher = exact_batcher(100);
        batcher.push_all(vec!["a".to_string(), "b".to_string(), "c".to_string()]);

        let batches = batcher.finish();

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 3);
    }

    #[test]
    fn test_token_budget_batcher_pending_count() {
        let mut batcher = exact_batcher(100);
        assert_eq!(batcher.pending_count(), 0);

        batcher.push("hello".to_string());
        assert_eq!(batcher.pending_count(), 1);

        batcher.push("world".to_string());
        assert_eq!(batcher.pending_count(), 2);
    }

    #[test]
    fn test_token_budget_batcher_pending_tokens() {
        let mut batcher = exact_batcher(100);
        assert_eq!(batcher.pending_tokens(), 0);

        batcher.push("hello".to_string());
        assert_eq!(batcher.pending_tokens(), 5);

        batcher.push("world".to_string());
        assert_eq!(batcher.pending_tokens(), 10);

        let _ = batcher.finish();
        assert_eq!(batcher.pending_tokens(), 0);
    }

    #[test]
    fn test_token_budget_batcher_default_token_estimate() {
        // Default estimator: byte length / 4, but never less than 1.
        let mut batcher = TokenBudgetBatcher::new(100);

        batcher.push("abcdefgh".to_string());
        assert_eq!(batcher.pending_tokens(), 2);

        let _ = batcher.finish();
        batcher.push(String::new());
        assert_eq!(batcher.pending_tokens(), 1);
    }

    // ------------------------------------------------------------------
    // truncate_mrl
    // ------------------------------------------------------------------

    #[test]
    fn test_mrl_truncation_basic() {
        let embedding = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let truncated = truncate_mrl(&embedding, 4).unwrap();

        assert_eq!(truncated.len(), 4);
    }

    #[test]
    fn test_mrl_truncation_normalized() {
        // Keeping (3, 4) and renormalizing gives (0.6, 0.8).
        let embedding = vec![3.0f32, 4.0, 0.0, 0.0];

        let truncated = truncate_mrl(&embedding, 2).unwrap();

        assert_close(truncated[0], 0.6);
        assert_close(truncated[1], 0.8);
    }

    #[test]
    fn test_mrl_truncation_256_from_1024() {
        let embedding: Vec<f32> = (0..1024).map(|i| i as f32).collect();

        let truncated = truncate_mrl(&embedding, 256).unwrap();

        assert_eq!(truncated.len(), 256);
        let norm = l2_norm(&truncated);
        assert!((norm - 1.0).abs() < 1e-4, "norm={norm}");
    }

    #[test]
    fn test_mrl_truncation_full_dimension_is_noop_shape() {
        let embedding = vec![0.0f32; 1024];

        let truncated = truncate_mrl(&embedding, 1024).unwrap();

        assert_eq!(truncated.len(), 1024);
        // A zero vector has no direction; it must stay zero rather than turn into NaN.
        assert!(truncated.iter().all(|v| *v == 0.0));
    }

    #[test]
    fn test_mrl_truncation_zero_target_dim_error() {
        let embedding = vec![1.0f32; 10];

        let result = truncate_mrl(&embedding, 0);

        assert!(result.is_err());
    }

    #[test]
    fn test_mrl_truncation_target_exceeds_length_error() {
        let embedding = vec![1.0f32; 4];

        let result = truncate_mrl(&embedding, 5);

        assert!(result.is_err());
    }

    #[test]
    fn test_mrl_preserves_semantic_direction() {
        // All of the signal lives in the first 256 dimensions, so truncating
        // to 256 must keep the direction of the (unit-normalized) original.
        let mut embedding: Vec<f32> = (0..1024)
            .map(|i| if i < 256 { (i % 16) as f32 + 1.0 } else { 0.0 })
            .collect();
        let norm: f32 = l2_norm(&embedding).max(1e-12);
        for v in embedding.iter_mut() {
            *v /= norm;
        }

        let truncated = truncate_mrl(&embedding, 256).unwrap();

        let dot: f32 = truncated
            .iter()
            .zip(embedding.iter().take(256))
            .map(|(a, b)| a * b)
            .sum();
        assert!(dot > 0.9, "cosine similarity {dot} should be >0.9");
    }
}