kizzasi_tokenizer/
multiscale.rs

//! Multi-scale (hierarchical) tokenization
//!
//! Processes signals at multiple temporal resolutions for better
//! representation of both local details and global structure.
//!
//! This is similar to hierarchical audio codecs (SoundStream, EnCodec)
//! where lower levels capture fine details and higher levels capture
//! coarse structure.
9
10use crate::error::{TokenizerError, TokenizerResult};
11use crate::SignalTokenizer;
12use scirs2_core::ndarray::{Array1, Array2};
13use scirs2_core::random::thread_rng;
14
/// Shape parameters describing one level of a multi-scale tokenizer.
///
/// A level with `downsample_factor = k` pools the original signal by `k` and
/// projects the resulting `input_dim` samples to `embed_dim` values.
#[derive(Debug, Clone)]
pub struct ScaleLevel {
    /// How many original samples are pooled into one sample at this level
    downsample_factor: usize,
    /// Number of embedding values this level produces
    embed_dim: usize,
    /// Number of pooled samples this level's projection expects
    input_dim: usize,
}

impl ScaleLevel {
    /// Build a level description from its downsample factor, embedding size and
    /// post-downsampling input size.
    ///
    /// The values are stored as given; no consistency checks are performed.
    pub fn new(downsample_factor: usize, embed_dim: usize, input_dim: usize) -> Self {
        ScaleLevel {
            input_dim,
            embed_dim,
            downsample_factor,
        }
    }
}
36
/// Multi-scale hierarchical tokenizer
///
/// Processes signals at multiple temporal resolutions:
/// - Level 0: Full resolution (finest details)
/// - Level 1: 2x downsampled
/// - Level 2: 4x downsampled
/// - ...
///
/// Each level captures patterns at its corresponding scale.
///
/// The list above is the layout built by [`MultiScaleTokenizer::new`] (factors
/// 1, 2, 4); [`MultiScaleTokenizer::with_factors`] accepts arbitrary factors.
/// Encoding concatenates the per-level embeddings in level order; decoding
/// averages the per-level reconstructions with equal weights.
#[derive(Debug, Clone)]
pub struct MultiScaleTokenizer {
    /// Encoder projections for each level, indexed like `levels`.
    /// Level `i` has shape `(levels[i].input_dim, levels[i].embed_dim)` and is
    /// applied as `pooled_signal.dot(encoder)`.
    encoders: Vec<Array2<f32>>,
    /// Decoder projections for each level, indexed like `levels`.
    /// Level `i` has shape `(levels[i].embed_dim, levels[i].input_dim)`.
    decoders: Vec<Array2<f32>>,
    /// Scale configuration, one entry per level, in the order the factors were given
    levels: Vec<ScaleLevel>,
    /// Original input dimension: the signal length `encode` accepts and the
    /// length each level's reconstruction is upsampled back to
    input_dim: usize,
    /// Pooling method for downsampling
    pool_method: PoolMethod,
    /// Upsampling method for reconstruction
    upsample_method: UpsampleMethod,
}
61
/// Method for downsampling signals
///
/// Pooling works on consecutive, non-overlapping windows of `factor` samples;
/// trailing samples that do not fill a whole window are ignored.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PoolMethod {
    /// Take every Nth sample (the first sample of each window)
    Stride,
    /// Average pooling (mean of each window); this is the default
    #[default]
    Average,
    /// Max pooling (largest value in each window)
    Max,
}
73
/// Method for upsampling signals
///
/// Used to stretch a level's reconstruction back to the original input length.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum UpsampleMethod {
    /// Repeat values (nearest neighbor): each value is emitted `factor` times,
    /// then the result is truncated or padded with the last value to the target length
    Repeat,
    /// Linear interpolation between neighbouring samples; positions past the
    /// last source sample hold the last value. This is the default.
    #[default]
    Linear,
}
83
84impl MultiScaleTokenizer {
85    /// Create a new multi-scale tokenizer with default scales
86    ///
87    /// Default: 3 levels with downsample factors 1, 2, 4
88    pub fn new(input_dim: usize, embed_dim_per_level: usize) -> Self {
89        Self::with_factors(input_dim, embed_dim_per_level, &[1, 2, 4])
90    }
91
92    /// Create with custom downsample factors
93    pub fn with_factors(input_dim: usize, embed_dim_per_level: usize, factors: &[usize]) -> Self {
94        let mut rng = thread_rng();
95        let mut encoders = Vec::with_capacity(factors.len());
96        let mut decoders = Vec::with_capacity(factors.len());
97        let mut levels = Vec::with_capacity(factors.len());
98
99        for &factor in factors {
100            let level_input_dim = input_dim / factor;
101            if level_input_dim == 0 {
102                continue;
103            }
104
105            // Xavier initialization
106            let enc_scale = (2.0 / (level_input_dim + embed_dim_per_level) as f32).sqrt();
107            let encoder = Array2::from_shape_fn((level_input_dim, embed_dim_per_level), |_| {
108                (rng.random::<f32>() - 0.5) * 2.0 * enc_scale
109            });
110
111            let dec_scale = (2.0 / (embed_dim_per_level + level_input_dim) as f32).sqrt();
112            let decoder = Array2::from_shape_fn((embed_dim_per_level, level_input_dim), |_| {
113                (rng.random::<f32>() - 0.5) * 2.0 * dec_scale
114            });
115
116            encoders.push(encoder);
117            decoders.push(decoder);
118            levels.push(ScaleLevel::new(
119                factor,
120                embed_dim_per_level,
121                level_input_dim,
122            ));
123        }
124
125        Self {
126            encoders,
127            decoders,
128            levels,
129            input_dim,
130            pool_method: PoolMethod::default(),
131            upsample_method: UpsampleMethod::default(),
132        }
133    }
134
135    /// Set pooling method
136    pub fn with_pool_method(mut self, method: PoolMethod) -> Self {
137        self.pool_method = method;
138        self
139    }
140
141    /// Set upsampling method
142    pub fn with_upsample_method(mut self, method: UpsampleMethod) -> Self {
143        self.upsample_method = method;
144        self
145    }
146
147    /// Get number of levels
148    pub fn num_levels(&self) -> usize {
149        self.levels.len()
150    }
151
152    /// Get total embedding dimension across all levels
153    pub fn total_embed_dim(&self) -> usize {
154        self.levels.iter().map(|l| l.embed_dim).sum()
155    }
156
157    /// Downsample signal by given factor
158    fn downsample(&self, signal: &Array1<f32>, factor: usize) -> Array1<f32> {
159        if factor <= 1 {
160            return signal.clone();
161        }
162
163        let new_len = signal.len() / factor;
164        if new_len == 0 {
165            return Array1::zeros(1);
166        }
167
168        match self.pool_method {
169            PoolMethod::Stride => {
170                Array1::from_vec((0..new_len).map(|i| signal[i * factor]).collect())
171            }
172            PoolMethod::Average => Array1::from_vec(
173                (0..new_len)
174                    .map(|i| {
175                        let start = i * factor;
176                        let end = (start + factor).min(signal.len());
177                        signal.iter().skip(start).take(end - start).sum::<f32>()
178                            / (end - start) as f32
179                    })
180                    .collect(),
181            ),
182            PoolMethod::Max => Array1::from_vec(
183                (0..new_len)
184                    .map(|i| {
185                        let start = i * factor;
186                        let end = (start + factor).min(signal.len());
187                        signal
188                            .iter()
189                            .skip(start)
190                            .take(end - start)
191                            .cloned()
192                            .fold(f32::NEG_INFINITY, f32::max)
193                    })
194                    .collect(),
195            ),
196        }
197    }
198
199    /// Upsample signal by given factor
200    fn upsample(&self, signal: &Array1<f32>, factor: usize, target_len: usize) -> Array1<f32> {
201        if factor <= 1 {
202            return signal.clone();
203        }
204
205        match self.upsample_method {
206            UpsampleMethod::Repeat => {
207                let mut result = Vec::with_capacity(target_len);
208                for &val in signal.iter() {
209                    for _ in 0..factor {
210                        if result.len() < target_len {
211                            result.push(val);
212                        }
213                    }
214                }
215                // Pad if needed
216                while result.len() < target_len {
217                    result.push(*signal.last().unwrap_or(&0.0));
218                }
219                Array1::from_vec(result)
220            }
221            UpsampleMethod::Linear => {
222                if signal.len() < 2 {
223                    return Array1::from_elem(target_len, signal.get(0).copied().unwrap_or(0.0));
224                }
225
226                let mut result = Vec::with_capacity(target_len);
227                for i in 0..target_len {
228                    // Map target position to source position
229                    let src_pos = i as f32 / factor as f32;
230                    let src_idx = src_pos.floor() as usize;
231                    let t = src_pos - src_idx as f32;
232
233                    let val = if src_idx + 1 < signal.len() {
234                        signal[src_idx] * (1.0 - t) + signal[src_idx + 1] * t
235                    } else {
236                        signal[signal.len() - 1]
237                    };
238                    result.push(val);
239                }
240                Array1::from_vec(result)
241            }
242        }
243    }
244
245    /// Encode at a specific level
246    pub fn encode_level(&self, signal: &Array1<f32>, level: usize) -> TokenizerResult<Array1<f32>> {
247        if level >= self.levels.len() {
248            return Err(TokenizerError::InvalidConfig(format!(
249                "Level {} out of range (0..{})",
250                level,
251                self.levels.len()
252            )));
253        }
254
255        let factor = self.levels[level].downsample_factor;
256        let downsampled = self.downsample(signal, factor);
257
258        if downsampled.len() != self.levels[level].input_dim {
259            // Resize to match expected dimension
260            let mut resized = Array1::zeros(self.levels[level].input_dim);
261            for i in 0..resized.len().min(downsampled.len()) {
262                resized[i] = downsampled[i];
263            }
264            return Ok(resized.dot(&self.encoders[level]));
265        }
266
267        Ok(downsampled.dot(&self.encoders[level]))
268    }
269
270    /// Decode at a specific level
271    pub fn decode_level(
272        &self,
273        embedding: &Array1<f32>,
274        level: usize,
275    ) -> TokenizerResult<Array1<f32>> {
276        if level >= self.levels.len() {
277            return Err(TokenizerError::InvalidConfig(format!(
278                "Level {} out of range (0..{})",
279                level,
280                self.levels.len()
281            )));
282        }
283
284        if embedding.len() != self.levels[level].embed_dim {
285            return Err(TokenizerError::dim_mismatch(
286                self.levels[level].embed_dim,
287                embedding.len(),
288                "dimension validation",
289            ));
290        }
291
292        let decoded = embedding.dot(&self.decoders[level]);
293        let factor = self.levels[level].downsample_factor;
294
295        Ok(self.upsample(&decoded, factor, self.input_dim))
296    }
297
298    /// Encode all levels and concatenate embeddings
299    pub fn encode_all(&self, signal: &Array1<f32>) -> TokenizerResult<Vec<Array1<f32>>> {
300        let mut embeddings = Vec::with_capacity(self.levels.len());
301        for level in 0..self.levels.len() {
302            embeddings.push(self.encode_level(signal, level)?);
303        }
304        Ok(embeddings)
305    }
306
307    /// Decode all levels and combine
308    pub fn decode_all(&self, embeddings: &[Array1<f32>]) -> TokenizerResult<Array1<f32>> {
309        if embeddings.len() != self.levels.len() {
310            return Err(TokenizerError::InvalidConfig(format!(
311                "Expected {} embeddings, got {}",
312                self.levels.len(),
313                embeddings.len()
314            )));
315        }
316
317        let mut result = Array1::zeros(self.input_dim);
318        let weight = 1.0 / self.levels.len() as f32;
319
320        for (level, embedding) in embeddings.iter().enumerate() {
321            let decoded = self.decode_level(embedding, level)?;
322            result = &result + &(&decoded * weight);
323        }
324
325        Ok(result)
326    }
327
328    /// Get concatenated encoding (all levels flattened)
329    pub fn encode_concat(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
330        let embeddings = self.encode_all(signal)?;
331        let total_len: usize = embeddings.iter().map(|e| e.len()).sum();
332        let mut result = Vec::with_capacity(total_len);
333        for emb in embeddings {
334            result.extend(emb.iter());
335        }
336        Ok(Array1::from_vec(result))
337    }
338}
339
340impl SignalTokenizer for MultiScaleTokenizer {
341    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
342        if signal.len() != self.input_dim {
343            return Err(TokenizerError::dim_mismatch(
344                self.input_dim,
345                signal.len(),
346                "dimension validation",
347            ));
348        }
349        self.encode_concat(signal)
350    }
351
352    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
353        if tokens.len() != self.total_embed_dim() {
354            return Err(TokenizerError::dim_mismatch(
355                self.total_embed_dim(),
356                tokens.len(),
357                "dimension validation",
358            ));
359        }
360
361        // Split tokens back into level embeddings
362        let mut embeddings = Vec::with_capacity(self.levels.len());
363        let mut offset = 0;
364        for level in &self.levels {
365            let end = offset + level.embed_dim;
366            let embedding: Array1<f32> = Array1::from_vec(
367                tokens
368                    .iter()
369                    .skip(offset)
370                    .take(level.embed_dim)
371                    .cloned()
372                    .collect(),
373            );
374            embeddings.push(embedding);
375            offset = end;
376        }
377
378        self.decode_all(&embeddings)
379    }
380
381    fn embed_dim(&self) -> usize {
382        self.total_embed_dim()
383    }
384
385    fn vocab_size(&self) -> usize {
386        0 // Continuous
387    }
388}
389
/// Pyramid tokenizer with residual connections
///
/// Each level encodes the residual from the previous level's reconstruction,
/// similar to residual VQ (RVQ) used in audio codecs.
///
/// Levels use dyadic downsample factors (1, 2, 4, ...) and are processed
/// finest-first. Unlike RVQ, the embeddings here are continuous; this type
/// performs no quantization.
#[derive(Debug, Clone)]
pub struct PyramidTokenizer {
    /// Base multi-scale tokenizer providing the per-level projections
    inner: MultiScaleTokenizer,
    /// Whether to use residual encoding; when false, levels are encoded
    /// independently and decoding averages them (delegates to the inner
    /// tokenizer's `encode_all` / `decode_all`)
    use_residual: bool,
}
401
402impl PyramidTokenizer {
403    /// Create a new pyramid tokenizer
404    pub fn new(input_dim: usize, embed_dim_per_level: usize, num_levels: usize) -> Self {
405        // Generate factors: 1, 2, 4, 8, ...
406        let factors: Vec<usize> = (0..num_levels).map(|i| 1 << i).collect();
407        let inner = MultiScaleTokenizer::with_factors(input_dim, embed_dim_per_level, &factors);
408
409        Self {
410            inner,
411            use_residual: true,
412        }
413    }
414
415    /// Disable residual encoding (use independent levels)
416    pub fn without_residual(mut self) -> Self {
417        self.use_residual = false;
418        self
419    }
420
421    /// Encode with residual pyramid
422    pub fn encode_pyramid(&self, signal: &Array1<f32>) -> TokenizerResult<Vec<Array1<f32>>> {
423        if !self.use_residual {
424            return self.inner.encode_all(signal);
425        }
426
427        let mut embeddings = Vec::with_capacity(self.inner.num_levels());
428        let mut residual = signal.clone();
429
430        for level in 0..self.inner.num_levels() {
431            let embedding = self.inner.encode_level(&residual, level)?;
432            embeddings.push(embedding.clone());
433
434            // Compute reconstruction and subtract from residual
435            let reconstruction = self.inner.decode_level(&embedding, level)?;
436            residual = &residual - &reconstruction;
437        }
438
439        Ok(embeddings)
440    }
441
442    /// Decode from pyramid embeddings
443    pub fn decode_pyramid(&self, embeddings: &[Array1<f32>]) -> TokenizerResult<Array1<f32>> {
444        if !self.use_residual {
445            return self.inner.decode_all(embeddings);
446        }
447
448        // Sum all level reconstructions
449        let mut result = Array1::zeros(self.inner.input_dim);
450
451        for (level, embedding) in embeddings.iter().enumerate() {
452            let decoded = self.inner.decode_level(embedding, level)?;
453            result = &result + &decoded;
454        }
455
456        Ok(result)
457    }
458
459    /// Get number of levels
460    pub fn num_levels(&self) -> usize {
461        self.inner.num_levels()
462    }
463}
464
465impl SignalTokenizer for PyramidTokenizer {
466    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
467        let embeddings = self.encode_pyramid(signal)?;
468        let total_len: usize = embeddings.iter().map(|e| e.len()).sum();
469        let mut result = Vec::with_capacity(total_len);
470        for emb in embeddings {
471            result.extend(emb.iter());
472        }
473        Ok(Array1::from_vec(result))
474    }
475
476    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
477        let total_dim = self.inner.total_embed_dim();
478        if tokens.len() != total_dim {
479            return Err(TokenizerError::dim_mismatch(
480                total_dim,
481                tokens.len(),
482                "dimension validation",
483            ));
484        }
485
486        // Split tokens
487        let mut embeddings = Vec::new();
488        let mut offset = 0;
489        for level in &self.inner.levels {
490            let end = offset + level.embed_dim;
491            let embedding = Array1::from_vec(
492                tokens
493                    .iter()
494                    .skip(offset)
495                    .take(level.embed_dim)
496                    .cloned()
497                    .collect(),
498            );
499            embeddings.push(embedding);
500            offset = end;
501        }
502
503        self.decode_pyramid(&embeddings)
504    }
505
506    fn embed_dim(&self) -> usize {
507        self.inner.total_embed_dim()
508    }
509
510    fn vocab_size(&self) -> usize {
511        0
512    }
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// `len` samples of `sin(i * step)` for `i = 0, 1, ...`.
    fn sine_wave(len: usize, step: f32) -> Array1<f32> {
        Array1::from_vec((0..len).map(|i| (i as f32 * step).sin()).collect())
    }

    /// The `len` consecutive integers starting at `start`, as `f32`.
    fn ramp(start: usize, len: usize) -> Array1<f32> {
        Array1::from_vec((start..start + len).map(|v| v as f32).collect())
    }

    #[test]
    fn test_multiscale_basic() {
        let tokenizer = MultiScaleTokenizer::new(64, 16);
        assert_eq!(tokenizer.num_levels(), 3);
        // Three levels of 16 values each.
        assert_eq!(tokenizer.total_embed_dim(), 3 * 16);

        let encoded = tokenizer.encode(&sine_wave(64, 0.1)).unwrap();
        assert_eq!(encoded.len(), 48);

        let decoded = tokenizer.decode(&encoded).unwrap();
        assert_eq!(decoded.len(), 64);
    }

    #[test]
    fn test_downsample_average() {
        let tokenizer = MultiScaleTokenizer::new(8, 4);
        let pooled = tokenizer.downsample(&ramp(1, 8), 2);

        assert_eq!(pooled.len(), 4);
        // Pairwise means: (1 + 2) / 2 = 1.5, (3 + 4) / 2 = 3.5, ...
        for (idx, expected) in [(0usize, 1.5f32), (1, 3.5)] {
            assert!((pooled[idx] - expected).abs() < 0.01);
        }
    }

    #[test]
    fn test_downsample_stride() {
        let tokenizer = MultiScaleTokenizer::new(8, 4).with_pool_method(PoolMethod::Stride);
        let picked = tokenizer.downsample(&ramp(1, 8), 2);

        assert_eq!(picked.len(), 4);
        // First sample of every pair: 1, 3, 5, 7
        assert_eq!(picked[0], 1.0);
        assert_eq!(picked[1], 3.0);
    }

    #[test]
    fn test_upsample_repeat() {
        let tokenizer = MultiScaleTokenizer::new(8, 4).with_upsample_method(UpsampleMethod::Repeat);
        let up = tokenizer.upsample(&ramp(1, 4), 2, 8);

        assert_eq!(up.len(), 8);
        let head: Vec<f32> = up.iter().take(4).copied().collect();
        assert_eq!(head, vec![1.0, 1.0, 2.0, 2.0]);
    }

    #[test]
    fn test_upsample_linear() {
        let tokenizer = MultiScaleTokenizer::new(8, 4).with_upsample_method(UpsampleMethod::Linear);
        let up = tokenizer.upsample(&Array1::from_vec(vec![0.0, 2.0]), 4, 8);

        assert_eq!(up.len(), 8);
        // Target index i maps to source position i / 4.
        // i = 0 lands exactly on signal[0] = 0.0.
        assert!(up[0].abs() < 0.01);
        // i = 2 lands halfway: 0.0 * 0.5 + 2.0 * 0.5 = 1.0.
        assert!((up[2] - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_encode_level() {
        let tokenizer = MultiScaleTokenizer::new(64, 16);
        let signal = ramp(0, 64);

        // Levels 0, 1 and 2 run at full, half and quarter resolution, yet each
        // yields a 16-value embedding.
        for level in 0..3 {
            let embedding = tokenizer.encode_level(&signal, level).unwrap();
            assert_eq!(embedding.len(), 16, "level {level}");
        }
    }

    #[test]
    fn test_pyramid_tokenizer() {
        let tokenizer = PyramidTokenizer::new(64, 16, 3);
        assert_eq!(tokenizer.num_levels(), 3);

        let embeddings = tokenizer.encode_pyramid(&sine_wave(64, 0.1)).unwrap();
        assert_eq!(embeddings.len(), 3);

        let decoded = tokenizer.decode_pyramid(&embeddings).unwrap();
        assert_eq!(decoded.len(), 64);
    }

    #[test]
    fn test_pyramid_residual() {
        let tokenizer = PyramidTokenizer::new(32, 8, 3);
        let embeddings = tokenizer.encode_pyramid(&sine_wave(32, 0.2)).unwrap();

        // Population variance of one embedding.
        let variance = |e: &Array1<f32>| {
            let n = e.len() as f32;
            let mean = e.sum() / n;
            e.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / n
        };
        let variances: Vec<f32> = embeddings.iter().map(variance).collect();

        // Later levels encode residuals and would generally shrink, but with
        // random projections that is not guaranteed. Only level 0 is checked:
        // it sees the raw signal, so its embedding should not be flat.
        assert!(variances[0] > 0.0);
    }

    #[test]
    fn test_custom_factors() {
        let tokenizer = MultiScaleTokenizer::with_factors(100, 10, &[1, 5, 10, 20]);
        assert_eq!(tokenizer.num_levels(), 4);

        let encoded = tokenizer.encode(&ramp(0, 100)).unwrap();
        assert_eq!(encoded.len(), 4 * 10);

        let decoded = tokenizer.decode(&encoded).unwrap();
        assert_eq!(decoded.len(), 100);
    }
}
657}