//! Batch tokenization and pooling utilities for embedding inference.

use crate::error::{InferenceError, Result};
use crate::models::EmbeddingModel;
use candle_core::{Device, Tensor};
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use tracing::{debug, instrument};

/// A tokenized batch ready to be fed to the model.
#[derive(Debug)]
pub struct PreparedBatch {
    /// Token ids, shape `(batch_size, seq_len)`.
    pub input_ids: Tensor,
    /// Attention mask (1 for real tokens, 0 for padding), shape `(batch_size, seq_len)`.
    pub attention_mask: Tensor,
    /// Token type ids, shape `(batch_size, seq_len)`.
    pub token_type_ids: Tensor,
    /// Number of texts in the batch.
    pub batch_size: usize,
    /// Byte lengths of the input texts before tokenization.
    pub original_lengths: Vec<usize>,
}

/// Tokenizes input texts into padded, device-ready tensors.
pub struct BatchProcessor {
    tokenizer: Tokenizer,
    model: EmbeddingModel,
    max_batch_size: usize,
}

impl BatchProcessor {
    pub fn new(mut tokenizer: Tokenizer, model: EmbeddingModel, max_batch_size: usize) -> Self {
        // Pad every sequence in a batch to the longest one, reusing the
        // tokenizer's existing pad token if it has one.
        let padding = PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            pad_id: tokenizer.get_padding().map_or(0, |p| p.pad_id),
            pad_token: tokenizer
                .get_padding()
                .map_or("[PAD]".to_string(), |p| p.pad_token.clone()),
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));

        // Truncate to the model's maximum sequence length. with_truncation
        // returns a Result in recent tokenizers versions; it is ignored here.
        let truncation = TruncationParams {
            max_length: model.max_seq_length(),
            ..Default::default()
        };
        let _ = tokenizer.with_truncation(Some(truncation));

        Self {
            tokenizer,
            model,
            max_batch_size,
        }
    }
    pub fn max_batch_size(&self) -> usize {
        self.max_batch_size
    }

    /// Prepends the model's query or document prefix (e.g. E5's `"query: "`)
    /// to each text, if the model defines one.
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub fn prepare_texts(&self, texts: &[String], is_query: bool) -> Vec<String> {
        let prefix = if is_query {
            self.model.query_prefix()
        } else {
            self.model.document_prefix()
        };

        match prefix {
            Some(p) => texts.iter().map(|t| format!("{}{}", p, t)).collect(),
            None => texts.to_vec(),
        }
    }

    /// Tokenizes a batch of texts into padded tensors on `device`. Fails on
    /// an empty batch or one larger than `max_batch_size`.
    #[instrument(skip(self, texts, device), fields(count = texts.len()))]
    pub fn tokenize_batch(&self, texts: &[String], device: &Device) -> Result<PreparedBatch> {
        if texts.is_empty() {
            return Err(InferenceError::InvalidInput("Empty text batch".into()));
        }

        if texts.len() > self.max_batch_size {
            return Err(InferenceError::InvalidInput(format!(
                "Batch size {} exceeds maximum {}",
                texts.len(),
                self.max_batch_size
            )));
        }

        let original_lengths: Vec<usize> = texts.iter().map(|t| t.len()).collect();

        debug!(
            "Tokenizing {} texts, longest input: {} bytes",
            texts.len(),
            original_lengths.iter().max().unwrap_or(&0)
        );

        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;

        let batch_size = encodings.len();

        let input_ids: Vec<Vec<u32>> = encodings.iter().map(|e| e.get_ids().to_vec()).collect();

        let attention_masks: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| e.get_attention_mask().to_vec())
            .collect();

        // Some tokenizers don't emit type ids; fall back to zeros so the
        // tensor shapes stay consistent.
        let token_type_ids: Vec<Vec<u32>> = encodings
            .iter()
            .map(|e| {
                let type_ids = e.get_type_ids();
                if type_ids.is_empty() {
                    vec![0u32; e.get_ids().len()]
                } else {
                    type_ids.to_vec()
                }
            })
            .collect();

        // Padding to the batch-longest means every row has the same length.
        let seq_len = input_ids.first().map(|v| v.len()).unwrap_or(0);

        let input_ids_flat: Vec<u32> = input_ids.into_iter().flatten().collect();
        let attention_mask_flat: Vec<u32> = attention_masks.into_iter().flatten().collect();
        let token_type_ids_flat: Vec<u32> = token_type_ids.into_iter().flatten().collect();

        let input_ids = Tensor::from_vec(input_ids_flat, (batch_size, seq_len), device)?;
        let attention_mask = Tensor::from_vec(attention_mask_flat, (batch_size, seq_len), device)?;
        let token_type_ids = Tensor::from_vec(token_type_ids_flat, (batch_size, seq_len), device)?;

        debug!(
            "Created tensors: input_ids {:?}, attention_mask {:?}",
            input_ids.shape(),
            attention_mask.shape()
        );

        Ok(PreparedBatch {
            input_ids,
            attention_mask,
            token_type_ids,
            batch_size,
            original_lengths,
        })
    }

    /// Splits `texts` into chunks of at most `max_batch_size`.
    pub fn split_into_batches<'a>(&self, texts: &'a [String]) -> Vec<&'a [String]> {
        texts.chunks(self.max_batch_size).collect()
    }
}

/// Mean-pools token embeddings, counting only non-padding positions.
///
/// `last_hidden_state` has shape `(batch, seq_len, hidden)`; `attention_mask`
/// has shape `(batch, seq_len)`. Returns `(batch, hidden)`.
#[instrument(skip_all)]
pub fn mean_pooling(last_hidden_state: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
    // Expand the mask to (batch, seq_len, hidden) in the hidden state's dtype.
    let attention_mask = attention_mask.unsqueeze(2)?;
    let attention_mask = attention_mask.to_dtype(last_hidden_state.dtype())?;
    let attention_mask = attention_mask.broadcast_as(last_hidden_state.shape())?;

    // Zero out padding positions, then average over the sequence dimension.
    let masked_hidden = last_hidden_state.mul(&attention_mask)?;
    let sum_hidden = masked_hidden.sum(1)?;
    let sum_mask = attention_mask.sum(1)?;
    // Clamp to avoid division by zero for fully masked rows.
    let sum_mask = sum_mask.clamp(1e-9, f64::MAX)?;

    let mean_pooled = sum_hidden.div(&sum_mask)?;

    debug!("Mean pooled shape: {:?}", mean_pooled.shape());

    Ok(mean_pooled)
}

/// L2-normalizes each embedding row to unit length.
#[instrument(skip_all)]
pub fn normalize_embeddings(embeddings: &Tensor) -> Result<Tensor> {
    let norm = embeddings.sqr()?.sum_keepdim(1)?.sqrt()?;

    // Clamp to avoid division by zero for all-zero embeddings.
    let norm = norm.clamp(1e-12, f64::MAX)?;

    let normalized = embeddings.broadcast_div(&norm)?;

    debug!("Normalized embeddings shape: {:?}", normalized.shape());

    Ok(normalized)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal BPE tokenizer; the prefix tests below never tokenize,
    /// so no real vocabulary is needed.
    fn dummy_tokenizer() -> Tokenizer {
        use tokenizers::models::bpe::BPE;
        Tokenizer::new(BPE::default())
    }

    #[test]
    fn test_prepare_texts_with_prefix() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::E5Small, 32);

        let texts = vec!["Hello world".to_string(), "Test query".to_string()];
        let prepared = processor.prepare_texts(&texts, true);

        assert_eq!(prepared[0], "query: Hello world");
        assert_eq!(prepared[1], "query: Test query");
    }
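
    // A couple of added sanity checks (a minimal sketch reusing the existing
    // dummy_tokenizer and EmbeddingModel::MiniLM variant): batching should
    // chunk by max_batch_size, and an empty batch should be rejected.
    #[test]
    fn test_split_into_batches() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 2);

        let texts: Vec<String> = (0..5).map(|i| format!("text {}", i)).collect();
        let batches = processor.split_into_batches(&texts);

        // 5 texts with max_batch_size = 2 -> chunks of 2, 2, 1.
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0].len(), 2);
        assert_eq!(batches[2].len(), 1);
    }

    #[test]
    fn test_tokenize_batch_rejects_empty_input() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 2);

        let texts: Vec<String> = vec![];
        let result = processor.tokenize_batch(&texts, &Device::Cpu);

        assert!(result.is_err());
    }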

    #[test]
    fn test_prepare_texts_no_prefix() {
        let processor = BatchProcessor::new(dummy_tokenizer(), EmbeddingModel::MiniLM, 32);

        let texts = vec!["Hello world".to_string()];
        let prepared = processor.prepare_texts(&texts, true);

        assert_eq!(prepared[0], "Hello world");
    }
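
    // Numeric checks for the pooling helpers (a minimal sketch; the expected
    // values are hand-computed for tiny tensors on the Candle CPU device).
    #[test]
    fn test_mean_pooling_ignores_padding() {
        let device = Device::Cpu;
        // One sequence of two tokens with hidden size 2; the second token is
        // padding and must not contribute to the mean.
        let hidden =
            Tensor::from_vec(vec![1.0f32, 2.0, 10.0, 20.0], (1, 2, 2), &device).unwrap();
        let mask = Tensor::from_vec(vec![1u32, 0], (1, 2), &device).unwrap();

        let pooled = mean_pooling(&hidden, &mask).unwrap();
        let values = pooled.to_vec2::<f32>().unwrap();

        // Only the first token counts, so the mean equals its embedding.
        assert_eq!(values[0], vec![1.0, 2.0]);
    }

    #[test]
    fn test_normalize_embeddings_unit_norm() {
        let device = Device::Cpu;
        // [3, 4] has L2 norm 5, so it should normalize to [0.6, 0.8].
        let embeddings = Tensor::from_vec(vec![3.0f32, 4.0], (1, 2), &device).unwrap();

        let normalized = normalize_embeddings(&embeddings).unwrap();
        let values = normalized.to_vec2::<f32>().unwrap();

        assert!((values[0][0] - 0.6).abs() < 1e-6);
        assert!((values[0][1] - 0.8).abs() < 1e-6);
    }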
}