use crate::error::{TokenizerError, TokenizerResult};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::thread_rng;

/// Configuration for token dropout regularization.
#[derive(Debug, Clone)]
pub struct TokenDropoutConfig {
    /// Probability in [0, 1] that a token value is dropped.
    pub dropout_rate: f32,
    /// Value written in place of dropped tokens.
    pub fill_value: f32,
    /// Rescale surviving tokens by 1 / (1 - dropout_rate) (inverted dropout).
    pub scale_remaining: bool,
}

impl Default for TokenDropoutConfig {
    fn default() -> Self {
        Self {
            dropout_rate: 0.1,
            fill_value: 0.0,
            scale_remaining: true,
        }
    }
}

/// Randomly replaces tokens with `fill_value` during training.
///
/// With `scale_remaining` enabled, surviving tokens are scaled by
/// 1 / (1 - dropout_rate) so the expected magnitude is preserved
/// (inverted dropout), which lets inference skip any rescaling.
pub fn apply_token_dropout(
    tokens: &Array1<f32>,
    config: &TokenDropoutConfig,
    training: bool,
) -> TokenizerResult<Array1<f32>> {
    // Validate the configuration before the early return so an
    // out-of-range rate is reported even in evaluation mode.
    if !(0.0..=1.0).contains(&config.dropout_rate) {
        return Err(TokenizerError::InvalidConfig(
            "dropout_rate must be in [0, 1]".into(),
        ));
    }

    if !training || config.dropout_rate <= 0.0 {
        return Ok(tokens.clone());
    }

    let mut rng = thread_rng();
    let mut result = tokens.clone();

    for val in result.iter_mut() {
        if rng.random::<f32>() < config.dropout_rate {
            *val = config.fill_value;
        } else if config.scale_remaining {
            // Unreachable when dropout_rate == 1.0, since random() < 1.0
            // always holds, so this never divides by zero.
            *val /= 1.0 - config.dropout_rate;
        }
    }

    Ok(result)
}

/// Batch variant of [`apply_token_dropout`] operating on a
/// `(batch, seq_len)` matrix; each element is dropped independently.
pub fn apply_batch_token_dropout(
    tokens: &Array2<f32>,
    config: &TokenDropoutConfig,
    training: bool,
) -> TokenizerResult<Array2<f32>> {
    // Mirror the validation performed by the single-sequence version.
    if !(0.0..=1.0).contains(&config.dropout_rate) {
        return Err(TokenizerError::InvalidConfig(
            "dropout_rate must be in [0, 1]".into(),
        ));
    }

    if !training || config.dropout_rate <= 0.0 {
        return Ok(tokens.clone());
    }

    let (batch_size, seq_len) = (tokens.shape()[0], tokens.shape()[1]);
    let mut rng = thread_rng();
    let mut result = tokens.clone();

    for i in 0..batch_size {
        for j in 0..seq_len {
            if rng.random::<f32>() < config.dropout_rate {
                result[[i, j]] = config.fill_value;
            } else if config.scale_remaining {
                result[[i, j]] /= 1.0 - config.dropout_rate;
            }
        }
    }

    Ok(result)
}

/// Configuration for additive Gaussian jitter.
#[derive(Debug, Clone)]
pub struct JitterConfig {
    /// Standard deviation of the added noise (ignored when
    /// `target_snr_db` is set).
    pub noise_std: f32,
    /// Also add noise at inference time, not just during training.
    pub apply_at_inference: bool,
    /// If set, derive the noise level from a target signal-to-noise
    /// ratio in decibels instead of using `noise_std` directly.
    pub target_snr_db: Option<f32>,
}

impl Default for JitterConfig {
    fn default() -> Self {
        Self {
            noise_std: 0.01,
            apply_at_inference: false,
            target_snr_db: None,
        }
    }
}

impl JitterConfig {
    /// Builds a config whose noise level is derived from a target SNR.
    pub fn with_snr(target_snr_db: f32) -> Self {
        Self {
            noise_std: 0.0,
            apply_at_inference: false,
            target_snr_db: Some(target_snr_db),
        }
    }
}

/// Adds zero-mean Gaussian noise to a signal.
///
/// When `target_snr_db` is set, the noise standard deviation is derived
/// from the mean signal power via SNR_linear = 10^(SNR_dB / 10) and
/// noise_power = signal_power / SNR_linear.
pub fn add_jitter(
    signal: &Array1<f32>,
    config: &JitterConfig,
    training: bool,
) -> TokenizerResult<Array1<f32>> {
    if !training && !config.apply_at_inference {
        return Ok(signal.clone());
    }

    let noise_std = if let Some(target_snr_db) = config.target_snr_db {
        let signal_power = signal.iter().map(|x| x.powi(2)).sum::<f32>() / signal.len() as f32;
        let target_snr_linear = 10.0_f32.powf(target_snr_db / 10.0);
        let noise_power = signal_power / target_snr_linear;
        noise_power.sqrt()
    } else {
        config.noise_std
    };

    if noise_std <= 0.0 {
        return Ok(signal.clone());
    }

    let mut rng = thread_rng();
    let mut result = signal.clone();

    for val in result.iter_mut() {
        // Irwin-Hall approximation: the sum of 12 uniform samples minus 6
        // has zero mean and unit variance, approximating N(0, 1).
        let gaussian: f32 = (0..12).map(|_| rng.random::<f32>()).sum::<f32>() - 6.0;
        *val += gaussian * noise_std;
    }

    Ok(result)
}

/// Batch variant of [`add_jitter`]; noise is drawn independently per row.
pub fn add_batch_jitter(
    signals: &Array2<f32>,
    config: &JitterConfig,
    training: bool,
) -> TokenizerResult<Array2<f32>> {
    if !training && !config.apply_at_inference {
        return Ok(signals.clone());
    }

    let (batch_size, seq_len) = (signals.shape()[0], signals.shape()[1]);
    let mut result = signals.clone();

    for i in 0..batch_size {
        let row = signals.row(i).to_owned();
        let jittered = add_jitter(&row, config, training)?;

        for j in 0..seq_len {
            result[[i, j]] = jittered[[j]];
        }
    }

    Ok(result)
}

/// Configuration for temporal smoothing of a token or signal sequence.
#[derive(Debug, Clone)]
pub struct TemporalCoherenceConfig {
    /// Smoothing strength in [0, 1]; 0 leaves the signal unchanged.
    pub smoothness: f32,
    /// Window length for the moving-average filters.
    pub window_size: usize,
    /// Which smoothing filter to apply.
    pub filter_type: TemporalFilterType,
}

#[derive(Debug, Clone, Copy)]
pub enum TemporalFilterType {
    /// First-order recursive (exponential) smoothing.
    ExponentialMovingAverage,
    /// Unweighted average over a sliding window.
    SimpleMovingAverage,
    /// Sliding window with Gaussian weights.
    GaussianWeighted,
}

impl Default for TemporalCoherenceConfig {
    fn default() -> Self {
        Self {
            smoothness: 0.5,
            window_size: 5,
            filter_type: TemporalFilterType::SimpleMovingAverage,
        }
    }
}

/// Smooths a signal according to the configured filter.
pub fn apply_temporal_coherence(
    signal: &Array1<f32>,
    config: &TemporalCoherenceConfig,
) -> TokenizerResult<Array1<f32>> {
    if !(0.0..=1.0).contains(&config.smoothness) {
        return Err(TokenizerError::InvalidConfig(
            "smoothness must be in [0, 1]".into(),
        ));
    }

    if config.smoothness <= 0.0 {
        return Ok(signal.clone());
    }

    match config.filter_type {
        // The EMA weight on the current sample is 1 - smoothness, so that
        // smoothness -> 0 approaches the identity (matching the early
        // return above) and smoothness -> 1 approaches maximal smoothing.
        TemporalFilterType::ExponentialMovingAverage => {
            apply_ema(signal, 1.0 - config.smoothness)
        }
        TemporalFilterType::SimpleMovingAverage => apply_sma(signal, config.window_size),
        TemporalFilterType::GaussianWeighted => {
            apply_gaussian_smooth(signal, config.window_size, config.smoothness)
        }
    }
}

/// First-order exponential smoothing:
/// y[i] = alpha * x[i] + (1 - alpha) * y[i - 1].
fn apply_ema(signal: &Array1<f32>, alpha: f32) -> TokenizerResult<Array1<f32>> {
    let mut result = signal.clone();

    for i in 1..signal.len() {
        result[[i]] = alpha * signal[[i]] + (1.0 - alpha) * result[[i - 1]];
    }

    Ok(result)
}

/// Centered simple moving average with window truncation at the edges.
fn apply_sma(signal: &Array1<f32>, window_size: usize) -> TokenizerResult<Array1<f32>> {
    if window_size == 0 {
        return Err(TokenizerError::InvalidConfig(
            "window_size must be positive".into(),
        ));
    }

    let mut result = signal.clone();
    let half_window = window_size / 2;

    for i in 0..signal.len() {
        let start = i.saturating_sub(half_window);
        let end = (i + half_window + 1).min(signal.len());

        let sum: f32 = signal.iter().skip(start).take(end - start).sum();
        result[[i]] = sum / (end - start) as f32;
    }

    Ok(result)
}

/// Gaussian-weighted moving average; `sigma` controls the kernel width.
fn apply_gaussian_smooth(
    signal: &Array1<f32>,
    window_size: usize,
    sigma: f32,
) -> TokenizerResult<Array1<f32>> {
    if window_size == 0 {
        return Err(TokenizerError::InvalidConfig(
            "window_size must be positive".into(),
        ));
    }

    let mut result = signal.clone();
    let half_window = window_size / 2;

    // Precompute the normalized Gaussian kernel centered on the window.
    let mut weights = vec![0.0; window_size];
    let mut weight_sum = 0.0;
    for (i, w) in weights.iter_mut().enumerate() {
        let offset = i as f32 - half_window as f32;
        *w = (-offset.powi(2) / (2.0 * sigma.powi(2))).exp();
        weight_sum += *w;
    }

    for w in &mut weights {
        *w /= weight_sum;
    }

    for i in 0..signal.len() {
        let start = i.saturating_sub(half_window);
        let end = (i + half_window + 1).min(signal.len());

        let mut value = 0.0;
        let mut local_weight_sum = 0.0;

        // Align the kernel with the (possibly truncated) window so the
        // kernel center always lands on index `i`.
        for (j, idx) in (start..end).enumerate() {
            let weight_idx = j + half_window.saturating_sub(i.saturating_sub(start));
            if weight_idx < weights.len() {
                value += signal[[idx]] * weights[weight_idx];
                local_weight_sum += weights[weight_idx];
            }
        }

        // Renormalize at the edges where part of the kernel is cut off.
        result[[i]] = value / local_weight_sum.max(1e-8);
    }

    Ok(result)
}

/// Configuration for a multi-level (coarse-to-fine) tokenizer.
#[derive(Debug, Clone)]
pub struct HierarchicalConfig {
    /// Number of quantization levels.
    pub num_levels: usize,
    /// Codebook size for each level; must have `num_levels` entries.
    pub codebook_sizes: Vec<usize>,
    /// Quantize the residual at each level instead of the raw signal.
    pub use_residual: bool,
}

impl HierarchicalConfig {
    /// Builds a config whose codebook sizes shrink geometrically:
    /// level `k` gets `base_size * decay_factor^k` entries (at least 16).
    pub fn exponential(base_size: usize, num_levels: usize, decay_factor: f32) -> Self {
        let mut codebook_sizes = Vec::with_capacity(num_levels);

        for level in 0..num_levels {
            let size = (base_size as f32 * decay_factor.powi(level as i32)) as usize;
            codebook_sizes.push(size.max(16));
        }

        Self {
            num_levels,
            codebook_sizes,
            use_residual: true,
        }
    }
}

/// Residual vector quantizer with one codebook per level.
#[derive(Debug, Clone)]
pub struct HierarchicalTokenizer {
    config: HierarchicalConfig,
    codebooks: Vec<Array2<f32>>,
}

impl HierarchicalTokenizer {
    /// Creates a tokenizer with randomly initialized codebooks of shape
    /// `(codebook_size, embed_dim)`.
    pub fn new(embed_dim: usize, config: HierarchicalConfig) -> TokenizerResult<Self> {
        if config.num_levels == 0 {
            return Err(TokenizerError::InvalidConfig(
                "num_levels must be positive".into(),
            ));
        }

        if config.codebook_sizes.len() != config.num_levels {
            return Err(TokenizerError::InvalidConfig(
                "codebook_sizes.len() must equal num_levels".into(),
            ));
        }

        let mut rng = thread_rng();
        let mut codebooks = Vec::with_capacity(config.num_levels);

        for &size in &config.codebook_sizes {
            let mut codebook_data = vec![0.0; size * embed_dim];
            for val in &mut codebook_data {
                // Irwin-Hall approximation of a standard normal sample.
                let gaussian: f32 = (0..12).map(|_| rng.random::<f32>()).sum::<f32>() - 6.0;
                *val = gaussian;
            }

            let codebook =
                Array2::from_shape_vec((size, embed_dim), codebook_data).map_err(|e| {
                    TokenizerError::encoding("serialization", format!("Codebook init: {}", e))
                })?;

            codebooks.push(codebook);
        }

        Ok(Self { config, codebooks })
    }

    /// Encodes a signal with the first `num_levels` codebooks, returning
    /// one index per level (coarse to fine).
    pub fn encode_with_levels(
        &self,
        signal: &Array1<f32>,
        num_levels: usize,
    ) -> TokenizerResult<Vec<usize>> {
        if num_levels > self.config.num_levels {
            return Err(TokenizerError::InvalidConfig(format!(
                "num_levels {} exceeds configured {}",
                num_levels, self.config.num_levels
            )));
        }

        let mut indices = Vec::with_capacity(num_levels);
        let mut residual = signal.clone();

        for level in 0..num_levels {
            let codebook = &self.codebooks[level];

            // Exhaustive nearest-neighbor search by squared Euclidean distance.
            let mut best_idx = 0;
            let mut best_dist = f32::INFINITY;

            for (idx, code) in codebook.outer_iter().enumerate() {
                let dist: f32 = residual
                    .iter()
                    .zip(code.iter())
                    .map(|(r, c)| (r - c).powi(2))
                    .sum();

                if dist < best_dist {
                    best_dist = dist;
                    best_idx = idx;
                }
            }

            indices.push(best_idx);

            // Subtract the chosen code so the next level refines the residual.
            if self.config.use_residual && level < num_levels - 1 {
                let quantized = codebook.row(best_idx);
                for i in 0..residual.len().min(quantized.len()) {
                    residual[[i]] -= quantized[[i]];
                }
            }
        }

        Ok(indices)
    }

    /// Reconstructs a signal by summing the selected code at each level.
    pub fn decode_hierarchical(&self, indices: &[usize]) -> TokenizerResult<Array1<f32>> {
        if indices.is_empty() {
            return Err(TokenizerError::decoding("deserialization", "Empty indices"));
        }

        if indices.len() > self.config.num_levels {
            return Err(TokenizerError::decoding(
                "decoding",
                format!(
                    "Too many indices: {} > {}",
                    indices.len(),
                    self.config.num_levels
                ),
            ));
        }

        // Validate the level-0 index before the row lookup, matching the
        // checks performed for the higher levels below.
        if indices[0] >= self.codebooks[0].shape()[0] {
            return Err(TokenizerError::decoding(
                "decoding",
                format!("Invalid index {} at level 0", indices[0]),
            ));
        }

        let first_code = self.codebooks[0].row(indices[0]);
        let mut result = first_code.to_owned();

        if self.config.use_residual {
            for (level, &idx) in indices.iter().enumerate().skip(1) {
                if idx >= self.codebooks[level].shape()[0] {
                    return Err(TokenizerError::decoding(
                        "decoding",
                        format!("Invalid index {} at level {}", idx, level),
                    ));
                }

                let code = self.codebooks[level].row(idx);
                for i in 0..result.len().min(code.len()) {
                    result[[i]] += code[[i]];
                }
            }
        }

        Ok(result)
    }

    /// Bits required to address one token at the given number of levels:
    /// the sum of log2(codebook_size) over the levels used.
    pub fn bitrate_for_levels(&self, num_levels: usize) -> f32 {
        let mut total_bits = 0.0;

        for level in 0..num_levels.min(self.config.num_levels) {
            total_bits += (self.config.codebook_sizes[level] as f32).log2();
        }

        total_bits
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_dropout() {
        let tokens = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let config = TokenDropoutConfig {
            dropout_rate: 0.5,
            fill_value: 0.0,
            scale_remaining: false,
        };

        let result = apply_token_dropout(&tokens, &config, true).unwrap();
        assert_eq!(result.len(), tokens.len());
    }
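
    // Added sketch: with `scale_remaining` enabled, every surviving value
    // should equal the original scaled by 1 / (1 - dropout_rate), and every
    // dropped value should equal `fill_value`, regardless of which elements
    // the RNG happens to select.
    #[test]
    fn test_token_dropout_scaling() {
        let tokens = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let config = TokenDropoutConfig {
            dropout_rate: 0.5,
            fill_value: 0.0,
            scale_remaining: true,
        };

        let result = apply_token_dropout(&tokens, &config, true).unwrap();
        for (orig, out) in tokens.iter().zip(result.iter()) {
            let scaled = orig / (1.0 - config.dropout_rate);
            assert!(
                (*out - config.fill_value).abs() < 1e-6 || (*out - scaled).abs() < 1e-6,
                "each value must be either dropped or rescaled"
            );
        }
    }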

    #[test]
    fn test_jitter_injection() {
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let config = JitterConfig {
            noise_std: 0.1,
            apply_at_inference: false,
            target_snr_db: None,
        };

        let result = add_jitter(&signal, &config, true).unwrap();
        assert_eq!(result.len(), signal.len());
    }
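
    // Added sketch: with `apply_at_inference` disabled (the default),
    // jitter must be a no-op outside training, so the output should equal
    // the input exactly.
    #[test]
    fn test_jitter_noop_at_inference() {
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let config = JitterConfig::default();

        let result = add_jitter(&signal, &config, false).unwrap();
        assert_eq!(result, signal);
    }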

    #[test]
    fn test_temporal_coherence_sma() {
        let signal = Array1::from_vec(vec![1.0, 5.0, 2.0, 8.0, 3.0]);
        let config = TemporalCoherenceConfig {
            smoothness: 0.5,
            window_size: 3,
            filter_type: TemporalFilterType::SimpleMovingAverage,
        };

        let result = apply_temporal_coherence(&signal, &config).unwrap();
        assert_eq!(result.len(), signal.len());

        // Compare mean power (mean of squares) before and after smoothing;
        // averaging should not blow the signal up.
        let original_power: f32 =
            signal.iter().map(|x| x.powi(2)).sum::<f32>() / signal.len() as f32;
        let smoothed_power: f32 =
            result.iter().map(|x| x.powi(2)).sum::<f32>() / result.len() as f32;

        assert!(
            (smoothed_power - original_power).abs() < original_power,
            "Smoothed power should stay within the original's magnitude"
        );
    }
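
    // Added sketch: a constant signal is a fixed point of each averaging
    // filter here, so smoothing should return it unchanged (up to float
    // rounding from the normalized Gaussian weights).
    #[test]
    fn test_temporal_coherence_constant_signal() {
        let signal = Array1::from_vec(vec![2.0; 6]);

        for filter_type in [
            TemporalFilterType::ExponentialMovingAverage,
            TemporalFilterType::SimpleMovingAverage,
            TemporalFilterType::GaussianWeighted,
        ] {
            let config = TemporalCoherenceConfig {
                smoothness: 0.5,
                window_size: 3,
                filter_type,
            };

            let result = apply_temporal_coherence(&signal, &config).unwrap();
            for val in result.iter() {
                assert!((val - 2.0).abs() < 1e-5);
            }
        }
    }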

    #[test]
    fn test_hierarchical_tokenizer() {
        let config = HierarchicalConfig::exponential(256, 3, 0.5);
        let tokenizer = HierarchicalTokenizer::new(8, config).unwrap();

        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        let indices1 = tokenizer.encode_with_levels(&signal, 1).unwrap();
        let indices2 = tokenizer.encode_with_levels(&signal, 2).unwrap();
        let indices3 = tokenizer.encode_with_levels(&signal, 3).unwrap();

        assert_eq!(indices1.len(), 1);
        assert_eq!(indices2.len(), 2);
        assert_eq!(indices3.len(), 3);

        let decoded = tokenizer.decode_hierarchical(&indices3).unwrap();
        assert_eq!(decoded.len(), signal.len());
    }
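
    // Added sketch: decoding should reject out-of-range indices at every
    // level, including level 0, rather than panicking on a bad row lookup;
    // this assumes the decoder validates indices as in the bounds checks
    // above. Codebook sizes here are 256, 128, 64.
    #[test]
    fn test_hierarchical_decode_invalid_index() {
        let config = HierarchicalConfig::exponential(256, 3, 0.5);
        let tokenizer = HierarchicalTokenizer::new(8, config).unwrap();

        // Level 0 has 256 entries, so index 256 is out of range.
        assert!(tokenizer.decode_hierarchical(&[256]).is_err());
        // Level 1 has 128 entries, so index 200 is out of range there.
        assert!(tokenizer.decode_hierarchical(&[0, 200]).is_err());
    }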

    #[test]
    fn test_hierarchical_bitrate() {
        let config = HierarchicalConfig::exponential(256, 3, 0.5);
        let tokenizer = HierarchicalTokenizer::new(8, config).unwrap();

        let br1 = tokenizer.bitrate_for_levels(1);
        let br2 = tokenizer.bitrate_for_levels(2);
        let br3 = tokenizer.bitrate_for_levels(3);

        assert!(br1 < br2);
        assert!(br2 < br3);
    }
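
    // Added sketch: with codebook sizes 256, 128, and 64, the cumulative
    // bitrates are log2(256) = 8, 8 + log2(128) = 15, and 15 + log2(64) = 21
    // bits per token.
    #[test]
    fn test_hierarchical_bitrate_values() {
        let config = HierarchicalConfig::exponential(256, 3, 0.5);
        let tokenizer = HierarchicalTokenizer::new(8, config).unwrap();

        assert!((tokenizer.bitrate_for_levels(1) - 8.0).abs() < 1e-5);
        assert!((tokenizer.bitrate_for_levels(2) - 15.0).abs() < 1e-5);
        assert!((tokenizer.bitrate_for_levels(3) - 21.0).abs() < 1e-5);
    }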
}