1use crate::error::{InferenceError, InferenceResult};
17use kizzasi_core::HiddenState;
18use scirs2_core::ndarray::Array2;
19
/// Strategy used to shrink a hidden state before it is stored.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompressionMethod {
    /// Store every element as its raw little-endian `f32` bytes.
    None,
    /// Affine-quantize every element to a single byte (codes `0..=255`).
    Quantize8Bit,
    /// Affine-quantize every element to four bits (codes `0..=15`), two codes per byte.
    Quantize4Bit,
    /// Keep only elements whose magnitude exceeds the sparsity threshold, as raw `f32`
    /// bytes plus their flat indices.
    Sparse,
    /// Keep only elements whose magnitude exceeds the sparsity threshold, quantized to
    /// one byte each, plus their flat indices.
    QuantizedSparse,
}
34
35#[derive(Debug, Clone)]
37pub struct CompressedState {
38 method: CompressionMethod,
40 data: Vec<u8>,
42 shape: Vec<usize>,
44 scale: f32,
46 zero_point: i32,
48 sparse_indices: Option<Vec<usize>>,
50}
51
52impl CompressedState {
53 pub fn compression_ratio(&self) -> f32 {
55 let original_size = self.shape.iter().product::<usize>() * std::mem::size_of::<f32>();
56 let compressed_size = self.data.len()
57 + self
58 .sparse_indices
59 .as_ref()
60 .map(|v| v.len() * std::mem::size_of::<usize>())
61 .unwrap_or(0);
62 original_size as f32 / compressed_size as f32
63 }
64
65 pub fn method(&self) -> CompressionMethod {
67 self.method
68 }
69}
70
71pub struct StateCompressor {
73 method: CompressionMethod,
74 sparsity_threshold: f32,
76}
77
78impl StateCompressor {
79 pub fn new(method: CompressionMethod) -> Self {
81 Self {
82 method,
83 sparsity_threshold: 1e-4,
84 }
85 }
86
87 pub fn with_sparsity_threshold(mut self, threshold: f32) -> Self {
89 self.sparsity_threshold = threshold;
90 self
91 }
92
93 pub fn compress(&self, state: &HiddenState) -> InferenceResult<CompressedState> {
95 match self.method {
96 CompressionMethod::None => self.compress_none(state),
97 CompressionMethod::Quantize8Bit => self.compress_quantize_8bit(state),
98 CompressionMethod::Quantize4Bit => self.compress_quantize_4bit(state),
99 CompressionMethod::Sparse => self.compress_sparse(state),
100 CompressionMethod::QuantizedSparse => self.compress_quantized_sparse(state),
101 }
102 }
103
104 pub fn decompress(&self, compressed: &CompressedState) -> InferenceResult<HiddenState> {
106 match compressed.method {
107 CompressionMethod::None => self.decompress_none(compressed),
108 CompressionMethod::Quantize8Bit => self.decompress_quantize_8bit(compressed),
109 CompressionMethod::Quantize4Bit => self.decompress_quantize_4bit(compressed),
110 CompressionMethod::Sparse => self.decompress_sparse(compressed),
111 CompressionMethod::QuantizedSparse => self.decompress_quantized_sparse(compressed),
112 }
113 }
114
115 fn compress_none(&self, state: &HiddenState) -> InferenceResult<CompressedState> {
117 let data_vec: Vec<f32> = state.state().iter().copied().collect();
118 let data_bytes: Vec<u8> = data_vec.iter().flat_map(|&f| f.to_le_bytes()).collect();
119
120 let shape_vec: Vec<usize> = state.state().shape().to_vec();
121
122 Ok(CompressedState {
123 method: CompressionMethod::None,
124 data: data_bytes,
125 shape: shape_vec,
126 scale: 1.0,
127 zero_point: 0,
128 sparse_indices: None,
129 })
130 }
131
132 fn decompress_none(&self, compressed: &CompressedState) -> InferenceResult<HiddenState> {
133 let floats: Vec<f32> = compressed
134 .data
135 .chunks_exact(4)
136 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
137 .collect();
138
139 let data = Array2::from_shape_vec((compressed.shape[0], compressed.shape[1]), floats)
140 .map_err(|e| InferenceError::ForwardError(e.to_string()))?;
141
142 let mut hidden = HiddenState::new(compressed.shape[0], compressed.shape[1]);
143 hidden.update(data);
144 Ok(hidden)
145 }
146
147 fn compress_quantize_8bit(&self, state: &HiddenState) -> InferenceResult<CompressedState> {
149 let min_val = state.state().iter().copied().fold(f32::INFINITY, f32::min);
150 let max_val = state
151 .state()
152 .iter()
153 .copied()
154 .fold(f32::NEG_INFINITY, f32::max);
155
156 let scale = (max_val - min_val) / 255.0;
157 let zero_point = (-min_val / scale).round() as i32;
158
159 let quantized: Vec<u8> = state
160 .state()
161 .iter()
162 .map(|&v| {
163 let scaled = (v / scale + zero_point as f32).round();
164 scaled.clamp(0.0, 255.0) as u8
165 })
166 .collect();
167
168 let shape_vec: Vec<usize> = state.state().shape().to_vec();
169
170 Ok(CompressedState {
171 method: CompressionMethod::Quantize8Bit,
172 data: quantized,
173 shape: shape_vec,
174 scale,
175 zero_point,
176 sparse_indices: None,
177 })
178 }
179
180 fn decompress_quantize_8bit(
181 &self,
182 compressed: &CompressedState,
183 ) -> InferenceResult<HiddenState> {
184 let dequantized: Vec<f32> = compressed
185 .data
186 .iter()
187 .map(|&q| (q as f32 - compressed.zero_point as f32) * compressed.scale)
188 .collect();
189
190 let data = Array2::from_shape_vec((compressed.shape[0], compressed.shape[1]), dequantized)
191 .map_err(|e| InferenceError::ForwardError(e.to_string()))?;
192
193 let mut hidden = HiddenState::new(compressed.shape[0], compressed.shape[1]);
194 hidden.update(data);
195 Ok(hidden)
196 }
197
198 fn compress_quantize_4bit(&self, state: &HiddenState) -> InferenceResult<CompressedState> {
200 let min_val = state.state().iter().copied().fold(f32::INFINITY, f32::min);
201 let max_val = state
202 .state()
203 .iter()
204 .copied()
205 .fold(f32::NEG_INFINITY, f32::max);
206
207 let scale = (max_val - min_val) / 15.0;
208 let zero_point = (-min_val / scale).round() as i32;
209
210 let mut quantized = Vec::new();
211 let mut iter = state.state().iter();
212
213 while let Some(&v1) = iter.next() {
214 let q1 = ((v1 / scale + zero_point as f32).round().clamp(0.0, 15.0) as u8) & 0x0F;
215 let q2 = if let Some(&v2) = iter.next() {
216 ((v2 / scale + zero_point as f32).round().clamp(0.0, 15.0) as u8) & 0x0F
217 } else {
218 0
219 };
220 quantized.push((q1 << 4) | q2);
221 }
222
223 let shape_vec: Vec<usize> = state.state().shape().to_vec();
224
225 Ok(CompressedState {
226 method: CompressionMethod::Quantize4Bit,
227 data: quantized,
228 shape: shape_vec,
229 scale,
230 zero_point,
231 sparse_indices: None,
232 })
233 }
234
235 fn decompress_quantize_4bit(
236 &self,
237 compressed: &CompressedState,
238 ) -> InferenceResult<HiddenState> {
239 let total_elements = compressed.shape.iter().product();
240 let mut dequantized = Vec::with_capacity(total_elements);
241
242 for &byte in &compressed.data {
243 let q1 = (byte >> 4) & 0x0F;
244 let q2 = byte & 0x0F;
245
246 dequantized.push((q1 as f32 - compressed.zero_point as f32) * compressed.scale);
247 if dequantized.len() < total_elements {
248 dequantized.push((q2 as f32 - compressed.zero_point as f32) * compressed.scale);
249 }
250 }
251
252 let data = Array2::from_shape_vec((compressed.shape[0], compressed.shape[1]), dequantized)
253 .map_err(|e| InferenceError::ForwardError(e.to_string()))?;
254
255 let mut hidden = HiddenState::new(compressed.shape[0], compressed.shape[1]);
256 hidden.update(data);
257 Ok(hidden)
258 }
259
260 fn compress_sparse(&self, state: &HiddenState) -> InferenceResult<CompressedState> {
262 let mut values = Vec::new();
263 let mut indices = Vec::new();
264
265 for (i, &v) in state.state().iter().enumerate() {
266 if v.abs() > self.sparsity_threshold {
267 values.push(v);
268 indices.push(i);
269 }
270 }
271
272 let data_bytes: Vec<u8> = values.iter().flat_map(|&f| f.to_le_bytes()).collect();
273
274 let shape_vec: Vec<usize> = state.state().shape().to_vec();
275
276 Ok(CompressedState {
277 method: CompressionMethod::Sparse,
278 data: data_bytes,
279 shape: shape_vec,
280 scale: 1.0,
281 zero_point: 0,
282 sparse_indices: Some(indices),
283 })
284 }
285
286 fn decompress_sparse(&self, compressed: &CompressedState) -> InferenceResult<HiddenState> {
287 let values: Vec<f32> = compressed
288 .data
289 .chunks_exact(4)
290 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
291 .collect();
292
293 let indices = compressed
294 .sparse_indices
295 .as_ref()
296 .ok_or(InferenceError::ForwardError(
297 "Missing sparse indices".to_string(),
298 ))?;
299
300 let total_elements: usize = compressed.shape.iter().product();
301 let mut dense = vec![0.0f32; total_elements];
302 for (&idx, &val) in indices.iter().zip(values.iter()) {
303 if idx < dense.len() {
304 dense[idx] = val;
305 }
306 }
307
308 let data = Array2::from_shape_vec((compressed.shape[0], compressed.shape[1]), dense)
309 .map_err(|e| InferenceError::ForwardError(e.to_string()))?;
310
311 let mut hidden = HiddenState::new(compressed.shape[0], compressed.shape[1]);
312 hidden.update(data);
313 Ok(hidden)
314 }
315
316 fn compress_quantized_sparse(&self, state: &HiddenState) -> InferenceResult<CompressedState> {
318 let mut values = Vec::new();
319 let mut indices = Vec::new();
320
321 for (i, &v) in state.state().iter().enumerate() {
322 if v.abs() > self.sparsity_threshold {
323 values.push(v);
324 indices.push(i);
325 }
326 }
327
328 let shape_vec: Vec<usize> = state.state().shape().to_vec();
329
330 if values.is_empty() {
331 return Ok(CompressedState {
332 method: CompressionMethod::QuantizedSparse,
333 data: Vec::new(),
334 shape: shape_vec,
335 scale: 1.0,
336 zero_point: 0,
337 sparse_indices: Some(indices),
338 });
339 }
340
341 let min_val = values.iter().copied().fold(f32::INFINITY, f32::min);
342 let max_val = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
343
344 let scale = (max_val - min_val) / 255.0;
345 let zero_point = (-min_val / scale).round() as i32;
346
347 let quantized: Vec<u8> = values
348 .iter()
349 .map(|&v| {
350 let scaled = (v / scale + zero_point as f32).round();
351 scaled.clamp(0.0, 255.0) as u8
352 })
353 .collect();
354
355 Ok(CompressedState {
356 method: CompressionMethod::QuantizedSparse,
357 data: quantized,
358 shape: shape_vec,
359 scale,
360 zero_point,
361 sparse_indices: Some(indices),
362 })
363 }
364
365 fn decompress_quantized_sparse(
366 &self,
367 compressed: &CompressedState,
368 ) -> InferenceResult<HiddenState> {
369 let indices = compressed
370 .sparse_indices
371 .as_ref()
372 .ok_or(InferenceError::ForwardError(
373 "Missing sparse indices".to_string(),
374 ))?;
375
376 let total_elements: usize = compressed.shape.iter().product();
377 let mut dense = vec![0.0f32; total_elements];
378
379 if !compressed.data.is_empty() {
380 for (&idx, &q) in indices.iter().zip(compressed.data.iter()) {
381 if idx < dense.len() {
382 dense[idx] = (q as f32 - compressed.zero_point as f32) * compressed.scale;
383 }
384 }
385 }
386
387 let data = Array2::from_shape_vec((compressed.shape[0], compressed.shape[1]), dense)
388 .map_err(|e| InferenceError::ForwardError(e.to_string()))?;
389
390 let mut hidden = HiddenState::new(compressed.shape[0], compressed.shape[1]);
391 hidden.update(data);
392 Ok(hidden)
393 }
394}
395
396