use gamut_core::{Error, Result};
fn round_half_away_from_zero(x: f64) -> f64 {
if x >= 0.0 {
(x + 0.5).floor()
} else {
(x - 0.5).ceil()
}
}
fn max_index(bits: u32) -> Result<u32> {
if !(2..=31).contains(&bits) {
return Err(Error::InvalidInput("mu-law bit width must be in 2..=31"));
}
Ok((1u32 << bits) - 2)
}
#[must_use]
pub fn mu_compress(value: f64, mu: f64) -> f64 {
let v = value.clamp(-1.0, 1.0);
v.signum() * (1.0 + mu * v.abs()).ln() / (1.0 + mu).ln()
}
#[must_use]
pub fn mu_expand(compressed: f64, mu: f64) -> f64 {
compressed.signum() * ((1.0 + mu).powf(compressed.abs()) - 1.0) / mu
}
pub fn mu_quantize(value: f64, bits: u32, mu: f64) -> Result<u32> {
let max_idx = max_index(bits)?;
let compressed = mu_compress(value, mu);
let idx = round_half_away_from_zero((compressed + 1.0) / 2.0 * f64::from(max_idx));
Ok((idx as i64).clamp(0, i64::from(max_idx)) as u32)
}
pub fn mu_dequantize(index: u32, bits: u32, mu: f64) -> Result<f64> {
let max_idx = max_index(bits)?;
let index = index.min(max_idx);
let compressed = f64::from(index) / f64::from(max_idx) * 2.0 - 1.0;
Ok(mu_expand(compressed, mu))
}
#[cfg(test)]
mod tests {
use super::*;
const MU: f64 = 5.0;
#[test]
fn compress_expand_roundtrip() {
for &v in &[-1.0, -0.5, 0.0, 0.5, 1.0] {
let rt = mu_expand(mu_compress(v, MU), MU);
assert!((rt - v).abs() < 1e-12, "roundtrip at {v}: {rt}");
}
}
#[test]
fn compressed_endpoints() {
assert!((mu_compress(1.0, MU) - 1.0).abs() < 1e-12);
assert!((mu_compress(-1.0, MU) + 1.0).abs() < 1e-12);
assert!(mu_compress(0.0, MU).abs() < 1e-12);
}
#[test]
fn zero_quantizes_to_center_and_back_exactly() {
for bits in [4u32, 5, 6] {
let center = (1u32 << (bits - 1)) - 1;
assert_eq!(mu_quantize(0.0, bits, MU).unwrap(), center, "bits={bits}");
assert_eq!(mu_dequantize(center, bits, MU).unwrap(), 0.0, "bits={bits}");
}
}
#[test]
fn extremes_quantize_to_bounds() {
for bits in [4u32, 5, 6] {
let max_idx = (1u32 << bits) - 2;
assert_eq!(mu_quantize(-1.0, bits, MU).unwrap(), 0);
assert_eq!(mu_quantize(1.0, bits, MU).unwrap(), max_idx);
}
}
#[test]
fn top_code_clamps_on_dequantize() {
for bits in [4u32, 5, 6] {
let top = (1u32 << bits) - 1;
assert_eq!(
mu_dequantize(top, bits, MU).unwrap(),
mu_dequantize(top - 1, bits, MU).unwrap()
);
}
}
#[test]
fn symmetric_codes_around_center() {
for bits in [4u32, 5, 6] {
let center = (1u32 << (bits - 1)) - 1;
for &v in &[0.1, 0.3, 0.7] {
let qp = mu_quantize(v, bits, MU).unwrap();
let qn = mu_quantize(-v, bits, MU).unwrap();
assert_eq!(qp - center, center - qn, "±{v} at bits={bits}");
}
}
}
#[test]
fn invalid_bit_width_errors() {
assert!(mu_quantize(0.5, 1, MU).is_err());
assert!(mu_quantize(0.5, 32, MU).is_err());
assert!(mu_dequantize(0, 0, MU).is_err());
}
#[test]
fn matches_chromahash_mulaw_vectors() {
struct Case {
value: f64,
bits: u32,
mu: f64,
compressed: f64,
expanded: f64,
quantized: u32,
dequantized: f64,
}
let cases = [
Case {
value: 0.0,
bits: 5,
mu: 5.0,
compressed: 0.0,
expanded: 0.0,
quantized: 15,
dequantized: 0.0,
},
Case {
value: 1.0,
bits: 5,
mu: 5.0,
compressed: 1.0,
expanded: 1.0000000000000002,
quantized: 30,
dequantized: 1.0000000000000002,
},
Case {
value: -1.0,
bits: 5,
mu: 5.0,
compressed: -1.0,
expanded: -1.0000000000000002,
quantized: 0,
dequantized: -1.0000000000000002,
},
Case {
value: 0.5,
bits: 5,
mu: 5.0,
compressed: 0.6991803252671502,
expanded: 0.4999999999999999,
quantized: 25,
dequantized: 0.46038544977892554,
},
Case {
value: -0.5,
bits: 4,
mu: 5.0,
compressed: -0.6991803252671502,
expanded: -0.4999999999999999,
quantized: 2,
dequantized: -0.5192043696541104,
},
Case {
value: 0.75,
bits: 6,
mu: 5.0,
compressed: 0.8696170690354138,
expanded: 0.7499999999999998,
quantized: 58,
dequantized: 0.7523018611322912,
},
Case {
value: 0.0,
bits: 5,
mu: 8.0,
compressed: 0.0,
expanded: 0.0,
quantized: 15,
dequantized: 0.0,
},
Case {
value: 0.5,
bits: 5,
mu: 8.0,
compressed: 0.7324867603589637,
expanded: 0.4999999999999999,
quantized: 26,
dequantized: 0.5011636512658286,
},
Case {
value: -0.25,
bits: 6,
mu: 8.0,
compressed: -0.5000000000000001,
expanded: -0.25000000000000006,
quantized: 15,
dequantized: -0.2635279583396101,
},
Case {
value: 1.0,
bits: 4,
mu: 8.0,
compressed: 1.0,
expanded: 0.9999999999999998,
quantized: 14,
dequantized: 0.9999999999999998,
},
Case {
value: -0.75,
bits: 5,
mu: 8.0,
compressed: -0.8856218745807111,
expanded: -0.7499999999999999,
quantized: 2,
dequantized: -0.7143057295610805,
},
];
for c in &cases {
assert!(
(mu_compress(c.value, c.mu) - c.compressed).abs() < 1e-9,
"compress {}",
c.value
);
assert!(
(mu_expand(c.compressed, c.mu) - c.expanded).abs() < 1e-9,
"expand {}",
c.value
);
let q = mu_quantize(c.value, c.bits, c.mu).unwrap();
assert!(
(i64::from(q) - i64::from(c.quantized)).abs() <= 1,
"quantize {}: {q} vs {}",
c.value,
c.quantized
);
assert!(
(mu_dequantize(c.quantized, c.bits, c.mu).unwrap() - c.dequantized).abs() < 1e-9,
"dequantize {}",
c.value
);
}
}
}