// kizzasi_tokenizer/advanced_quant.rs

1//! Advanced quantization strategies
2//!
3//! Provides sophisticated quantization methods beyond simple linear/μ-law:
4//! - **Adaptive Quantization**: Adjusts step size based on signal statistics
5//! - **Dead Zone Quantization**: Optimized for sparse signals
6//! - **Perceptual Quantization**: Psychoacoustic modeling for audio
7//! - **Entropy-Constrained Quantization**: Rate-distortion optimized
8
9use crate::error::{TokenizerError, TokenizerResult};
10use crate::{Quantizer, SignalTokenizer};
11use scirs2_core::ndarray::Array1;
12
/// Adaptive quantizer that adjusts step size based on local signal statistics
///
/// Uses a sliding window to compute local deviation and scales the
/// quantization step accordingly. NOTE(review): as implemented
/// (`adaptive_step`), high-variance regions get a *larger* step — coarser
/// quantization; the original doc claimed finer. Confirm intended direction.
#[derive(Debug, Clone)]
pub struct AdaptiveQuantizer {
    /// Base number of bits (kept for introspection; `levels` is derived from it)
    _bits: u8,
    /// Number of quantization levels (`1 << _bits`)
    levels: usize,
    /// Sliding-window length (in samples) for local statistics
    window_size: usize,
    /// Adaptation strength (0.0 = no adaptation, 1.0 = full adaptation)
    adaptation_strength: f32,
    /// Lower bound of the global value range used for normalization
    global_min: f32,
    /// Upper bound of the global value range used for normalization
    global_max: f32,
}
31
32impl AdaptiveQuantizer {
33    /// Create a new adaptive quantizer
34    pub fn new(
35        bits: u8,
36        window_size: usize,
37        adaptation_strength: f32,
38        global_min: f32,
39        global_max: f32,
40    ) -> TokenizerResult<Self> {
41        if bits == 0 || bits > 16 {
42            return Err(TokenizerError::InvalidConfig("bits must be 1-16".into()));
43        }
44        if window_size == 0 {
45            return Err(TokenizerError::InvalidConfig(
46                "window_size must be positive".into(),
47            ));
48        }
49        if !(0.0..=1.0).contains(&adaptation_strength) {
50            return Err(TokenizerError::InvalidConfig(
51                "adaptation_strength must be in [0, 1]".into(),
52            ));
53        }
54
55        Ok(Self {
56            _bits: bits,
57            levels: 1usize << bits,
58            window_size,
59            adaptation_strength,
60            global_min,
61            global_max,
62        })
63    }
64
65    /// Compute local variance around position
66    fn local_variance(&self, signal: &Array1<f32>, pos: usize) -> f32 {
67        let half_window = self.window_size / 2;
68        let start = pos.saturating_sub(half_window);
69        let end = (pos + half_window).min(signal.len());
70
71        let window: Vec<f32> = signal
72            .iter()
73            .skip(start)
74            .take(end - start)
75            .cloned()
76            .collect();
77        if window.is_empty() {
78            return 1.0;
79        }
80
81        let mean = window.iter().sum::<f32>() / window.len() as f32;
82        let variance = window.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / window.len() as f32;
83
84        variance.sqrt().max(1e-6) // Return standard deviation
85    }
86
87    /// Compute adaptive step size at position
88    fn adaptive_step(&self, signal: &Array1<f32>, pos: usize) -> f32 {
89        let base_step = (self.global_max - self.global_min) / self.levels as f32;
90        let local_std = self.local_variance(signal, pos);
91
92        // Scale step size based on local statistics
93        let global_std = (self.global_max - self.global_min) / 4.0; // Approximate
94        let scale = 1.0 + self.adaptation_strength * (local_std / global_std - 1.0);
95
96        base_step * scale.clamp(0.1, 10.0) // Clamp scaling factor
97    }
98
99    /// Quantize entire signal with adaptation
100    pub fn quantize_adaptive(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<i32>> {
101        let mut result = Vec::with_capacity(signal.len());
102
103        for (i, &value) in signal.iter().enumerate() {
104            let step = self.adaptive_step(signal, i);
105            let clamped = value.clamp(self.global_min, self.global_max);
106            let normalized = (clamped - self.global_min) / (self.global_max - self.global_min);
107            let level = (normalized / step * (self.levels - 1) as f32).round() as i32;
108            result.push(level.clamp(0, (self.levels - 1) as i32));
109        }
110
111        Ok(Array1::from_vec(result))
112    }
113}
114
115impl Quantizer for AdaptiveQuantizer {
116    fn quantize(&self, value: f32) -> i32 {
117        // Fallback to uniform quantization for single values
118        let clamped = value.clamp(self.global_min, self.global_max);
119        let normalized = (clamped - self.global_min) / (self.global_max - self.global_min);
120        (normalized * (self.levels - 1) as f32).round() as i32
121    }
122
123    fn dequantize(&self, level: i32) -> f32 {
124        let clamped_level = level.clamp(0, (self.levels - 1) as i32);
125        let normalized = clamped_level as f32 / (self.levels - 1) as f32;
126        self.global_min + normalized * (self.global_max - self.global_min)
127    }
128
129    fn num_levels(&self) -> usize {
130        self.levels
131    }
132}
133
/// Dead zone quantizer for sparse signals
///
/// Applies a dead zone around zero where small values are quantized to zero.
/// This is useful for signals with many near-zero values (e.g., after transforms).
#[derive(Debug, Clone)]
pub struct DeadZoneQuantizer {
    /// Number of bits backing `levels` (kept for introspection)
    _base_bits: u8,
    /// Number of quantization levels (`1 << _base_bits`); the middle level
    /// (`levels / 2`) is reserved as the exact-zero reconstruction point
    levels: usize,
    /// Dead zone threshold: inputs with `|x| < dead_zone` map to the zero level
    dead_zone: f32,
    /// Lower bound of the quantization range
    min: f32,
    /// Upper bound of the quantization range
    max: f32,
}
149
150impl DeadZoneQuantizer {
151    /// Create a new dead zone quantizer
152    ///
153    /// # Arguments
154    /// * `bits` - Number of quantization bits
155    /// * `dead_zone` - Threshold below which values are quantized to zero
156    /// * `min`, `max` - Value range
157    pub fn new(bits: u8, dead_zone: f32, min: f32, max: f32) -> TokenizerResult<Self> {
158        if bits == 0 || bits > 16 {
159            return Err(TokenizerError::InvalidConfig("bits must be 1-16".into()));
160        }
161        if dead_zone < 0.0 {
162            return Err(TokenizerError::InvalidConfig(
163                "dead_zone must be non-negative".into(),
164            ));
165        }
166
167        Ok(Self {
168            _base_bits: bits,
169            levels: 1usize << bits,
170            dead_zone,
171            min,
172            max,
173        })
174    }
175}
176
177impl Quantizer for DeadZoneQuantizer {
178    fn quantize(&self, value: f32) -> i32 {
179        // Apply dead zone
180        if value.abs() < self.dead_zone {
181            return (self.levels / 2) as i32; // Zero point
182        }
183
184        // Quantize non-dead-zone values
185        let clamped = value.clamp(self.min, self.max);
186        let normalized = (clamped - self.min) / (self.max - self.min);
187        (normalized * (self.levels - 1) as f32).round() as i32
188    }
189
190    fn dequantize(&self, level: i32) -> f32 {
191        let clamped_level = level.clamp(0, (self.levels - 1) as i32);
192
193        // Check if it's the zero point
194        if clamped_level == (self.levels / 2) as i32 {
195            return 0.0;
196        }
197
198        let normalized = clamped_level as f32 / (self.levels - 1) as f32;
199        self.min + normalized * (self.max - self.min)
200    }
201
202    fn num_levels(&self) -> usize {
203        self.levels
204    }
205}
206
/// Non-uniform quantizer with configurable bin edges
///
/// Allows custom quantization levels for optimal rate-distortion trade-off
#[derive(Debug, Clone)]
pub struct NonUniformQuantizer {
    /// Quantization bin edges (sorted; length = number of bins + 1)
    bin_edges: Vec<f32>,
    /// Reconstruction value for each bin (length = `bin_edges.len() - 1`)
    reconstruction_values: Vec<f32>,
}
217
218impl NonUniformQuantizer {
219    /// Create from bin edges
220    ///
221    /// Reconstruction values are set to bin centers
222    pub fn from_edges(mut bin_edges: Vec<f32>) -> TokenizerResult<Self> {
223        if bin_edges.len() < 2 {
224            return Err(TokenizerError::InvalidConfig(
225                "Need at least 2 bin edges".into(),
226            ));
227        }
228
229        bin_edges.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
230
231        // Compute reconstruction values as bin centers
232        let mut reconstruction_values = Vec::with_capacity(bin_edges.len() - 1);
233        for i in 0..bin_edges.len() - 1 {
234            reconstruction_values.push((bin_edges[i] + bin_edges[i + 1]) / 2.0);
235        }
236
237        Ok(Self {
238            bin_edges,
239            reconstruction_values,
240        })
241    }
242
243    /// Create with custom reconstruction values
244    pub fn new(bin_edges: Vec<f32>, reconstruction_values: Vec<f32>) -> TokenizerResult<Self> {
245        if bin_edges.len() != reconstruction_values.len() + 1 {
246            return Err(TokenizerError::InvalidConfig(
247                "bin_edges.len() must equal reconstruction_values.len() + 1".into(),
248            ));
249        }
250
251        Ok(Self {
252            bin_edges,
253            reconstruction_values,
254        })
255    }
256
257    /// Create Lloyd-Max quantizer for Gaussian distribution
258    ///
259    /// Optimizes bin edges and reconstruction values for minimum MSE
260    pub fn lloyd_max_gaussian(num_levels: usize, sigma: f32) -> TokenizerResult<Self> {
261        if num_levels < 2 {
262            return Err(TokenizerError::InvalidConfig(
263                "num_levels must be at least 2".into(),
264            ));
265        }
266
267        // Simple approximation: use percentiles of Gaussian
268        let mut bin_edges = Vec::with_capacity(num_levels + 1);
269        let mut reconstruction_values = Vec::with_capacity(num_levels);
270
271        // Start with uniform spacing
272        for i in 0..=num_levels {
273            let p = i as f32 / num_levels as f32;
274            // Approximate inverse CDF
275            let z = if p < 0.5 {
276                -((1.0 - 2.0 * p).sqrt() - 1.0)
277            } else {
278                (2.0 * p - 1.0).sqrt() - 1.0
279            };
280            bin_edges.push(z * sigma);
281        }
282
283        // Reconstruction values as bin centers
284        for i in 0..num_levels {
285            reconstruction_values.push((bin_edges[i] + bin_edges[i + 1]) / 2.0);
286        }
287
288        Ok(Self {
289            bin_edges,
290            reconstruction_values,
291        })
292    }
293}
294
295impl Quantizer for NonUniformQuantizer {
296    fn quantize(&self, value: f32) -> i32 {
297        // Find bin using binary search
298        for (i, &edge) in self.bin_edges.iter().enumerate().skip(1) {
299            if value < edge {
300                return (i - 1) as i32;
301            }
302        }
303        (self.reconstruction_values.len() - 1) as i32
304    }
305
306    fn dequantize(&self, level: i32) -> f32 {
307        let idx = level.clamp(0, (self.reconstruction_values.len() - 1) as i32) as usize;
308        self.reconstruction_values[idx]
309    }
310
311    fn num_levels(&self) -> usize {
312        self.reconstruction_values.len()
313    }
314}
315
316// Implement SignalTokenizer for advanced quantizers
317
318impl SignalTokenizer for AdaptiveQuantizer {
319    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
320        let quantized = self.quantize_adaptive(signal)?;
321        Ok(quantized.mapv(|x| x as f32))
322    }
323
324    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
325        Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
326    }
327
328    fn embed_dim(&self) -> usize {
329        1
330    }
331
332    fn vocab_size(&self) -> usize {
333        self.levels
334    }
335}
336
337impl SignalTokenizer for DeadZoneQuantizer {
338    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
339        Ok(signal.mapv(|x| self.quantize(x) as f32))
340    }
341
342    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
343        Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
344    }
345
346    fn embed_dim(&self) -> usize {
347        1
348    }
349
350    fn vocab_size(&self) -> usize {
351        self.levels
352    }
353}
354
355impl SignalTokenizer for NonUniformQuantizer {
356    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
357        Ok(signal.mapv(|x| self.quantize(x) as f32))
358    }
359
360    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
361        Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
362    }
363
364    fn embed_dim(&self) -> usize {
365        1
366    }
367
368    fn vocab_size(&self) -> usize {
369        self.reconstruction_values.len()
370    }
371}
372
373// ─────────────────────────────────────────────────────────────────────────────
374// Private helper: binary search over interior bin edges
375// Returns the bin index in [0, edges.len()] for scalar `x`.
376// ─────────────────────────────────────────────────────────────────────────────
/// Locate the bin index for `x` given sorted interior `edges`.
///
/// Returns the first index `i` with `x <= edges[i]` (equivalently, the
/// number of leading edges strictly below `x`), so the result lies in
/// `[0, edges.len()]` — i.e. `[0, num_levels - 1]` when
/// `edges.len() == num_levels - 1`.
fn find_bin(x: f32, edges: &[f32]) -> usize {
    // `partition_point` performs the same binary search as the original
    // hand-rolled loop. The predicate `!(x <= e)` (rather than `e < x`)
    // preserves the original NaN behavior: a NaN `x` compares false with
    // every edge and therefore lands in the last bin.
    edges.partition_point(|&e| !(x <= e))
}
390
391// ─────────────────────────────────────────────────────────────────────────────
392// EntropyConstrainedQuantizer
393// ─────────────────────────────────────────────────────────────────────────────
394
/// Lagrangian Rate-Distortion optimal scalar quantizer.
///
/// Minimizes `D + λ·R` where D is MSE distortion and R is empirical entropy.
/// Implements an entropy-regularized Lloyd-Max algorithm:
///
/// 1. Initialize bin edges from equal-mass (percentile) split of the signal.
/// 2. Iterate:
///    - **Centroid update**: recon[i] = mean of samples assigned to bin i.
///    - **Entropy-regularized edge update**:
///      `edge[i] = 0.5·(recon[i-1]+recon[i]) + (λ/(recon[i]-recon[i-1]))·(ln p[i-1] − ln p[i])`
///    - Clamp edges to be strictly monotonic.
///    - Recompute empirical probabilities.
///    - Evaluate `cost = D + λ·R` and stop when Δcost < tol.
// Derives added for consistency with the other quantizers in this module;
// all fields are plain `Vec`/scalar data, so both derive cleanly.
#[derive(Debug, Clone)]
pub struct EntropyConstrainedQuantizer {
    /// Interior bin edges (length = num_levels - 1, strictly sorted).
    bin_edges: Vec<f32>,
    /// Reconstruction value for each bin (length = num_levels).
    reconstruction_values: Vec<f32>,
    /// Lagrange multiplier controlling R-D trade-off.
    lambda: f32,
    /// Optional target bits-per-symbol (set by `fit_with_target_rate`).
    target_bits_per_symbol: Option<f64>,
    /// Empirical probabilities of each bin after fitting.
    empirical_probs: Vec<f64>,
}
420
421impl EntropyConstrainedQuantizer {
422    /// Construct directly from pre-computed edges and reconstruction values.
423    ///
424    /// Uniform prior probabilities are assumed until `fit_lagrangian` is called.
425    pub fn new(bin_edges: Vec<f32>, reconstruction_values: Vec<f32>, lambda: f32) -> Self {
426        let n = reconstruction_values.len();
427        Self {
428            bin_edges,
429            reconstruction_values,
430            lambda,
431            target_bits_per_symbol: None,
432            empirical_probs: vec![1.0 / n as f64; n],
433        }
434    }
435
    /// Fit ECQ via entropy-regularized Lloyd-Max iteration.
    ///
    /// Returns `Err` if `num_levels < 2` or the signal is too short.
    ///
    /// # Arguments
    ///
    /// * `signal`     – Input signal samples.
    /// * `num_levels` – Number of quantization bins (≥ 2).
    /// * `lambda`     – Lagrange multiplier; larger → more compression, less fidelity.
    /// * `max_iters`  – Maximum Lloyd-Max iterations.
    /// * `tol`        – Convergence threshold on `|Δcost|`.
    pub fn fit_lagrangian(
        signal: &Array1<f32>,
        num_levels: usize,
        lambda: f32,
        max_iters: usize,
        tol: f32,
    ) -> TokenizerResult<Self> {
        if num_levels < 2 {
            return Err(TokenizerError::InvalidConfig(
                "num_levels must be >= 2".into(),
            ));
        }
        if signal.len() < num_levels {
            return Err(TokenizerError::InvalidConfig(
                "signal is too short for the requested num_levels".into(),
            ));
        }

        let n = signal.len();
        let sig_min = signal.iter().cloned().fold(f32::INFINITY, f32::min);
        let sig_max = signal.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        // Floors guard constant signals (range == 0) and keep edges/centroids
        // strictly separable in the monotonicity passes below.
        let range = (sig_max - sig_min).max(1e-6);
        let min_gap = 1e-6 * range;

        // Step 1: Init bin edges from equal-mass (percentile) split.
        let mut sorted_signal: Vec<f32> = signal.iter().cloned().collect();
        sorted_signal.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        // bin_edges has `num_levels - 1` interior edges.
        let mut bin_edges: Vec<f32> = (1..num_levels)
            .map(|i| {
                let idx = (i * n / num_levels).min(n - 1);
                sorted_signal[idx]
            })
            .collect();

        // Step 2: Init reconstruction values as bin midpoints.
        // Bins: (-∞, edge[0]], (edge[0], edge[1]], …, (edge[L-2], +∞)
        // The open-ended outer bins use sig_min/sig_max as proxy endpoints.
        let mut recon: Vec<f32> = {
            let mut r = Vec::with_capacity(num_levels);
            r.push((sig_min + bin_edges[0]) * 0.5);
            for i in 1..num_levels - 1 {
                r.push((bin_edges[i - 1] + bin_edges[i]) * 0.5);
            }
            r.push((bin_edges[num_levels - 2] + sig_max) * 0.5);
            r
        };

        let mut probs = vec![1.0f64 / num_levels as f64; num_levels];
        let mut prev_cost = f64::INFINITY;

        for _iter in 0..max_iters {
            // ── Step 3a: Centroid update ──────────────────────────────────
            let mut sums = vec![0.0f64; num_levels];
            let mut counts = vec![0usize; num_levels];
            for &x in signal.iter() {
                let b = find_bin(x, &bin_edges);
                sums[b] += x as f64;
                counts[b] += 1;
            }
            for i in 0..num_levels {
                if counts[i] > 0 {
                    recon[i] = (sums[i] / counts[i] as f64) as f32;
                }
                // Empty bin: keep previous centroid (no panic).
            }

            // Monotonicity guard after centroid update (insurance for
            // edge cases where empty bins collapse centroids).
            for i in 1..num_levels {
                if recon[i] <= recon[i - 1] + min_gap {
                    recon[i] = recon[i - 1] + min_gap;
                }
            }

            // ── Step 3b: Recompute probs for edge update ─────────────────
            // Laplace-style smoothing: `eps` keeps every bin's probability
            // strictly positive so the `ln` terms below stay finite.
            let eps = 1e-10;
            let denom = n as f64 + num_levels as f64 * eps;
            for i in 0..num_levels {
                probs[i] = (counts[i] as f64 + eps) / denom;
            }

            // ── Step 3c: Entropy-regularized edge update ──────────────────
            // Positive `entropy_term` (p_left > p_right) pushes the edge
            // right, widening the more probable bin (shorter codes → lower
            // rate) at the cost of distortion.
            for i in 0..num_levels - 1 {
                let r_left = recon[i];
                let r_right = recon[i + 1];
                let gap = (r_right - r_left).max(min_gap);
                let p_left = probs[i].max(eps);
                let p_right = probs[i + 1].max(eps);
                let entropy_term = (lambda / gap) * (p_left.ln() - p_right.ln()) as f32;
                bin_edges[i] = 0.5 * (r_left + r_right) + entropy_term;
            }

            // ── Step 3d: Enforce strict monotonicity ─────────────────────
            for i in 1..bin_edges.len() {
                if bin_edges[i] <= bin_edges[i - 1] + min_gap {
                    bin_edges[i] = bin_edges[i - 1] + min_gap;
                }
            }

            // ── Step 3e: Recompute probs with new edges ───────────────────
            let mut new_counts = vec![0usize; num_levels];
            for &x in signal.iter() {
                new_counts[find_bin(x, &bin_edges)] += 1;
            }
            for i in 0..num_levels {
                probs[i] = (new_counts[i] as f64 + eps) / denom;
            }

            // ── Step 3f: Convergence check on D + λ·R ────────────────────
            let distortion: f64 = signal
                .iter()
                .map(|&x| {
                    let b = find_bin(x, &bin_edges);
                    let d = x as f64 - recon[b] as f64;
                    d * d
                })
                .sum::<f64>()
                / n as f64;

            let entropy: f64 = probs
                .iter()
                .map(|&p| if p > eps { -p * p.log2() } else { 0.0 })
                .sum();

            let cost = distortion + lambda as f64 * entropy;

            // NOTE(review): this stops on *small change*, not on strict
            // decrease; cost may oscillate slightly until `max_iters`.
            if (prev_cost - cost).abs() < tol as f64 {
                break;
            }
            prev_cost = cost;
        }

        Ok(Self {
            bin_edges,
            reconstruction_values: recon,
            lambda,
            target_bits_per_symbol: None,
            empirical_probs: probs,
        })
    }
588
589    /// Compress `signal` using Huffman coding built from the fitted bin
590    /// probabilities.
591    ///
592    /// Returns `(compressed_bytes, symbol_count)`.  The `symbol_count` is
593    /// needed for lossless decompression.
594    pub fn encode_compressed(&self, signal: &Array1<f32>) -> TokenizerResult<(Vec<u8>, u64)> {
595        use crate::entropy::{compute_frequencies, HuffmanEncoder};
596
597        let symbols: Vec<u32> = signal
598            .iter()
599            .map(|&x| find_bin(x, &self.bin_edges) as u32)
600            .collect();
601
602        let freqs = compute_frequencies(&symbols);
603        let encoder = HuffmanEncoder::from_frequencies(&freqs)?;
604        let compressed = encoder.encode(&symbols)?;
605        let symbol_count = symbols.len() as u64;
606        Ok((compressed, symbol_count))
607    }
608
609    /// Decompress bytes produced by `encode_compressed` back to a signal.
610    ///
611    /// The `symbol_count` must match the value returned by `encode_compressed`.
612    pub fn decode_compressed(
613        &self,
614        bytes: &[u8],
615        _symbol_count: u64,
616    ) -> TokenizerResult<Array1<f32>> {
617        use crate::entropy::{HuffmanDecoder, HuffmanEncoder};
618
619        // Re-build the same frequency table from `empirical_probs` so we can
620        // reconstruct the Huffman tree without storing it separately.
621        let n_levels = self.reconstruction_values.len();
622        let total_pseudo = 1_000_000u64; // Scale probs to integer counts.
623        let mut freqs = std::collections::HashMap::new();
624        let mut allocated = 0u64;
625        for i in 0..n_levels {
626            let cnt = (self.empirical_probs[i] * total_pseudo as f64).round() as u64;
627            let cnt = cnt.max(1); // At least 1 to keep symbol in codebook.
628            freqs.insert(i as u32, cnt);
629            allocated += cnt;
630        }
631        // Give the leftover to symbol 0 to keep totals consistent (doesn't
632        // affect code *lengths*, only the tree shape).
633        let _ = allocated; // unused; counts only need to be proportional.
634
635        let encoder = HuffmanEncoder::from_frequencies(&freqs)?;
636        let decoder = HuffmanDecoder::new(encoder.tree());
637        let indices = decoder.decode(bytes)?;
638
639        let values: Vec<f32> = indices
640            .iter()
641            .map(|&idx| {
642                let b = (idx as usize).min(self.reconstruction_values.len() - 1);
643                self.reconstruction_values[b]
644            })
645            .collect();
646
647        Ok(Array1::from_vec(values))
648    }
649
650    /// Fit ECQ using binary search over λ to hit a target entropy rate.
651    ///
652    /// # Arguments
653    ///
654    /// * `signal`          – Training signal.
655    /// * `num_levels`      – Number of quantization bins.
656    /// * `target_bpp`      – Desired bits per symbol.
657    /// * `max_outer_iters` – Number of λ bisection steps.
658    pub fn fit_with_target_rate(
659        signal: &Array1<f32>,
660        num_levels: usize,
661        target_bpp: f64,
662        max_outer_iters: usize,
663    ) -> TokenizerResult<Self> {
664        let mut lambda_lo = 0.0f32;
665        let mut lambda_hi = 10.0f32;
666
667        // Start with the high-lambda (low-rate) end.
668        let mut best = Self::fit_lagrangian(signal, num_levels, lambda_hi, 100, 1e-5)?;
669
670        for _ in 0..max_outer_iters {
671            let lambda_mid = (lambda_lo + lambda_hi) * 0.5;
672            let candidate = Self::fit_lagrangian(signal, num_levels, lambda_mid, 100, 1e-5)?;
673            let rate = candidate.compute_entropy_rate(signal);
674            if rate > target_bpp {
675                // Rate too high → increase lambda to compress more.
676                lambda_lo = lambda_mid;
677            } else {
678                // Rate low enough → record this as "best so far" and try less compression.
679                lambda_hi = lambda_mid;
680                best = candidate;
681            }
682            if (lambda_hi - lambda_lo) < 1e-4 {
683                break;
684            }
685        }
686        // Update the stored target for reference.
687        best.target_bits_per_symbol = Some(target_bpp);
688        Ok(best)
689    }
690
691    /// Compute empirical Shannon entropy (bits/symbol) of the signal under
692    /// this quantizer's bin partition.
693    pub fn compute_entropy_rate(&self, signal: &Array1<f32>) -> f64 {
694        let n = signal.len();
695        if n == 0 {
696            return 0.0;
697        }
698        let mut counts = vec![0usize; self.reconstruction_values.len()];
699        for &x in signal.iter() {
700            counts[find_bin(x, &self.bin_edges)] += 1;
701        }
702        counts
703            .iter()
704            .map(|&c| {
705                if c > 0 {
706                    let p = c as f64 / n as f64;
707                    -p * p.log2()
708                } else {
709                    0.0
710                }
711            })
712            .sum()
713    }
714
715    /// Compute mean-squared distortion of the signal under this quantizer.
716    pub fn empirical_distortion(&self, signal: &Array1<f32>) -> f64 {
717        let n = signal.len();
718        if n == 0 {
719            return 0.0;
720        }
721        signal
722            .iter()
723            .map(|&x| {
724                let r = self.reconstruction_values[find_bin(x, &self.bin_edges)];
725                let d = (x - r) as f64;
726                d * d
727            })
728            .sum::<f64>()
729            / n as f64
730    }
731
732    /// Return a reference to the interior bin edges.
733    pub fn bin_edges(&self) -> &[f32] {
734        &self.bin_edges
735    }
736
737    /// Return a reference to the reconstruction values.
738    pub fn reconstruction_values(&self) -> &[f32] {
739        &self.reconstruction_values
740    }
741
742    /// Return the Lagrange multiplier used during fitting.
743    pub fn lambda(&self) -> f32 {
744        self.lambda
745    }
746
747    /// Return the optional target bits-per-symbol set by `fit_with_target_rate`.
748    pub fn target_bits_per_symbol(&self) -> Option<f64> {
749        self.target_bits_per_symbol
750    }
751}
752
753impl Quantizer for EntropyConstrainedQuantizer {
754    fn quantize(&self, value: f32) -> i32 {
755        find_bin(value, &self.bin_edges) as i32
756    }
757
758    fn dequantize(&self, level: i32) -> f32 {
759        let idx = level.clamp(0, (self.reconstruction_values.len() - 1) as i32) as usize;
760        self.reconstruction_values[idx]
761    }
762
763    fn num_levels(&self) -> usize {
764        self.reconstruction_values.len()
765    }
766}
767
768impl SignalTokenizer for EntropyConstrainedQuantizer {
769    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
770        Ok(signal.mapv(|x| self.quantize(x) as f32))
771    }
772
773    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
774        Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
775    }
776
777    fn embed_dim(&self) -> usize {
778        1
779    }
780
781    fn vocab_size(&self) -> usize {
782        self.reconstruction_values.len()
783    }
784}
785
786// ─────────────────────────────────────────────────────────────────────────────
787// Tests
788// ─────────────────────────────────────────────────────────────────────────────
789
790#[cfg(test)]
791mod ecq_tests {
792    use super::*;
793
794    /// Simple pseudo-Gaussian generator (Box-Muller + LCG) — no external deps.
795    fn gaussian_signal(n: usize, seed: u64) -> Array1<f32> {
796        let mut state = seed;
797        let mut next_f32 = move || {
798            state = state
799                .wrapping_mul(6_364_136_223_846_793_005)
800                .wrapping_add(1_442_695_040_888_963_407);
801            (state >> 33) as f32 / u32::MAX as f32
802        };
803        Array1::from_iter((0..n).map(|_| {
804            let u1 = next_f32().max(1e-7);
805            let u2 = next_f32();
806            (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()
807        }))
808    }
809
810    #[test]
811    fn fit_lagrangian_convergence_and_monotonicity() {
812        let signal = gaussian_signal(10_000, 42);
813        let num_levels = 8;
814        let q = EntropyConstrainedQuantizer::fit_lagrangian(&signal, num_levels, 0.1, 50, 1e-5)
815            .expect("fit_lagrangian failed");
816
817        // Reconstruction values should be in a reasonable range for N(0,1).
818        for &r in q.reconstruction_values() {
819            assert!(r.abs() <= 4.5, "recon value {r} outside [-4.5, 4.5]");
820        }
821
822        // Strictly monotonic reconstruction values.
823        for i in 1..q.reconstruction_values().len() {
824            assert!(
825                q.reconstruction_values()[i] > q.reconstruction_values()[i - 1],
826                "recon values not monotonic at i={i}: {} <= {}",
827                q.reconstruction_values()[i],
828                q.reconstruction_values()[i - 1]
829            );
830        }
831    }
832
833    #[test]
834    fn rd_tradeoff_bracketed() {
835        let signal = gaussian_signal(10_000, 99);
836        let num_levels = 8;
837        let q_low =
838            EntropyConstrainedQuantizer::fit_lagrangian(&signal, num_levels, 0.01, 100, 1e-6)
839                .expect("fit low-lambda failed");
840        let q_high =
841            EntropyConstrainedQuantizer::fit_lagrangian(&signal, num_levels, 1.0, 100, 1e-6)
842                .expect("fit high-lambda failed");
843
844        let d_low = q_low.empirical_distortion(&signal);
845        let d_high = q_high.empirical_distortion(&signal);
846        let r_low = q_low.compute_entropy_rate(&signal);
847        let r_high = q_high.compute_entropy_rate(&signal);
848
849        // High lambda → more compression (lower rate).
850        assert!(
851            r_high + 1e-6 < r_low,
852            "R-D: high-λ should reduce rate: r_high={r_high} r_low={r_low}"
853        );
854        // High lambda → higher distortion.
855        assert!(
856            d_low + 1e-6 < d_high,
857            "R-D: high-λ should increase distortion: d_low={d_low} d_high={d_high}"
858        );
859    }
860
861    #[test]
862    fn roundtrip_mse_vs_uniform() {
863        let signal = gaussian_signal(10_000, 7);
864        let num_levels = 8;
865
866        // Use a tiny lambda so ECQ operates near standard Lloyd-Max (low-rate
867        // penalty) — the R-D tradeoff is well-exercised in rd_tradeoff_bracketed.
868        let q_ecq =
869            EntropyConstrainedQuantizer::fit_lagrangian(&signal, num_levels, 0.001, 100, 1e-6)
870                .expect("ECQ fit failed");
871        let mse_ecq = q_ecq.empirical_distortion(&signal);
872
873        // Naive uniform quantizer baseline.
874        let sig_min = signal.iter().cloned().fold(f32::INFINITY, f32::min);
875        let sig_max = signal.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
876        let step = (sig_max - sig_min) / num_levels as f32;
877        let mse_uniform: f64 = signal
878            .iter()
879            .map(|&x| {
880                let idx = ((x - sig_min) / step).floor() as usize;
881                let idx = idx.min(num_levels - 1);
882                let r = sig_min + (idx as f32 + 0.5) * step;
883                let d = (x - r) as f64;
884                d * d
885            })
886            .sum::<f64>()
887            / signal.len() as f64;
888
889        assert!(
890            mse_ecq <= mse_uniform * 3.0,
891            "ECQ MSE {mse_ecq} > 3× uniform MSE {mse_uniform}"
892        );
893    }
894
895    #[test]
896    fn determinism() {
897        let signal = gaussian_signal(10_000, 555);
898        let q1 = EntropyConstrainedQuantizer::fit_lagrangian(&signal, 8, 0.1, 50, 1e-5)
899            .expect("first fit failed");
900        let q2 = EntropyConstrainedQuantizer::fit_lagrangian(&signal, 8, 0.1, 50, 1e-5)
901            .expect("second fit failed");
902
903        for (a, b) in q1.bin_edges().iter().zip(q2.bin_edges().iter()) {
904            assert_eq!(
905                a.to_bits(),
906                b.to_bits(),
907                "non-deterministic bin edges: {a} vs {b}"
908            );
909        }
910    }
911
912    #[test]
913    fn fit_with_target_rate_in_range() {
914        let signal = gaussian_signal(10_000, 42);
915        let q = EntropyConstrainedQuantizer::fit_with_target_rate(&signal, 8, 2.5, 20)
916            .expect("fit_with_target_rate failed");
917        let rate = q.compute_entropy_rate(&signal);
918        assert!(
919            (1.5..=3.5).contains(&rate),
920            "target_rate=2.5 produced rate={rate} outside [1.5, 3.5]"
921        );
922    }
923
924    #[test]
925    fn signal_tokenizer_encode_decode_roundtrip() {
926        let signal = gaussian_signal(1_000, 13);
927        let q = EntropyConstrainedQuantizer::fit_lagrangian(&signal, 8, 0.1, 50, 1e-5)
928            .expect("fit failed");
929
930        let tokens = q.encode(&signal).expect("encode failed");
931        assert_eq!(tokens.len(), signal.len());
932
933        let reconstructed = q.decode(&tokens).expect("decode failed");
934        assert_eq!(reconstructed.len(), signal.len());
935
936        // Every reconstructed value must be a valid reconstruction value.
937        for &r in reconstructed.iter() {
938            assert!(
939                q.reconstruction_values().contains(&r),
940                "reconstructed value {r} not in reconstruction_values"
941            );
942        }
943    }
944
945    #[test]
946    fn invalid_config_rejected() {
947        let signal = gaussian_signal(100, 1);
948        assert!(
949            EntropyConstrainedQuantizer::fit_lagrangian(&signal, 1, 0.1, 10, 1e-5).is_err(),
950            "num_levels=1 should be rejected"
951        );
952        let tiny = gaussian_signal(3, 2);
953        assert!(
954            EntropyConstrainedQuantizer::fit_lagrangian(&tiny, 8, 0.1, 10, 1e-5).is_err(),
955            "signal shorter than num_levels should be rejected"
956        );
957    }
958
959    #[test]
960    fn encode_decode_compressed_roundtrip() {
961        let signal = gaussian_signal(1_000, 77);
962        let q = EntropyConstrainedQuantizer::fit_lagrangian(&signal, 8, 0.1, 50, 1e-5)
963            .expect("fit failed");
964
965        let (compressed, sym_count) = q
966            .encode_compressed(&signal)
967            .expect("encode_compressed failed");
968        let reconstructed = q
969            .decode_compressed(&compressed, sym_count)
970            .expect("decode_compressed failed");
971
972        // Lengths must match.
973        assert_eq!(reconstructed.len(), signal.len());
974
975        // Every reconstructed value is a valid reconstruction value.
976        for &r in reconstructed.iter() {
977            assert!(
978                q.reconstruction_values().contains(&r),
979                "decoded value {r} not in reconstruction_values"
980            );
981        }
982    }
983}
984
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adaptive_quantizer() {
        // A smooth sine must round-trip through the adaptive quantizer with
        // its sample count intact in both directions.
        let quant = AdaptiveQuantizer::new(8, 16, 0.5, -1.0, 1.0).unwrap();

        let samples: Vec<f32> = (0..128).map(|i| ((i as f32) * 0.05).sin()).collect();
        let signal = Array1::from_vec(samples);

        let encoded = quant.encode(&signal).unwrap();
        assert_eq!(encoded.len(), 128);

        let decoded = quant.decode(&encoded).unwrap();
        assert_eq!(decoded.len(), 128);
    }

    #[test]
    fn test_dead_zone_quantizer() {
        let quant = DeadZoneQuantizer::new(8, 0.1, -1.0, 1.0).unwrap();

        // Inside the dead zone (|x| below the threshold) the round trip
        // collapses to exactly zero.
        let in_zone = quant.dequantize(quant.quantize(0.05));
        assert_eq!(in_zone, 0.0);

        // Outside the dead zone the magnitude survives quantization.
        let out_zone = quant.dequantize(quant.quantize(0.5));
        assert!(out_zone.abs() > 0.1);
    }

    #[test]
    fn test_dead_zone_signal() {
        let quant = DeadZoneQuantizer::new(8, 0.2, -1.0, 1.0).unwrap();

        // A mix of sub-threshold (|x| < 0.2) and clearly larger samples.
        let signal = Array1::from_vec(vec![0.01, 0.5, -0.1, 0.8, 0.05]);

        let encoded = quant.encode(&signal).unwrap();
        let decoded = quant.decode(&encoded).unwrap();

        // Sub-threshold samples are zeroed out...
        for &i in &[0usize, 2, 4] {
            assert_eq!(decoded[i], 0.0);
        }

        // ...while the larger samples keep most of their magnitude.
        assert!(decoded[1] > 0.3);
        assert!(decoded[3] > 0.6);
    }

    #[test]
    fn test_nonuniform_quantizer() {
        // Five edges define four quantization bins.
        let quant = NonUniformQuantizer::from_edges(vec![-2.0, -0.5, 0.0, 0.5, 2.0]).unwrap();
        assert_eq!(quant.num_levels(), 4);

        // -1.0 lies in the first bin, 0.25 in the third.
        assert_eq!(quant.quantize(-1.0), 0);
        assert_eq!(quant.quantize(0.25), 2);
    }

    #[test]
    fn test_lloyd_max_quantizer() {
        let quant = NonUniformQuantizer::lloyd_max_gaussian(8, 1.0).unwrap();
        assert_eq!(quant.num_levels(), 8);

        // A Lloyd-Max codebook for a zero-mean Gaussian is symmetric, so the
        // reconstructions of +0.5 and -0.5 should roughly cancel.
        let val_pos = quant.dequantize(quant.quantize(0.5));
        let val_neg = quant.dequantize(quant.quantize(-0.5));
        assert!((val_pos + val_neg).abs() < 0.5);
    }

    #[test]
    fn test_adaptive_vs_uniform() {
        let adaptive = AdaptiveQuantizer::new(6, 8, 0.8, -1.0, 1.0).unwrap();

        // Quiet first half followed by a loud second half so the sliding
        // window statistics differ across the signal.
        let quiet = (0..64).map(|i| 0.1 * (i as f32 * 0.05).sin());
        let loud = (64..128).map(|i| 0.8 * (i as f32 * 0.1).sin());
        let signal = Array1::from_vec(quiet.chain(loud).collect::<Vec<_>>());

        let encoded = adaptive.encode(&signal).unwrap();
        assert_eq!(encoded.len(), 128);
    }

    #[test]
    fn test_nonuniform_with_custom_values() {
        // Explicit reconstruction value per bin.
        let quant = NonUniformQuantizer::new(
            vec![-1.0, -0.3, 0.0, 0.3, 1.0],
            vec![-0.7, -0.15, 0.15, 0.7],
        )
        .unwrap();

        // 0.1 falls into the third bin, whose codeword is 0.15.
        let level = quant.quantize(0.1);
        let value = quant.dequantize(level);
        assert!((value - 0.15).abs() < 0.01);
    }
}