use crate::error::TokenizerResult;
use crate::SignalTokenizer;
use scirs2_core::ndarray::Array1;
/// μ-law companding codec.
///
/// Samples are clamped to `[-1, 1]` and passed through the μ-law curve; the
/// compressed values can additionally be mapped to integer levels with
/// `quantize` and mapped back with `dequantize`.
#[derive(Debug, Clone)]
pub struct MuLawCodec {
/// Compression parameter μ. `new` derives it as `levels - 1`; `with_mu`
/// takes it from the caller.
mu: f32,
/// Bit depth the codec was constructed with; `levels` is `1 << bits`.
bits: u8,
/// Number of quantization levels (`1 << bits`); also reported as the
/// tokenizer's `vocab_size`.
levels: usize,
}
impl MuLawCodec {
    /// Creates a codec with `2^bits` quantization levels and the conventional
    /// companding parameter `mu = 2^bits - 1` (e.g. `mu = 255` for 8 bits).
    ///
    /// # Panics
    ///
    /// Panics if `bits` is 0 (a single level yields `mu = 0`, for which the
    /// companding curve is undefined) or if `1 << bits` does not fit in `usize`.
    pub fn new(bits: u8) -> Self {
        let levels = Self::levels_for_bits(bits);
        let mu = (levels - 1) as f32;
        Self { mu, bits, levels }
    }

    /// Creates a codec with `2^bits` quantization levels and a caller-chosen `mu`.
    ///
    /// # Panics
    ///
    /// Panics if `mu` is not finite and strictly positive (the curve divides by
    /// `ln(1 + mu)`, which is zero at `mu = 0`), or under the same `bits`
    /// conditions as [`MuLawCodec::new`].
    pub fn with_mu(mu: f32, bits: u8) -> Self {
        assert!(
            mu.is_finite() && mu > 0.0,
            "mu-law parameter must be finite and positive, got {}",
            mu
        );
        Self {
            mu,
            bits,
            levels: Self::levels_for_bits(bits),
        }
    }

    /// Returns `2^bits`, rejecting bit depths that would make the codec
    /// degenerate (`bits == 0`) or overflow the shift. A plain `1usize << bits`
    /// panics on overflow in debug builds but masks the shift amount in release
    /// builds, which would silently produce a tiny level count.
    fn levels_for_bits(bits: u8) -> usize {
        assert!(bits > 0, "MuLawCodec requires at least 1 bit");
        1usize
            .checked_shl(u32::from(bits))
            .unwrap_or_else(|| panic!("MuLawCodec: {} bits do not fit in usize", bits))
    }

    /// Applies the μ-law compression curve
    /// `F(x) = sign(x) * ln(1 + mu*|x|) / ln(1 + mu)` to `x` clamped to `[-1, 1]`.
    /// The result lies in `[-1, 1]`.
    fn encode_sample(&self, x: f32) -> f32 {
        let x = x.clamp(-1.0, 1.0);
        // `ln_1p` keeps precision when `mu*|x|` is tiny, where `1.0 + tiny`
        // would round away most of the signal.
        let magnitude = (self.mu * x.abs()).ln_1p() / self.mu.ln_1p();
        x.signum() * magnitude
    }

    /// Applies the inverse μ-law curve
    /// `F^-1(y) = sign(y) * ((1 + mu)^|y| - 1) / mu` to `y` clamped to `[-1, 1]`.
    fn decode_sample(&self, y: f32) -> f32 {
        let y = y.clamp(-1.0, 1.0);
        // `(1 + mu)^|y| - 1 == exp_m1(|y| * ln(1 + mu))`; the latter avoids
        // catastrophic cancellation for `y` near zero.
        let magnitude = (y.abs() * self.mu.ln_1p()).exp_m1() / self.mu;
        y.signum() * magnitude
    }

    /// Compresses `x` and maps it to an integer level in `0..levels`.
    ///
    /// The compressed value in `[-1, 1]` is spread linearly over the `levels`
    /// codes, so `-1.0` maps to `0` and `1.0` maps to `levels - 1`. Every result
    /// is therefore a valid index for a vocabulary of `levels` entries. For
    /// 8 bits, `0.0` maps to `128`.
    ///
    /// For bit depths above 31 the top levels exceed `i32::MAX`; the
    /// float-to-int cast saturates there.
    pub fn quantize(&self, x: f32) -> i32 {
        let max_level = (self.levels - 1) as f32;
        // Shift [-1, 1] onto [0, 1] before scaling to [0, max_level].
        let unit = (self.encode_sample(x) + 1.0) * 0.5;
        (unit * max_level).round() as i32
    }

    /// Inverse of [`MuLawCodec::quantize`]: maps a level in `0..levels` back to
    /// the compressed range `[-1, 1]` and expands it.
    ///
    /// Levels outside `0..levels` are clamped to the signal range `[-1, 1]`.
    pub fn dequantize(&self, level: i32) -> f32 {
        let max_level = (self.levels - 1) as f32;
        let encoded = (level as f32 / max_level) * 2.0 - 1.0;
        self.decode_sample(encoded)
    }

    /// Bit depth the codec was constructed with.
    pub fn bits(&self) -> u8 {
        self.bits
    }

    /// Companding parameter μ.
    pub fn mu(&self) -> f32 {
        self.mu
    }
}

impl SignalTokenizer for MuLawCodec {
    /// Compresses every sample with the μ-law curve. The output is continuous
    /// (values in `[-1, 1]`), not integer level ids; use
    /// [`MuLawCodec::quantize`] for those.
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        Ok(signal.mapv(|x| self.encode_sample(x)))
    }

    /// Expands every value with the inverse μ-law curve.
    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        Ok(tokens.mapv(|y| self.decode_sample(y)))
    }

    /// Each sample is represented by a single scalar.
    fn embed_dim(&self) -> usize {
        1
    }

    /// Number of quantization levels; every [`MuLawCodec::quantize`] output is
    /// a valid index below this value.
    fn vocab_size(&self) -> usize {
        self.levels
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mulaw_encode_decode() {
        let codec = MuLawCodec::new(8);
        for x in [-1.0, -0.5, 0.0, 0.5, 1.0] {
            let encoded = codec.encode_sample(x);
            let decoded = codec.decode_sample(encoded);
            assert!((decoded - x).abs() < 0.01, "Roundtrip failed for {}", x);
        }
    }

    #[test]
    fn test_mulaw_quantize() {
        let codec = MuLawCodec::new(8);
        assert_eq!(codec.quantize(0.0), 128);
        assert_eq!(codec.quantize(-1.0), 0);
        // The loudest level must still be a valid token id (< vocab_size == 256).
        assert_eq!(codec.quantize(1.0), 255);
        // Out-of-range input is clamped before quantization.
        assert_eq!(codec.quantize(-3.0), 0);
        assert_eq!(codec.quantize(3.0), 255);
    }

    #[test]
    fn test_mulaw_quantize_stays_in_vocab() {
        let codec = MuLawCodec::new(8);
        let vocab = codec.vocab_size() as i32;
        for step in -100..=100 {
            let x = step as f32 / 100.0;
            let level = codec.quantize(x);
            assert!(
                (0..vocab).contains(&level),
                "level {} out of range for input {}",
                level,
                x
            );
        }
    }

    #[test]
    fn test_mulaw_dequantize_endpoints() {
        let codec = MuLawCodec::new(8);
        assert!((codec.dequantize(0) + 1.0).abs() < 1e-4);
        assert!((codec.dequantize(255) - 1.0).abs() < 1e-4);
    }

    #[test]
    fn test_mulaw_quantize_roundtrip() {
        let codec = MuLawCodec::new(8);
        for level in 0..256 {
            assert_eq!(codec.quantize(codec.dequantize(level)), level);
        }
    }

    #[test]
    #[should_panic(expected = "at least 1 bit")]
    fn test_mulaw_rejects_zero_bits() {
        let _ = MuLawCodec::new(0);
    }

    #[test]
    #[should_panic(expected = "finite and positive")]
    fn test_mulaw_rejects_zero_mu() {
        let _ = MuLawCodec::with_mu(0.0, 8);
    }

    #[test]
    fn test_mulaw_signal() {
        let codec = MuLawCodec::new(8);
        let signal = Array1::from_vec(vec![0.0, 0.5, -0.5, 1.0, -1.0]);
        let encoded = codec.encode(&signal).unwrap();
        let decoded = codec.decode(&encoded).unwrap();
        for (orig, dec) in signal.iter().zip(decoded.iter()) {
            assert!(
                (orig - dec).abs() < 0.01,
                "Signal roundtrip failed: {} vs {}",
                orig,
                dec
            );
        }
    }
}