kizzasi_tokenizer/
quantizer.rs

1//! Linear quantization for signals
2//!
3//! Provides simple uniform quantization of continuous signals
4//! into discrete levels.
5
6use crate::error::{TokenizerError, TokenizerResult};
7use crate::SignalTokenizer;
8use scirs2_core::ndarray::Array1;
9
/// Common interface for strategies that map continuous samples onto a
/// finite set of integer levels and back.
pub trait Quantizer {
    /// Map a continuous `value` to the index of its discrete level.
    fn quantize(&self, value: f32) -> i32;

    /// Map a discrete `level` index back to a representative continuous value.
    fn dequantize(&self, level: i32) -> f32;

    /// Total number of discrete levels this quantizer can produce.
    fn num_levels(&self) -> usize;
}
21
/// Uniform (linear) quantizer over a fixed value range.
///
/// All fields are private; the inherent accessor methods expose the range,
/// bit width and step size.
#[derive(Debug, Clone)]
pub struct LinearQuantizer {
    /// Lower bound of the representable range.
    min: f32,
    /// Upper bound of the representable range.
    max: f32,
    /// Bit width the quantizer was configured with.
    bits: u8,
    /// Count of discrete levels.
    levels: usize,
    /// Step size as reported by `step_size`.
    step: f32,
}
36
37impl LinearQuantizer {
38    /// Create a new linear quantizer
39    pub fn new(min: f32, max: f32, bits: u8) -> TokenizerResult<Self> {
40        if min >= max {
41            return Err(TokenizerError::InvalidConfig(
42                "min must be less than max".into(),
43            ));
44        }
45        if bits == 0 || bits > 16 {
46            return Err(TokenizerError::InvalidConfig("bits must be 1-16".into()));
47        }
48
49        let levels = 1usize << bits;
50        let step = (max - min) / levels as f32;
51
52        Ok(Self {
53            min,
54            max,
55            bits,
56            levels,
57            step,
58        })
59    }
60
61    /// Create quantizer for normalized [-1, 1] range
62    pub fn normalized(bits: u8) -> TokenizerResult<Self> {
63        Self::new(-1.0, 1.0, bits)
64    }
65
66    /// Get the range
67    pub fn range(&self) -> (f32, f32) {
68        (self.min, self.max)
69    }
70
71    /// Get number of bits
72    pub fn bits(&self) -> u8 {
73        self.bits
74    }
75
76    /// Get the quantization step size
77    pub fn step_size(&self) -> f32 {
78        self.step
79    }
80}
81
82impl Quantizer for LinearQuantizer {
83    fn quantize(&self, value: f32) -> i32 {
84        let clamped = value.clamp(self.min, self.max);
85        let normalized = (clamped - self.min) / (self.max - self.min);
86        (normalized * (self.levels - 1) as f32).round() as i32
87    }
88
89    fn dequantize(&self, level: i32) -> f32 {
90        let clamped_level = level.clamp(0, (self.levels - 1) as i32);
91        let normalized = clamped_level as f32 / (self.levels - 1) as f32;
92        self.min + normalized * (self.max - self.min)
93    }
94
95    fn num_levels(&self) -> usize {
96        self.levels
97    }
98}
99
100impl SignalTokenizer for LinearQuantizer {
101    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
102        // Return quantized values as floats for embedding lookup
103        Ok(signal.mapv(|x| self.quantize(x) as f32))
104    }
105
106    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
107        Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
108    }
109
110    fn embed_dim(&self) -> usize {
111        1
112    }
113
114    fn vocab_size(&self) -> usize {
115        self.levels
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// An 8-bit quantizer over [-1, 1] has 256 levels, puts the midpoint near
    /// level 127 and maps the range endpoints to the first and last level.
    #[test]
    fn test_linear_quantizer() {
        let quantizer = LinearQuantizer::new(-1.0, 1.0, 8).unwrap();
        assert_eq!(quantizer.num_levels(), 256);

        // The centre of the range must land within one level of 127.
        let mid_level = quantizer.quantize(0.0);
        let distance = (mid_level - 127).abs();
        assert!(distance <= 1);

        // Range endpoints map to the extreme levels.
        for (input, expected) in [(-1.0, 0), (1.0, 255)] {
            assert_eq!(quantizer.quantize(input), expected);
        }
    }

    /// Quantizing and then dequantizing stays within 0.2 of the input for a
    /// 10-bit quantizer over [0, 100].
    #[test]
    fn test_roundtrip() {
        let quantizer = LinearQuantizer::new(0.0, 100.0, 10).unwrap();
        let samples = [0.0, 25.0, 50.0, 75.0, 100.0];

        for value in samples {
            let level = quantizer.quantize(value);
            let recovered = quantizer.dequantize(level);
            let error = (recovered - value).abs();
            assert!(
                error < 0.2,
                "Roundtrip failed for {} -> {} -> {}",
                value,
                level,
                recovered
            );
        }
    }

    /// Inputs outside the configured range saturate at the extreme levels.
    #[test]
    fn test_clamping() {
        let quantizer = LinearQuantizer::new(-1.0, 1.0, 8).unwrap();

        let below_range = quantizer.quantize(-2.0);
        let above_range = quantizer.quantize(2.0);
        assert_eq!(below_range, 0);
        assert_eq!(above_range, 255);
    }
}