use crate::error::{TokenizerError, TokenizerResult};
use crate::SignalTokenizer;
use scirs2_core::ndarray::{Array1, Array2};

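/// Batch extensions for any [`SignalTokenizer`]: encode or decode many
/// signals at once, with optional padding and rayon-based parallelism.
///
/// A minimal usage sketch. `ContinuousTokenizer` is the concrete tokenizer
/// used in this crate's tests; the crate path in the `use` line is assumed,
/// so the example is marked `ignore`:
///
/// ```ignore
/// use signal_tokenizer::{BatchTokenizer, ContinuousTokenizer};
/// use scirs2_core::ndarray::Array2;
///
/// // 3 signals of length 4, encoded row by row into 3 token vectors of length 8.
/// let tokenizer = ContinuousTokenizer::new(4, 8);
/// let signals = Array2::from_shape_fn((3, 4), |(i, j)| (i * 4 + j) as f32);
/// let tokens = tokenizer.encode_batch(&signals).unwrap();
/// assert_eq!(tokens.shape(), &[3, 8]);
/// ```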
pub trait BatchTokenizer: SignalTokenizer {
    /// Encode each row of `signals` independently and stack the results.
    fn encode_batch(&self, signals: &Array2<f32>) -> TokenizerResult<Array2<f32>> {
        let batch_size = signals.shape()[0];
        let mut results = Vec::with_capacity(batch_size);

        for i in 0..batch_size {
            let signal = signals.row(i).to_owned();
            let encoded = self.encode(&signal)?;
            results.push(encoded);
        }

        batch_from_vec(results)
    }

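    /// Decode each row of `tokens` independently and stack the recovered
    /// signals.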
    fn decode_batch(&self, tokens: &Array2<f32>) -> TokenizerResult<Array2<f32>> {
        let batch_size = tokens.shape()[0];
        let mut results = Vec::with_capacity(batch_size);

        for i in 0..batch_size {
            let token_seq = tokens.row(i).to_owned();
            let decoded = self.decode(&token_seq)?;
            results.push(decoded);
        }

        batch_from_vec(results)
    }

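    /// Encode signals of varying lengths by first zero-padding (or
    /// truncating) each one to exactly `target_len` samples.
    ///
    /// Returns an error for an empty batch.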
    fn encode_batch_padded_to(
        &self,
        signals: &[Array1<f32>],
        target_len: usize,
    ) -> TokenizerResult<Array2<f32>> {
        if signals.is_empty() {
            return Err(TokenizerError::InvalidConfig("Empty batch".into()));
        }

        let mut results = Vec::with_capacity(signals.len());

        for signal in signals {
            let mut padded = Array1::zeros(target_len);
            // Copy as many samples as fit; anything beyond `target_len` is dropped.
            let copy_len = signal.len().min(target_len);
            for i in 0..copy_len {
                padded[i] = signal[i];
            }

            let encoded = self.encode(&padded)?;
            results.push(encoded);
        }

        batch_from_vec(results)
    }

    #[cfg(feature = "parallel")]
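    /// Parallel version of [`BatchTokenizer::encode_batch`] that encodes rows
    /// across a rayon thread pool. Results are collected in row order, so the
    /// output matches the sequential method for a deterministic tokenizer.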
    fn encode_batch_parallel(&self, signals: &Array2<f32>) -> TokenizerResult<Array2<f32>>
    where
        Self: Sync,
    {
        use rayon::prelude::*;

        let batch_size = signals.shape()[0];
        let results: Result<Vec<_>, _> = (0..batch_size)
            .into_par_iter()
            .map(|i| {
                let signal = signals.row(i).to_owned();
                self.encode(&signal)
            })
            .collect();

        let results = results?;
        batch_from_vec(results)
    }

    #[cfg(feature = "parallel")]
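    /// Parallel version of [`BatchTokenizer::decode_batch`] using rayon.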
    fn decode_batch_parallel(&self, tokens: &Array2<f32>) -> TokenizerResult<Array2<f32>>
    where
        Self: Sync,
    {
        use rayon::prelude::*;

        let batch_size = tokens.shape()[0];
        let results: Result<Vec<_>, _> = (0..batch_size)
            .into_par_iter()
            .map(|i| {
                let token_seq = tokens.row(i).to_owned();
                self.decode(&token_seq)
            })
            .collect();

        let results = results?;
        batch_from_vec(results)
    }
}

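/// Stack equal-length 1-D arrays into a single 2-D batch (one array per row),
/// returning an error if the batch is empty or the lengths disagree.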
fn batch_from_vec(arrays: Vec<Array1<f32>>) -> TokenizerResult<Array2<f32>> {
    if arrays.is_empty() {
        return Err(TokenizerError::InvalidConfig("Empty batch".into()));
    }

    let batch_size = arrays.len();
    let elem_len = arrays[0].len();

    // Every array must match the length of the first one.
    for arr in arrays.iter() {
        if arr.len() != elem_len {
            return Err(TokenizerError::dim_mismatch(
                elem_len,
                arr.len(),
                "dimension validation",
            ));
        }
    }

    let mut batch = Array2::zeros((batch_size, elem_len));
    for (i, arr) in arrays.iter().enumerate() {
        for (j, &val) in arr.iter().enumerate() {
            batch[[i, j]] = val;
        }
    }

    Ok(batch)
}

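/// Wraps a [`SignalTokenizer`] to process arbitrarily long signals in
/// overlapping fixed-size chunks, reconstructing by overlap-add averaging.
///
/// A construction sketch (the crate path is assumed, hence `ignore`; chunks
/// of 8 samples advance by a stride of 8 - 2 = 6):
///
/// ```ignore
/// use signal_tokenizer::{ContinuousTokenizer, StreamingTokenizer};
///
/// let tokenizer = ContinuousTokenizer::new(8, 16);
/// let streaming = StreamingTokenizer::new(tokenizer, 8, 2).unwrap();
/// ```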
pub struct StreamingTokenizer<T: SignalTokenizer> {
    tokenizer: T,
    chunk_size: usize,
    overlap: usize,
}

impl<T: SignalTokenizer> StreamingTokenizer<T> {
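    /// Create a streaming tokenizer with the given chunk size and overlap
    /// (both in samples). Fails if `overlap >= chunk_size`, which would make
    /// the stride non-positive.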
    pub fn new(tokenizer: T, chunk_size: usize, overlap: usize) -> TokenizerResult<Self> {
        if overlap >= chunk_size {
            return Err(TokenizerError::InvalidConfig(
                "Overlap must be less than chunk_size".into(),
            ));
        }

        Ok(Self {
            tokenizer,
            chunk_size,
            overlap,
        })
    }

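    /// Split `signal` into overlapping chunks of `chunk_size` samples
    /// (advancing by `chunk_size - overlap` each step), zero-pad the final
    /// short chunk, and encode each chunk independently.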
    pub fn encode_streaming(&self, signal: &Array1<f32>) -> TokenizerResult<Vec<Array1<f32>>> {
        let mut chunks = Vec::new();
        let stride = self.chunk_size - self.overlap;

        let mut start = 0;
        while start < signal.len() {
            let end = (start + self.chunk_size).min(signal.len());
            let chunk = Array1::from_vec(
                signal
                    .iter()
                    .skip(start)
                    .take(end - start)
                    .cloned()
                    .collect(),
            );

            // Zero-pad the final chunk so every chunk has the same length.
            let chunk = if chunk.len() < self.chunk_size {
                let mut padded = Array1::zeros(self.chunk_size);
                for (i, &val) in chunk.iter().enumerate() {
                    padded[i] = val;
                }
                padded
            } else {
                chunk
            };

            chunks.push(self.tokenizer.encode(&chunk)?);
            start += stride;
        }

        Ok(chunks)
    }

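    /// Decode a sequence of encoded chunks and stitch them back together by
    /// overlap-add: overlapping samples are summed and then divided by how
    /// many chunks covered them. Assumes the tokenizer decodes each chunk
    /// back to `chunk_size` samples.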
    pub fn decode_streaming(&self, chunks: &[Array1<f32>]) -> TokenizerResult<Array1<f32>> {
        if chunks.is_empty() {
            return Err(TokenizerError::InvalidConfig("No chunks to decode".into()));
        }

        let mut decoded_chunks = Vec::new();
        for chunk in chunks {
            decoded_chunks.push(self.tokenizer.decode(chunk)?);
        }

        let stride = self.chunk_size - self.overlap;
        // The last chunk contributes a full chunk_size; every earlier chunk
        // advances the output by one stride.
        let total_len = (chunks.len() - 1) * stride + self.chunk_size;
        let mut result = Array1::<f32>::zeros(total_len);
        let mut weight = Array1::<f32>::zeros(total_len);

        for (i, chunk) in decoded_chunks.iter().enumerate() {
            let start = i * stride;
            for (j, &val) in chunk.iter().enumerate() {
                if start + j < total_len {
                    result[start + j] += val;
                    weight[start + j] += 1.0;
                }
            }
        }

        // Average samples that were covered by more than one chunk.
        for i in 0..total_len {
            if weight[i] > 0.0 {
                result[i] /= weight[i];
            }
        }

        Ok(result)
    }

    /// Reference to the wrapped tokenizer.
    pub fn tokenizer(&self) -> &T {
        &self.tokenizer
    }

    /// Chunk length in samples.
    pub fn chunk_size(&self) -> usize {
        self.chunk_size
    }

    /// Overlap between consecutive chunks, in samples.
    pub fn overlap(&self) -> usize {
        self.overlap
    }
}

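/// Blanket implementation: every [`SignalTokenizer`] gets the batch methods
/// for free via the trait's default bodies.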
impl<T: SignalTokenizer> BatchTokenizer for T {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ContinuousTokenizer;

    #[test]
    fn test_batch_from_vec() {
        let arrays = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 6.0]),
            Array1::from_vec(vec![7.0, 8.0, 9.0]),
        ];

        let batch = batch_from_vec(arrays).unwrap();
        assert_eq!(batch.shape(), &[3, 3]);
        assert_eq!(batch[[0, 0]], 1.0);
        assert_eq!(batch[[2, 2]], 9.0);
    }

    #[test]
    fn test_batch_from_vec_dimension_mismatch() {
        let arrays = vec![
            Array1::from_vec(vec![1.0, 2.0]),
            Array1::from_vec(vec![3.0, 4.0, 5.0]),
        ];

        assert!(batch_from_vec(arrays).is_err());
    }

    #[test]
    fn test_encode_batch() {
        let tokenizer = ContinuousTokenizer::new(4, 8);

        let signals = Array2::from_shape_fn((3, 4), |(i, j)| (i * 4 + j) as f32 * 0.1);

        let encoded = tokenizer.encode_batch(&signals).unwrap();
        assert_eq!(encoded.shape(), &[3, 8]);
    }

    #[test]
    fn test_decode_batch() {
        let tokenizer = ContinuousTokenizer::new(4, 8);

        let tokens = Array2::from_shape_fn((3, 8), |(i, j)| (i * 8 + j) as f32 * 0.1);

        let decoded = tokenizer.decode_batch(&tokens).unwrap();
        assert_eq!(decoded.shape(), &[3, 4]);
    }

    #[test]
    fn test_encode_decode_batch_roundtrip() {
        let tokenizer = ContinuousTokenizer::new(10, 20);

        let signals = Array2::from_shape_fn((5, 10), |(i, j)| ((i * 10 + j) as f32 * 0.05).sin());

        let encoded = tokenizer.encode_batch(&signals).unwrap();
        let decoded = tokenizer.decode_batch(&encoded).unwrap();

        assert_eq!(decoded.shape(), signals.shape());
    }

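    // A hedged sketch rather than a spec: assuming `ContinuousTokenizer`
    // encodes deterministically, the rayon path should reproduce the
    // sequential result exactly. Gated on the same feature as the method.
    #[cfg(feature = "parallel")]
    #[test]
    fn test_encode_batch_parallel_matches_sequential() {
        let tokenizer = ContinuousTokenizer::new(4, 8);
        let signals = Array2::from_shape_fn((3, 4), |(i, j)| (i * 4 + j) as f32 * 0.1);

        let sequential = tokenizer.encode_batch(&signals).unwrap();
        let parallel = tokenizer.encode_batch_parallel(&signals).unwrap();
        assert_eq!(sequential, parallel);
    }
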
    #[test]
    fn test_encode_batch_padded() {
        let tokenizer = ContinuousTokenizer::new(10, 20);

        let signals = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 6.0, 7.0, 8.0]),
            Array1::from_vec(vec![9.0, 10.0]),
        ];

        // Each signal is zero-padded to 10 samples before encoding to 20 tokens.
        let encoded = tokenizer.encode_batch_padded_to(&signals, 10).unwrap();
        assert_eq!(encoded.shape(), &[3, 20]);
    }

    #[test]
    fn test_streaming_tokenizer() {
        let tokenizer = ContinuousTokenizer::new(8, 16);
        let streaming = StreamingTokenizer::new(tokenizer, 8, 2).unwrap();

        let signal = Array1::from_vec((0..20).map(|i| (i as f32 * 0.1).sin()).collect());

        let chunks = streaming.encode_streaming(&signal).unwrap();
        assert!(chunks.len() > 1);

        // Reconstruction may be longer than the input because the final chunk
        // is zero-padded, but it is never shorter.
        let reconstructed = streaming.decode_streaming(&chunks).unwrap();
        assert!(reconstructed.len() >= signal.len());
    }

    #[test]
    fn test_streaming_tokenizer_overlap_validation() {
        let tokenizer = ContinuousTokenizer::new(8, 16);

        // Both overlap == chunk_size and overlap > chunk_size are rejected.
        assert!(StreamingTokenizer::new(tokenizer.clone(), 8, 8).is_err());
        assert!(StreamingTokenizer::new(tokenizer, 8, 9).is_err());
    }

    #[test]
    fn test_streaming_short_signal() {
        let tokenizer = ContinuousTokenizer::new(16, 32);
        let streaming = StreamingTokenizer::new(tokenizer, 16, 4).unwrap();

        let signal = Array1::from_vec((0..10).map(|i| i as f32).collect());

        let chunks = streaming.encode_streaming(&signal).unwrap();
        assert_eq!(chunks.len(), 1);

        // A 10-sample signal is padded to one full 16-sample chunk.
        let reconstructed = streaming.decode_streaming(&chunks).unwrap();
        assert_eq!(reconstructed.len(), 16);
    }
}