1use crate::error::{TokenizerError, TokenizerResult};
10use crate::{Quantizer, SignalTokenizer};
11use scirs2_core::ndarray::Array1;
12
/// Scalar quantizer whose step size adapts to the local signal statistics.
#[derive(Debug, Clone)]
pub struct AdaptiveQuantizer {
    /// Requested bit width (kept for reference; `levels` is derived from it).
    _bits: u8,
    /// Number of quantization levels, `1 << bits`.
    levels: usize,
    /// Number of samples in the window used for local statistics.
    window_size: usize,
    /// Blend factor in [0, 1]: 0 = fixed step, 1 = fully local adaptation.
    adaptation_strength: f32,
    /// Lower bound of the representable signal range.
    global_min: f32,
    /// Upper bound of the representable signal range.
    global_max: f32,
}
31
32impl AdaptiveQuantizer {
33 pub fn new(
35 bits: u8,
36 window_size: usize,
37 adaptation_strength: f32,
38 global_min: f32,
39 global_max: f32,
40 ) -> TokenizerResult<Self> {
41 if bits == 0 || bits > 16 {
42 return Err(TokenizerError::InvalidConfig("bits must be 1-16".into()));
43 }
44 if window_size == 0 {
45 return Err(TokenizerError::InvalidConfig(
46 "window_size must be positive".into(),
47 ));
48 }
49 if !(0.0..=1.0).contains(&adaptation_strength) {
50 return Err(TokenizerError::InvalidConfig(
51 "adaptation_strength must be in [0, 1]".into(),
52 ));
53 }
54
55 Ok(Self {
56 _bits: bits,
57 levels: 1usize << bits,
58 window_size,
59 adaptation_strength,
60 global_min,
61 global_max,
62 })
63 }
64
65 fn local_variance(&self, signal: &Array1<f32>, pos: usize) -> f32 {
67 let half_window = self.window_size / 2;
68 let start = pos.saturating_sub(half_window);
69 let end = (pos + half_window).min(signal.len());
70
71 let window: Vec<f32> = signal
72 .iter()
73 .skip(start)
74 .take(end - start)
75 .cloned()
76 .collect();
77 if window.is_empty() {
78 return 1.0;
79 }
80
81 let mean = window.iter().sum::<f32>() / window.len() as f32;
82 let variance = window.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / window.len() as f32;
83
84 variance.sqrt().max(1e-6) }
86
87 fn adaptive_step(&self, signal: &Array1<f32>, pos: usize) -> f32 {
89 let base_step = (self.global_max - self.global_min) / self.levels as f32;
90 let local_std = self.local_variance(signal, pos);
91
92 let global_std = (self.global_max - self.global_min) / 4.0; let scale = 1.0 + self.adaptation_strength * (local_std / global_std - 1.0);
95
96 base_step * scale.clamp(0.1, 10.0) }
98
99 pub fn quantize_adaptive(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<i32>> {
101 let mut result = Vec::with_capacity(signal.len());
102
103 for (i, &value) in signal.iter().enumerate() {
104 let step = self.adaptive_step(signal, i);
105 let clamped = value.clamp(self.global_min, self.global_max);
106 let normalized = (clamped - self.global_min) / (self.global_max - self.global_min);
107 let level = (normalized / step * (self.levels - 1) as f32).round() as i32;
108 result.push(level.clamp(0, (self.levels - 1) as i32));
109 }
110
111 Ok(Array1::from_vec(result))
112 }
113}
114
115impl Quantizer for AdaptiveQuantizer {
116 fn quantize(&self, value: f32) -> i32 {
117 let clamped = value.clamp(self.global_min, self.global_max);
119 let normalized = (clamped - self.global_min) / (self.global_max - self.global_min);
120 (normalized * (self.levels - 1) as f32).round() as i32
121 }
122
123 fn dequantize(&self, level: i32) -> f32 {
124 let clamped_level = level.clamp(0, (self.levels - 1) as i32);
125 let normalized = clamped_level as f32 / (self.levels - 1) as f32;
126 self.global_min + normalized * (self.global_max - self.global_min)
127 }
128
129 fn num_levels(&self) -> usize {
130 self.levels
131 }
132}
133
/// Quantizer that snaps small-magnitude samples to a dedicated zero level,
/// useful for suppressing low-level noise.
#[derive(Debug, Clone)]
pub struct DeadZoneQuantizer {
    /// Requested bit width (kept for reference; `levels` is derived from it).
    _base_bits: u8,
    /// Number of quantization levels, `1 << bits`.
    levels: usize,
    /// Half-width of the zero band: inputs with |x| below this map to zero.
    dead_zone: f32,
    /// Lower bound of the representable range.
    min: f32,
    /// Upper bound of the representable range.
    max: f32,
}
149
150impl DeadZoneQuantizer {
151 pub fn new(bits: u8, dead_zone: f32, min: f32, max: f32) -> TokenizerResult<Self> {
158 if bits == 0 || bits > 16 {
159 return Err(TokenizerError::InvalidConfig("bits must be 1-16".into()));
160 }
161 if dead_zone < 0.0 {
162 return Err(TokenizerError::InvalidConfig(
163 "dead_zone must be non-negative".into(),
164 ));
165 }
166
167 Ok(Self {
168 _base_bits: bits,
169 levels: 1usize << bits,
170 dead_zone,
171 min,
172 max,
173 })
174 }
175}
176
177impl Quantizer for DeadZoneQuantizer {
178 fn quantize(&self, value: f32) -> i32 {
179 if value.abs() < self.dead_zone {
181 return (self.levels / 2) as i32; }
183
184 let clamped = value.clamp(self.min, self.max);
186 let normalized = (clamped - self.min) / (self.max - self.min);
187 (normalized * (self.levels - 1) as f32).round() as i32
188 }
189
190 fn dequantize(&self, level: i32) -> f32 {
191 let clamped_level = level.clamp(0, (self.levels - 1) as i32);
192
193 if clamped_level == (self.levels / 2) as i32 {
195 return 0.0;
196 }
197
198 let normalized = clamped_level as f32 / (self.levels - 1) as f32;
199 self.min + normalized * (self.max - self.min)
200 }
201
202 fn num_levels(&self) -> usize {
203 self.levels
204 }
205}
206
/// Quantizer with arbitrary (non-uniform) bin boundaries and per-bin
/// reconstruction values.
#[derive(Debug, Clone)]
pub struct NonUniformQuantizer {
    /// Bin boundaries, expected ascending (`from_edges` sorts its input);
    /// n + 1 edges delimit n bins.
    bin_edges: Vec<f32>,
    /// Output value for each of the n bins.
    reconstruction_values: Vec<f32>,
}
217
218impl NonUniformQuantizer {
219 pub fn from_edges(mut bin_edges: Vec<f32>) -> TokenizerResult<Self> {
223 if bin_edges.len() < 2 {
224 return Err(TokenizerError::InvalidConfig(
225 "Need at least 2 bin edges".into(),
226 ));
227 }
228
229 bin_edges.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
230
231 let mut reconstruction_values = Vec::with_capacity(bin_edges.len() - 1);
233 for i in 0..bin_edges.len() - 1 {
234 reconstruction_values.push((bin_edges[i] + bin_edges[i + 1]) / 2.0);
235 }
236
237 Ok(Self {
238 bin_edges,
239 reconstruction_values,
240 })
241 }
242
243 pub fn new(bin_edges: Vec<f32>, reconstruction_values: Vec<f32>) -> TokenizerResult<Self> {
245 if bin_edges.len() != reconstruction_values.len() + 1 {
246 return Err(TokenizerError::InvalidConfig(
247 "bin_edges.len() must equal reconstruction_values.len() + 1".into(),
248 ));
249 }
250
251 Ok(Self {
252 bin_edges,
253 reconstruction_values,
254 })
255 }
256
257 pub fn lloyd_max_gaussian(num_levels: usize, sigma: f32) -> TokenizerResult<Self> {
261 if num_levels < 2 {
262 return Err(TokenizerError::InvalidConfig(
263 "num_levels must be at least 2".into(),
264 ));
265 }
266
267 let mut bin_edges = Vec::with_capacity(num_levels + 1);
269 let mut reconstruction_values = Vec::with_capacity(num_levels);
270
271 for i in 0..=num_levels {
273 let p = i as f32 / num_levels as f32;
274 let z = if p < 0.5 {
276 -((1.0 - 2.0 * p).sqrt() - 1.0)
277 } else {
278 (2.0 * p - 1.0).sqrt() - 1.0
279 };
280 bin_edges.push(z * sigma);
281 }
282
283 for i in 0..num_levels {
285 reconstruction_values.push((bin_edges[i] + bin_edges[i + 1]) / 2.0);
286 }
287
288 Ok(Self {
289 bin_edges,
290 reconstruction_values,
291 })
292 }
293}
294
295impl Quantizer for NonUniformQuantizer {
296 fn quantize(&self, value: f32) -> i32 {
297 for (i, &edge) in self.bin_edges.iter().enumerate().skip(1) {
299 if value < edge {
300 return (i - 1) as i32;
301 }
302 }
303 (self.reconstruction_values.len() - 1) as i32
304 }
305
306 fn dequantize(&self, level: i32) -> f32 {
307 let idx = level.clamp(0, (self.reconstruction_values.len() - 1) as i32) as usize;
308 self.reconstruction_values[idx]
309 }
310
311 fn num_levels(&self) -> usize {
312 self.reconstruction_values.len()
313 }
314}
315
316impl SignalTokenizer for AdaptiveQuantizer {
319 fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
320 let quantized = self.quantize_adaptive(signal)?;
321 Ok(quantized.mapv(|x| x as f32))
322 }
323
324 fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
325 Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
326 }
327
328 fn embed_dim(&self) -> usize {
329 1
330 }
331
332 fn vocab_size(&self) -> usize {
333 self.levels
334 }
335}
336
337impl SignalTokenizer for DeadZoneQuantizer {
338 fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
339 Ok(signal.mapv(|x| self.quantize(x) as f32))
340 }
341
342 fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
343 Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
344 }
345
346 fn embed_dim(&self) -> usize {
347 1
348 }
349
350 fn vocab_size(&self) -> usize {
351 self.levels
352 }
353}
354
355impl SignalTokenizer for NonUniformQuantizer {
356 fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
357 Ok(signal.mapv(|x| self.quantize(x) as f32))
358 }
359
360 fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
361 Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
362 }
363
364 fn embed_dim(&self) -> usize {
365 1
366 }
367
368 fn vocab_size(&self) -> usize {
369 self.reconstruction_values.len()
370 }
371}
372
/// Index of the bin `x` falls into: the first index `i` with `x <= edges[i]`
/// (equivalently, the count of edges strictly below `x`), or `edges.len()`
/// when `x` exceeds every edge.
fn find_bin(x: f32, edges: &[f32]) -> usize {
    // Binary search via partition_point. The negated `!(x <= edge)` predicate
    // (rather than `edge < x`) reproduces the original comparison exactly,
    // including for NaN `x`, which maps to `edges.len()`.
    edges.partition_point(|&edge| !(x <= edge))
}
390
/// Scalar quantizer fitted to minimize the Lagrangian rate-distortion cost
/// `distortion + lambda * entropy`.
pub struct EntropyConstrainedQuantizer {
    /// Interior decision thresholds (no outer +/- infinity edges); expected
    /// ascending — enforced by `fit_lagrangian`, not validated by `new`.
    bin_edges: Vec<f32>,
    /// Reconstruction value for each quantization level
    /// (`bin_edges.len() + 1` entries when built by `fit_lagrangian`).
    reconstruction_values: Vec<f32>,
    /// Lagrange multiplier trading distortion against rate.
    lambda: f32,
    /// Rate target (bits/symbol) when fitted via `fit_with_target_rate`.
    target_bits_per_symbol: Option<f64>,
    /// Smoothed per-level probabilities estimated during fitting.
    empirical_probs: Vec<f64>,
}
420
impl EntropyConstrainedQuantizer {
    /// Builds a quantizer directly from edges and reconstruction values,
    /// assuming a uniform symbol distribution until refitted.
    pub fn new(bin_edges: Vec<f32>, reconstruction_values: Vec<f32>, lambda: f32) -> Self {
        let n = reconstruction_values.len();
        Self {
            bin_edges,
            reconstruction_values,
            lambda,
            target_bits_per_symbol: None,
            // Uniform prior over all n levels.
            empirical_probs: vec![1.0 / n as f64; n],
        }
    }

    /// Fits the quantizer to `signal` by iteratively minimizing the
    /// Lagrangian cost `distortion + lambda * entropy` (entropy in bits).
    ///
    /// Each iteration: (1) recompute reconstruction values as bin centroids,
    /// (2) re-estimate level probabilities, (3) shift decision boundaries
    /// toward less probable neighbors, (4) stop once the cost change drops
    /// below `tol` or `max_iters` is reached.
    ///
    /// # Errors
    /// Rejects `num_levels < 2` and signals shorter than `num_levels`.
    pub fn fit_lagrangian(
        signal: &Array1<f32>,
        num_levels: usize,
        lambda: f32,
        max_iters: usize,
        tol: f32,
    ) -> TokenizerResult<Self> {
        if num_levels < 2 {
            return Err(TokenizerError::InvalidConfig(
                "num_levels must be >= 2".into(),
            ));
        }
        if signal.len() < num_levels {
            return Err(TokenizerError::InvalidConfig(
                "signal is too short for the requested num_levels".into(),
            ));
        }

        let n = signal.len();
        let sig_min = signal.iter().cloned().fold(f32::INFINITY, f32::min);
        let sig_max = signal.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        // `range` guards against constant signals; `min_gap` keeps edges and
        // centroids strictly separated throughout the iteration.
        let range = (sig_max - sig_min).max(1e-6);
        let min_gap = 1e-6 * range;

        // Initialize interior edges at empirical quantiles of the signal.
        let mut sorted_signal: Vec<f32> = signal.iter().cloned().collect();
        sorted_signal.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        let mut bin_edges: Vec<f32> = (1..num_levels)
            .map(|i| {
                let idx = (i * n / num_levels).min(n - 1);
                sorted_signal[idx]
            })
            .collect();

        // Initial reconstruction values: midpoints of adjacent boundaries,
        // using the signal extremes for the two outermost bins.
        let mut recon: Vec<f32> = {
            let mut r = Vec::with_capacity(num_levels);
            r.push((sig_min + bin_edges[0]) * 0.5);
            for i in 1..num_levels - 1 {
                r.push((bin_edges[i - 1] + bin_edges[i]) * 0.5);
            }
            r.push((bin_edges[num_levels - 2] + sig_max) * 0.5);
            r
        };

        let mut probs = vec![1.0f64 / num_levels as f64; num_levels];
        let mut prev_cost = f64::INFINITY;

        for _iter in 0..max_iters {
            // Centroid step: each reconstruction value becomes the mean of
            // the samples currently assigned to its bin.
            let mut sums = vec![0.0f64; num_levels];
            let mut counts = vec![0usize; num_levels];
            for &x in signal.iter() {
                let b = find_bin(x, &bin_edges);
                sums[b] += x as f64;
                counts[b] += 1;
            }
            for i in 0..num_levels {
                if counts[i] > 0 {
                    recon[i] = (sums[i] / counts[i] as f64) as f32;
                }
            }

            // Keep centroids strictly increasing; empty bins keep their stale
            // values and could otherwise collapse onto a neighbor.
            for i in 1..num_levels {
                if recon[i] <= recon[i - 1] + min_gap {
                    recon[i] = recon[i - 1] + min_gap;
                }
            }

            // Laplace-style smoothing so log-probabilities stay finite.
            let eps = 1e-10;
            let denom = n as f64 + num_levels as f64 * eps;
            for i in 0..num_levels {
                probs[i] = (counts[i] as f64 + eps) / denom;
            }

            // Boundary step: move each edge from the midpoint toward the
            // less probable neighbor, trading distortion for rate.
            // NOTE(review): the textbook entropy-constrained threshold is
            // (r_l + r_r)/2 + lambda * ln(p_l/p_r) / (2 * (r_r - r_l) * ln 2)
            // for rate in bits; this omits the 1/2 and ln 2 factors, which
            // effectively rescales lambda — confirm that is intentional.
            for i in 0..num_levels - 1 {
                let r_left = recon[i];
                let r_right = recon[i + 1];
                let gap = (r_right - r_left).max(min_gap);
                let p_left = probs[i].max(eps);
                let p_right = probs[i + 1].max(eps);
                let entropy_term = (lambda / gap) * (p_left.ln() - p_right.ln()) as f32;
                bin_edges[i] = 0.5 * (r_left + r_right) + entropy_term;
            }

            // Restore strict ordering of the edges after the shifts.
            for i in 1..bin_edges.len() {
                if bin_edges[i] <= bin_edges[i - 1] + min_gap {
                    bin_edges[i] = bin_edges[i - 1] + min_gap;
                }
            }

            // Re-estimate probabilities under the updated boundaries.
            let mut new_counts = vec![0usize; num_levels];
            for &x in signal.iter() {
                new_counts[find_bin(x, &bin_edges)] += 1;
            }
            for i in 0..num_levels {
                probs[i] = (new_counts[i] as f64 + eps) / denom;
            }

            // Mean squared error under the current partition and centroids.
            let distortion: f64 = signal
                .iter()
                .map(|&x| {
                    let b = find_bin(x, &bin_edges);
                    let d = x as f64 - recon[b] as f64;
                    d * d
                })
                .sum::<f64>()
                / n as f64;

            // Shannon entropy (bits/symbol) of the level distribution.
            let entropy: f64 = probs
                .iter()
                .map(|&p| if p > eps { -p * p.log2() } else { 0.0 })
                .sum();

            let cost = distortion + lambda as f64 * entropy;

            // Converged once the Lagrangian cost stops moving.
            if (prev_cost - cost).abs() < tol as f64 {
                break;
            }
            prev_cost = cost;
        }

        Ok(Self {
            bin_edges,
            reconstruction_values: recon,
            lambda,
            target_bits_per_symbol: None,
            empirical_probs: probs,
        })
    }

    /// Quantizes `signal` to level indices and Huffman-codes them.
    ///
    /// Returns the compressed bytes and the number of encoded symbols.
    pub fn encode_compressed(&self, signal: &Array1<f32>) -> TokenizerResult<(Vec<u8>, u64)> {
        use crate::entropy::{compute_frequencies, HuffmanEncoder};

        let symbols: Vec<u32> = signal
            .iter()
            .map(|&x| find_bin(x, &self.bin_edges) as u32)
            .collect();

        // The code table is built from this signal's actual frequencies.
        let freqs = compute_frequencies(&symbols);
        let encoder = HuffmanEncoder::from_frequencies(&freqs)?;
        let compressed = encoder.encode(&symbols)?;
        let symbol_count = symbols.len() as u64;
        Ok((compressed, symbol_count))
    }

    /// Decodes bytes produced by `encode_compressed` back to reconstruction
    /// values.
    ///
    /// NOTE(review): the decoder rebuilds its Huffman tree from the *fitted*
    /// `empirical_probs` (scaled to pseudo-counts), while `encode_compressed`
    /// builds its tree from the frequencies of the signal being encoded. The
    /// roundtrip is only correct when both constructions yield the same tree
    /// (e.g. when encoding the training signal itself) — confirm against
    /// `HuffmanEncoder`'s construction and tie-breaking rules.
    pub fn decode_compressed(
        &self,
        bytes: &[u8],
        _symbol_count: u64,
    ) -> TokenizerResult<Array1<f32>> {
        use crate::entropy::{HuffmanDecoder, HuffmanEncoder};

        let n_levels = self.reconstruction_values.len();
        // Convert model probabilities to integer pseudo-counts (at least 1
        // each) so the Huffman builder sees every symbol.
        let total_pseudo = 1_000_000u64;
        let mut freqs = std::collections::HashMap::new();
        let mut allocated = 0u64;
        for i in 0..n_levels {
            let cnt = (self.empirical_probs[i] * total_pseudo as f64).round() as u64;
            let cnt = cnt.max(1);
            freqs.insert(i as u32, cnt);
            allocated += cnt;
        }
        // Total allocation is tracked but unused; rounding drift is harmless.
        let _ = allocated;
        let encoder = HuffmanEncoder::from_frequencies(&freqs)?;
        let decoder = HuffmanDecoder::new(encoder.tree());
        let indices = decoder.decode(bytes)?;

        let values: Vec<f32> = indices
            .iter()
            .map(|&idx| {
                // Clamp indices defensively in case of corrupt input.
                let b = (idx as usize).min(self.reconstruction_values.len() - 1);
                self.reconstruction_values[b]
            })
            .collect();

        Ok(Array1::from_vec(values))
    }

    /// Fits a quantizer whose empirical entropy rate approximates
    /// `target_bpp` bits/symbol by bisecting on `lambda`
    /// (a larger `lambda` lowers the rate).
    pub fn fit_with_target_rate(
        signal: &Array1<f32>,
        num_levels: usize,
        target_bpp: f64,
        max_outer_iters: usize,
    ) -> TokenizerResult<Self> {
        let mut lambda_lo = 0.0f32;
        let mut lambda_hi = 10.0f32;

        // Start at the highest-lambda (lowest-rate) end so `best` always
        // holds a candidate at or below the target rate.
        let mut best = Self::fit_lagrangian(signal, num_levels, lambda_hi, 100, 1e-5)?;

        for _ in 0..max_outer_iters {
            let lambda_mid = (lambda_lo + lambda_hi) * 0.5;
            let candidate = Self::fit_lagrangian(signal, num_levels, lambda_mid, 100, 1e-5)?;
            let rate = candidate.compute_entropy_rate(signal);
            if rate > target_bpp {
                // Rate too high: need a larger lambda.
                lambda_lo = lambda_mid;
            } else {
                // At or below target: keep this candidate, tighten from above.
                lambda_hi = lambda_mid;
                best = candidate;
            }
            if (lambda_hi - lambda_lo) < 1e-4 {
                break;
            }
        }
        best.target_bits_per_symbol = Some(target_bpp);
        Ok(best)
    }

    /// Empirical entropy (bits/symbol) of `signal` under this quantizer's
    /// bins; returns 0.0 for an empty signal.
    pub fn compute_entropy_rate(&self, signal: &Array1<f32>) -> f64 {
        let n = signal.len();
        if n == 0 {
            return 0.0;
        }
        let mut counts = vec![0usize; self.reconstruction_values.len()];
        for &x in signal.iter() {
            counts[find_bin(x, &self.bin_edges)] += 1;
        }
        counts
            .iter()
            .map(|&c| {
                if c > 0 {
                    let p = c as f64 / n as f64;
                    -p * p.log2()
                } else {
                    0.0
                }
            })
            .sum()
    }

    /// Mean squared reconstruction error of `signal` under this quantizer;
    /// returns 0.0 for an empty signal.
    pub fn empirical_distortion(&self, signal: &Array1<f32>) -> f64 {
        let n = signal.len();
        if n == 0 {
            return 0.0;
        }
        signal
            .iter()
            .map(|&x| {
                let r = self.reconstruction_values[find_bin(x, &self.bin_edges)];
                let d = (x - r) as f64;
                d * d
            })
            .sum::<f64>()
            / n as f64
    }

    /// Interior decision thresholds.
    pub fn bin_edges(&self) -> &[f32] {
        &self.bin_edges
    }

    /// Per-level reconstruction values.
    pub fn reconstruction_values(&self) -> &[f32] {
        &self.reconstruction_values
    }

    /// Lagrange multiplier used in fitting.
    pub fn lambda(&self) -> f32 {
        self.lambda
    }

    /// Rate target, if this quantizer was fitted with one.
    pub fn target_bits_per_symbol(&self) -> Option<f64> {
        self.target_bits_per_symbol
    }
}
752
753impl Quantizer for EntropyConstrainedQuantizer {
754 fn quantize(&self, value: f32) -> i32 {
755 find_bin(value, &self.bin_edges) as i32
756 }
757
758 fn dequantize(&self, level: i32) -> f32 {
759 let idx = level.clamp(0, (self.reconstruction_values.len() - 1) as i32) as usize;
760 self.reconstruction_values[idx]
761 }
762
763 fn num_levels(&self) -> usize {
764 self.reconstruction_values.len()
765 }
766}
767
768impl SignalTokenizer for EntropyConstrainedQuantizer {
769 fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
770 Ok(signal.mapv(|x| self.quantize(x) as f32))
771 }
772
773 fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
774 Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
775 }
776
777 fn embed_dim(&self) -> usize {
778 1
779 }
780
781 fn vocab_size(&self) -> usize {
782 self.reconstruction_values.len()
783 }
784}
785
#[cfg(test)]
mod ecq_tests {
    // Tests for EntropyConstrainedQuantizer: fitting, rate-distortion
    // behavior, determinism, and compressed roundtrips.
    use super::*;

    // Deterministic pseudo-Gaussian signal via an LCG plus Box-Muller.
    // NOTE(review): `state >> 33` keeps only 31 bits, so dividing by
    // `u32::MAX` yields uniforms in [0, 0.5) rather than [0, 1); the output
    // is therefore only approximately normal, which these tests tolerate.
    fn gaussian_signal(n: usize, seed: u64) -> Array1<f32> {
        let mut state = seed;
        let mut next_f32 = move || {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            (state >> 33) as f32 / u32::MAX as f32
        };
        Array1::from_iter((0..n).map(|_| {
            let u1 = next_f32().max(1e-7);
            let u2 = next_f32();
            // Box-Muller transform: two uniforms -> one normal deviate.
            (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()
        }))
    }

    #[test]
    fn fit_lagrangian_convergence_and_monotonicity() {
        let signal = gaussian_signal(10_000, 42);
        let num_levels = 8;
        let q = EntropyConstrainedQuantizer::fit_lagrangian(&signal, num_levels, 0.1, 50, 1e-5)
            .expect("fit_lagrangian failed");

        // Centroids should stay within a plausible range for ~N(0,1) data.
        for &r in q.reconstruction_values() {
            assert!(r.abs() <= 4.5, "recon value {r} outside [-4.5, 4.5]");
        }

        // The fitting loop enforces strictly increasing centroids.
        for i in 1..q.reconstruction_values().len() {
            assert!(
                q.reconstruction_values()[i] > q.reconstruction_values()[i - 1],
                "recon values not monotonic at i={i}: {} <= {}",
                q.reconstruction_values()[i],
                q.reconstruction_values()[i - 1]
            );
        }
    }

    #[test]
    fn rd_tradeoff_bracketed() {
        // A larger lambda should buy a lower rate at higher distortion.
        let signal = gaussian_signal(10_000, 99);
        let num_levels = 8;
        let q_low =
            EntropyConstrainedQuantizer::fit_lagrangian(&signal, num_levels, 0.01, 100, 1e-6)
                .expect("fit low-lambda failed");
        let q_high =
            EntropyConstrainedQuantizer::fit_lagrangian(&signal, num_levels, 1.0, 100, 1e-6)
                .expect("fit high-lambda failed");

        let d_low = q_low.empirical_distortion(&signal);
        let d_high = q_high.empirical_distortion(&signal);
        let r_low = q_low.compute_entropy_rate(&signal);
        let r_high = q_high.compute_entropy_rate(&signal);

        assert!(
            r_high + 1e-6 < r_low,
            "R-D: high-λ should reduce rate: r_high={r_high} r_low={r_low}"
        );
        assert!(
            d_low + 1e-6 < d_high,
            "R-D: high-λ should increase distortion: d_low={d_low} d_high={d_high}"
        );
    }

    #[test]
    fn roundtrip_mse_vs_uniform() {
        let signal = gaussian_signal(10_000, 7);
        let num_levels = 8;

        // Low lambda makes the ECQ nearly distortion-optimal, so its MSE
        // should be in the same ballpark as a uniform quantizer's.
        let q_ecq =
            EntropyConstrainedQuantizer::fit_lagrangian(&signal, num_levels, 0.001, 100, 1e-6)
                .expect("ECQ fit failed");
        let mse_ecq = q_ecq.empirical_distortion(&signal);

        // Baseline: uniform mid-rise quantizer over the signal's range.
        let sig_min = signal.iter().cloned().fold(f32::INFINITY, f32::min);
        let sig_max = signal.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let step = (sig_max - sig_min) / num_levels as f32;
        let mse_uniform: f64 = signal
            .iter()
            .map(|&x| {
                let idx = ((x - sig_min) / step).floor() as usize;
                let idx = idx.min(num_levels - 1);
                let r = sig_min + (idx as f32 + 0.5) * step;
                let d = (x - r) as f64;
                d * d
            })
            .sum::<f64>()
            / signal.len() as f64;

        assert!(
            mse_ecq <= mse_uniform * 3.0,
            "ECQ MSE {mse_ecq} > 3× uniform MSE {mse_uniform}"
        );
    }

    #[test]
    fn determinism() {
        // Two fits on identical input must agree bit-for-bit.
        let signal = gaussian_signal(10_000, 555);
        let q1 = EntropyConstrainedQuantizer::fit_lagrangian(&signal, 8, 0.1, 50, 1e-5)
            .expect("first fit failed");
        let q2 = EntropyConstrainedQuantizer::fit_lagrangian(&signal, 8, 0.1, 50, 1e-5)
            .expect("second fit failed");

        for (a, b) in q1.bin_edges().iter().zip(q2.bin_edges().iter()) {
            assert_eq!(
                a.to_bits(),
                b.to_bits(),
                "non-deterministic bin edges: {a} vs {b}"
            );
        }
    }

    #[test]
    fn fit_with_target_rate_in_range() {
        // The bisection should land within +/- 1 bit of the requested rate.
        let signal = gaussian_signal(10_000, 42);
        let q = EntropyConstrainedQuantizer::fit_with_target_rate(&signal, 8, 2.5, 20)
            .expect("fit_with_target_rate failed");
        let rate = q.compute_entropy_rate(&signal);
        assert!(
            (1.5..=3.5).contains(&rate),
            "target_rate=2.5 produced rate={rate} outside [1.5, 3.5]"
        );
    }

    #[test]
    fn signal_tokenizer_encode_decode_roundtrip() {
        let signal = gaussian_signal(1_000, 13);
        let q = EntropyConstrainedQuantizer::fit_lagrangian(&signal, 8, 0.1, 50, 1e-5)
            .expect("fit failed");

        let tokens = q.encode(&signal).expect("encode failed");
        assert_eq!(tokens.len(), signal.len());

        let reconstructed = q.decode(&tokens).expect("decode failed");
        assert_eq!(reconstructed.len(), signal.len());

        // Every decoded value must be one of the codebook entries.
        for &r in reconstructed.iter() {
            assert!(
                q.reconstruction_values().contains(&r),
                "reconstructed value {r} not in reconstruction_values"
            );
        }
    }

    #[test]
    fn invalid_config_rejected() {
        let signal = gaussian_signal(100, 1);
        assert!(
            EntropyConstrainedQuantizer::fit_lagrangian(&signal, 1, 0.1, 10, 1e-5).is_err(),
            "num_levels=1 should be rejected"
        );
        let tiny = gaussian_signal(3, 2);
        assert!(
            EntropyConstrainedQuantizer::fit_lagrangian(&tiny, 8, 0.1, 10, 1e-5).is_err(),
            "signal shorter than num_levels should be rejected"
        );
    }

    #[test]
    fn encode_decode_compressed_roundtrip() {
        // Compressing the same signal the quantizer was fitted on, so the
        // encoder's and decoder's Huffman trees should coincide.
        let signal = gaussian_signal(1_000, 77);
        let q = EntropyConstrainedQuantizer::fit_lagrangian(&signal, 8, 0.1, 50, 1e-5)
            .expect("fit failed");

        let (compressed, sym_count) = q
            .encode_compressed(&signal)
            .expect("encode_compressed failed");
        let reconstructed = q
            .decode_compressed(&compressed, sym_count)
            .expect("decode_compressed failed");

        assert_eq!(reconstructed.len(), signal.len());

        // Every decoded value must be one of the codebook entries.
        for &r in reconstructed.iter() {
            assert!(
                q.reconstruction_values().contains(&r),
                "decoded value {r} not in reconstruction_values"
            );
        }
    }
}
984
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adaptive_quantizer() {
        // Smoke test: encode/decode preserve length on a smooth sine input.
        let quant = AdaptiveQuantizer::new(8, 16, 0.5, -1.0, 1.0).unwrap();

        let signal = Array1::from_vec((0..128).map(|i| ((i as f32) * 0.05).sin()).collect());

        let encoded = quant.encode(&signal).unwrap();
        assert_eq!(encoded.len(), 128);

        let decoded = quant.decode(&encoded).unwrap();
        assert_eq!(decoded.len(), 128);
    }

    #[test]
    fn test_dead_zone_quantizer() {
        let quant = DeadZoneQuantizer::new(8, 0.1, -1.0, 1.0).unwrap();

        // A sample inside the dead zone (|x| < 0.1) decodes to exactly zero.
        let level = quant.quantize(0.05);
        let recovered = quant.dequantize(level);
        assert_eq!(recovered, 0.0);

        // A sample well outside the dead zone keeps its magnitude.
        let level = quant.quantize(0.5);
        let recovered = quant.dequantize(level);
        assert!(recovered.abs() > 0.1);
    }

    #[test]
    fn test_dead_zone_signal() {
        let quant = DeadZoneQuantizer::new(8, 0.2, -1.0, 1.0).unwrap();

        let signal = Array1::from_vec(vec![0.01, 0.5, -0.1, 0.8, 0.05]);

        let encoded = quant.encode(&signal).unwrap();
        let decoded = quant.decode(&encoded).unwrap();

        // Samples 0, 2, 4 fall inside the dead zone (|x| < 0.2) -> exactly 0.
        assert_eq!(decoded[0], 0.0);
        assert_eq!(decoded[2], 0.0);
        assert_eq!(decoded[4], 0.0);

        // The larger samples reconstruct near their originals.
        assert!(decoded[1] > 0.3);
        assert!(decoded[3] > 0.6);
    }

    #[test]
    fn test_nonuniform_quantizer() {
        // 5 edges define 4 bins.
        let edges = vec![-2.0, -0.5, 0.0, 0.5, 2.0];
        let quant = NonUniformQuantizer::from_edges(edges).unwrap();

        assert_eq!(quant.num_levels(), 4);

        // -1.0 lies in the first bin [-2.0, -0.5).
        let level = quant.quantize(-1.0);
        assert_eq!(level, 0);

        // 0.25 lies in the third bin [0.0, 0.5).
        let level = quant.quantize(0.25);
        assert_eq!(level, 2);
    }

    #[test]
    fn test_lloyd_max_quantizer() {
        let quant = NonUniformQuantizer::lloyd_max_gaussian(8, 1.0).unwrap();

        assert_eq!(quant.num_levels(), 8);

        // Symmetric inputs should map to roughly mirrored outputs.
        let level_pos = quant.quantize(0.5);
        let level_neg = quant.quantize(-0.5);
        let val_pos = quant.dequantize(level_pos);
        let val_neg = quant.dequantize(level_neg);

        assert!((val_pos + val_neg).abs() < 0.5);
    }

    #[test]
    fn test_adaptive_vs_uniform() {
        let adaptive = AdaptiveQuantizer::new(6, 8, 0.8, -1.0, 1.0).unwrap();

        // Build a signal with a quiet first half and a loud second half so
        // the local statistics differ between the regions.
        let mut signal_vec = Vec::new();
        for i in 0..64 {
            signal_vec.push(0.1 * (i as f32 * 0.05).sin());
        }
        for i in 64..128 {
            signal_vec.push(0.8 * (i as f32 * 0.1).sin());
        }

        let signal = Array1::from_vec(signal_vec);
        let encoded = adaptive.encode(&signal).unwrap();

        assert_eq!(encoded.len(), 128);
    }

    #[test]
    fn test_nonuniform_with_custom_values() {
        let edges = vec![-1.0, -0.3, 0.0, 0.3, 1.0];
        let recon = vec![-0.7, -0.15, 0.15, 0.7];

        let quant = NonUniformQuantizer::new(edges, recon).unwrap();

        // 0.1 lies in bin [0.0, 0.3) -> custom reconstruction value 0.15.
        let level = quant.quantize(0.1);
        let value = quant.dequantize(level);
        assert!((value - 0.15).abs() < 0.01);
    }
}