// inference/batch.rs

//! Batch processing utilities for efficient embedding generation.
2
use crate::error::{InferenceError, Result};
use crate::models::EmbeddingModel;
use candle_core::{Device, Tensor};
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use tracing::{debug, instrument, warn};
8
9/// Prepared batch of tokenized inputs ready for inference.
10#[derive(Debug)]
11pub struct PreparedBatch {
12    /// Token IDs tensor [batch_size, seq_len]
13    pub input_ids: Tensor,
14    /// Attention mask tensor [batch_size, seq_len]
15    pub attention_mask: Tensor,
16    /// Token type IDs tensor [batch_size, seq_len]
17    pub token_type_ids: Tensor,
18    /// Number of items in this batch
19    pub batch_size: usize,
20    /// Original text lengths (for debugging)
21    pub original_lengths: Vec<usize>,
22}
23
24/// Batch processor for preparing text inputs for embedding models.
25pub struct BatchProcessor {
26    tokenizer: Tokenizer,
27    model: EmbeddingModel,
28    max_batch_size: usize,
29}
30
31impl BatchProcessor {
32    /// Create a new batch processor.
33    pub fn new(mut tokenizer: Tokenizer, model: EmbeddingModel, max_batch_size: usize) -> Self {
34        // Configure padding
35        let padding = PaddingParams {
36            strategy: PaddingStrategy::BatchLongest,
37            pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
38            pad_token: tokenizer
39                .get_padding()
40                .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
41            ..Default::default()
42        };
43        tokenizer.with_padding(Some(padding));
44
45        // Configure truncation
46        let truncation = TruncationParams {
47            max_length: model.max_seq_length(),
48            ..Default::default()
49        };
50        let _ = tokenizer.with_truncation(Some(truncation));
51
52        Self {
53            tokenizer,
54            model,
55            max_batch_size,
56        }
57    }
58
59    /// Get the maximum batch size.
60    pub fn max_batch_size(&self) -> usize {
61        self.max_batch_size
62    }
63
64    /// Prepare texts for embedding, optionally applying model-specific prefixes.
65    #[instrument(skip(self, texts), fields(count = texts.len()))]
66    pub fn prepare_texts(&self, texts: &[String], is_query: bool) -> Vec<String> {
67        let prefix = if is_query {
68            self.model.query_prefix()
69        } else {
70            self.model.document_prefix()
71        };
72
73        match prefix {
74            Some(p) => texts.iter().map(|t| format!("{}{}", p, t)).collect(),
75            None => texts.to_vec(),
76        }
77    }
78
79    /// Tokenize a batch of texts and prepare tensors for the model.
80    #[instrument(skip(self, texts, device), fields(count = texts.len()))]
81    pub fn tokenize_batch(&self, texts: &[String], device: &Device) -> Result<PreparedBatch> {
82        if texts.is_empty() {
83            return Err(InferenceError::InvalidInput("Empty text batch".into()));
84        }
85
86        if texts.len() > self.max_batch_size {
87            return Err(InferenceError::InvalidInput(format!(
88                "Batch size {} exceeds maximum {}",
89                texts.len(),
90                self.max_batch_size
91            )));
92        }
93
94        let original_lengths: Vec<usize> = texts.iter().map(|t| t.len()).collect();
95
96        debug!(
97            "Tokenizing {} texts, max length: {}",
98            texts.len(),
99            original_lengths.iter().max().unwrap_or(&0)
100        );
101
102        // Tokenize all texts
103        let encodings = self
104            .tokenizer
105            .encode_batch(texts.to_vec(), true)
106            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
107
108        let batch_size = encodings.len();
109
110        // Extract token IDs
111        let input_ids: Vec<Vec<u32>> = encodings.iter().map(|e| e.get_ids().to_vec()).collect();
112
113        // Extract attention masks
114        let attention_masks: Vec<Vec<u32>> = encodings
115            .iter()
116            .map(|e| e.get_attention_mask().to_vec())
117            .collect();
118
119        // Extract token type IDs (or create zeros)
120        let token_type_ids: Vec<Vec<u32>> = encodings
121            .iter()
122            .map(|e| {
123                let type_ids = e.get_type_ids();
124                if type_ids.is_empty() {
125                    vec![0u32; e.get_ids().len()]
126                } else {
127                    type_ids.to_vec()
128                }
129            })
130            .collect();
131
132        // Get sequence length (should be uniform after padding)
133        let seq_len = input_ids.first().map(|v| v.len()).unwrap_or(0);
134
135        // Convert to tensors
136        let input_ids_flat: Vec<u32> = input_ids.into_iter().flatten().collect();
137        let attention_mask_flat: Vec<u32> = attention_masks.into_iter().flatten().collect();
138        let token_type_ids_flat: Vec<u32> = token_type_ids.into_iter().flatten().collect();
139
140        let input_ids = Tensor::from_vec(input_ids_flat, (batch_size, seq_len), device)?;
141        let attention_mask = Tensor::from_vec(attention_mask_flat, (batch_size, seq_len), device)?;
142        let token_type_ids = Tensor::from_vec(token_type_ids_flat, (batch_size, seq_len), device)?;
143
144        debug!(
145            "Created tensors: input_ids {:?}, attention_mask {:?}",
146            input_ids.shape(),
147            attention_mask.shape()
148        );
149
150        Ok(PreparedBatch {
151            input_ids,
152            attention_mask,
153            token_type_ids,
154            batch_size,
155            original_lengths,
156        })
157    }
158
159    /// Split texts into batches of maximum size.
160    pub fn split_into_batches<'a>(&self, texts: &'a [String]) -> Vec<&'a [String]> {
161        texts.chunks(self.max_batch_size).collect()
162    }
163}
164
165/// Apply mean pooling to model outputs.
166///
167/// Mean pooling averages the token embeddings, weighted by the attention mask.
168#[instrument(skip_all)]
169pub fn mean_pooling(last_hidden_state: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
170    // Expand attention mask to match hidden state dimensions
171    // last_hidden_state: [batch, seq_len, hidden_size]
172    // attention_mask: [batch, seq_len]
173    let attention_mask = attention_mask.unsqueeze(2)?; // [batch, seq_len, 1]
174    let attention_mask = attention_mask.to_dtype(last_hidden_state.dtype())?;
175
176    // Expand to match hidden size (broadcast to last_hidden_state shape)
177    let attention_mask = attention_mask.broadcast_as(last_hidden_state.shape())?;
178
179    // Multiply hidden states by attention mask
180    let masked_hidden = last_hidden_state.mul(&attention_mask)?;
181
182    // Sum across sequence dimension
183    let sum_hidden = masked_hidden.sum(1)?; // [batch, hidden_size]
184
185    // Sum attention mask for normalization
186    let sum_mask = attention_mask.sum(1)?; // [batch, hidden_size]
187
188    // Clamp to avoid division by zero
189    let sum_mask = sum_mask.clamp(1e-9, f64::MAX)?;
190
191    // Divide to get mean
192    let mean_pooled = sum_hidden.div(&sum_mask)?;
193
194    debug!("Mean pooled shape: {:?}", mean_pooled.shape());
195
196    Ok(mean_pooled)
197}
198
199/// Normalize embeddings to unit length (L2 normalization).
200#[instrument(skip_all)]
201pub fn normalize_embeddings(embeddings: &Tensor) -> Result<Tensor> {
202    // Compute L2 norm across the embedding dimension
203    let norm = embeddings.sqr()?.sum_keepdim(1)?.sqrt()?;
204
205    // Clamp to avoid division by zero
206    let norm = norm.clamp(1e-12, f64::MAX)?;
207
208    // Normalize
209    let normalized = embeddings.broadcast_div(&norm)?;
210
211    debug!("Normalized embeddings shape: {:?}", normalized.shape());
212
213    Ok(normalized)
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use tokenizers::models::bpe::BPE;

    /// Builds an empty BPE tokenizer so the tests run without network access.
    ///
    /// `prepare_texts` relies only on the model's prefix logic, never on the
    /// tokenizer, so any valid tokenizer is sufficient here.
    fn offline_tokenizer() -> Tokenizer {
        Tokenizer::new(BPE::default())
    }

    /// Processor for `model` with a batch limit of 32.
    fn processor_for(model: EmbeddingModel) -> BatchProcessor {
        BatchProcessor::new(offline_tokenizer(), model, 32)
    }

    #[test]
    fn test_prepare_texts_with_prefix() {
        let inputs = ["Hello world", "Test query"].map(String::from);
        let prepared = processor_for(EmbeddingModel::E5Small).prepare_texts(&inputs, true);

        assert_eq!(prepared[0], "query: Hello world");
        assert_eq!(prepared[1], "query: Test query");
    }

    #[test]
    fn test_prepare_texts_no_prefix() {
        let inputs = ["Hello world"].map(String::from);
        let prepared = processor_for(EmbeddingModel::MiniLM).prepare_texts(&inputs, true);

        assert_eq!(prepared[0], "Hello world");
    }
}