kizzasi_tokenizer/advanced_features.rs

//! Advanced features for tokenizer robustness and regularization
//!
//! This module provides:
//! - **Token Dropout**: Randomly drop tokens during training for regularization
//! - **Jitter Injection**: Add controlled noise for robustness
//! - **Temporal Coherence**: Enforce smoothness constraints across time
//! - **Hierarchical Tokenization**: Variable-length codes with hierarchical structure

use crate::error::{TokenizerError, TokenizerResult};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::thread_rng;

// ============================================================================
// Token Dropout for Regularization
// ============================================================================

/// Token dropout configuration
#[derive(Debug, Clone)]
pub struct TokenDropoutConfig {
    /// Dropout probability (0.0 = no dropout, 1.0 = drop all)
    pub dropout_rate: f32,
    /// Value to use for dropped tokens (typically 0.0 or codebook mean)
    pub fill_value: f32,
    /// Whether to scale remaining tokens to compensate for dropout
    pub scale_remaining: bool,
}

impl Default for TokenDropoutConfig {
    fn default() -> Self {
        Self {
            dropout_rate: 0.1,
            fill_value: 0.0,
            scale_remaining: true,
        }
    }
}

/// Apply token dropout to a signal
///
/// During training, randomly set tokens to `fill_value` with probability `dropout_rate`.
/// This acts as a regularization technique to prevent over-reliance on specific tokens.
///
/// # Arguments
/// * `tokens` - Input token array
/// * `config` - Dropout configuration
/// * `training` - Whether dropout should be applied (true during training)
///
/// # Returns
/// Token array with dropout applied (if training=true)
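///
/// # Example
///
/// A minimal usage sketch (illustrative, not part of the original tests):
///
/// ```ignore
/// let tokens = Array1::from_vec(vec![0.5, 1.2, -0.3, 0.8]);
/// let config = TokenDropoutConfig::default();
///
/// // With `training = false` the input is returned unchanged.
/// let passthrough = apply_token_dropout(&tokens, &config, false).unwrap();
/// assert_eq!(passthrough, tokens);
///
/// // With `training = true`, each token is zeroed with probability 0.1 and the
/// // survivors are scaled by 1 / (1 - 0.1) to keep the expected value unchanged.
/// let dropped = apply_token_dropout(&tokens, &config, true).unwrap();
/// assert_eq!(dropped.len(), tokens.len());
/// ```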
pub fn apply_token_dropout(
    tokens: &Array1<f32>,
    config: &TokenDropoutConfig,
    training: bool,
) -> TokenizerResult<Array1<f32>> {
    if !training || config.dropout_rate <= 0.0 {
        return Ok(tokens.clone());
    }

    if !(0.0..=1.0).contains(&config.dropout_rate) {
        return Err(TokenizerError::InvalidConfig(
            "dropout_rate must be in [0, 1]".into(),
        ));
    }

    let mut rng = thread_rng();
    let mut result = tokens.clone();

    for val in result.iter_mut() {
        if rng.random::<f32>() < config.dropout_rate {
            *val = config.fill_value;
        } else if config.scale_remaining {
            // Scale up to compensate for dropped tokens
            *val /= 1.0 - config.dropout_rate;
        }
    }

    Ok(result)
}

/// Apply batch token dropout
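///
/// Same behavior as [`apply_token_dropout`], with an independent dropout decision
/// drawn for every element of the 2-D batch.
///
/// # Example
///
/// A minimal sketch (illustrative):
///
/// ```ignore
/// let batch = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
/// let dropped = apply_batch_token_dropout(&batch, &TokenDropoutConfig::default(), true).unwrap();
/// assert_eq!(dropped.dim(), (2, 3));
/// ```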
pub fn apply_batch_token_dropout(
    tokens: &Array2<f32>,
    config: &TokenDropoutConfig,
    training: bool,
) -> TokenizerResult<Array2<f32>> {
    if !training || config.dropout_rate <= 0.0 {
        return Ok(tokens.clone());
    }

    if !(0.0..=1.0).contains(&config.dropout_rate) {
        return Err(TokenizerError::InvalidConfig(
            "dropout_rate must be in [0, 1]".into(),
        ));
    }

    let (batch_size, seq_len) = (tokens.shape()[0], tokens.shape()[1]);
    let mut rng = thread_rng();
    let mut result = tokens.clone();

    for i in 0..batch_size {
        for j in 0..seq_len {
            if rng.random::<f32>() < config.dropout_rate {
                result[[i, j]] = config.fill_value;
            } else if config.scale_remaining {
                result[[i, j]] /= 1.0 - config.dropout_rate;
            }
        }
    }

    Ok(result)
}

// ============================================================================
// Jitter Injection for Robustness
// ============================================================================

/// Jitter injection configuration
#[derive(Debug, Clone)]
pub struct JitterConfig {
    /// Standard deviation of Gaussian noise
    pub noise_std: f32,
    /// Whether to apply jitter during inference (usually false)
    pub apply_at_inference: bool,
    /// SNR target in dB (alternative to noise_std)
    pub target_snr_db: Option<f32>,
}

impl Default for JitterConfig {
    fn default() -> Self {
        Self {
            noise_std: 0.01,
            apply_at_inference: false,
            target_snr_db: None,
        }
    }
}

impl JitterConfig {
    /// Create jitter config with target SNR
    pub fn with_snr(target_snr_db: f32) -> Self {
        Self {
            noise_std: 0.0, // Will be computed based on signal
            apply_at_inference: false,
            target_snr_db: Some(target_snr_db),
        }
    }
}

/// Add Gaussian jitter to signal for robustness
///
/// Injects controlled noise to make the model robust to small perturbations.
/// Can be applied during training to improve generalization.
///
/// # Arguments
/// * `signal` - Input signal
/// * `config` - Jitter configuration
/// * `training` - Whether currently in training mode
///
/// # Returns
/// Signal with added jitter (if applicable)
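///
/// # Example
///
/// A minimal sketch (illustrative). When `target_snr_db` is set, the noise level is
/// derived from the signal power as `noise_std = sqrt(signal_power / 10^(SNR_dB / 10))`:
///
/// ```ignore
/// let signal = Array1::from_vec(vec![1.0, -1.0, 1.0, -1.0]);
///
/// // Fixed noise level (default noise_std = 0.01):
/// let jittered = add_jitter(&signal, &JitterConfig::default(), true).unwrap();
/// assert_eq!(jittered.len(), signal.len());
///
/// // SNR-targeted noise: this signal has power 1.0, so a 20 dB target gives
/// // noise_std = sqrt(1.0 / 100.0) = 0.1.
/// let jittered_snr = add_jitter(&signal, &JitterConfig::with_snr(20.0), true).unwrap();
/// ```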
pub fn add_jitter(
    signal: &Array1<f32>,
    config: &JitterConfig,
    training: bool,
) -> TokenizerResult<Array1<f32>> {
    if !training && !config.apply_at_inference {
        return Ok(signal.clone());
    }

    // Compute noise std based on SNR target if specified
    let noise_std = if let Some(target_snr_db) = config.target_snr_db {
        let signal_power = signal.iter().map(|x| x.powi(2)).sum::<f32>() / signal.len() as f32;
        let target_snr_linear = 10.0_f32.powf(target_snr_db / 10.0);
        let noise_power = signal_power / target_snr_linear;
        noise_power.sqrt()
    } else {
        config.noise_std
    };

    if noise_std <= 0.0 {
        return Ok(signal.clone());
    }

    let mut rng = thread_rng();
    let mut result = signal.clone();

    for val in result.iter_mut() {
        // Use central limit theorem: sum of 12 uniforms approximates Gaussian(0,1)
        let gaussian: f32 = (0..12).map(|_| rng.random::<f32>()).sum::<f32>() - 6.0;
        *val += gaussian * noise_std;
    }

    Ok(result)
}

/// Add batch jitter
pub fn add_batch_jitter(
    signals: &Array2<f32>,
    config: &JitterConfig,
    training: bool,
) -> TokenizerResult<Array2<f32>> {
    if !training && !config.apply_at_inference {
        return Ok(signals.clone());
    }

    let (batch_size, seq_len) = (signals.shape()[0], signals.shape()[1]);
    let mut result = signals.clone();

    // Apply jitter to each sample in the batch
    for i in 0..batch_size {
        let row = signals.row(i).to_owned();
        let jittered = add_jitter(&row, config, training)?;

        for j in 0..seq_len {
            result[[i, j]] = jittered[[j]];
        }
    }

    Ok(result)
}

// ============================================================================
// Temporal Coherence Constraints
// ============================================================================

/// Temporal coherence configuration
#[derive(Debug, Clone)]
pub struct TemporalCoherenceConfig {
    /// Smoothness strength (0.0 = no smoothing, 1.0 = maximum smoothing)
    pub smoothness: f32,
    /// Window size for temporal smoothing
    pub window_size: usize,
    /// Type of temporal filter
    pub filter_type: TemporalFilterType,
}

/// Type of temporal smoothing filter
#[derive(Debug, Clone, Copy)]
pub enum TemporalFilterType {
    /// Exponential moving average
    ExponentialMovingAverage,
    /// Simple moving average
    SimpleMovingAverage,
    /// Gaussian weighted
    GaussianWeighted,
}

impl Default for TemporalCoherenceConfig {
    fn default() -> Self {
        Self {
            smoothness: 0.5,
            window_size: 5,
            filter_type: TemporalFilterType::SimpleMovingAverage,
        }
    }
}

/// Apply temporal coherence constraint to enforce smoothness
///
/// Smooths the signal across time to reduce jitter and enforce
/// temporal consistency. Useful for signals that should vary smoothly.
///
/// # Arguments
/// * `signal` - Input signal (assumed to be temporal)
/// * `config` - Temporal coherence configuration
///
/// # Returns
/// Temporally smoothed signal
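///
/// # Example
///
/// A minimal sketch (illustrative), using the default simple-moving-average filter:
///
/// ```ignore
/// let noisy = Array1::from_vec(vec![1.0, 5.0, 2.0, 8.0, 3.0]);
/// let smoothed = apply_temporal_coherence(&noisy, &TemporalCoherenceConfig::default()).unwrap();
/// assert_eq!(smoothed.len(), noisy.len());
/// ```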
pub fn apply_temporal_coherence(
    signal: &Array1<f32>,
    config: &TemporalCoherenceConfig,
) -> TokenizerResult<Array1<f32>> {
    if !(0.0..=1.0).contains(&config.smoothness) {
        return Err(TokenizerError::InvalidConfig(
            "smoothness must be in [0, 1]".into(),
        ));
    }

    if config.smoothness <= 0.0 {
        return Ok(signal.clone());
    }

    match config.filter_type {
        // Higher smoothness means more smoothing, so the EMA coefficient on the
        // current sample is `1.0 - smoothness`.
        TemporalFilterType::ExponentialMovingAverage => apply_ema(signal, 1.0 - config.smoothness),
        TemporalFilterType::SimpleMovingAverage => apply_sma(signal, config.window_size),
        TemporalFilterType::GaussianWeighted => {
            apply_gaussian_smooth(signal, config.window_size, config.smoothness)
        }
    }
}

/// Apply Exponential Moving Average (EMA)
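///
/// Recurrence: `y[0] = x[0]`, then `y[i] = alpha * x[i] + (1 - alpha) * y[i - 1]`.
/// For example (illustrative), `alpha = 0.25` applied to `x = [0, 4, 4]` yields
/// `[0.0, 1.0, 1.75]`.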
fn apply_ema(signal: &Array1<f32>, alpha: f32) -> TokenizerResult<Array1<f32>> {
    let mut result = signal.clone();

    for i in 1..signal.len() {
        result[[i]] = alpha * signal[[i]] + (1.0 - alpha) * result[[i - 1]];
    }

    Ok(result)
}

/// Apply Simple Moving Average (SMA)
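///
/// Each output sample is the mean over a centered window, truncated at the signal
/// edges. For example (illustrative), `window_size = 3` on `[1, 5, 2, 8, 3]` yields
/// `[3.0, 8/3, 5.0, 13/3, 5.5]`.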
fn apply_sma(signal: &Array1<f32>, window_size: usize) -> TokenizerResult<Array1<f32>> {
    if window_size == 0 {
        return Err(TokenizerError::InvalidConfig(
            "window_size must be positive".into(),
        ));
    }

    let mut result = signal.clone();
    let half_window = window_size / 2;

    for i in 0..signal.len() {
        let start = i.saturating_sub(half_window);
        let end = (i + half_window + 1).min(signal.len());

        let sum: f32 = signal.iter().skip(start).take(end - start).sum();
        result[[i]] = sum / (end - start) as f32;
    }

    Ok(result)
}

/// Apply Gaussian-weighted smoothing
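///
/// Weights follow `exp(-offset^2 / (2 * sigma^2))` over a centered window and are
/// normalized to sum to 1. For example (illustrative), `window_size = 5` with
/// `sigma = 1.0` gives weights of roughly `[0.054, 0.244, 0.403, 0.244, 0.054]`.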
fn apply_gaussian_smooth(
    signal: &Array1<f32>,
    window_size: usize,
    sigma: f32,
) -> TokenizerResult<Array1<f32>> {
    if window_size == 0 {
        return Err(TokenizerError::InvalidConfig(
            "window_size must be positive".into(),
        ));
    }

    let mut result = signal.clone();
    let half_window = window_size / 2;

    // Precompute Gaussian weights
    let mut weights = vec![0.0; window_size];
    let mut weight_sum = 0.0;
    for (i, w) in weights.iter_mut().enumerate() {
        let offset = i as f32 - half_window as f32;
        *w = (-offset.powi(2) / (2.0 * sigma.powi(2))).exp();
        weight_sum += *w;
    }

    // Normalize weights
    for w in &mut weights {
        *w /= weight_sum;
    }

    // Apply weighted smoothing
    for i in 0..signal.len() {
        let start = i.saturating_sub(half_window);
        let end = (i + half_window + 1).min(signal.len());

        let mut value = 0.0;
        let mut local_weight_sum = 0.0;

        for (j, idx) in (start..end).enumerate() {
            let weight_idx = j + half_window.saturating_sub(i.saturating_sub(start));
            if weight_idx < weights.len() {
                value += signal[[idx]] * weights[weight_idx];
                local_weight_sum += weights[weight_idx];
            }
        }

        result[[i]] = value / local_weight_sum.max(1e-8);
    }

    Ok(result)
}

// ============================================================================
// Hierarchical Tokenization with Variable-Length Codes
// ============================================================================

/// Hierarchical tokenization configuration
#[derive(Debug, Clone)]
pub struct HierarchicalConfig {
    /// Number of hierarchy levels (1 = flat, >1 = hierarchical)
    pub num_levels: usize,
    /// Codebook sizes per level
    pub codebook_sizes: Vec<usize>,
    /// Whether to use residual coding between levels
    pub use_residual: bool,
}

impl HierarchicalConfig {
    /// Create a hierarchical config with exponentially decreasing codebook sizes
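    ///
    /// # Example
    ///
    /// A worked example (illustrative): `exponential(256, 3, 0.5)` produces codebook
    /// sizes `[256, 128, 64]`; each size is floored at 16 codes per level.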
    pub fn exponential(base_size: usize, num_levels: usize, decay_factor: f32) -> Self {
        let mut codebook_sizes = Vec::with_capacity(num_levels);

        for level in 0..num_levels {
            let size = (base_size as f32 * decay_factor.powi(level as i32)) as usize;
            codebook_sizes.push(size.max(16)); // Minimum 16 codes per level
        }

        Self {
            num_levels,
            codebook_sizes,
            use_residual: true,
        }
    }
}

/// Hierarchical tokenizer with variable-length codes
///
/// Encodes signals using multiple levels of granularity:
/// - Coarse level: Few bits, captures main structure
/// - Fine levels: More bits, capture details
///
/// Allows variable bitrate by using different numbers of levels.
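///
/// # Example
///
/// A minimal encode/decode sketch (illustrative, not part of the original tests):
///
/// ```ignore
/// let config = HierarchicalConfig::exponential(256, 3, 0.5);
/// let tokenizer = HierarchicalTokenizer::new(8, config).unwrap();
///
/// let signal = Array1::from_vec(vec![1.0; 8]);
/// let indices = tokenizer.encode_with_levels(&signal, 2).unwrap();
/// let decoded = tokenizer.decode_hierarchical(&indices).unwrap();
///
/// assert_eq!(indices.len(), 2);
/// assert_eq!(decoded.len(), 8);
/// ```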
#[derive(Debug, Clone)]
pub struct HierarchicalTokenizer {
    config: HierarchicalConfig,
    /// Codebooks for each level (simplified - just centers)
    codebooks: Vec<Array2<f32>>,
}

impl HierarchicalTokenizer {
    /// Create a new hierarchical tokenizer
    pub fn new(embed_dim: usize, config: HierarchicalConfig) -> TokenizerResult<Self> {
        if config.num_levels == 0 {
            return Err(TokenizerError::InvalidConfig(
                "num_levels must be positive".into(),
            ));
        }

        if config.codebook_sizes.len() != config.num_levels {
            return Err(TokenizerError::InvalidConfig(
                "codebook_sizes.len() must equal num_levels".into(),
            ));
        }

        // Initialize random codebooks for each level
        let mut rng = thread_rng();
        let mut codebooks = Vec::with_capacity(config.num_levels);

        for &size in &config.codebook_sizes {
            let mut codebook_data = vec![0.0; size * embed_dim];
            for val in &mut codebook_data {
                // Use central limit theorem for Gaussian initialization
                let gaussian: f32 = (0..12).map(|_| rng.random::<f32>()).sum::<f32>() - 6.0;
                *val = gaussian;
            }

            let codebook =
                Array2::from_shape_vec((size, embed_dim), codebook_data).map_err(|e| {
                    TokenizerError::encoding("serialization", format!("Codebook init: {}", e))
                })?;

            codebooks.push(codebook);
        }

        Ok(Self { config, codebooks })
    }

    /// Encode using specified number of levels (for variable bitrate)
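    ///
    /// With residual coding enabled, level 0 quantizes the signal itself and every
    /// later level quantizes what the previous levels left over.
    ///
    /// # Example
    ///
    /// A minimal sketch (illustrative), assuming `tokenizer` and `signal` as in the
    /// type-level example above:
    ///
    /// ```ignore
    /// // Coarse-only encoding: one index, lowest bitrate.
    /// let coarse = tokenizer.encode_with_levels(&signal, 1).unwrap();
    /// assert_eq!(coarse.len(), 1);
    /// ```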
    pub fn encode_with_levels(
        &self,
        signal: &Array1<f32>,
        num_levels: usize,
    ) -> TokenizerResult<Vec<usize>> {
        if num_levels > self.config.num_levels {
            return Err(TokenizerError::InvalidConfig(format!(
                "num_levels {} exceeds configured {}",
                num_levels, self.config.num_levels
            )));
        }

        let mut indices = Vec::with_capacity(num_levels);
        let mut residual = signal.clone();

        for level in 0..num_levels {
            // Find nearest codebook entry at this level
            let codebook = &self.codebooks[level];
            let mut best_idx = 0;
            let mut best_dist = f32::INFINITY;

            for (idx, code) in codebook.outer_iter().enumerate() {
                let dist: f32 = residual
                    .iter()
                    .zip(code.iter())
                    .map(|(r, c)| (r - c).powi(2))
                    .sum();

                if dist < best_dist {
                    best_dist = dist;
                    best_idx = idx;
                }
            }

            indices.push(best_idx);

            // Update residual if using residual coding
            if self.config.use_residual && level < num_levels - 1 {
                let quantized = codebook.row(best_idx);
                for i in 0..residual.len().min(quantized.len()) {
                    residual[[i]] -= quantized[[i]];
                }
            }
        }

        Ok(indices)
    }

    /// Decode from hierarchical indices
    pub fn decode_hierarchical(&self, indices: &[usize]) -> TokenizerResult<Array1<f32>> {
        if indices.is_empty() {
            return Err(TokenizerError::decoding("deserialization", "Empty indices"));
        }

        if indices.len() > self.config.num_levels {
            return Err(TokenizerError::decoding(
                "decoding",
                format!(
                    "Too many indices: {} > {}",
                    indices.len(),
                    self.config.num_levels
                ),
            ));
        }

        if indices[0] >= self.codebooks[0].shape()[0] {
            return Err(TokenizerError::decoding(
                "decoding",
                format!("Invalid index {} at level 0", indices[0]),
            ));
        }

        // Get first level codebook entry
        let first_code = self.codebooks[0].row(indices[0]);
        let mut result = first_code.to_owned();

        // Add residuals from subsequent levels
        if self.config.use_residual {
            for (level, &idx) in indices.iter().enumerate().skip(1) {
                if idx >= self.codebooks[level].shape()[0] {
                    return Err(TokenizerError::decoding(
                        "decoding",
                        format!("Invalid index {} at level {}", idx, level),
                    ));
                }

                let code = self.codebooks[level].row(idx);
                for i in 0..result.len().min(code.len()) {
                    result[[i]] += code[[i]];
                }
            }
        }

        Ok(result)
    }

    /// Get the bitrate for a given number of levels
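    ///
    /// The result is the sum of `log2(codebook_size)` over the levels used. For
    /// example (illustrative), codebook sizes `[256, 128, 64]` give 8 + 7 + 6 = 21
    /// bits when all three levels are used.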
    pub fn bitrate_for_levels(&self, num_levels: usize) -> f32 {
        let mut total_bits = 0.0;

        for level in 0..num_levels.min(self.config.num_levels) {
            total_bits += (self.config.codebook_sizes[level] as f32).log2();
        }

        total_bits
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_dropout() {
        let tokens = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let config = TokenDropoutConfig {
            dropout_rate: 0.5,
            fill_value: 0.0,
            scale_remaining: false,
        };

        let result = apply_token_dropout(&tokens, &config, true).unwrap();
        assert_eq!(result.len(), tokens.len());
    }

    #[test]
    fn test_jitter_injection() {
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let config = JitterConfig {
            noise_std: 0.1,
            apply_at_inference: false,
            target_snr_db: None,
        };

        let result = add_jitter(&signal, &config, true).unwrap();
        assert_eq!(result.len(), signal.len());
    }

    #[test]
    fn test_temporal_coherence_sma() {
        let signal = Array1::from_vec(vec![1.0, 5.0, 2.0, 8.0, 3.0]);
        let config = TemporalCoherenceConfig {
            smoothness: 0.5,
            window_size: 3,
            filter_type: TemporalFilterType::SimpleMovingAverage,
        };

        let result = apply_temporal_coherence(&signal, &config).unwrap();
        assert_eq!(result.len(), signal.len());

        // Compare mean squared values (a rough energy proxy) before and after smoothing
        let original_ms: f32 = signal.iter().map(|x| x.powi(2)).sum::<f32>() / signal.len() as f32;
        let smoothed_ms: f32 = result.iter().map(|x| x.powi(2)).sum::<f32>() / result.len() as f32;

        // Not strictly guaranteed, but very likely with this test signal
        assert!(
            (smoothed_ms - original_ms).abs() < original_ms,
            "Smoothed energy should stay in the same range as the original"
        );
    }

    #[test]
    fn test_hierarchical_tokenizer() {
        let config = HierarchicalConfig::exponential(256, 3, 0.5);
        let tokenizer = HierarchicalTokenizer::new(8, config).unwrap();

        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        // Encode with different numbers of levels
        let indices1 = tokenizer.encode_with_levels(&signal, 1).unwrap();
        let indices2 = tokenizer.encode_with_levels(&signal, 2).unwrap();
        let indices3 = tokenizer.encode_with_levels(&signal, 3).unwrap();

        assert_eq!(indices1.len(), 1);
        assert_eq!(indices2.len(), 2);
        assert_eq!(indices3.len(), 3);

        // Decode and check dimension preservation
        let decoded = tokenizer.decode_hierarchical(&indices3).unwrap();
        assert_eq!(decoded.len(), signal.len());
    }

    #[test]
    fn test_hierarchical_bitrate() {
        let config = HierarchicalConfig::exponential(256, 3, 0.5);
        let tokenizer = HierarchicalTokenizer::new(8, config).unwrap();

        let br1 = tokenizer.bitrate_for_levels(1);
        let br2 = tokenizer.bitrate_for_levels(2);
        let br3 = tokenizer.bitrate_for_levels(3);

        // More levels = higher bitrate
        assert!(br1 < br2);
        assert!(br2 < br3);
    }
}