kizzasi_tokenizer/
batch.rs

//! Batch processing for efficient tokenization of multiple signals
//!
//! Provides batch encode/decode operations that can leverage
//! vectorization and parallelization for better performance.

use crate::error::{TokenizerError, TokenizerResult};
use crate::SignalTokenizer;
use scirs2_core::ndarray::{Array1, Array2};

/// Trait for batch tokenization operations
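///
/// A blanket implementation is provided for every `SignalTokenizer`
/// (`impl<T: SignalTokenizer> BatchTokenizer for T {}` at the bottom of this
/// module), so any tokenizer gets these batch methods for free.
///
/// # Example
///
/// A minimal sketch mirroring the tests at the bottom of this file;
/// `ContinuousTokenizer::new(input_dim, embed_dim)` is assumed to behave as it
/// does there.
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
///
/// let tokenizer = ContinuousTokenizer::new(4, 8);
/// // 3 signals of length 4 -> 3 encodings of length 8
/// let signals = Array2::from_shape_fn((3, 4), |(i, j)| (i * 4 + j) as f32 * 0.1);
/// let encoded = tokenizer.encode_batch(&signals).unwrap();
/// assert_eq!(encoded.shape(), &[3, 8]);
/// ```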
pub trait BatchTokenizer: SignalTokenizer {
    /// Encode multiple signals in batch
    ///
    /// Input: batch of signals [batch_size, signal_length]
    /// Output: batch of encodings [batch_size, embed_dim]
    fn encode_batch(&self, signals: &Array2<f32>) -> TokenizerResult<Array2<f32>> {
        let batch_size = signals.shape()[0];
        let mut results = Vec::with_capacity(batch_size);

        for i in 0..batch_size {
            let signal = signals.row(i).to_owned();
            let encoded = self.encode(&signal)?;
            results.push(encoded);
        }

        // Stack results into batch
        batch_from_vec(results)
    }

    /// Decode multiple token sequences in batch
    ///
    /// Input: batch of tokens [batch_size, embed_dim]
    /// Output: batch of signals [batch_size, signal_length]
    fn decode_batch(&self, tokens: &Array2<f32>) -> TokenizerResult<Array2<f32>> {
        let batch_size = tokens.shape()[0];
        let mut results = Vec::with_capacity(batch_size);

        for i in 0..batch_size {
            let token_seq = tokens.row(i).to_owned();
            let decoded = self.decode(&token_seq)?;
            results.push(decoded);
        }

        // Stack results into batch
        batch_from_vec(results)
    }

    /// Encode batch with padding to handle variable length signals
    ///
    /// Pads/truncates all signals to the specified target length before encoding.
    ///
    /// # Arguments
    /// * `signals` - Variable-length signals to encode
    /// * `target_len` - Target length to pad/truncate to (should match tokenizer's expected input dim)
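    ///
    /// # Example
    ///
    /// A minimal sketch following the `test_encode_batch_padded` test below;
    /// `ContinuousTokenizer::new(10, 20)` is assumed to accept signals of length 10.
    ///
    /// ```ignore
    /// let tokenizer = ContinuousTokenizer::new(10, 20);
    /// let signals = vec![
    ///     Array1::from_vec(vec![1.0, 2.0, 3.0]),           // shorter: zero-padded to 10
    ///     Array1::from_vec(vec![4.0, 5.0, 6.0, 7.0, 8.0]), // shorter: zero-padded to 10
    /// ];
    /// let encoded = tokenizer.encode_batch_padded_to(&signals, 10).unwrap();
    /// assert_eq!(encoded.shape(), &[2, 20]);
    /// ```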
    fn encode_batch_padded_to(
        &self,
        signals: &[Array1<f32>],
        target_len: usize,
    ) -> TokenizerResult<Array2<f32>> {
        if signals.is_empty() {
            return Err(TokenizerError::InvalidConfig("Empty batch".into()));
        }

        let mut results = Vec::with_capacity(signals.len());

        for signal in signals {
            // Pad or truncate to target length
            let mut padded = Array1::zeros(target_len);
            let copy_len = signal.len().min(target_len);
            for i in 0..copy_len {
                padded[i] = signal[i];
            }

            let encoded = self.encode(&padded)?;
            results.push(encoded);
        }

        // Stack results into batch
        batch_from_vec(results)
    }

    /// Process a batch in parallel using multiple threads
    ///
    /// This can provide a significant speedup for large batches.
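    ///
    /// Available only when the crate is built with the `parallel` feature
    /// (the implementation uses `rayon` internally).
    ///
    /// ```ignore
    /// // Sketch assuming the `parallel` feature is enabled; same shape
    /// // contract as `encode_batch`.
    /// let tokenizer = ContinuousTokenizer::new(4, 8);
    /// let signals = Array2::from_shape_fn((3, 4), |(i, j)| (i * 4 + j) as f32 * 0.1);
    /// let encoded = tokenizer.encode_batch_parallel(&signals).unwrap();
    /// assert_eq!(encoded.shape(), &[3, 8]);
    /// ```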
    #[cfg(feature = "parallel")]
    fn encode_batch_parallel(&self, signals: &Array2<f32>) -> TokenizerResult<Array2<f32>>
    where
        Self: Sync,
    {
        use rayon::prelude::*;

        let batch_size = signals.shape()[0];
        let results: Result<Vec<_>, _> = (0..batch_size)
            .into_par_iter()
            .map(|i| {
                let signal = signals.row(i).to_owned();
                self.encode(&signal)
            })
            .collect();

        let results = results?;
        batch_from_vec(results)
    }

    /// Decode batch in parallel
    #[cfg(feature = "parallel")]
    fn decode_batch_parallel(&self, tokens: &Array2<f32>) -> TokenizerResult<Array2<f32>>
    where
        Self: Sync,
    {
        use rayon::prelude::*;

        let batch_size = tokens.shape()[0];
        let results: Result<Vec<_>, _> = (0..batch_size)
            .into_par_iter()
            .map(|i| {
                let token_seq = tokens.row(i).to_owned();
                self.decode(&token_seq)
            })
            .collect();

        let results = results?;
        batch_from_vec(results)
    }
}

/// Helper function to stack Array1 vectors into Array2 batch
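///
/// Returns an error if the batch is empty or if the arrays do not all have the
/// same length.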
fn batch_from_vec(arrays: Vec<Array1<f32>>) -> TokenizerResult<Array2<f32>> {
    if arrays.is_empty() {
        return Err(TokenizerError::InvalidConfig("Empty batch".into()));
    }

    let batch_size = arrays.len();
    let elem_len = arrays[0].len();

    // Verify all arrays have same length
    for arr in arrays.iter() {
        if arr.len() != elem_len {
            return Err(TokenizerError::dim_mismatch(
                elem_len,
                arr.len(),
                "dimension validation",
            ));
        }
    }

    let mut batch = Array2::zeros((batch_size, elem_len));
    for (i, arr) in arrays.iter().enumerate() {
        for (j, &val) in arr.iter().enumerate() {
            batch[[i, j]] = val;
        }
    }

    Ok(batch)
}

/// Streaming tokenizer for processing long sequences in chunks
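///
/// The signal is split into windows of `chunk_size` samples that overlap by
/// `overlap` samples. Each window is encoded independently with the wrapped
/// tokenizer, and `decode_streaming` blends the overlapping regions back into a
/// single signal.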
pub struct StreamingTokenizer<T: SignalTokenizer> {
    tokenizer: T,
    chunk_size: usize,
    overlap: usize,
}

impl<T: SignalTokenizer> StreamingTokenizer<T> {
    /// Create a new streaming tokenizer
    ///
    /// # Arguments
    /// * `tokenizer` - The underlying tokenizer
    /// * `chunk_size` - Size of each chunk
    /// * `overlap` - Overlap between chunks (for continuity)
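    ///
    /// Returns an error if `overlap >= chunk_size`.
    ///
    /// # Example
    ///
    /// A minimal sketch based on the streaming tests below; `ContinuousTokenizer::new(8, 16)`
    /// is assumed to take signals of length 8, matching `chunk_size`.
    ///
    /// ```ignore
    /// let tokenizer = ContinuousTokenizer::new(8, 16);
    /// let streaming = StreamingTokenizer::new(tokenizer, 8, 2).unwrap();
    ///
    /// // A signal longer than one chunk is encoded as several overlapping chunks.
    /// let signal = Array1::from_vec((0..20).map(|i| (i as f32 * 0.1).sin()).collect());
    /// let chunks = streaming.encode_streaming(&signal).unwrap();
    /// assert!(chunks.len() > 1);
    ///
    /// // Decoding blends the overlapping regions back into one signal.
    /// let reconstructed = streaming.decode_streaming(&chunks).unwrap();
    /// assert!(reconstructed.len() >= signal.len());
    /// ```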
    pub fn new(tokenizer: T, chunk_size: usize, overlap: usize) -> TokenizerResult<Self> {
        if overlap >= chunk_size {
            return Err(TokenizerError::InvalidConfig(
                "Overlap must be less than chunk_size".into(),
            ));
        }

        Ok(Self {
            tokenizer,
            chunk_size,
            overlap,
        })
    }

    /// Encode a long signal in chunks
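    ///
    /// Windows start every `chunk_size - overlap` samples; a trailing window that
    /// is shorter than `chunk_size` is zero-padded before encoding.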
    pub fn encode_streaming(&self, signal: &Array1<f32>) -> TokenizerResult<Vec<Array1<f32>>> {
        let mut chunks = Vec::new();
        let stride = self.chunk_size - self.overlap;

        let mut start = 0;
        while start < signal.len() {
            let end = (start + self.chunk_size).min(signal.len());
            let chunk = Array1::from_vec(
                signal
                    .iter()
                    .skip(start)
                    .take(end - start)
                    .cloned()
                    .collect(),
            );

            // Pad last chunk if needed
            let chunk = if chunk.len() < self.chunk_size {
                let mut padded = Array1::zeros(self.chunk_size);
                for (i, &val) in chunk.iter().enumerate() {
                    padded[i] = val;
                }
                padded
            } else {
                chunk
            };

            chunks.push(self.tokenizer.encode(&chunk)?);
            start += stride;
        }

        Ok(chunks)
    }

    /// Decode chunks and reconstruct signal with overlap handling
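    ///
    /// Each decoded chunk is written at its window offset; samples covered by more
    /// than one chunk are summed and then divided by their coverage count (a simple
    /// overlap-add average). The reconstruction assumes each decoded chunk is
    /// `chunk_size` samples long.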
    pub fn decode_streaming(&self, chunks: &[Array1<f32>]) -> TokenizerResult<Array1<f32>> {
        if chunks.is_empty() {
            return Err(TokenizerError::InvalidConfig("No chunks to decode".into()));
        }

        let mut decoded_chunks = Vec::new();
        for chunk in chunks {
            decoded_chunks.push(self.tokenizer.decode(chunk)?);
        }

        // Reconstruct with overlap blending
        let stride = self.chunk_size - self.overlap;
        let total_len = (chunks.len() - 1) * stride + self.chunk_size;
        let mut result = Array1::<f32>::zeros(total_len);
        let mut weight = Array1::<f32>::zeros(total_len);

        for (i, chunk) in decoded_chunks.iter().enumerate() {
            let start = i * stride;
            for (j, &val) in chunk.iter().enumerate() {
                if start + j < total_len {
                    result[start + j] += val;
                    weight[start + j] += 1.0;
                }
            }
        }

        // Average overlapping regions
        for i in 0..total_len {
            if weight[i] > 0.0 {
                result[i] /= weight[i];
            }
        }

        Ok(result)
    }

    /// Get the underlying tokenizer
    pub fn tokenizer(&self) -> &T {
        &self.tokenizer
    }

    /// Get chunk size
    pub fn chunk_size(&self) -> usize {
        self.chunk_size
    }

    /// Get overlap size
    pub fn overlap(&self) -> usize {
        self.overlap
    }
}

// Implement BatchTokenizer for all SignalTokenizer types
impl<T: SignalTokenizer> BatchTokenizer for T {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ContinuousTokenizer;

    #[test]
    fn test_batch_from_vec() {
        let arrays = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 6.0]),
            Array1::from_vec(vec![7.0, 8.0, 9.0]),
        ];

        let batch = batch_from_vec(arrays).unwrap();
        assert_eq!(batch.shape(), &[3, 3]);
        assert_eq!(batch[[0, 0]], 1.0);
        assert_eq!(batch[[2, 2]], 9.0);
    }

    #[test]
    fn test_batch_from_vec_dimension_mismatch() {
        let arrays = vec![
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![3.0, 4.0, 5.0]), // Wrong length
        ];

        assert!(batch_from_vec(arrays).is_err());
    }

    #[test]
    fn test_encode_batch() {
        let tokenizer = ContinuousTokenizer::new(4, 8);

        let signals = Array2::from_shape_fn((3, 4), |(i, j)| (i * 4 + j) as f32 * 0.1);

        let encoded = tokenizer.encode_batch(&signals).unwrap();
        assert_eq!(encoded.shape(), &[3, 8]);
    }

    #[test]
    fn test_decode_batch() {
        let tokenizer = ContinuousTokenizer::new(4, 8);

        let tokens = Array2::from_shape_fn((3, 8), |(i, j)| (i * 8 + j) as f32 * 0.1);

        let decoded = tokenizer.decode_batch(&tokens).unwrap();
        assert_eq!(decoded.shape(), &[3, 4]);
    }

    #[test]
    fn test_encode_decode_batch_roundtrip() {
        let tokenizer = ContinuousTokenizer::new(10, 20);

        let signals = Array2::from_shape_fn((5, 10), |(i, j)| ((i * 10 + j) as f32 * 0.05).sin());

        let encoded = tokenizer.encode_batch(&signals).unwrap();
        let decoded = tokenizer.decode_batch(&encoded).unwrap();

        assert_eq!(decoded.shape(), signals.shape());
    }

    #[test]
    fn test_encode_batch_padded() {
        let tokenizer = ContinuousTokenizer::new(10, 20);

        let signals = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 6.0, 7.0, 8.0]),
            Array1::from_vec(vec![9.0, 10.0]),
        ];

        // Pad to input dimension (10)
        let encoded = tokenizer.encode_batch_padded_to(&signals, 10).unwrap();
        assert_eq!(encoded.shape(), &[3, 20]);
    }

    #[test]
    fn test_streaming_tokenizer() {
        let tokenizer = ContinuousTokenizer::new(8, 16);
        let streaming = StreamingTokenizer::new(tokenizer, 8, 2).unwrap();

        // Create a signal longer than chunk_size
        let signal = Array1::from_vec((0..20).map(|i| (i as f32 * 0.1).sin()).collect());

        let chunks = streaming.encode_streaming(&signal).unwrap();
        assert!(chunks.len() > 1); // Should be split into chunks

        let reconstructed = streaming.decode_streaming(&chunks).unwrap();
        // Allow some length difference due to padding
        assert!(reconstructed.len() >= signal.len());
    }

    #[test]
    fn test_streaming_tokenizer_overlap_validation() {
        let tokenizer = ContinuousTokenizer::new(8, 16);

        // Overlap >= chunk_size should fail
        assert!(StreamingTokenizer::new(tokenizer.clone(), 8, 8).is_err());
        assert!(StreamingTokenizer::new(tokenizer, 8, 9).is_err());
    }

    #[test]
    fn test_streaming_short_signal() {
        let tokenizer = ContinuousTokenizer::new(16, 32);
        let streaming = StreamingTokenizer::new(tokenizer, 16, 4).unwrap();

        // Signal shorter than chunk_size
        let signal = Array1::from_vec((0..10).map(|i| i as f32).collect());

        let chunks = streaming.encode_streaming(&signal).unwrap();
        assert_eq!(chunks.len(), 1); // Should be a single chunk

        let reconstructed = streaming.decode_streaming(&chunks).unwrap();
        assert_eq!(reconstructed.len(), 16); // Padded to chunk_size
    }
}