//! Batch preparation for embedding inference: prefixing, tokenization into
//! padded tensors, and pooling/normalization helpers.
use crate::error::{InferenceError, Result};
use crate::models::EmbeddingModel;
use candle_core::{Device, Tensor};
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use tracing::{debug, instrument};
/// A tokenized batch ready for a model forward pass: all tensors share the
/// shape `(batch_size, seq_len)` and live on the target device.
#[derive(Debug)]
pub struct PreparedBatch {
    pub input_ids: Tensor,
    pub attention_mask: Tensor,
    pub token_type_ids: Tensor,
    pub batch_size: usize,
    /// Byte lengths of the original input strings, in batch order.
    pub original_lengths: Vec<usize>,
}
/// Turns raw texts into padded, device-resident tensor batches for an
/// `EmbeddingModel`, enforcing a maximum batch size.
pub struct BatchProcessor {
    tokenizer: Tokenizer,
    model: EmbeddingModel,
    max_batch_size: usize,
}
impl BatchProcessor {
    /// Builds a processor, configuring the tokenizer to pad each batch to its
    /// longest sequence and to truncate at the model's maximum length.
    pub fn new(mut tokenizer: Tokenizer, model: EmbeddingModel, max_batch_size: usize) -> Self {
        let padding = PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            // Reuse the tokenizer's pad token if it has one; otherwise fall
            // back to the conventional "[PAD]" with id 0.
            pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
            pad_token: tokenizer
                .get_padding()
                .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));
        let truncation = TruncationParams {
            max_length: model.max_seq_length(),
            ..Default::default()
        };
        // Errors from `with_truncation` (invalid parameter combinations) are
        // intentionally ignored here.
        let _ = tokenizer.with_truncation(Some(truncation));
Self {
tokenizer,
model,
max_batch_size,
}
}
    /// Maximum number of texts accepted per `tokenize_batch` call.
    pub fn max_batch_size(&self) -> usize {
self.max_batch_size
}
    /// Prepends the model's query or document prefix (e.g. "query: " for E5)
    /// to each text; models without prefixes get the texts back unchanged.
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub fn prepare_texts(&self, texts: &[String], is_query: bool) -> Vec<String> {
let prefix = if is_query {
self.model.query_prefix()
} else {
self.model.document_prefix()
};
match prefix {
Some(p) => texts.iter().map(|t| format!("{}{}", p, t)).collect(),
None => texts.to_vec(),
}
}
    /// Tokenizes a batch into padded `input_ids`, `attention_mask`, and
    /// `token_type_ids` tensors on `device`.
    ///
    /// Returns an error for an empty batch or one exceeding `max_batch_size`.
    #[instrument(skip(self, texts, device), fields(count = texts.len()))]
    pub fn tokenize_batch(&self, texts: &[String], device: &Device) -> Result<PreparedBatch> {
if texts.is_empty() {
return Err(InferenceError::InvalidInput("Empty text batch".into()));
}
if texts.len() > self.max_batch_size {
return Err(InferenceError::InvalidInput(format!(
"Batch size {} exceeds maximum {}",
texts.len(),
self.max_batch_size
)));
}
        // Byte lengths of the raw inputs, kept so callers can relate
        // embeddings back to their original texts.
        let original_lengths: Vec<usize> = texts.iter().map(|t| t.len()).collect();
        debug!(
            "Tokenizing {} texts, longest input: {} bytes",
            texts.len(),
            original_lengths.iter().max().unwrap_or(&0)
        );
let encodings = self
.tokenizer
.encode_batch(texts.to_vec(), true)
.map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
let batch_size = encodings.len();
let input_ids: Vec<Vec<u32>> = encodings.iter().map(|e| e.get_ids().to_vec()).collect();
let attention_masks: Vec<Vec<u32>> = encodings
.iter()
.map(|e| e.get_attention_mask().to_vec())
.collect();
        // Some tokenizers emit no type ids; fall back to zeros so BERT-style
        // models still receive a valid segment tensor.
        let token_type_ids: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| {
                let type_ids = e.get_type_ids();
                if type_ids.is_empty() {
                    vec![0u32; e.get_ids().len()]
                } else {
                    type_ids.to_vec()
                }
            })
            .collect();
        // BatchLongest padding makes every row the same length, so the first
        // row determines the padded sequence length.
        let seq_len = input_ids.first().map(|v| v.len()).unwrap_or(0);
        // Flatten row-major so each (batch, seq_len) tensor is built in one copy.
        let input_ids_flat: Vec<u32> = input_ids.into_iter().flatten().collect();
        let attention_mask_flat: Vec<u32> = attention_masks.into_iter().flatten().collect();
        let token_type_ids_flat: Vec<u32> = token_type_ids.into_iter().flatten().collect();
let input_ids = Tensor::from_vec(input_ids_flat, (batch_size, seq_len), device)?;
let attention_mask = Tensor::from_vec(attention_mask_flat, (batch_size, seq_len), device)?;
let token_type_ids = Tensor::from_vec(token_type_ids_flat, (batch_size, seq_len), device)?;
debug!(
"Created tensors: input_ids {:?}, attention_mask {:?}",
input_ids.shape(),
attention_mask.shape()
);
Ok(PreparedBatch {
input_ids,
attention_mask,
token_type_ids,
batch_size,
original_lengths,
})
}
    /// Splits `texts` into consecutive chunks of at most `max_batch_size`
    /// items; the final chunk may be shorter.
    pub fn split_into_batches<'a>(&self, texts: &'a [String]) -> Vec<&'a [String]> {
texts.chunks(self.max_batch_size).collect()
}
}
/// Mean-pools token embeddings over the sequence dimension, weighted by the
/// attention mask so padding tokens contribute nothing.
///
/// Expects `last_hidden_state` of shape `(batch, seq_len, hidden)` and
/// `attention_mask` of shape `(batch, seq_len)`; returns `(batch, hidden)`.
#[instrument(skip_all)]
pub fn mean_pooling(last_hidden_state: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
    let attention_mask = attention_mask.unsqueeze(2)?;
    let attention_mask = attention_mask.to_dtype(last_hidden_state.dtype())?;
    let attention_mask = attention_mask.broadcast_as(last_hidden_state.shape())?;
    let masked_hidden = last_hidden_state.mul(&attention_mask)?;
    let sum_hidden = masked_hidden.sum(1)?;
    let sum_mask = attention_mask.sum(1)?;
    // Clamp so rows that are entirely padding don't divide by zero.
    let sum_mask = sum_mask.clamp(1e-9, f64::MAX)?;
    let mean_pooled = sum_hidden.div(&sum_mask)?;
debug!("Mean pooled shape: {:?}", mean_pooled.shape());
Ok(mean_pooled)
}
/// L2-normalizes each row of a `(batch, hidden)` embedding matrix so cosine
/// similarity reduces to a dot product.
#[instrument(skip_all)]
pub fn normalize_embeddings(embeddings: &Tensor) -> Result<Tensor> {
    let norm = embeddings.sqr()?.sum_keepdim(1)?.sqrt()?;
    // Clamp so zero vectors don't divide by zero.
    let norm = norm.clamp(1e-12, f64::MAX)?;
let normalized = embeddings.broadcast_div(&norm)?;
debug!("Normalized embeddings shape: {:?}", normalized.shape());
Ok(normalized)
}
#[cfg(test)]
mod tests {
use super::*;
    // These tests never tokenize, so an empty BPE model is sufficient.
    fn dummy_tokenizer() -> Tokenizer {
        use tokenizers::models::bpe::BPE;
        Tokenizer::new(BPE::default())
    }
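    // A sketch of sanity checks for `tokenize_batch`'s guard clauses: both
    // the empty and oversized cases fail before any tokenization happens, so
    // the dummy tokenizer never actually runs.
    #[test]
    fn test_tokenize_batch_rejects_empty_and_oversized() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 2);
        assert!(processor.tokenize_batch(&[], &Device::Cpu).is_err());
        let too_many: Vec<String> = (0..3).map(|i| format!("text {}", i)).collect();
        assert!(processor.tokenize_batch(&too_many, &Device::Cpu).is_err());
    }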
#[test]
fn test_prepare_texts_with_prefix() {
let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::E5Small, 32);
let texts = vec!["Hello world".to_string(), "Test query".to_string()];
let prepared = processor.prepare_texts(&texts, true);
assert_eq!(prepared[0], "query: Hello world");
assert_eq!(prepared[1], "query: Test query");
}
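    // Chunking check for `split_into_batches`: five inputs with a max batch
    // size of two should yield chunks of 2, 2, and 1.
    #[test]
    fn test_split_into_batches() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 2);
        let texts: Vec<String> = (0..5).map(|i| format!("text {}", i)).collect();
        let batches = processor.split_into_batches(&texts);
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0].len(), 2);
        assert_eq!(batches[2].len(), 1);
    }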
#[test]
fn test_prepare_texts_no_prefix() {
let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);
let texts = vec!["Hello world".to_string()];
let prepared = processor.prepare_texts(&texts, true);
assert_eq!(prepared[0], "Hello world");
}
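    // Small CPU-only numerical checks for the pooling helpers; the input
    // values below are arbitrary but hand-checkable examples.
    #[test]
    fn test_mean_pooling_masks_padding() {
        // hidden: (1, 2, 2); mask [[1, 0]] keeps only the first token, so the
        // pooled output should equal that token's hidden state exactly.
        let hidden = Tensor::new(&[[[1.0f32, 2.0], [3.0, 4.0]]], &Device::Cpu).unwrap();
        let mask = Tensor::new(&[[1u32, 0]], &Device::Cpu).unwrap();
        let pooled = mean_pooling(&hidden, &mask).unwrap();
        assert_eq!(pooled.dims(), &[1, 2]);
        let values = pooled.to_vec2::<f32>().unwrap();
        assert!((values[0][0] - 1.0).abs() < 1e-5);
        assert!((values[0][1] - 2.0).abs() < 1e-5);
    }
    #[test]
    fn test_normalize_embeddings_unit_norm() {
        let embeddings = Tensor::new(&[[3.0f32, 4.0, 0.0], [0.0, 5.0, 12.0]], &Device::Cpu).unwrap();
        let normalized = normalize_embeddings(&embeddings).unwrap();
        // Every row should have (approximately) unit L2 norm after normalization.
        for row in normalized.to_vec2::<f32>().unwrap() {
            let norm: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!((norm - 1.0).abs() < 1e-5);
        }
    }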
}