use crate::error::{TokenizerError, TokenizerResult};
use crate::{Quantizer, SignalTokenizer};
use scirs2_core::ndarray::Array1;

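/// Quantizer whose step size adapts to local signal statistics: windows
/// with high local variance get a coarser step, quiet windows a finer one.
///
/// A minimal usage sketch (marked `ignore` because the crate's public
/// import path is assumed rather than taken from this file):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// // 8 bits over [-1, 1], 16-sample windows, adaptation strength 0.5.
/// let quant = AdaptiveQuantizer::new(8, 16, 0.5, -1.0, 1.0).unwrap();
/// let signal = Array1::from_vec(vec![0.0_f32, 0.2, -0.4, 0.9]);
/// let tokens = quant.encode(&signal).unwrap();
/// let approx = quant.decode(&tokens).unwrap();
/// ```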
#[derive(Debug, Clone)]
pub struct AdaptiveQuantizer {
    /// Bit width (retained for reference; `levels` is derived from it).
    _bits: u8,
    /// Number of quantization levels (`2^bits`).
    levels: usize,
    /// Window length, in samples, used to estimate local statistics.
    window_size: usize,
    /// How strongly the step adapts to local variance, in `[0, 1]`.
    adaptation_strength: f32,
    /// Lower bound of the representable signal range.
    global_min: f32,
    /// Upper bound of the representable signal range.
    global_max: f32,
}

impl AdaptiveQuantizer {
    /// Create a new adaptive quantizer. `bits` must be 1-16, `window_size`
    /// positive, `adaptation_strength` in `[0, 1]`, and the global range
    /// non-empty.
    pub fn new(
        bits: u8,
        window_size: usize,
        adaptation_strength: f32,
        global_min: f32,
        global_max: f32,
    ) -> TokenizerResult<Self> {
        if bits == 0 || bits > 16 {
            return Err(TokenizerError::InvalidConfig("bits must be 1-16".into()));
        }
        if window_size == 0 {
            return Err(TokenizerError::InvalidConfig(
                "window_size must be positive".into(),
            ));
        }
        if !(0.0..=1.0).contains(&adaptation_strength) {
            return Err(TokenizerError::InvalidConfig(
                "adaptation_strength must be in [0, 1]".into(),
            ));
        }
        if global_min >= global_max {
            return Err(TokenizerError::InvalidConfig(
                "global_min must be less than global_max".into(),
            ));
        }

        Ok(Self {
            _bits: bits,
            levels: 1usize << bits,
            window_size,
            adaptation_strength,
            global_min,
            global_max,
        })
    }

    /// Local standard deviation of the window centered at `pos` (floored
    /// at a small epsilon so a constant window never yields a zero step).
    fn local_std(&self, signal: &Array1<f32>, pos: usize) -> f32 {
        let half_window = self.window_size / 2;
        let start = pos.saturating_sub(half_window);
        let end = (pos + half_window).min(signal.len());

        let window: Vec<f32> = signal
            .iter()
            .skip(start)
            .take(end - start)
            .cloned()
            .collect();
        if window.is_empty() {
            return 1.0;
        }

        let mean = window.iter().sum::<f32>() / window.len() as f32;
        let variance = window.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / window.len() as f32;

        variance.sqrt().max(1e-6)
    }

    /// Step size at `pos`: the uniform base step scaled by how the local
    /// spread compares to a global reference, blended by
    /// `adaptation_strength`.
    fn adaptive_step(&self, signal: &Array1<f32>, pos: usize) -> f32 {
        let base_step = (self.global_max - self.global_min) / self.levels as f32;
        let local_std = self.local_std(signal, pos);

        // Treat a quarter of the global range as a rough "typical" spread.
        let global_std = (self.global_max - self.global_min) / 4.0;
        let scale = 1.0 + self.adaptation_strength * (local_std / global_std - 1.0);

        // Bound the scale so the step never collapses or explodes.
        base_step * scale.clamp(0.1, 10.0)
    }
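
    // Worked example for `adaptive_step` above (numbers chosen for
    // illustration): with range [-1, 1], 8 bits, and strength 0.5,
    // base_step = 2 / 256 ≈ 0.0078 and global_std = 0.5. A window with
    // local std 1.0 gives scale = 1 + 0.5 * (1.0 / 0.5 - 1) = 1.5, so the
    // local step widens to ≈ 0.0117 where the signal is busier.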

    /// Quantize a whole signal, recomputing the step at every sample.
    pub fn quantize_adaptive(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<i32>> {
        let mut result = Vec::with_capacity(signal.len());

        for (i, &value) in signal.iter().enumerate() {
            let step = self.adaptive_step(signal, i);
            let clamped = value.clamp(self.global_min, self.global_max);
            // Index of the step-sized bin above global_min; a coarser local
            // step therefore spends fewer of the available levels.
            let level = ((clamped - self.global_min) / step).round() as i32;
            result.push(level.clamp(0, (self.levels - 1) as i32));
        }

        Ok(Array1::from_vec(result))
    }
}

impl Quantizer for AdaptiveQuantizer {
    /// Scalar quantization falls back to a uniform grid: with a single
    /// value there is no neighborhood to adapt to.
    fn quantize(&self, value: f32) -> i32 {
        let clamped = value.clamp(self.global_min, self.global_max);
        let normalized = (clamped - self.global_min) / (self.global_max - self.global_min);
        (normalized * (self.levels - 1) as f32).round() as i32
    }

    fn dequantize(&self, level: i32) -> f32 {
        let clamped_level = level.clamp(0, (self.levels - 1) as i32);
        let normalized = clamped_level as f32 / (self.levels - 1) as f32;
        self.global_min + normalized * (self.global_max - self.global_min)
    }

    fn num_levels(&self) -> usize {
        self.levels
    }
}

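/// Quantizer that maps any input with magnitude below `dead_zone` onto a
/// single "zero" level (the middle of the level range), suppressing
/// low-amplitude noise while quantizing the rest of the range uniformly.
///
/// A minimal usage sketch (marked `ignore` because the crate's public
/// import path is assumed rather than taken from this file):
///
/// ```ignore
/// // 8 bits over [-1, 1] with a +/-0.1 dead zone around zero.
/// let quant = DeadZoneQuantizer::new(8, 0.1, -1.0, 1.0).unwrap();
/// assert_eq!(quant.dequantize(quant.quantize(0.05)), 0.0); // inside the zone
/// let loud = quant.dequantize(quant.quantize(0.5));        // outside the zone
/// ```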
#[derive(Debug, Clone)]
pub struct DeadZoneQuantizer {
    /// Bit width (retained for reference; `levels` is derived from it).
    _base_bits: u8,
    /// Number of quantization levels (`2^bits`).
    levels: usize,
    /// Half-width of the zero region around the origin.
    dead_zone: f32,
    /// Lower bound of the representable range.
    min: f32,
    /// Upper bound of the representable range.
    max: f32,
}

impl DeadZoneQuantizer {
    /// Create a new dead-zone quantizer. `bits` must be 1-16, `dead_zone`
    /// non-negative, and the range non-empty.
    pub fn new(bits: u8, dead_zone: f32, min: f32, max: f32) -> TokenizerResult<Self> {
        if bits == 0 || bits > 16 {
            return Err(TokenizerError::InvalidConfig("bits must be 1-16".into()));
        }
        if dead_zone < 0.0 {
            return Err(TokenizerError::InvalidConfig(
                "dead_zone must be non-negative".into(),
            ));
        }
        if min >= max {
            return Err(TokenizerError::InvalidConfig(
                "min must be less than max".into(),
            ));
        }

        Ok(Self {
            _base_bits: bits,
            levels: 1usize << bits,
            dead_zone,
            min,
            max,
        })
    }
}

impl Quantizer for DeadZoneQuantizer {
    fn quantize(&self, value: f32) -> i32 {
        // Anything inside the dead zone collapses onto the middle level.
        if value.abs() < self.dead_zone {
            return (self.levels / 2) as i32;
        }

        let clamped = value.clamp(self.min, self.max);
        let normalized = (clamped - self.min) / (self.max - self.min);
        (normalized * (self.levels - 1) as f32).round() as i32
    }

    fn dequantize(&self, level: i32) -> f32 {
        let clamped_level = level.clamp(0, (self.levels - 1) as i32);

        // The middle level is reserved for the dead zone and decodes to 0.
        if clamped_level == (self.levels / 2) as i32 {
            return 0.0;
        }

        let normalized = clamped_level as f32 / (self.levels - 1) as f32;
        self.min + normalized * (self.max - self.min)
    }

    fn num_levels(&self) -> usize {
        self.levels
    }
}

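/// Quantizer with arbitrary (non-uniform) bin edges and one reconstruction
/// value per bin.
///
/// A minimal usage sketch (marked `ignore` because the crate's public
/// import path is assumed rather than taken from this file):
///
/// ```ignore
/// // Four bins; each reconstructs to the midpoint of its edges.
/// let quant = NonUniformQuantizer::from_edges(vec![-2.0, -0.5, 0.0, 0.5, 2.0]).unwrap();
/// assert_eq!(quant.num_levels(), 4);
/// let level = quant.quantize(0.25); // falls in the [0.0, 0.5) bin
/// let value = quant.dequantize(level);
/// ```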
#[derive(Debug, Clone)]
pub struct NonUniformQuantizer {
    /// Sorted bin boundaries; `bin_edges.len() == num_levels + 1`.
    bin_edges: Vec<f32>,
    /// One reconstruction value per bin.
    reconstruction_values: Vec<f32>,
}

impl NonUniformQuantizer {
    /// Build a quantizer from bin edges alone; each bin reconstructs to
    /// the midpoint of its edges. Edges are sorted before use.
    pub fn from_edges(mut bin_edges: Vec<f32>) -> TokenizerResult<Self> {
        if bin_edges.len() < 2 {
            return Err(TokenizerError::InvalidConfig(
                "Need at least 2 bin edges".into(),
            ));
        }

        bin_edges.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        // Default reconstruction value: midpoint of each bin.
        let mut reconstruction_values = Vec::with_capacity(bin_edges.len() - 1);
        for i in 0..bin_edges.len() - 1 {
            reconstruction_values.push((bin_edges[i] + bin_edges[i + 1]) / 2.0);
        }

        Ok(Self {
            bin_edges,
            reconstruction_values,
        })
    }

    /// Build a quantizer from explicit edges and reconstruction values.
    pub fn new(bin_edges: Vec<f32>, reconstruction_values: Vec<f32>) -> TokenizerResult<Self> {
        if reconstruction_values.is_empty() || bin_edges.len() != reconstruction_values.len() + 1 {
            return Err(TokenizerError::InvalidConfig(
                "bin_edges.len() must equal reconstruction_values.len() + 1".into(),
            ));
        }

        Ok(Self {
            bin_edges,
            reconstruction_values,
        })
    }

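    /// Build a quantizer whose edges follow a piecewise square-root
    /// profile: denser near the mean, a crude stand-in for Lloyd-Max
    /// spacing under a Gaussian input (a true Lloyd-Max design would
    /// iterate centroid/boundary updates). For example, `num_levels = 4`
    /// with `sigma = 1.0` yields edges of roughly
    /// `[-1.0, -0.29, 0.0, 0.29, 1.0]`.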
    pub fn lloyd_max_gaussian(num_levels: usize, sigma: f32) -> TokenizerResult<Self> {
        if num_levels < 2 {
            return Err(TokenizerError::InvalidConfig(
                "num_levels must be at least 2".into(),
            ));
        }

        let mut bin_edges = Vec::with_capacity(num_levels + 1);
        let mut reconstruction_values = Vec::with_capacity(num_levels);

        for i in 0..=num_levels {
            let p = i as f32 / num_levels as f32;
            // Map the quantile p in [0, 1] onto [-1, 1] with a symmetric,
            // monotonically increasing square-root profile that compresses
            // edges toward zero.
            let z = if p < 0.5 {
                (2.0 * p).sqrt() - 1.0
            } else {
                1.0 - (2.0 * (1.0 - p)).sqrt()
            };
            bin_edges.push(z * sigma);
        }

        // Reconstruct each bin at the midpoint of its edges.
        for i in 0..num_levels {
            reconstruction_values.push((bin_edges[i] + bin_edges[i + 1]) / 2.0);
        }

        Ok(Self {
            bin_edges,
            reconstruction_values,
        })
    }
}

impl Quantizer for NonUniformQuantizer {
    fn quantize(&self, value: f32) -> i32 {
        // Linear scan over the interior edges; the first edge the value
        // falls below identifies its bin.
        for (i, &edge) in self.bin_edges.iter().enumerate().skip(1) {
            if value < edge {
                return (i - 1) as i32;
            }
        }
        // At or above the last edge: clamp into the top bin.
        (self.reconstruction_values.len() - 1) as i32
    }

    fn dequantize(&self, level: i32) -> f32 {
        let idx = level.clamp(0, (self.reconstruction_values.len() - 1) as i32) as usize;
        self.reconstruction_values[idx]
    }

    fn num_levels(&self) -> usize {
        self.reconstruction_values.len()
    }
}

impl SignalTokenizer for AdaptiveQuantizer {
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        let quantized = self.quantize_adaptive(signal)?;
        Ok(quantized.mapv(|x| x as f32))
    }

    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
    }

    fn embed_dim(&self) -> usize {
        1
    }

    fn vocab_size(&self) -> usize {
        self.levels
    }
}

impl SignalTokenizer for DeadZoneQuantizer {
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        Ok(signal.mapv(|x| self.quantize(x) as f32))
    }

    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
    }

    fn embed_dim(&self) -> usize {
        1
    }

    fn vocab_size(&self) -> usize {
        self.levels
    }
}

impl SignalTokenizer for NonUniformQuantizer {
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        Ok(signal.mapv(|x| self.quantize(x) as f32))
    }

    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
    }

    fn embed_dim(&self) -> usize {
        1
    }

    fn vocab_size(&self) -> usize {
        self.reconstruction_values.len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adaptive_quantizer() {
        let quant = AdaptiveQuantizer::new(8, 16, 0.5, -1.0, 1.0).unwrap();

        let signal = Array1::from_vec((0..128).map(|i| ((i as f32) * 0.05).sin()).collect());

        let encoded = quant.encode(&signal).unwrap();
        assert_eq!(encoded.len(), 128);

        let decoded = quant.decode(&encoded).unwrap();
        assert_eq!(decoded.len(), 128);
    }

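    // Illustrative property check (not part of the original suite): per the
    // step formula in `adaptive_step`, a high-variance window should get a
    // coarser step than a near-constant one.
    #[test]
    fn test_adaptive_step_grows_with_variance() {
        let quant = AdaptiveQuantizer::new(8, 8, 1.0, -1.0, 1.0).unwrap();
        let flat = Array1::from_elem(32, 0.1_f32);
        let noisy =
            Array1::from_vec((0..32).map(|i| if i % 2 == 0 { 1.0 } else { -1.0 }).collect());
        assert!(quant.adaptive_step(&noisy, 16) > quant.adaptive_step(&flat, 16));
    }
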
    #[test]
    fn test_dead_zone_quantizer() {
        let quant = DeadZoneQuantizer::new(8, 0.1, -1.0, 1.0).unwrap();

        // A value inside the dead zone must decode to exactly zero.
        let level = quant.quantize(0.05);
        let recovered = quant.dequantize(level);
        assert_eq!(recovered, 0.0);

        // A value outside the dead zone must keep its magnitude.
        let level = quant.quantize(0.5);
        let recovered = quant.dequantize(level);
        assert!(recovered.abs() > 0.1);
    }

    #[test]
    fn test_dead_zone_signal() {
        let quant = DeadZoneQuantizer::new(8, 0.2, -1.0, 1.0).unwrap();

        let signal = Array1::from_vec(vec![0.01, 0.5, -0.1, 0.8, 0.05]);

        let encoded = quant.encode(&signal).unwrap();
        let decoded = quant.decode(&encoded).unwrap();

        // Samples inside the dead zone collapse to zero.
        assert_eq!(decoded[0], 0.0);
        assert_eq!(decoded[2], 0.0);
        assert_eq!(decoded[4], 0.0);

        // Samples outside keep their approximate magnitude.
        assert!(decoded[1] > 0.3);
        assert!(decoded[3] > 0.6);
    }

    #[test]
    fn test_nonuniform_quantizer() {
        let edges = vec![-2.0, -0.5, 0.0, 0.5, 2.0];
        let quant = NonUniformQuantizer::from_edges(edges).unwrap();

        assert_eq!(quant.num_levels(), 4);

        let level = quant.quantize(-1.0);
        assert_eq!(level, 0);

        let level = quant.quantize(0.25);
        assert_eq!(level, 2);
    }

    #[test]
    fn test_lloyd_max_quantizer() {
        let quant = NonUniformQuantizer::lloyd_max_gaussian(8, 1.0).unwrap();

        assert_eq!(quant.num_levels(), 8);

        // Symmetric inputs should decode to approximately opposite values.
        let level_pos = quant.quantize(0.5);
        let level_neg = quant.quantize(-0.5);
        let val_pos = quant.dequantize(level_pos);
        let val_neg = quant.dequantize(level_neg);

        assert!((val_pos + val_neg).abs() < 0.5);
    }

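    // Illustrative property check (not part of the original suite): the
    // square-root edge profile used by `lloyd_max_gaussian` must produce
    // strictly increasing bin edges.
    #[test]
    fn test_lloyd_max_edges_monotonic() {
        let quant = NonUniformQuantizer::lloyd_max_gaussian(8, 1.0).unwrap();
        for pair in quant.bin_edges.windows(2) {
            assert!(pair[0] < pair[1]);
        }
    }
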
    #[test]
    fn test_adaptive_vs_uniform() {
        let adaptive = AdaptiveQuantizer::new(6, 8, 0.8, -1.0, 1.0).unwrap();

        // Quiet first half, loud second half.
        let mut signal_vec = Vec::new();
        for i in 0..64 {
            signal_vec.push(0.1 * (i as f32 * 0.05).sin());
        }
        for i in 64..128 {
            signal_vec.push(0.8 * (i as f32 * 0.1).sin());
        }

        let signal = Array1::from_vec(signal_vec);
        let encoded = adaptive.encode(&signal).unwrap();

        assert_eq!(encoded.len(), 128);
    }

    #[test]
    fn test_nonuniform_with_custom_values() {
        let edges = vec![-1.0, -0.3, 0.0, 0.3, 1.0];
        let recon = vec![-0.7, -0.15, 0.15, 0.7];

        let quant = NonUniformQuantizer::new(edges, recon).unwrap();

        let level = quant.quantize(0.1);
        let value = quant.dequantize(level);
        assert!((value - 0.15).abs() < 0.01);
    }
}