kizzasi_tokenizer/
quantizer.rs1use crate::error::{TokenizerError, TokenizerResult};
7use crate::SignalTokenizer;
8use scirs2_core::ndarray::Array1;
9
/// Common interface for scalar quantizers.
///
/// An implementor maps a single `f32` sample to an integer level and back,
/// and reports how many distinct levels it uses.
pub trait Quantizer {
    /// Maps `value` to an integer quantization level.
    fn quantize(&self, value: f32) -> i32;

    /// Maps a quantization `level` back to a representative `f32` value.
    fn dequantize(&self, level: i32) -> f32;

    /// Returns the number of quantization levels.
    fn num_levels(&self) -> usize;
}
21
/// Uniform quantizer over a fixed `[min, max]` range.
#[derive(Debug, Clone)]
pub struct LinearQuantizer {
    /// Lower bound of the input range; inputs below it are clamped.
    min: f32,
    /// Upper bound of the input range; inputs above it are clamped.
    max: f32,
    /// Bit width that determines the number of levels.
    bits: u8,
    /// Number of quantization levels (`1 << bits`).
    levels: usize,
    /// Range divided by the number of levels: `(max - min) / levels`.
    step: f32,
}
36
37impl LinearQuantizer {
38 pub fn new(min: f32, max: f32, bits: u8) -> TokenizerResult<Self> {
40 if min >= max {
41 return Err(TokenizerError::InvalidConfig(
42 "min must be less than max".into(),
43 ));
44 }
45 if bits == 0 || bits > 16 {
46 return Err(TokenizerError::InvalidConfig("bits must be 1-16".into()));
47 }
48
49 let levels = 1usize << bits;
50 let step = (max - min) / levels as f32;
51
52 Ok(Self {
53 min,
54 max,
55 bits,
56 levels,
57 step,
58 })
59 }
60
61 pub fn normalized(bits: u8) -> TokenizerResult<Self> {
63 Self::new(-1.0, 1.0, bits)
64 }
65
66 pub fn range(&self) -> (f32, f32) {
68 (self.min, self.max)
69 }
70
71 pub fn bits(&self) -> u8 {
73 self.bits
74 }
75
76 pub fn step_size(&self) -> f32 {
78 self.step
79 }
80}
81
82impl Quantizer for LinearQuantizer {
83 fn quantize(&self, value: f32) -> i32 {
84 let clamped = value.clamp(self.min, self.max);
85 let normalized = (clamped - self.min) / (self.max - self.min);
86 (normalized * (self.levels - 1) as f32).round() as i32
87 }
88
89 fn dequantize(&self, level: i32) -> f32 {
90 let clamped_level = level.clamp(0, (self.levels - 1) as i32);
91 let normalized = clamped_level as f32 / (self.levels - 1) as f32;
92 self.min + normalized * (self.max - self.min)
93 }
94
95 fn num_levels(&self) -> usize {
96 self.levels
97 }
98}
99
100impl SignalTokenizer for LinearQuantizer {
101 fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
102 Ok(signal.mapv(|x| self.quantize(x) as f32))
104 }
105
106 fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
107 Ok(tokens.mapv(|t| self.dequantize(t.round() as i32)))
108 }
109
110 fn embed_dim(&self) -> usize {
111 1
112 }
113
114 fn vocab_size(&self) -> usize {
115 self.levels
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// An 8-bit quantizer over `[-1, 1]` has 256 levels and maps the range
    /// endpoints to the lowest and highest level.
    #[test]
    fn test_linear_quantizer() {
        let quantizer = LinearQuantizer::new(-1.0, 1.0, 8).unwrap();

        assert_eq!(quantizer.num_levels(), 256);

        // The midpoint of the range must land on one of the central levels.
        let mid_level = quantizer.quantize(0.0);
        assert!((mid_level - 127).abs() <= 1);

        // Range endpoints map to the extreme levels.
        assert_eq!(quantizer.quantize(-1.0), 0);
        assert_eq!(quantizer.quantize(1.0), 255);
    }

    /// Quantizing and then dequantizing stays within a small error bound.
    #[test]
    fn test_roundtrip() {
        let quantizer = LinearQuantizer::new(0.0, 100.0, 10).unwrap();

        for value in [0.0, 25.0, 50.0, 75.0, 100.0] {
            let level = quantizer.quantize(value);
            let recovered = quantizer.dequantize(level);
            let error = (recovered - value).abs();
            assert!(
                error < 0.2,
                "Roundtrip failed for {} -> {} -> {}",
                value,
                level,
                recovered
            );
        }
    }

    /// Inputs outside the configured range saturate at the extreme levels.
    #[test]
    fn test_clamping() {
        let quantizer = LinearQuantizer::new(-1.0, 1.0, 8).unwrap();

        assert_eq!(quantizer.quantize(-2.0), 0);
        assert_eq!(quantizer.quantize(2.0), 255);
    }
}