use crate::error::{TokenizerError, TokenizerResult};
use crate::SignalTokenizer;
use scirs2_core::ndarray::Array1;
/// Common interface for scalar quantizers that map between `f32` values and
/// integer level indices.
pub trait Quantizer {
    /// Converts a continuous `value` into an integer quantization level.
    fn quantize(&self, value: f32) -> i32;

    /// Converts a quantization `level` back into a continuous value.
    fn dequantize(&self, level: i32) -> f32;

    /// Reports how many distinct levels this quantizer has.
    fn num_levels(&self) -> usize;
}
/// Uniform quantizer over a closed `[min, max]` interval with `2^bits` levels.
#[derive(Clone, Debug)]
pub struct LinearQuantizer {
    /// Lower bound of the representable range.
    min: f32,
    /// Upper bound of the representable range.
    max: f32,
    /// Bit depth the quantizer was configured with.
    bits: u8,
    /// Number of quantization levels (`1 << bits`).
    levels: usize,
    /// Cached `(max - min) / levels`, as computed by the constructor.
    step: f32,
}
impl LinearQuantizer {
    /// Builds a quantizer covering `[min, max]` with `2^bits` levels.
    ///
    /// # Errors
    ///
    /// Returns [`TokenizerError::InvalidConfig`] when:
    /// - `min` or `max` is NaN or infinite,
    /// - `min` is not strictly less than `max`,
    /// - `max - min` overflows to infinity,
    /// - `bits` is outside `1..=16`.
    pub fn new(min: f32, max: f32, bits: u8) -> TokenizerResult<Self> {
        // NaN compares false against everything, so a plain `min >= max` check
        // lets NaN bounds through; `f32::clamp` would then panic in `quantize`.
        if !min.is_finite() || !max.is_finite() {
            return Err(TokenizerError::InvalidConfig(
                "min and max must be finite".into(),
            ));
        }
        if min >= max {
            return Err(TokenizerError::InvalidConfig(
                "min must be less than max".into(),
            ));
        }
        // Two finite bounds can still be further apart than f32 can represent
        // (e.g. f32::MIN..f32::MAX), which would make every derived value useless.
        let span = max - min;
        if !span.is_finite() {
            return Err(TokenizerError::InvalidConfig(
                "range (max - min) must be finite".into(),
            ));
        }
        if !(1..=16).contains(&bits) {
            return Err(TokenizerError::InvalidConfig("bits must be 1-16".into()));
        }

        let levels = 1usize << bits;
        // NOTE(review): this divides by `levels`, whereas quantize/dequantize
        // space the levels by `span / (levels - 1)`. Kept as-is because callers
        // may rely on the current `step_size()` value — confirm which is intended.
        let step = span / levels as f32;

        Ok(Self {
            min,
            max,
            bits,
            levels,
            step,
        })
    }

    /// Builds a quantizer over the normalized range `[-1.0, 1.0]`.
    ///
    /// # Errors
    ///
    /// Propagates the validation errors of [`LinearQuantizer::new`].
    pub fn normalized(bits: u8) -> TokenizerResult<Self> {
        Self::new(-1.0, 1.0, bits)
    }

    /// Returns the configured `(min, max)` bounds.
    pub fn range(&self) -> (f32, f32) {
        (self.min, self.max)
    }

    /// Returns the configured bit depth.
    pub fn bits(&self) -> u8 {
        self.bits
    }

    /// Returns the cached step value, `(max - min) / 2^bits`.
    pub fn step_size(&self) -> f32 {
        self.step
    }
}
impl Quantizer for LinearQuantizer {
    /// Maps `value` to the nearest level index in `0..=levels - 1`.
    ///
    /// Inputs outside `[min, max]` are clamped to the nearest bound first.
    /// A NaN input stays NaN through the arithmetic and becomes level `0`,
    /// because a float-to-int `as` cast turns NaN into zero.
    fn quantize(&self, value: f32) -> i32 {
        let top_index = (self.levels - 1) as f32;
        let offset = value.clamp(self.min, self.max) - self.min;
        // Position of the value within the range, as a fraction in [0, 1].
        let fraction = offset / (self.max - self.min);
        let scaled = fraction * top_index;
        scaled.round() as i32
    }

    /// Reconstructs the value represented by `level`.
    ///
    /// Levels outside `0..=levels - 1` are clamped into that interval, so
    /// level `0` yields `min` and the highest level yields `max`.
    fn dequantize(&self, level: i32) -> f32 {
        let top_index = self.levels - 1;
        let bounded = level.clamp(0, top_index as i32);
        let fraction = bounded as f32 / top_index as f32;
        self.min + fraction * (self.max - self.min)
    }

    /// Returns the number of quantization levels.
    fn num_levels(&self) -> usize {
        self.levels
    }
}
impl SignalTokenizer for LinearQuantizer {
fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
Ok(signal.mapv(|x| self.quantize(x) as f32))
}
fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
}
fn embed_dim(&self) -> usize {
1
}
fn vocab_size(&self) -> usize {
self.levels
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An 8-bit quantizer over `[-1, 1]` has 256 levels, and the range
    /// endpoints land on the first and last level.
    #[test]
    fn test_linear_quantizer() {
        let quantizer = LinearQuantizer::new(-1.0, 1.0, 8).unwrap();
        assert_eq!(quantizer.num_levels(), 256);

        // The midpoint sits between two levels, so allow one level of slack.
        let midpoint_level = quantizer.quantize(0.0);
        let distance_from_centre = (midpoint_level - 127).abs();
        assert!(distance_from_centre <= 1);

        assert_eq!(quantizer.quantize(-1.0), 0);
        assert_eq!(quantizer.quantize(1.0), 255);
    }

    /// Quantizing then dequantizing stays within 0.2 of the input for a
    /// 10-bit quantizer over `[0, 100]`.
    #[test]
    fn test_roundtrip() {
        let quantizer = LinearQuantizer::new(0.0, 100.0, 10).unwrap();
        let samples: [f32; 5] = [0.0, 25.0, 50.0, 75.0, 100.0];

        for &value in samples.iter() {
            let level = quantizer.quantize(value);
            let recovered = quantizer.dequantize(level);
            let error = (recovered - value).abs();
            assert!(
                error < 0.2,
                "Roundtrip failed for {} -> {} -> {}",
                value,
                level,
                recovered
            );
        }
    }

    /// Out-of-range inputs saturate at the lowest and highest level.
    #[test]
    fn test_clamping() {
        let quantizer = LinearQuantizer::new(-1.0, 1.0, 8).unwrap();
        let below = quantizer.quantize(-2.0);
        let above = quantizer.quantize(2.0);
        assert_eq!(below, 0);
        assert_eq!(above, 255);
    }
}