kizzasi_tokenizer/
mulaw.rs1use crate::error::TokenizerResult;
13use crate::SignalTokenizer;
14use scirs2_core::ndarray::Array1;
15
/// μ-law companding codec.
///
/// Samples are clamped to `[-1.0, 1.0]` and compressed with the curve
/// `sign(x) * ln(1 + mu * |x|) / ln(1 + mu)`; compressed values can further be
/// quantized into one of `levels` integer codes.
#[derive(Clone, Debug)]
pub struct MuLawCodec {
    /// Companding parameter `mu` used by the compression/expansion curves.
    mu: f32,
    /// Bit depth the codec was constructed with.
    bits: u8,
    /// Number of quantization levels (`1 << bits` as built by the constructors).
    levels: usize,
}
26
27impl MuLawCodec {
28 pub fn new(bits: u8) -> Self {
30 let levels = 1usize << bits;
31 let mu = (levels - 1) as f32;
32 Self { mu, bits, levels }
33 }
34
35 pub fn with_mu(mu: f32, bits: u8) -> Self {
37 Self {
38 mu,
39 bits,
40 levels: 1usize << bits,
41 }
42 }
43
44 fn encode_sample(&self, x: f32) -> f32 {
46 let x_clamped = x.clamp(-1.0, 1.0);
47 let sign = x_clamped.signum();
48 let magnitude = (1.0 + self.mu * x_clamped.abs()).ln() / (1.0 + self.mu).ln();
49 sign * magnitude
50 }
51
52 fn decode_sample(&self, y: f32) -> f32 {
54 let y_clamped = y.clamp(-1.0, 1.0);
55 let sign = y_clamped.signum();
56 let magnitude = ((1.0 + self.mu).powf(y_clamped.abs()) - 1.0) / self.mu;
57 sign * magnitude
58 }
59
60 pub fn quantize(&self, x: f32) -> i32 {
62 let encoded = self.encode_sample(x);
63 let half_levels = (self.levels / 2) as f32;
64 ((encoded + 1.0) * half_levels).round() as i32
65 }
66
67 pub fn dequantize(&self, level: i32) -> f32 {
69 let half_levels = (self.levels / 2) as f32;
70 let encoded = (level as f32 / half_levels) - 1.0;
71 self.decode_sample(encoded)
72 }
73
74 pub fn bits(&self) -> u8 {
76 self.bits
77 }
78
79 pub fn mu(&self) -> f32 {
81 self.mu
82 }
83}
84
85impl SignalTokenizer for MuLawCodec {
86 fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
87 Ok(signal.mapv(|x| self.encode_sample(x)))
88 }
89
90 fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
91 Ok(tokens.mapv(|y| self.decode_sample(y)))
92 }
93
94 fn embed_dim(&self) -> usize {
95 1 }
97
98 fn vocab_size(&self) -> usize {
99 self.levels
100 }
101}
102
103#[cfg(test)]
104mod tests {
105 use super::*;
106
107 #[test]
108 fn test_mulaw_encode_decode() {
109 let codec = MuLawCodec::new(8);
110
111 for x in [-1.0, -0.5, 0.0, 0.5, 1.0] {
113 let encoded = codec.encode_sample(x);
114 let decoded = codec.decode_sample(encoded);
115 assert!((decoded - x).abs() < 0.01, "Roundtrip failed for {}", x);
116 }
117 }
118
119 #[test]
120 fn test_mulaw_quantize() {
121 let codec = MuLawCodec::new(8);
122
123 let level = codec.quantize(0.0);
124 assert_eq!(level, 128); let level = codec.quantize(-1.0);
127 assert_eq!(level, 0);
128
129 let level = codec.quantize(1.0);
130 assert_eq!(level, 256);
131 }
132
133 #[test]
134 fn test_mulaw_signal() {
135 let codec = MuLawCodec::new(8);
136 let signal = Array1::from_vec(vec![0.0, 0.5, -0.5, 1.0, -1.0]);
137
138 let encoded = codec.encode(&signal).unwrap();
139 let decoded = codec.decode(&encoded).unwrap();
140
141 for (orig, dec) in signal.iter().zip(decoded.iter()) {
142 assert!(
143 (orig - dec).abs() < 0.01,
144 "Signal roundtrip failed: {} vs {}",
145 orig,
146 dec
147 );
148 }
149 }
150}