kizzasi_tokenizer/
mulaw.rs

1//! μ-law companding codec for audio signals
2//!
3//! μ-law encoding is a logarithmic quantization scheme commonly used
4//! for audio signals. It provides better dynamic range preservation
5//! than linear quantization, especially for quiet sounds.
6//!
7//! The μ-law formula is:
8//! F(x) = sign(x) * ln(1 + μ|x|) / ln(1 + μ)
9//!
10//! where μ is typically 255 for 8-bit quantization.
11
12use crate::error::TokenizerResult;
13use crate::SignalTokenizer;
14use scirs2_core::ndarray::Array1;
15
/// μ-law companding codec.
///
/// Compresses samples in `[-1, 1]` with the μ-law curve and can additionally
/// quantize the companded value to one of `levels` integer levels.
#[derive(Debug, Clone, PartialEq)]
pub struct MuLawCodec {
    /// μ parameter controlling the curvature of the companding curve
    /// (`MuLawCodec::new` sets it to `2^bits - 1`, e.g. 255 for 8 bits).
    mu: f32,
    /// Number of quantization bits.
    bits: u8,
    /// Number of quantization levels, `2^bits`.
    levels: usize,
}
26
27impl MuLawCodec {
28    /// Create a new μ-law codec with specified bits
29    pub fn new(bits: u8) -> Self {
30        let levels = 1usize << bits;
31        let mu = (levels - 1) as f32;
32        Self { mu, bits, levels }
33    }
34
35    /// Create with custom μ value
36    pub fn with_mu(mu: f32, bits: u8) -> Self {
37        Self {
38            mu,
39            bits,
40            levels: 1usize << bits,
41        }
42    }
43
44    /// Encode a single sample using μ-law
45    fn encode_sample(&self, x: f32) -> f32 {
46        let x_clamped = x.clamp(-1.0, 1.0);
47        let sign = x_clamped.signum();
48        let magnitude = (1.0 + self.mu * x_clamped.abs()).ln() / (1.0 + self.mu).ln();
49        sign * magnitude
50    }
51
52    /// Decode a single sample using μ-law
53    fn decode_sample(&self, y: f32) -> f32 {
54        let y_clamped = y.clamp(-1.0, 1.0);
55        let sign = y_clamped.signum();
56        let magnitude = ((1.0 + self.mu).powf(y_clamped.abs()) - 1.0) / self.mu;
57        sign * magnitude
58    }
59
60    /// Quantize to integer level
61    pub fn quantize(&self, x: f32) -> i32 {
62        let encoded = self.encode_sample(x);
63        let half_levels = (self.levels / 2) as f32;
64        ((encoded + 1.0) * half_levels).round() as i32
65    }
66
67    /// Dequantize from integer level
68    pub fn dequantize(&self, level: i32) -> f32 {
69        let half_levels = (self.levels / 2) as f32;
70        let encoded = (level as f32 / half_levels) - 1.0;
71        self.decode_sample(encoded)
72    }
73
74    /// Get the number of bits
75    pub fn bits(&self) -> u8 {
76        self.bits
77    }
78
79    /// Get μ value
80    pub fn mu(&self) -> f32 {
81        self.mu
82    }
83}
84
85impl SignalTokenizer for MuLawCodec {
86    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
87        Ok(signal.mapv(|x| self.encode_sample(x)))
88    }
89
90    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
91        Ok(tokens.mapv(|y| self.decode_sample(y)))
92    }
93
94    fn embed_dim(&self) -> usize {
95        1 // μ-law maintains dimensionality
96    }
97
98    fn vocab_size(&self) -> usize {
99        self.levels
100    }
101}
102
103#[cfg(test)]
104mod tests {
105    use super::*;
106
107    #[test]
108    fn test_mulaw_encode_decode() {
109        let codec = MuLawCodec::new(8);
110
111        // Test roundtrip for various values
112        for x in [-1.0, -0.5, 0.0, 0.5, 1.0] {
113            let encoded = codec.encode_sample(x);
114            let decoded = codec.decode_sample(encoded);
115            assert!((decoded - x).abs() < 0.01, "Roundtrip failed for {}", x);
116        }
117    }
118
119    #[test]
120    fn test_mulaw_quantize() {
121        let codec = MuLawCodec::new(8);
122
123        let level = codec.quantize(0.0);
124        assert_eq!(level, 128); // Middle of 256 levels
125
126        let level = codec.quantize(-1.0);
127        assert_eq!(level, 0);
128
129        let level = codec.quantize(1.0);
130        assert_eq!(level, 256);
131    }
132
133    #[test]
134    fn test_mulaw_signal() {
135        let codec = MuLawCodec::new(8);
136        let signal = Array1::from_vec(vec![0.0, 0.5, -0.5, 1.0, -1.0]);
137
138        let encoded = codec.encode(&signal).unwrap();
139        let decoded = codec.decode(&encoded).unwrap();
140
141        for (orig, dec) in signal.iter().zip(decoded.iter()) {
142            assert!(
143                (orig - dec).abs() < 0.01,
144                "Signal roundtrip failed: {} vs {}",
145                orig,
146                dec
147            );
148        }
149    }
150}