use crate::error::{InferenceError, Result};
use crate::models::EmbeddingModel;
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use tracing::{debug, instrument};
/// A tokenized batch flattened into row-major buffers, ready to be fed to
/// an embedding model as `[batch_size, seq_len]` tensors.
#[derive(Debug)]
pub struct PreparedBatch {
    /// Token ids, flattened row-major to `batch_size * seq_len` entries.
    pub input_ids: Vec<i64>,
    /// 1 for real tokens, 0 for padding; same shape as `input_ids`.
    pub attention_mask: Vec<i64>,
    /// Segment/type ids; all zeros when the tokenizer emits none.
    pub token_type_ids: Vec<i64>,
    /// Number of texts in the batch.
    pub batch_size: usize,
    /// Padded sequence length shared by every row in the batch.
    pub seq_len: usize,
    /// Byte length (`String::len`) of each original input text, in order.
    pub original_lengths: Vec<usize>,
}
/// Tokenizes text batches for an embedding model, applying the model's
/// query/document prefixes plus batch-longest padding and truncation.
pub struct BatchProcessor {
    // Configured with padding + truncation in `BatchProcessor::new`.
    tokenizer: Tokenizer,
    model: EmbeddingModel,
    // Upper bound enforced by `tokenize_batch` and used by `split_into_batches`.
    max_batch_size: usize,
}
impl BatchProcessor {
    /// Builds a processor around `tokenizer`, configuring batch-longest
    /// padding and truncation at the model's maximum sequence length.
    ///
    /// The pad id/token are taken from the tokenizer's pre-existing padding
    /// configuration when present, falling back to `0` / `"[PAD]"`.
    pub fn new(mut tokenizer: Tokenizer, model: EmbeddingModel, max_batch_size: usize) -> Self {
        let padding = PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
            pad_token: tokenizer
                .get_padding()
                .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));
        let truncation = TruncationParams {
            max_length: model.max_seq_length(),
            ..Default::default()
        };
        // Truncation setup is best-effort: surface a failure in the logs
        // instead of swallowing it silently or failing construction.
        if let Err(e) = tokenizer.with_truncation(Some(truncation)) {
            debug!("Failed to configure truncation: {}", e);
        }
        Self {
            tokenizer,
            model,
            max_batch_size,
        }
    }

    /// Maximum number of texts accepted by a single `tokenize_batch` call.
    pub fn max_batch_size(&self) -> usize {
        self.max_batch_size
    }

    /// Prepends the model's query or document prefix (if any) to each text.
    ///
    /// `is_query` selects between `query_prefix()` and `document_prefix()`;
    /// models without a prefix return the texts unchanged.
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub fn prepare_texts(&self, texts: &[String], is_query: bool) -> Vec<String> {
        let prefix = if is_query {
            self.model.query_prefix()
        } else {
            self.model.document_prefix()
        };
        match prefix {
            Some(p) => texts.iter().map(|t| format!("{}{}", p, t)).collect(),
            None => texts.to_vec(),
        }
    }

    /// Tokenizes `texts` into a single padded batch.
    ///
    /// # Errors
    /// Returns `InferenceError::InvalidInput` for an empty batch or one
    /// larger than `max_batch_size`, and `InferenceError::TokenizationError`
    /// if the underlying tokenizer fails.
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub fn tokenize_batch(&self, texts: &[String]) -> Result<PreparedBatch> {
        if texts.is_empty() {
            return Err(InferenceError::InvalidInput("Empty text batch".into()));
        }
        if texts.len() > self.max_batch_size {
            return Err(InferenceError::InvalidInput(format!(
                "Batch size {} exceeds maximum {}",
                texts.len(),
                self.max_batch_size
            )));
        }
        let original_lengths: Vec<usize> = texts.iter().map(|t| t.len()).collect();
        debug!(
            "Tokenizing {} texts, max length: {}",
            texts.len(),
            original_lengths.iter().max().unwrap_or(&0)
        );
        // Borrow the inputs as &str instead of cloning every String:
        // encode_batch accepts anything convertible into EncodeInput.
        let inputs: Vec<&str> = texts.iter().map(String::as_str).collect();
        let encodings = self
            .tokenizer
            .encode_batch(inputs, true)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
        let batch_size = encodings.len();
        // Batch-longest padding (set in `new`) makes every row this length.
        let seq_len = encodings.first().map(|e| e.get_ids().len()).unwrap_or(0);
        debug!("Tokenized: batch_size={}, seq_len={}", batch_size, seq_len);
        let mut input_ids = Vec::with_capacity(batch_size * seq_len);
        let mut attention_mask = Vec::with_capacity(batch_size * seq_len);
        let mut token_type_ids = Vec::with_capacity(batch_size * seq_len);
        for enc in &encodings {
            input_ids.extend(enc.get_ids().iter().map(|&id| id as i64));
            attention_mask.extend(enc.get_attention_mask().iter().map(|&m| m as i64));
            let type_ids = enc.get_type_ids();
            if type_ids.is_empty() {
                // Keep the flattened buffer rectangular even when the
                // tokenizer emits no type ids.
                token_type_ids.extend(std::iter::repeat_n(0i64, seq_len));
            } else {
                token_type_ids.extend(type_ids.iter().map(|&t| t as i64));
            }
        }
        Ok(PreparedBatch {
            input_ids,
            attention_mask,
            token_type_ids,
            batch_size,
            seq_len,
            original_lengths,
        })
    }

    /// Splits `texts` into chunks of at most `max_batch_size`, borrowing the
    /// original slice. The final chunk may be shorter.
    pub fn split_into_batches<'a>(&self, texts: &'a [String]) -> Vec<&'a [String]> {
        texts.chunks(self.max_batch_size).collect()
    }
}
/// Mean-pools token embeddings into one vector per batch item.
///
/// `last_hidden_state` is a row-major `[batch_size, seq_len, hidden_size]`
/// buffer and `attention_mask` is `[batch_size, seq_len]` with 1 for real
/// tokens and 0 for padding. Masked positions contribute nothing to the
/// mean; the mask sum is clamped to `1e-9` so a fully-masked row divides
/// safely instead of producing NaN.
///
/// # Panics
/// Debug builds assert that both slices match the stated dimensions.
// Recording the dimension parameters (instead of declaring empty fields
// via skip_all) makes the span actually carry their values.
#[instrument(skip(last_hidden_state, attention_mask))]
pub fn mean_pooling(
    last_hidden_state: &[f32],
    batch_size: usize,
    seq_len: usize,
    hidden_size: usize,
    attention_mask: &[i64],
) -> Vec<Vec<f32>> {
    debug_assert_eq!(last_hidden_state.len(), batch_size * seq_len * hidden_size);
    debug_assert_eq!(attention_mask.len(), batch_size * seq_len);
    let mut result = vec![vec![0.0f32; hidden_size]; batch_size];
    for b in 0..batch_size {
        let mask_row = &attention_mask[b * seq_len..(b + 1) * seq_len];
        let mask_sum: f32 = mask_row
            .iter()
            .map(|&m| m as f32)
            .sum::<f32>()
            .max(1e-9);
        let pooled = &mut result[b];
        // Accumulate sequence-major so the hidden axis is read contiguously.
        // Each cell still sums its terms in order s = 0..seq_len, so the
        // floating-point result is identical to the naive per-cell loop.
        for (s, &m) in mask_row.iter().enumerate() {
            let weight = m as f32;
            let start = (b * seq_len + s) * hidden_size;
            let row = &last_hidden_state[start..start + hidden_size];
            for (cell, &v) in pooled.iter_mut().zip(row) {
                *cell += v * weight;
            }
        }
        for cell in pooled.iter_mut() {
            *cell /= mask_sum;
        }
    }
    debug!(
        "Mean pooled: batch={}, hidden={}",
        result.len(),
        result.first().map(|v| v.len()).unwrap_or(0)
    );
    result
}
/// Scales each embedding in place to unit L2 norm.
///
/// The norm is floored at `1e-12` so near-zero vectors stay finite rather
/// than dividing by zero.
#[instrument(skip_all, fields(count = embeddings.len()))]
pub fn normalize_embeddings(embeddings: &mut [Vec<f32>]) {
    for embedding in embeddings.iter_mut() {
        let squared_sum = embedding.iter().fold(0.0f32, |acc, &x| acc + x * x);
        let norm = squared_sum.sqrt().max(1e-12);
        embedding.iter_mut().for_each(|value| *value /= norm);
    }
    debug!("Normalized {} embeddings", embeddings.len());
}
#[cfg(test)]
mod tests {
    use super::*;

    // BPE with an empty vocab: fine for tests that never call `encode`.
    fn dummy_tokenizer() -> Tokenizer {
        use tokenizers::models::bpe::BPE;
        Tokenizer::new(BPE::default())
    }

    // Whitespace-split WordLevel tokenizer with a tiny fixed vocab, for
    // tests that actually encode text.
    fn simple_tokenizer() -> Tokenizer {
        use std::collections::HashMap;
        use tokenizers::models::wordlevel::WordLevel;
        use tokenizers::pre_tokenizers::whitespace::Whitespace;
        let mut vocab: HashMap<String, u32> = HashMap::new();
        for (i, w) in [
            "[PAD]", "[UNK]", "hello", "world", "test", "text", "one", "two", "foo", "bar", "baz",
        ]
        .iter()
        .enumerate()
        {
            vocab.insert(w.to_string(), i as u32);
        }
        let model = WordLevel::builder()
            .vocab(vocab)
            .unk_token("[UNK]".to_string())
            .build()
            .unwrap();
        let mut tok = Tokenizer::new(model);
        tok.with_pre_tokenizer(Some(Whitespace {}));
        tok
    }

    // --- prepare_texts ----------------------------------------------------

    #[test]
    fn test_prepare_texts_with_prefix() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::E5Small, 32);
        let texts = vec!["Hello world".to_string(), "Test query".to_string()];
        let prepared = processor.prepare_texts(&texts, true);
        assert_eq!(prepared[0], "query: Hello world");
        assert_eq!(prepared[1], "query: Test query");
    }

    #[test]
    fn test_prepare_texts_no_prefix() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["Hello world".to_string()];
        let prepared = processor.prepare_texts(&texts, true);
        assert_eq!(prepared[0], "Hello world");
    }

    #[test]
    fn test_prepare_texts_document_prefix_e5() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::E5Small, 32);
        let texts = vec!["Some document".to_string(), "Another doc".to_string()];
        let prepared = processor.prepare_texts(&texts, false);
        assert_eq!(prepared[0], "passage: Some document");
        assert_eq!(prepared[1], "passage: Another doc");
    }

    #[test]
    fn test_prepare_texts_bge_no_prefix_query() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::BgeSmall, 32);
        let texts = vec!["Test".to_string()];
        let prepared = processor.prepare_texts(&texts, true);
        assert_eq!(prepared[0], "Test");
    }

    #[test]
    fn test_prepare_texts_bge_no_prefix_document() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::BgeSmall, 32);
        let texts = vec!["Doc text".to_string()];
        let prepared = processor.prepare_texts(&texts, false);
        assert_eq!(prepared[0], "Doc text");
    }

    #[test]
    fn test_prepare_texts_empty_input() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts: Vec<String> = vec![];
        let prepared = processor.prepare_texts(&texts, true);
        assert!(prepared.is_empty());
    }

    // --- max_batch_size ---------------------------------------------------

    #[test]
    fn test_max_batch_size() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 64);
        assert_eq!(processor.max_batch_size(), 64);
    }

    #[test]
    fn test_max_batch_size_default() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::BgeSmall, 32);
        assert_eq!(processor.max_batch_size(), 32);
    }

    // --- split_into_batches -----------------------------------------------

    #[test]
    fn test_split_into_batches_exact_multiple() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 4);
        let texts: Vec<String> = (0..8).map(|i| format!("text {i}")).collect();
        let batches = processor.split_into_batches(&texts);
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].len(), 4);
        assert_eq!(batches[1].len(), 4);
    }

    #[test]
    fn test_split_into_batches_partial_last() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 4);
        let texts: Vec<String> = (0..6).map(|i| format!("text {i}")).collect();
        let batches = processor.split_into_batches(&texts);
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].len(), 4);
        assert_eq!(batches[1].len(), 2);
    }

    #[test]
    fn test_split_into_batches_smaller_than_max() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts: Vec<String> = (0..5).map(|i| format!("text {i}")).collect();
        let batches = processor.split_into_batches(&texts);
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 5);
    }

    #[test]
    fn test_split_into_batches_empty() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts: Vec<String> = vec![];
        let batches = processor.split_into_batches(&texts);
        assert!(batches.is_empty());
    }

    #[test]
    fn test_split_into_batches_preserves_content() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 3);
        let texts = vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
            "d".to_string(),
        ];
        let batches = processor.split_into_batches(&texts);
        assert_eq!(batches[0], &["a", "b", "c"]);
        assert_eq!(batches[1], &["d"]);
    }

    // --- tokenize_batch: validation ---------------------------------------

    #[test]
    fn test_tokenize_batch_empty_error() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);
        let result = processor.tokenize_batch(&[]);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, InferenceError::InvalidInput(_)));
        assert!(err.to_string().contains("Empty text batch"));
    }

    #[test]
    fn test_tokenize_batch_exceeds_max_size_error() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 2);
        let texts: Vec<String> = (0..5).map(|i| format!("text {i}")).collect();
        let result = processor.tokenize_batch(&texts);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, InferenceError::InvalidInput(_)));
        assert!(err.to_string().contains("exceeds maximum"));
    }

    #[test]
    fn test_tokenize_batch_exactly_at_max_size_does_not_error_before_encode() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 2);
        let texts = vec!["text one".to_string(), "text two".to_string()];
        let result = processor.tokenize_batch(&texts);
        if let Err(InferenceError::InvalidInput(msg)) = &result {
            assert!(
                !msg.contains("exceeds maximum"),
                "Batch at exactly max_size should pass size check, got: {msg}"
            );
        }
    }

    // --- mean_pooling -----------------------------------------------------

    #[test]
    fn test_mean_pooling_output_shape() {
        let lhs = vec![0.0f32; 2 * 3 * 4];
        let mask = vec![1i64; 2 * 3];
        let result = mean_pooling(&lhs, 2, 3, 4, &mask);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].len(), 4);
        assert_eq!(result[1].len(), 4);
    }

    #[test]
    fn test_mean_pooling_uniform_hidden_all_ones_mask() {
        let lhs = vec![2.0f32; 1 * 4 * 3];
        let mask = vec![1i64; 1 * 4];
        let result = mean_pooling(&lhs, 1, 4, 3, &mask);
        assert_eq!(result.len(), 1);
        for v in &result[0] {
            assert!((v - 2.0).abs() < 1e-5, "expected 2.0, got {v}");
        }
    }

    #[test]
    fn test_mean_pooling_masked_tokens_ignored() {
        let lhs = vec![1.0f32, 1.0, 9.0, 9.0];
        let mask = vec![1i64, 0i64];
        let result = mean_pooling(&lhs, 1, 2, 2, &mask);
        assert!(
            (result[0][0] - 1.0).abs() < 1e-5,
            "expected 1.0, got {}",
            result[0][0]
        );
        assert!(
            (result[0][1] - 1.0).abs() < 1e-5,
            "expected 1.0, got {}",
            result[0][1]
        );
    }

    #[test]
    fn test_mean_pooling_batch_independence() {
        let lhs = vec![3.0f32, 4.0, 6.0, 8.0];
        let mask = vec![1i64, 1i64];
        let result = mean_pooling(&lhs, 2, 1, 2, &mask);
        assert_eq!(result.len(), 2);
        assert!((result[0][0] - 3.0).abs() < 1e-5);
        assert!((result[0][1] - 4.0).abs() < 1e-5);
        assert!((result[1][0] - 6.0).abs() < 1e-5);
        assert!((result[1][1] - 8.0).abs() < 1e-5);
    }

    // --- normalize_embeddings ---------------------------------------------

    #[test]
    fn test_normalize_embeddings_unit_length() {
        let mut embeddings = vec![vec![3.0f32, 4.0]];
        normalize_embeddings(&mut embeddings);
        let norm: f32 = embeddings[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "L2 norm should be 1.0, got {norm}"
        );
    }

    #[test]
    fn test_normalize_embeddings_values() {
        let mut embeddings = vec![vec![3.0f32, 4.0]];
        normalize_embeddings(&mut embeddings);
        assert!(
            (embeddings[0][0] - 0.6).abs() < 1e-5,
            "expected 0.6, got {}",
            embeddings[0][0]
        );
        assert!(
            (embeddings[0][1] - 0.8).abs() < 1e-5,
            "expected 0.8, got {}",
            embeddings[0][1]
        );
    }

    #[test]
    fn test_normalize_embeddings_batch() {
        let mut embeddings = vec![vec![1.0f32, 0.0], vec![0.0f32, 1.0]];
        normalize_embeddings(&mut embeddings);
        let norm0: f32 = embeddings[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm1: f32 = embeddings[1].iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm0 - 1.0).abs() < 1e-5);
        assert!((norm1 - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_normalize_embeddings_output_shape() {
        let mut embeddings: Vec<Vec<f32>> = (1..=3)
            .map(|i| (1..=4).map(|j| (i * j) as f32).collect())
            .collect();
        normalize_embeddings(&mut embeddings);
        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|v| v.len() == 4));
    }

    #[test]
    fn test_normalize_embeddings_near_zero_safe() {
        let mut embeddings = vec![vec![1e-14f32, 1e-14]];
        normalize_embeddings(&mut embeddings);
        for v in &embeddings[0] {
            assert!(v.is_finite(), "expected finite value, got {v}");
        }
    }

    // --- tokenize_batch: encoding -----------------------------------------

    #[test]
    fn test_tokenize_batch_single_text_success() {
        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["hello world".to_string()];
        let result = processor.tokenize_batch(&texts);
        assert!(result.is_ok(), "Expected Ok, got {:?}", result);
        let batch = result.unwrap();
        assert_eq!(batch.batch_size, 1);
        assert_eq!(batch.original_lengths, vec![11]);
    }

    #[test]
    fn test_tokenize_batch_tensor_shapes_single() {
        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["hello world".to_string()];
        let batch = processor.tokenize_batch(&texts).unwrap();
        assert_eq!(batch.batch_size, 1);
        assert_eq!(batch.input_ids.len(), batch.batch_size * batch.seq_len);
        assert_eq!(batch.attention_mask.len(), batch.batch_size * batch.seq_len);
        assert_eq!(batch.token_type_ids.len(), batch.batch_size * batch.seq_len);
    }

    #[test]
    fn test_tokenize_batch_multiple_texts_batch_dim() {
        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["hello".to_string(), "hello world test".to_string()];
        let batch = processor.tokenize_batch(&texts).unwrap();
        assert_eq!(batch.batch_size, 2);
        assert_eq!(batch.original_lengths.len(), 2);
        assert_eq!(batch.input_ids.len(), batch.batch_size * batch.seq_len);
    }

    #[test]
    fn test_tokenize_batch_token_type_ids_default_zeros() {
        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["hello world".to_string()];
        let batch = processor.tokenize_batch(&texts).unwrap();
        for &v in &batch.token_type_ids {
            assert_eq!(v, 0, "Expected zero token_type_id from WordLevel, got {v}");
        }
    }

    #[test]
    fn test_tokenize_batch_original_lengths_preserved() {
        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["hello".to_string(), "hello world".to_string()];
        let batch = processor.tokenize_batch(&texts).unwrap();
        assert_eq!(batch.original_lengths[0], 5);
        assert_eq!(batch.original_lengths[1], 11);
    }

    #[test]
    fn test_tokenize_batch_three_texts_batch_size_field() {
        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["hello".to_string(), "world".to_string(), "test".to_string()];
        let batch = processor.tokenize_batch(&texts).unwrap();
        assert_eq!(batch.batch_size, 3);
    }

    #[test]
    fn test_tokenize_batch_all_arrays_consistent_length() {
        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["foo bar".to_string(), "baz".to_string()];
        let batch = processor.tokenize_batch(&texts).unwrap();
        let expected_len = batch.batch_size * batch.seq_len;
        assert_eq!(batch.input_ids.len(), expected_len);
        assert_eq!(batch.attention_mask.len(), expected_len);
        assert_eq!(batch.token_type_ids.len(), expected_len);
    }

    #[test]
    fn test_tokenize_batch_ids_are_i64() {
        let processor = BatchProcessor::new(simple_tokenizer(), EmbeddingModel::MiniLM, 32);
        let texts = vec!["hello world".to_string()];
        let batch = processor.tokenize_batch(&texts).unwrap();
        for &id in &batch.input_ids {
            assert!(id >= 0, "input_id should be non-negative, got {id}");
        }
        for &m in &batch.attention_mask {
            assert!(m == 0 || m == 1, "attention_mask should be 0 or 1, got {m}");
        }
    }
}