//! `yscv_model/quantize.rs` — INT8 quantization (per-tensor and per-channel)
//! and magnitude-based weight pruning utilities.
1use yscv_tensor::Tensor;
2
3use crate::ModelError;
4
/// Quantized tensor representation: INT8 values + per-tensor scale + zero-point.
///
/// Dequantization contract (see `to_tensor`): `f32 ≈ (q as f32 - zero_point) * scale`.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized values, in the same flat element order as the source tensor's data.
    pub data: Vec<i8>,
    /// Original tensor shape; product of dims equals `data.len()`.
    pub shape: Vec<usize>,
    /// Per-tensor scale factor applied during dequantization.
    pub scale: f32,
    /// Offset subtracted before scaling; always 0 for symmetric quantization.
    pub zero_point: i8,
}
13
/// Quantization mode: how f32 values are mapped onto the INT8 range.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantMode {
    /// Symmetric: zero_point = 0, range maps to [-127, 127].
    /// Best when values are roughly centered on zero (e.g. weights).
    Symmetric,
    /// Asymmetric: full [-128, 127] range with dynamic zero_point.
    /// Uses the full range for skewed value distributions (e.g. activations).
    Asymmetric,
}
22
23impl QuantizedTensor {
24    /// Quantize an f32 tensor to INT8.
25    pub fn from_tensor(tensor: &Tensor, mode: QuantMode) -> Self {
26        let data = tensor.data();
27        let shape = tensor.shape().to_vec();
28
29        match mode {
30            QuantMode::Symmetric => {
31                let max_abs = data
32                    .iter()
33                    .map(|v| v.abs())
34                    .fold(0.0f32, f32::max)
35                    .max(1e-8);
36                let scale = max_abs / 127.0;
37                let quantized: Vec<i8> = data
38                    .iter()
39                    .map(|&v| (v / scale).round().clamp(-127.0, 127.0) as i8)
40                    .collect();
41                Self {
42                    data: quantized,
43                    shape,
44                    scale,
45                    zero_point: 0,
46                }
47            }
48            QuantMode::Asymmetric => {
49                let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
50                let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
51                let range = (max_val - min_val).max(1e-8);
52                let scale = range / 255.0;
53                let zp = (-128.0 - min_val / scale).round().clamp(-128.0, 127.0) as i8;
54                let quantized: Vec<i8> = data
55                    .iter()
56                    .map(|&v| (v / scale + zp as f32).round().clamp(-128.0, 127.0) as i8)
57                    .collect();
58                Self {
59                    data: quantized,
60                    shape,
61                    scale,
62                    zero_point: zp,
63                }
64            }
65        }
66    }
67
68    /// Dequantize back to f32 tensor.
69    pub fn to_tensor(&self) -> Result<Tensor, ModelError> {
70        let data: Vec<f32> = self
71            .data
72            .iter()
73            .map(|&q| (q as f32 - self.zero_point as f32) * self.scale)
74            .collect();
75        Tensor::from_vec(self.shape.clone(), data).map_err(Into::into)
76    }
77
78    /// Number of elements.
79    pub fn len(&self) -> usize {
80        self.data.len()
81    }
82
83    /// Whether empty.
84    pub fn is_empty(&self) -> bool {
85        self.data.is_empty()
86    }
87
88    /// Compression ratio vs f32 (4x for INT8).
89    pub fn compression_ratio(&self) -> f32 {
90        4.0
91    }
92
93    /// Total bytes of quantized data (not including metadata).
94    pub fn byte_size(&self) -> usize {
95        self.data.len()
96    }
97}
98
99/// Quantized matmul: dequantize -> f32 matmul -> re-quantize.
100///
101/// This is a naive implementation; a production path would use integer GEMM.
102/// Integer-accumulating quantized matmul: `C = A @ B` in INT8 with INT32 accumulation.
103///
104/// Avoids dequantizing to f32 — computes directly in integer domain:
105/// `C_f32[i,j] = scale_a * scale_b * sum_k((A_i8[i,k] - zp_a) * (B_i8[k,j] - zp_b))`
106///
107/// Then re-quantizes the result.
108pub fn quantized_matmul(
109    lhs: &QuantizedTensor,
110    rhs: &QuantizedTensor,
111    mode: QuantMode,
112) -> Result<QuantizedTensor, ModelError> {
113    if lhs.shape.len() != 2 || rhs.shape.len() != 2 {
114        // Fallback to dequant path for non-2D
115        let a = lhs.to_tensor()?;
116        let b = rhs.to_tensor()?;
117        let c = yscv_kernels::matmul_2d(&a, &b)?;
118        return Ok(QuantizedTensor::from_tensor(&c, mode));
119    }
120
121    let m = lhs.shape[0];
122    let k = lhs.shape[1];
123    let n = rhs.shape[1];
124    if rhs.shape[0] != k {
125        let a = lhs.to_tensor()?;
126        let b = rhs.to_tensor()?;
127        let c = yscv_kernels::matmul_2d(&a, &b)?;
128        return Ok(QuantizedTensor::from_tensor(&c, mode));
129    }
130
131    let zp_a = lhs.zero_point as i32;
132    let zp_b = rhs.zero_point as i32;
133    let combined_scale = lhs.scale * rhs.scale;
134
135    // Integer GEMM with i32 accumulation
136    let mut c_f32 = vec![0.0f32; m * n];
137    for i in 0..m {
138        for j in 0..n {
139            let mut acc = 0i32;
140            for kk in 0..k {
141                let a_val = lhs.data[i * k + kk] as i32 - zp_a;
142                let b_val = rhs.data[kk * n + j] as i32 - zp_b;
143                acc += a_val * b_val;
144            }
145            c_f32[i * n + j] = acc as f32 * combined_scale;
146        }
147    }
148
149    let result = Tensor::from_vec(vec![m, n], c_f32)?;
150    Ok(QuantizedTensor::from_tensor(&result, mode))
151}
152
153/// Quantize all weight tensors in a model checkpoint for storage/inference.
154///
155/// Returns `(quantized_weights, original_shapes)` for each weight tensor.
156pub fn quantize_weights(weights: &[Tensor], mode: QuantMode) -> Vec<QuantizedTensor> {
157    weights
158        .iter()
159        .map(|w| QuantizedTensor::from_tensor(w, mode))
160        .collect()
161}
162
163/// Dequantize a set of quantized weights back to f32 tensors.
164pub fn dequantize_weights(quantized: &[QuantizedTensor]) -> Result<Vec<Tensor>, ModelError> {
165    quantized.iter().map(|q| q.to_tensor()).collect()
166}
167
/// Result of per-channel symmetric quantization: INT8 data plus one scale
/// per channel along the chosen axis.
pub struct PerChannelQuantResult {
    /// Quantized values, in the same flat element order as the source tensor.
    pub data: Vec<i8>,
    /// One symmetric scale per channel (zero-point is implicitly 0).
    pub scales: Vec<f32>,
    /// Original tensor shape.
    pub shape: Vec<usize>,
}
177
178pub fn quantize_per_channel(
179    tensor: &Tensor,
180    channel_axis: usize,
181) -> Result<PerChannelQuantResult, ModelError> {
182    let shape = tensor.shape();
183    let data = tensor.data();
184    let num_channels = shape[channel_axis];
185    let total = data.len();
186    let channel_stride: usize = shape[channel_axis + 1..].iter().product();
187    let _outer_stride: usize = shape[channel_axis..].iter().product();
188
189    let mut scales = vec![0.0f32; num_channels];
190    let mut quantized = vec![0i8; total];
191
192    // Compute per-channel max abs
193    for (i, &v) in data.iter().enumerate() {
194        let ch = (i / channel_stride) % num_channels;
195        scales[ch] = scales[ch].max(v.abs());
196    }
197    for s in &mut scales {
198        *s = (*s).max(1e-8) / 127.0;
199    }
200
201    for (i, &v) in data.iter().enumerate() {
202        let ch = (i / channel_stride) % num_channels;
203        quantized[i] = (v / scales[ch]).round().clamp(-127.0, 127.0) as i8;
204    }
205
206    Ok(PerChannelQuantResult {
207        data: quantized,
208        scales,
209        shape: shape.to_vec(),
210    })
211}
212
// ---------------------------------------------------------------------------
// Weight Pruning
// ---------------------------------------------------------------------------

/// Result of magnitude-based weight pruning (produced by `prune_magnitude`).
#[derive(Debug, Clone)]
pub struct PrunedTensor {
    /// Binary mask with the same shape as the weights: 1.0 = keep, 0.0 = pruned.
    pub mask: Tensor,
    /// Original weights with pruned values zeroed out.
    pub pruned_weights: Tensor,
    /// Fraction of weights actually set to zero (0.0–1.0).
    pub sparsity: f32,
}
227
228/// Prune weights by magnitude: zero out the smallest `sparsity` fraction.
229///
230/// For example, `sparsity = 0.5` removes the 50% of weights with smallest
231/// absolute value. Returns the pruned weights and binary mask.
232pub fn prune_magnitude(weights: &Tensor, sparsity: f32) -> Result<PrunedTensor, ModelError> {
233    if !(0.0..=1.0).contains(&sparsity) {
234        return Err(ModelError::InvalidDropoutRate { rate: sparsity });
235    }
236    let data = weights.data();
237    let n = data.len();
238    if n == 0 || sparsity == 0.0 {
239        let mask = Tensor::from_vec(weights.shape().to_vec(), vec![1.0f32; n])?;
240        return Ok(PrunedTensor {
241            mask,
242            pruned_weights: weights.clone(),
243            sparsity: 0.0,
244        });
245    }
246
247    // Find threshold: sort absolute values, pick the sparsity-percentile value
248    let mut abs_vals: Vec<f32> = data.iter().map(|v| v.abs()).collect();
249    abs_vals.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
250    let cutoff_idx = ((n as f32 * sparsity) as usize).min(n - 1);
251    let threshold = abs_vals[cutoff_idx];
252
253    let mut mask_data = Vec::with_capacity(n);
254    let mut pruned_data = Vec::with_capacity(n);
255    let mut pruned_count = 0usize;
256    for &v in data {
257        if v.abs() <= threshold {
258            mask_data.push(0.0f32);
259            pruned_data.push(0.0f32);
260            pruned_count += 1;
261        } else {
262            mask_data.push(1.0f32);
263            pruned_data.push(v);
264        }
265    }
266
267    let actual_sparsity = pruned_count as f32 / n as f32;
268    let mask = Tensor::from_vec(weights.shape().to_vec(), mask_data)?;
269    let pruned_weights = Tensor::from_vec(weights.shape().to_vec(), pruned_data)?;
270
271    Ok(PrunedTensor {
272        mask,
273        pruned_weights,
274        sparsity: actual_sparsity,
275    })
276}
277
278/// Apply a binary mask to weights (element-wise multiply).
279pub fn apply_pruning_mask(weights: &Tensor, mask: &Tensor) -> Result<Tensor, ModelError> {
280    weights.mul(mask).map_err(ModelError::Tensor)
281}