kizzasi_tokenizer/advanced_quant.rs

//! Advanced quantization strategies
//!
//! Provides sophisticated quantization methods beyond simple linear/μ-law:
//! - **Adaptive Quantization**: Adjusts step size based on signal statistics
//! - **Dead Zone Quantization**: Optimized for sparse signals
//! - **Perceptual Quantization**: Psychoacoustic modeling for audio
//! - **Entropy-Constrained Quantization**: Rate-distortion optimized
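//!
//! # Example
//!
//! A minimal usage sketch. The crate/module path is an assumption about how this
//! module is exposed, so the snippet is marked `ignore` rather than run as a doctest:
//!
//! ```ignore
//! use kizzasi_tokenizer::advanced_quant::AdaptiveQuantizer;
//! use kizzasi_tokenizer::SignalTokenizer;
//! use scirs2_core::ndarray::Array1;
//!
//! // 8-bit adaptive quantizer over [-1, 1] with a 16-sample statistics window.
//! let quant = AdaptiveQuantizer::new(8, 16, 0.5, -1.0, 1.0).unwrap();
//! let signal = Array1::from_vec((0..128).map(|i| (i as f32 * 0.05).sin()).collect());
//! let tokens = quant.encode(&signal).unwrap();
//! let recovered = quant.decode(&tokens).unwrap();
//! assert_eq!(recovered.len(), signal.len());
//! ```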

use crate::error::{TokenizerError, TokenizerResult};
use crate::{Quantizer, SignalTokenizer};
use scirs2_core::ndarray::Array1;

/// Adaptive quantizer that adjusts step size based on local signal statistics
///
/// Uses a sliding window to compute the local standard deviation and adapts the
/// quantization step size accordingly: low-variance regions get a smaller step
/// (finer quantization), high-variance regions a larger one.
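///
/// # Example
///
/// A small sketch of direct use (marked `ignore`; assumes the type is in scope
/// along with `Array1` from `scirs2_core`):
///
/// ```ignore
/// // 6-bit quantizer, 8-sample window, strong adaptation, range [-1, 1].
/// let quant = AdaptiveQuantizer::new(6, 8, 0.8, -1.0, 1.0).unwrap();
/// let signal = Array1::from_vec(vec![0.05, 0.02, -0.03, 0.9, -0.8, 0.7]);
/// let levels = quant.quantize_adaptive(&signal).unwrap();
/// assert_eq!(levels.len(), signal.len());
/// ```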
#[derive(Debug, Clone)]
pub struct AdaptiveQuantizer {
    /// Base number of bits
    _bits: u8,
    /// Number of levels
    levels: usize,
    /// Window size for local statistics
    window_size: usize,
    /// Adaptation strength (0.0 = no adaptation, 1.0 = full adaptation)
    adaptation_strength: f32,
    /// Global min/max for normalization
    global_min: f32,
    global_max: f32,
}

impl AdaptiveQuantizer {
    /// Create a new adaptive quantizer
    pub fn new(
        bits: u8,
        window_size: usize,
        adaptation_strength: f32,
        global_min: f32,
        global_max: f32,
    ) -> TokenizerResult<Self> {
        if bits == 0 || bits > 16 {
            return Err(TokenizerError::InvalidConfig("bits must be 1-16".into()));
        }
        if window_size == 0 {
            return Err(TokenizerError::InvalidConfig(
                "window_size must be positive".into(),
            ));
        }
        if !(0.0..=1.0).contains(&adaptation_strength) {
            return Err(TokenizerError::InvalidConfig(
                "adaptation_strength must be in [0, 1]".into(),
            ));
        }
        if global_min >= global_max {
            return Err(TokenizerError::InvalidConfig(
                "global_min must be less than global_max".into(),
            ));
        }

        Ok(Self {
            _bits: bits,
            levels: 1usize << bits,
            window_size,
            adaptation_strength,
            global_min,
            global_max,
        })
    }

    /// Compute the local standard deviation around `pos`
    fn local_std(&self, signal: &Array1<f32>, pos: usize) -> f32 {
        let half_window = self.window_size / 2;
        let start = pos.saturating_sub(half_window);
        let end = (pos + half_window).min(signal.len());

        let window: Vec<f32> = signal
            .iter()
            .skip(start)
            .take(end - start)
            .cloned()
            .collect();
        if window.is_empty() {
            return 1.0;
        }

        let mean = window.iter().sum::<f32>() / window.len() as f32;
        let variance = window.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / window.len() as f32;

        variance.sqrt().max(1e-6)
    }

    /// Compute the adaptive step size at `pos`
    fn adaptive_step(&self, signal: &Array1<f32>, pos: usize) -> f32 {
        let base_step = (self.global_max - self.global_min) / self.levels as f32;
        let local_std = self.local_std(signal, pos);

        // Scale the step with the local signal spread: quieter regions get a
        // smaller step, louder regions a larger one.
        let global_std = (self.global_max - self.global_min) / 4.0; // Rough estimate of the global spread
        let scale = 1.0 + self.adaptation_strength * (local_std / global_std - 1.0);

        base_step * scale.clamp(0.1, 10.0) // Keep the scaling factor in a sane range
    }

    /// Quantize an entire signal with per-sample adaptive step sizes
    pub fn quantize_adaptive(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<i32>> {
        let mut result = Vec::with_capacity(signal.len());

        for (i, &value) in signal.iter().enumerate() {
            let step = self.adaptive_step(signal, i);
            let clamped = value.clamp(self.global_min, self.global_max);
            // Index with the locally adapted step, then clamp to the valid level range
            let level = ((clamped - self.global_min) / step).round() as i32;
            result.push(level.clamp(0, (self.levels - 1) as i32));
        }

        Ok(Array1::from_vec(result))
    }
}

impl Quantizer for AdaptiveQuantizer {
    fn quantize(&self, value: f32) -> i32 {
        // Fallback to uniform quantization for single values
        let clamped = value.clamp(self.global_min, self.global_max);
        let normalized = (clamped - self.global_min) / (self.global_max - self.global_min);
        (normalized * (self.levels - 1) as f32).round() as i32
    }

    fn dequantize(&self, level: i32) -> f32 {
        let clamped_level = level.clamp(0, (self.levels - 1) as i32);
        let normalized = clamped_level as f32 / (self.levels - 1) as f32;
        self.global_min + normalized * (self.global_max - self.global_min)
    }

    fn num_levels(&self) -> usize {
        self.levels
    }
}

/// Dead zone quantizer for sparse signals
///
/// Applies a dead zone around zero where small values are quantized to zero.
/// This is useful for signals with many near-zero values (e.g., after transforms).
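///
/// # Example
///
/// A sketch of the dead-zone behavior (marked `ignore`; assumes the type and the
/// `Quantizer` trait are in scope):
///
/// ```ignore
/// let quant = DeadZoneQuantizer::new(8, 0.1, -1.0, 1.0).unwrap();
/// // Values inside the dead zone collapse to the reserved zero level...
/// assert_eq!(quant.dequantize(quant.quantize(0.05)), 0.0);
/// // ...while values outside it are quantized normally.
/// assert!(quant.dequantize(quant.quantize(0.5)) > 0.4);
/// ```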
#[derive(Debug, Clone)]
pub struct DeadZoneQuantizer {
    /// Base number of bits
    _base_bits: u8,
    /// Number of levels
    levels: usize,
    /// Dead zone threshold
    dead_zone: f32,
    /// Range for quantization
    min: f32,
    max: f32,
}

impl DeadZoneQuantizer {
    /// Create a new dead zone quantizer
    ///
    /// # Arguments
    /// * `bits` - Number of quantization bits
    /// * `dead_zone` - Threshold below which values are quantized to zero
    /// * `min`, `max` - Value range
    pub fn new(bits: u8, dead_zone: f32, min: f32, max: f32) -> TokenizerResult<Self> {
        if bits == 0 || bits > 16 {
            return Err(TokenizerError::InvalidConfig("bits must be 1-16".into()));
        }
        if dead_zone < 0.0 {
            return Err(TokenizerError::InvalidConfig(
                "dead_zone must be non-negative".into(),
            ));
        }

        Ok(Self {
            _base_bits: bits,
            levels: 1usize << bits,
            dead_zone,
            min,
            max,
        })
    }
}

impl Quantizer for DeadZoneQuantizer {
    fn quantize(&self, value: f32) -> i32 {
        // Apply dead zone
        if value.abs() < self.dead_zone {
            return (self.levels / 2) as i32; // Zero point
        }

        // Quantize non-dead-zone values
        let clamped = value.clamp(self.min, self.max);
        let normalized = (clamped - self.min) / (self.max - self.min);
        (normalized * (self.levels - 1) as f32).round() as i32
    }

    fn dequantize(&self, level: i32) -> f32 {
        let clamped_level = level.clamp(0, (self.levels - 1) as i32);

        // Check if it's the zero point
        if clamped_level == (self.levels / 2) as i32 {
            return 0.0;
        }

        let normalized = clamped_level as f32 / (self.levels - 1) as f32;
        self.min + normalized * (self.max - self.min)
    }

    fn num_levels(&self) -> usize {
        self.levels
    }
}

/// Non-uniform quantizer with configurable bin edges
///
/// Allows custom quantization levels for optimal rate-distortion trade-off
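///
/// # Example
///
/// A sketch using explicit bin edges (marked `ignore`; assumes the type and the
/// `Quantizer` trait are in scope):
///
/// ```ignore
/// // Four bins: [-2, -0.5), [-0.5, 0), [0, 0.5), [0.5, 2]
/// let quant = NonUniformQuantizer::from_edges(vec![-2.0, -0.5, 0.0, 0.5, 2.0]).unwrap();
/// assert_eq!(quant.num_levels(), 4);
/// assert_eq!(quant.quantize(0.25), 2);
/// assert_eq!(quant.dequantize(2), 0.25); // center of [0.0, 0.5)
/// ```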
#[derive(Debug, Clone)]
pub struct NonUniformQuantizer {
    /// Quantization bin edges (sorted)
    bin_edges: Vec<f32>,
    /// Reconstruction values for each bin
    reconstruction_values: Vec<f32>,
}

impl NonUniformQuantizer {
    /// Create from bin edges
    ///
    /// Reconstruction values are set to bin centers
    pub fn from_edges(mut bin_edges: Vec<f32>) -> TokenizerResult<Self> {
        if bin_edges.len() < 2 {
            return Err(TokenizerError::InvalidConfig(
                "Need at least 2 bin edges".into(),
            ));
        }

        bin_edges.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        // Compute reconstruction values as bin centers
        let mut reconstruction_values = Vec::with_capacity(bin_edges.len() - 1);
        for i in 0..bin_edges.len() - 1 {
            reconstruction_values.push((bin_edges[i] + bin_edges[i + 1]) / 2.0);
        }

        Ok(Self {
            bin_edges,
            reconstruction_values,
        })
    }

    /// Create with custom reconstruction values
    ///
    /// `bin_edges` must be sorted in ascending order.
    pub fn new(bin_edges: Vec<f32>, reconstruction_values: Vec<f32>) -> TokenizerResult<Self> {
        if bin_edges.len() != reconstruction_values.len() + 1 {
            return Err(TokenizerError::InvalidConfig(
                "bin_edges.len() must equal reconstruction_values.len() + 1".into(),
            ));
        }
        if bin_edges.windows(2).any(|w| w[0] > w[1]) {
            return Err(TokenizerError::InvalidConfig(
                "bin_edges must be sorted in ascending order".into(),
            ));
        }

        Ok(Self {
            bin_edges,
            reconstruction_values,
        })
    }

    /// Create a Lloyd-Max-style quantizer for a zero-mean Gaussian distribution
    ///
    /// Places bin edges at Gaussian percentiles and uses bin centers as
    /// reconstruction values, approximating the MSE-optimal Lloyd-Max solution
    /// rather than computing it by iterative refinement.
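    ///
    /// # Example
    ///
    /// A sketch (marked `ignore`; assumes the type and the `Quantizer` trait are in scope):
    ///
    /// ```ignore
    /// let quant = NonUniformQuantizer::lloyd_max_gaussian(8, 1.0).unwrap();
    /// assert_eq!(quant.num_levels(), 8);
    /// // Reconstruction levels are roughly symmetric about zero for a zero-mean Gaussian.
    /// let sum = quant.dequantize(quant.quantize(0.5)) + quant.dequantize(quant.quantize(-0.5));
    /// assert!(sum.abs() < 0.5);
    /// ```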
    pub fn lloyd_max_gaussian(num_levels: usize, sigma: f32) -> TokenizerResult<Self> {
        if num_levels < 2 {
            return Err(TokenizerError::InvalidConfig(
                "num_levels must be at least 2".into(),
            ));
        }

        let mut bin_edges = Vec::with_capacity(num_levels + 1);
        let mut reconstruction_values = Vec::with_capacity(num_levels);

        // Bin edges at Gaussian percentiles, using the logistic approximation of the
        // standard normal inverse CDF: Phi^-1(p) ≈ ln(p / (1 - p)) / 1.702.
        // Probabilities are clamped away from 0 and 1 to keep the outer edges finite.
        for i in 0..=num_levels {
            let p = (i as f32 / num_levels as f32).clamp(1e-3, 1.0 - 1e-3);
            let z = (p / (1.0 - p)).ln() / 1.702;
            bin_edges.push(z * sigma);
        }

        // Reconstruction values as bin centers
        for i in 0..num_levels {
            reconstruction_values.push((bin_edges[i] + bin_edges[i + 1]) / 2.0);
        }

        Ok(Self {
            bin_edges,
            reconstruction_values,
        })
    }
}

impl Quantizer for NonUniformQuantizer {
    fn quantize(&self, value: f32) -> i32 {
        // Linear scan over the sorted bin edges; values beyond the last edge
        // fall into the final bin.
        for (i, &edge) in self.bin_edges.iter().enumerate().skip(1) {
            if value < edge {
                return (i - 1) as i32;
            }
        }
        (self.reconstruction_values.len() - 1) as i32
    }

    fn dequantize(&self, level: i32) -> f32 {
        let idx = level.clamp(0, (self.reconstruction_values.len() - 1) as i32) as usize;
        self.reconstruction_values[idx]
    }

    fn num_levels(&self) -> usize {
        self.reconstruction_values.len()
    }
}

// Implement SignalTokenizer for advanced quantizers

impl SignalTokenizer for AdaptiveQuantizer {
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        let quantized = self.quantize_adaptive(signal)?;
        Ok(quantized.mapv(|x| x as f32))
    }

    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
    }

    fn embed_dim(&self) -> usize {
        1
    }

    fn vocab_size(&self) -> usize {
        self.levels
    }
}

impl SignalTokenizer for DeadZoneQuantizer {
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        Ok(signal.mapv(|x| self.quantize(x) as f32))
    }

    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
    }

    fn embed_dim(&self) -> usize {
        1
    }

    fn vocab_size(&self) -> usize {
        self.levels
    }
}

impl SignalTokenizer for NonUniformQuantizer {
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        Ok(signal.mapv(|x| self.quantize(x) as f32))
    }

    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
    }

    fn embed_dim(&self) -> usize {
        1
    }

    fn vocab_size(&self) -> usize {
        self.reconstruction_values.len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adaptive_quantizer() {
        let quant = AdaptiveQuantizer::new(8, 16, 0.5, -1.0, 1.0).unwrap();

        let signal = Array1::from_vec((0..128).map(|i| ((i as f32) * 0.05).sin()).collect());

        let encoded = quant.encode(&signal).unwrap();
        assert_eq!(encoded.len(), 128);

        let decoded = quant.decode(&encoded).unwrap();
        assert_eq!(decoded.len(), 128);
    }

    #[test]
    fn test_dead_zone_quantizer() {
        let quant = DeadZoneQuantizer::new(8, 0.1, -1.0, 1.0).unwrap();

        // Test dead zone behavior
        let level = quant.quantize(0.05);
        let recovered = quant.dequantize(level);
        assert_eq!(recovered, 0.0); // Should be in dead zone

        // Test outside dead zone
        let level = quant.quantize(0.5);
        let recovered = quant.dequantize(level);
        assert!(recovered.abs() > 0.1);
    }

    #[test]
    fn test_dead_zone_signal() {
        let quant = DeadZoneQuantizer::new(8, 0.2, -1.0, 1.0).unwrap();

        // Signal with small values that should be zeroed
        let signal = Array1::from_vec(vec![0.01, 0.5, -0.1, 0.8, 0.05]);

        let encoded = quant.encode(&signal).unwrap();
        let decoded = quant.decode(&encoded).unwrap();

        // Small values should become zero
        assert_eq!(decoded[0], 0.0);
        assert_eq!(decoded[2], 0.0);
        assert_eq!(decoded[4], 0.0);

        // Large values should be preserved (approximately)
        assert!(decoded[1] > 0.3);
        assert!(decoded[3] > 0.6);
    }

    #[test]
    fn test_nonuniform_quantizer() {
        let edges = vec![-2.0, -0.5, 0.0, 0.5, 2.0];
        let quant = NonUniformQuantizer::from_edges(edges).unwrap();

        assert_eq!(quant.num_levels(), 4);

        let level = quant.quantize(-1.0);
        assert_eq!(level, 0);

        let level = quant.quantize(0.25);
        assert_eq!(level, 2);
    }

    #[test]
    fn test_lloyd_max_quantizer() {
        let quant = NonUniformQuantizer::lloyd_max_gaussian(8, 1.0).unwrap();

        assert_eq!(quant.num_levels(), 8);

        // Test symmetry
        let level_pos = quant.quantize(0.5);
        let level_neg = quant.quantize(-0.5);
        let val_pos = quant.dequantize(level_pos);
        let val_neg = quant.dequantize(level_neg);

        assert!((val_pos + val_neg).abs() < 0.5); // Should be roughly symmetric
    }

    #[test]
    fn test_adaptive_vs_uniform() {
        let adaptive = AdaptiveQuantizer::new(6, 8, 0.8, -1.0, 1.0).unwrap();

        // Signal with varying local statistics
        let mut signal_vec = Vec::new();
        // Low variance region
        for i in 0..64 {
            signal_vec.push(0.1 * (i as f32 * 0.05).sin());
        }
        // High variance region
        for i in 64..128 {
            signal_vec.push(0.8 * (i as f32 * 0.1).sin());
        }

        let signal = Array1::from_vec(signal_vec);
        let encoded = adaptive.encode(&signal).unwrap();

        assert_eq!(encoded.len(), 128);
    }

    #[test]
    fn test_nonuniform_with_custom_values() {
        let edges = vec![-1.0, -0.3, 0.0, 0.3, 1.0];
        let recon = vec![-0.7, -0.15, 0.15, 0.7];

        let quant = NonUniformQuantizer::new(edges, recon).unwrap();

        let level = quant.quantize(0.1);
        let value = quant.dequantize(level);
        assert!((value - 0.15).abs() < 0.01);
    }
}