1use crate::error::{TokenizerError, TokenizerResult};
11use crate::SignalTokenizer;
12use scirs2_core::ndarray::{Array1, Array2};
13use scirs2_core::random::thread_rng;
14
/// One resolution level of a multi-scale tokenizer.
#[derive(Debug, Clone)]
pub struct ScaleLevel {
    /// Temporal downsampling factor applied before encoding (1 = full rate).
    downsample_factor: usize,
    /// Width of the embedding produced at this level.
    embed_dim: usize,
    /// Signal length this level's encoder expects (original length / factor).
    input_dim: usize,
}
25
26impl ScaleLevel {
27 pub fn new(downsample_factor: usize, embed_dim: usize, input_dim: usize) -> Self {
29 Self {
30 downsample_factor,
31 embed_dim,
32 input_dim,
33 }
34 }
35}
36
/// Tokenizer that encodes a signal at several temporal resolutions.
///
/// Each level downsamples the input by its own factor, projects it through a
/// per-level linear encoder, and can reconstruct via a matching decoder plus
/// upsampling back to the original length.
#[derive(Debug, Clone)]
pub struct MultiScaleTokenizer {
    /// Per-level encoder matrices, shape (level input dim, embed dim).
    encoders: Vec<Array2<f32>>,
    /// Per-level decoder matrices, shape (embed dim, level input dim).
    decoders: Vec<Array2<f32>>,
    /// Metadata for each constructed scale level.
    levels: Vec<ScaleLevel>,
    /// Expected length of input signals.
    input_dim: usize,
    /// Pooling strategy used when downsampling.
    pool_method: PoolMethod,
    /// Interpolation strategy used when upsampling during decode.
    upsample_method: UpsampleMethod,
}
61
/// Pooling strategy used to downsample a signal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PoolMethod {
    /// Keep every `factor`-th sample, discarding the rest.
    Stride,
    /// Average each window of `factor` samples (the default).
    #[default]
    Average,
    /// Take the maximum within each window of `factor` samples.
    Max,
}
73
/// Interpolation strategy used to upsample a decoded signal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum UpsampleMethod {
    /// Duplicate each sample `factor` times (nearest-neighbor / hold).
    Repeat,
    /// Linearly interpolate between neighboring samples (the default).
    #[default]
    Linear,
}
83
84impl MultiScaleTokenizer {
85 pub fn new(input_dim: usize, embed_dim_per_level: usize) -> Self {
89 Self::with_factors(input_dim, embed_dim_per_level, &[1, 2, 4])
90 }
91
92 pub fn with_factors(input_dim: usize, embed_dim_per_level: usize, factors: &[usize]) -> Self {
94 let mut rng = thread_rng();
95 let mut encoders = Vec::with_capacity(factors.len());
96 let mut decoders = Vec::with_capacity(factors.len());
97 let mut levels = Vec::with_capacity(factors.len());
98
99 for &factor in factors {
100 let level_input_dim = input_dim / factor;
101 if level_input_dim == 0 {
102 continue;
103 }
104
105 let enc_scale = (2.0 / (level_input_dim + embed_dim_per_level) as f32).sqrt();
107 let encoder = Array2::from_shape_fn((level_input_dim, embed_dim_per_level), |_| {
108 (rng.random::<f32>() - 0.5) * 2.0 * enc_scale
109 });
110
111 let dec_scale = (2.0 / (embed_dim_per_level + level_input_dim) as f32).sqrt();
112 let decoder = Array2::from_shape_fn((embed_dim_per_level, level_input_dim), |_| {
113 (rng.random::<f32>() - 0.5) * 2.0 * dec_scale
114 });
115
116 encoders.push(encoder);
117 decoders.push(decoder);
118 levels.push(ScaleLevel::new(
119 factor,
120 embed_dim_per_level,
121 level_input_dim,
122 ));
123 }
124
125 Self {
126 encoders,
127 decoders,
128 levels,
129 input_dim,
130 pool_method: PoolMethod::default(),
131 upsample_method: UpsampleMethod::default(),
132 }
133 }
134
135 pub fn with_pool_method(mut self, method: PoolMethod) -> Self {
137 self.pool_method = method;
138 self
139 }
140
141 pub fn with_upsample_method(mut self, method: UpsampleMethod) -> Self {
143 self.upsample_method = method;
144 self
145 }
146
147 pub fn num_levels(&self) -> usize {
149 self.levels.len()
150 }
151
152 pub fn total_embed_dim(&self) -> usize {
154 self.levels.iter().map(|l| l.embed_dim).sum()
155 }
156
157 fn downsample(&self, signal: &Array1<f32>, factor: usize) -> Array1<f32> {
159 if factor <= 1 {
160 return signal.clone();
161 }
162
163 let new_len = signal.len() / factor;
164 if new_len == 0 {
165 return Array1::zeros(1);
166 }
167
168 match self.pool_method {
169 PoolMethod::Stride => {
170 Array1::from_vec((0..new_len).map(|i| signal[i * factor]).collect())
171 }
172 PoolMethod::Average => Array1::from_vec(
173 (0..new_len)
174 .map(|i| {
175 let start = i * factor;
176 let end = (start + factor).min(signal.len());
177 signal.iter().skip(start).take(end - start).sum::<f32>()
178 / (end - start) as f32
179 })
180 .collect(),
181 ),
182 PoolMethod::Max => Array1::from_vec(
183 (0..new_len)
184 .map(|i| {
185 let start = i * factor;
186 let end = (start + factor).min(signal.len());
187 signal
188 .iter()
189 .skip(start)
190 .take(end - start)
191 .cloned()
192 .fold(f32::NEG_INFINITY, f32::max)
193 })
194 .collect(),
195 ),
196 }
197 }
198
199 fn upsample(&self, signal: &Array1<f32>, factor: usize, target_len: usize) -> Array1<f32> {
201 if factor <= 1 {
202 return signal.clone();
203 }
204
205 match self.upsample_method {
206 UpsampleMethod::Repeat => {
207 let mut result = Vec::with_capacity(target_len);
208 for &val in signal.iter() {
209 for _ in 0..factor {
210 if result.len() < target_len {
211 result.push(val);
212 }
213 }
214 }
215 while result.len() < target_len {
217 result.push(*signal.last().unwrap_or(&0.0));
218 }
219 Array1::from_vec(result)
220 }
221 UpsampleMethod::Linear => {
222 if signal.len() < 2 {
223 return Array1::from_elem(target_len, signal.get(0).copied().unwrap_or(0.0));
224 }
225
226 let mut result = Vec::with_capacity(target_len);
227 for i in 0..target_len {
228 let src_pos = i as f32 / factor as f32;
230 let src_idx = src_pos.floor() as usize;
231 let t = src_pos - src_idx as f32;
232
233 let val = if src_idx + 1 < signal.len() {
234 signal[src_idx] * (1.0 - t) + signal[src_idx + 1] * t
235 } else {
236 signal[signal.len() - 1]
237 };
238 result.push(val);
239 }
240 Array1::from_vec(result)
241 }
242 }
243 }
244
245 pub fn encode_level(&self, signal: &Array1<f32>, level: usize) -> TokenizerResult<Array1<f32>> {
247 if level >= self.levels.len() {
248 return Err(TokenizerError::InvalidConfig(format!(
249 "Level {} out of range (0..{})",
250 level,
251 self.levels.len()
252 )));
253 }
254
255 let factor = self.levels[level].downsample_factor;
256 let downsampled = self.downsample(signal, factor);
257
258 if downsampled.len() != self.levels[level].input_dim {
259 let mut resized = Array1::zeros(self.levels[level].input_dim);
261 for i in 0..resized.len().min(downsampled.len()) {
262 resized[i] = downsampled[i];
263 }
264 return Ok(resized.dot(&self.encoders[level]));
265 }
266
267 Ok(downsampled.dot(&self.encoders[level]))
268 }
269
270 pub fn decode_level(
272 &self,
273 embedding: &Array1<f32>,
274 level: usize,
275 ) -> TokenizerResult<Array1<f32>> {
276 if level >= self.levels.len() {
277 return Err(TokenizerError::InvalidConfig(format!(
278 "Level {} out of range (0..{})",
279 level,
280 self.levels.len()
281 )));
282 }
283
284 if embedding.len() != self.levels[level].embed_dim {
285 return Err(TokenizerError::dim_mismatch(
286 self.levels[level].embed_dim,
287 embedding.len(),
288 "dimension validation",
289 ));
290 }
291
292 let decoded = embedding.dot(&self.decoders[level]);
293 let factor = self.levels[level].downsample_factor;
294
295 Ok(self.upsample(&decoded, factor, self.input_dim))
296 }
297
298 pub fn encode_all(&self, signal: &Array1<f32>) -> TokenizerResult<Vec<Array1<f32>>> {
300 let mut embeddings = Vec::with_capacity(self.levels.len());
301 for level in 0..self.levels.len() {
302 embeddings.push(self.encode_level(signal, level)?);
303 }
304 Ok(embeddings)
305 }
306
307 pub fn decode_all(&self, embeddings: &[Array1<f32>]) -> TokenizerResult<Array1<f32>> {
309 if embeddings.len() != self.levels.len() {
310 return Err(TokenizerError::InvalidConfig(format!(
311 "Expected {} embeddings, got {}",
312 self.levels.len(),
313 embeddings.len()
314 )));
315 }
316
317 let mut result = Array1::zeros(self.input_dim);
318 let weight = 1.0 / self.levels.len() as f32;
319
320 for (level, embedding) in embeddings.iter().enumerate() {
321 let decoded = self.decode_level(embedding, level)?;
322 result = &result + &(&decoded * weight);
323 }
324
325 Ok(result)
326 }
327
328 pub fn encode_concat(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
330 let embeddings = self.encode_all(signal)?;
331 let total_len: usize = embeddings.iter().map(|e| e.len()).sum();
332 let mut result = Vec::with_capacity(total_len);
333 for emb in embeddings {
334 result.extend(emb.iter());
335 }
336 Ok(Array1::from_vec(result))
337 }
338}
339
340impl SignalTokenizer for MultiScaleTokenizer {
341 fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
342 if signal.len() != self.input_dim {
343 return Err(TokenizerError::dim_mismatch(
344 self.input_dim,
345 signal.len(),
346 "dimension validation",
347 ));
348 }
349 self.encode_concat(signal)
350 }
351
352 fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
353 if tokens.len() != self.total_embed_dim() {
354 return Err(TokenizerError::dim_mismatch(
355 self.total_embed_dim(),
356 tokens.len(),
357 "dimension validation",
358 ));
359 }
360
361 let mut embeddings = Vec::with_capacity(self.levels.len());
363 let mut offset = 0;
364 for level in &self.levels {
365 let end = offset + level.embed_dim;
366 let embedding: Array1<f32> = Array1::from_vec(
367 tokens
368 .iter()
369 .skip(offset)
370 .take(level.embed_dim)
371 .cloned()
372 .collect(),
373 );
374 embeddings.push(embedding);
375 offset = end;
376 }
377
378 self.decode_all(&embeddings)
379 }
380
381 fn embed_dim(&self) -> usize {
382 self.total_embed_dim()
383 }
384
385 fn vocab_size(&self) -> usize {
386 0 }
388}
389
/// Pyramid tokenizer with power-of-two scale levels.
///
/// Wraps a `MultiScaleTokenizer` whose factors are 1, 2, 4, ... and can
/// optionally encode each level on the residual left over from the previous
/// levels' reconstructions.
#[derive(Debug, Clone)]
pub struct PyramidTokenizer {
    /// Underlying multi-scale tokenizer with octave-spaced factors.
    inner: MultiScaleTokenizer,
    /// When true (the default), each level encodes the reconstruction residual.
    use_residual: bool,
}
401
402impl PyramidTokenizer {
403 pub fn new(input_dim: usize, embed_dim_per_level: usize, num_levels: usize) -> Self {
405 let factors: Vec<usize> = (0..num_levels).map(|i| 1 << i).collect();
407 let inner = MultiScaleTokenizer::with_factors(input_dim, embed_dim_per_level, &factors);
408
409 Self {
410 inner,
411 use_residual: true,
412 }
413 }
414
415 pub fn without_residual(mut self) -> Self {
417 self.use_residual = false;
418 self
419 }
420
421 pub fn encode_pyramid(&self, signal: &Array1<f32>) -> TokenizerResult<Vec<Array1<f32>>> {
423 if !self.use_residual {
424 return self.inner.encode_all(signal);
425 }
426
427 let mut embeddings = Vec::with_capacity(self.inner.num_levels());
428 let mut residual = signal.clone();
429
430 for level in 0..self.inner.num_levels() {
431 let embedding = self.inner.encode_level(&residual, level)?;
432 embeddings.push(embedding.clone());
433
434 let reconstruction = self.inner.decode_level(&embedding, level)?;
436 residual = &residual - &reconstruction;
437 }
438
439 Ok(embeddings)
440 }
441
442 pub fn decode_pyramid(&self, embeddings: &[Array1<f32>]) -> TokenizerResult<Array1<f32>> {
444 if !self.use_residual {
445 return self.inner.decode_all(embeddings);
446 }
447
448 let mut result = Array1::zeros(self.inner.input_dim);
450
451 for (level, embedding) in embeddings.iter().enumerate() {
452 let decoded = self.inner.decode_level(embedding, level)?;
453 result = &result + &decoded;
454 }
455
456 Ok(result)
457 }
458
459 pub fn num_levels(&self) -> usize {
461 self.inner.num_levels()
462 }
463}
464
465impl SignalTokenizer for PyramidTokenizer {
466 fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
467 let embeddings = self.encode_pyramid(signal)?;
468 let total_len: usize = embeddings.iter().map(|e| e.len()).sum();
469 let mut result = Vec::with_capacity(total_len);
470 for emb in embeddings {
471 result.extend(emb.iter());
472 }
473 Ok(Array1::from_vec(result))
474 }
475
476 fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
477 let total_dim = self.inner.total_embed_dim();
478 if tokens.len() != total_dim {
479 return Err(TokenizerError::dim_mismatch(
480 total_dim,
481 tokens.len(),
482 "dimension validation",
483 ));
484 }
485
486 let mut embeddings = Vec::new();
488 let mut offset = 0;
489 for level in &self.inner.levels {
490 let end = offset + level.embed_dim;
491 let embedding = Array1::from_vec(
492 tokens
493 .iter()
494 .skip(offset)
495 .take(level.embed_dim)
496 .cloned()
497 .collect(),
498 );
499 embeddings.push(embedding);
500 offset = end;
501 }
502
503 self.decode_pyramid(&embeddings)
504 }
505
506 fn embed_dim(&self) -> usize {
507 self.inner.total_embed_dim()
508 }
509
510 fn vocab_size(&self) -> usize {
511 0
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multiscale_basic() {
        let tok = MultiScaleTokenizer::new(64, 16);
        assert_eq!(tok.num_levels(), 3);
        assert_eq!(tok.total_embed_dim(), 48);

        // Round-trip a sine wave through encode/decode and check lengths.
        let signal = Array1::from_vec((0..64).map(|i| (i as f32 * 0.1).sin()).collect());
        let encoded = tok.encode(&signal).unwrap();
        assert_eq!(encoded.len(), 48);

        let decoded = tok.decode(&encoded).unwrap();
        assert_eq!(decoded.len(), 64);
    }

    #[test]
    fn test_downsample_average() {
        let tok = MultiScaleTokenizer::new(8, 4);
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        let down = tok.downsample(&signal, 2);

        assert_eq!(down.len(), 4);
        // Pairwise means: (1+2)/2 = 1.5 and (3+4)/2 = 3.5.
        assert!((down[0] - 1.5).abs() < 0.01);
        assert!((down[1] - 3.5).abs() < 0.01);
    }

    #[test]
    fn test_downsample_stride() {
        let tok = MultiScaleTokenizer::new(8, 4).with_pool_method(PoolMethod::Stride);
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        let down = tok.downsample(&signal, 2);

        assert_eq!(down.len(), 4);
        // Stride pooling keeps samples at indices 0, 2, 4, 6.
        assert_eq!(down[0], 1.0);
        assert_eq!(down[1], 3.0);
    }

    #[test]
    fn test_upsample_repeat() {
        let tok = MultiScaleTokenizer::new(8, 4).with_upsample_method(UpsampleMethod::Repeat);
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let up = tok.upsample(&signal, 2, 8);

        assert_eq!(up.len(), 8);
        // Each input sample is duplicated factor = 2 times.
        assert_eq!(up[0], 1.0);
        assert_eq!(up[1], 1.0);
        assert_eq!(up[2], 2.0);
        assert_eq!(up[3], 2.0);
    }

    #[test]
    fn test_upsample_linear() {
        let tok = MultiScaleTokenizer::new(8, 4).with_upsample_method(UpsampleMethod::Linear);
        let signal = Array1::from_vec(vec![0.0, 2.0]);

        let up = tok.upsample(&signal, 4, 8);

        assert_eq!(up.len(), 8);
        // Linear ramp from 0 toward 2: the halfway sample is 1.
        assert!(up[0].abs() < 0.01);
        assert!((up[2] - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_encode_level() {
        let tok = MultiScaleTokenizer::new(64, 16);
        let signal = Array1::from_vec((0..64).map(|i| i as f32).collect());

        // Every level projects to the same embedding width.
        for level in 0..3 {
            let emb = tok.encode_level(&signal, level).unwrap();
            assert_eq!(emb.len(), 16);
        }
    }

    #[test]
    fn test_pyramid_tokenizer() {
        let tok = PyramidTokenizer::new(64, 16, 3);
        assert_eq!(tok.num_levels(), 3);

        let signal = Array1::from_vec((0..64).map(|i| (i as f32 * 0.1).sin()).collect());

        let embeddings = tok.encode_pyramid(&signal).unwrap();
        assert_eq!(embeddings.len(), 3);

        let decoded = tok.decode_pyramid(&embeddings).unwrap();
        assert_eq!(decoded.len(), 64);
    }

    #[test]
    fn test_pyramid_residual() {
        let tok = PyramidTokenizer::new(32, 8, 3);
        let signal = Array1::from_vec((0..32).map(|i| (i as f32 * 0.2).sin()).collect());

        let embeddings = tok.encode_pyramid(&signal).unwrap();

        // Compute the variance of each level's embedding; the first level
        // should carry non-trivial information.
        let variances: Vec<f32> = embeddings
            .iter()
            .map(|e| {
                let mean = e.sum() / e.len() as f32;
                e.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / e.len() as f32
            })
            .collect();

        assert!(variances[0] > 0.0);
    }

    #[test]
    fn test_custom_factors() {
        let tok = MultiScaleTokenizer::with_factors(100, 10, &[1, 5, 10, 20]);
        assert_eq!(tok.num_levels(), 4);

        let signal = Array1::from_vec((0..100).map(|i| i as f32).collect());
        let encoded = tok.encode(&signal).unwrap();
        assert_eq!(encoded.len(), 40);

        let decoded = tok.decode(&encoded).unwrap();
        assert_eq!(decoded.len(), 100);
    }
}