1use yscv_tensor::Tensor;
2
3use crate::ModelError;
4
/// An i8-quantized tensor with a single (per-tensor) scale and zero point.
///
/// Dequantization follows `value ≈ (q - zero_point) as f32 * scale`
/// (see `QuantizedTensor::to_tensor`).
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized values, stored in the same element order as the source tensor.
    pub data: Vec<i8>,
    /// Shape of the original tensor.
    pub shape: Vec<usize>,
    /// Scale factor mapping quantized units back to f32.
    pub scale: f32,
    /// Zero point of the affine mapping; 0 for symmetric quantization.
    pub zero_point: i8,
}
13
/// Quantization scheme used when converting an f32 tensor to i8.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantMode {
    /// Scale-only quantization centred on zero: range [-127, 127], zero point 0.
    Symmetric,
    /// Affine quantization with a zero point, using the full [-128, 127] range.
    Asymmetric,
}
22
23impl QuantizedTensor {
24 pub fn from_tensor(tensor: &Tensor, mode: QuantMode) -> Self {
26 let data = tensor.data();
27 let shape = tensor.shape().to_vec();
28
29 match mode {
30 QuantMode::Symmetric => {
31 let max_abs = data
32 .iter()
33 .map(|v| v.abs())
34 .fold(0.0f32, f32::max)
35 .max(1e-8);
36 let scale = max_abs / 127.0;
37 let quantized: Vec<i8> = data
38 .iter()
39 .map(|&v| (v / scale).round().clamp(-127.0, 127.0) as i8)
40 .collect();
41 Self {
42 data: quantized,
43 shape,
44 scale,
45 zero_point: 0,
46 }
47 }
48 QuantMode::Asymmetric => {
49 let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
50 let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
51 let range = (max_val - min_val).max(1e-8);
52 let scale = range / 255.0;
53 let zp = (-128.0 - min_val / scale).round().clamp(-128.0, 127.0) as i8;
54 let quantized: Vec<i8> = data
55 .iter()
56 .map(|&v| (v / scale + zp as f32).round().clamp(-128.0, 127.0) as i8)
57 .collect();
58 Self {
59 data: quantized,
60 shape,
61 scale,
62 zero_point: zp,
63 }
64 }
65 }
66 }
67
68 pub fn to_tensor(&self) -> Result<Tensor, ModelError> {
70 let data: Vec<f32> = self
71 .data
72 .iter()
73 .map(|&q| (q as f32 - self.zero_point as f32) * self.scale)
74 .collect();
75 Tensor::from_vec(self.shape.clone(), data).map_err(Into::into)
76 }
77
78 pub fn len(&self) -> usize {
80 self.data.len()
81 }
82
83 pub fn is_empty(&self) -> bool {
85 self.data.is_empty()
86 }
87
88 pub fn compression_ratio(&self) -> f32 {
90 4.0
91 }
92
93 pub fn byte_size(&self) -> usize {
95 self.data.len()
96 }
97}
98
99pub fn quantized_matmul(
109 lhs: &QuantizedTensor,
110 rhs: &QuantizedTensor,
111 mode: QuantMode,
112) -> Result<QuantizedTensor, ModelError> {
113 if lhs.shape.len() != 2 || rhs.shape.len() != 2 {
114 let a = lhs.to_tensor()?;
116 let b = rhs.to_tensor()?;
117 let c = yscv_kernels::matmul_2d(&a, &b)?;
118 return Ok(QuantizedTensor::from_tensor(&c, mode));
119 }
120
121 let m = lhs.shape[0];
122 let k = lhs.shape[1];
123 let n = rhs.shape[1];
124 if rhs.shape[0] != k {
125 let a = lhs.to_tensor()?;
126 let b = rhs.to_tensor()?;
127 let c = yscv_kernels::matmul_2d(&a, &b)?;
128 return Ok(QuantizedTensor::from_tensor(&c, mode));
129 }
130
131 let zp_a = lhs.zero_point as i32;
132 let zp_b = rhs.zero_point as i32;
133 let combined_scale = lhs.scale * rhs.scale;
134
135 let mut c_f32 = vec![0.0f32; m * n];
137 for i in 0..m {
138 for j in 0..n {
139 let mut acc = 0i32;
140 for kk in 0..k {
141 let a_val = lhs.data[i * k + kk] as i32 - zp_a;
142 let b_val = rhs.data[kk * n + j] as i32 - zp_b;
143 acc += a_val * b_val;
144 }
145 c_f32[i * n + j] = acc as f32 * combined_scale;
146 }
147 }
148
149 let result = Tensor::from_vec(vec![m, n], c_f32)?;
150 Ok(QuantizedTensor::from_tensor(&result, mode))
151}
152
153pub fn quantize_weights(weights: &[Tensor], mode: QuantMode) -> Vec<QuantizedTensor> {
157 weights
158 .iter()
159 .map(|w| QuantizedTensor::from_tensor(w, mode))
160 .collect()
161}
162
163pub fn dequantize_weights(quantized: &[QuantizedTensor]) -> Result<Vec<Tensor>, ModelError> {
165 quantized.iter().map(|q| q.to_tensor()).collect()
166}
167
/// Result of symmetric per-channel quantization (see `quantize_per_channel`).
pub struct PerChannelQuantResult {
    /// Quantized values in the same element order as the source tensor.
    pub data: Vec<i8>,
    /// One scale per channel along the chosen axis.
    pub scales: Vec<f32>,
    /// Shape of the original tensor.
    pub shape: Vec<usize>,
}
177
178pub fn quantize_per_channel(
179 tensor: &Tensor,
180 channel_axis: usize,
181) -> Result<PerChannelQuantResult, ModelError> {
182 let shape = tensor.shape();
183 let data = tensor.data();
184 let num_channels = shape[channel_axis];
185 let total = data.len();
186 let channel_stride: usize = shape[channel_axis + 1..].iter().product();
187 let _outer_stride: usize = shape[channel_axis..].iter().product();
188
189 let mut scales = vec![0.0f32; num_channels];
190 let mut quantized = vec![0i8; total];
191
192 for (i, &v) in data.iter().enumerate() {
194 let ch = (i / channel_stride) % num_channels;
195 scales[ch] = scales[ch].max(v.abs());
196 }
197 for s in &mut scales {
198 *s = (*s).max(1e-8) / 127.0;
199 }
200
201 for (i, &v) in data.iter().enumerate() {
202 let ch = (i / channel_stride) % num_channels;
203 quantized[i] = (v / scales[ch]).round().clamp(-127.0, 127.0) as i8;
204 }
205
206 Ok(PerChannelQuantResult {
207 data: quantized,
208 scales,
209 shape: shape.to_vec(),
210 })
211}
212
/// Result of magnitude pruning (see `prune_magnitude`).
#[derive(Debug, Clone)]
pub struct PrunedTensor {
    /// Binary (0.0 / 1.0) mask, same shape as the weights; 0 marks pruned entries.
    pub mask: Tensor,
    /// Weights with pruned entries zeroed out.
    pub pruned_weights: Tensor,
    /// Fraction of entries actually pruned (may exceed the request on ties).
    pub sparsity: f32,
}
227
228pub fn prune_magnitude(weights: &Tensor, sparsity: f32) -> Result<PrunedTensor, ModelError> {
233 if !(0.0..=1.0).contains(&sparsity) {
234 return Err(ModelError::InvalidDropoutRate { rate: sparsity });
235 }
236 let data = weights.data();
237 let n = data.len();
238 if n == 0 || sparsity == 0.0 {
239 let mask = Tensor::from_vec(weights.shape().to_vec(), vec![1.0f32; n])?;
240 return Ok(PrunedTensor {
241 mask,
242 pruned_weights: weights.clone(),
243 sparsity: 0.0,
244 });
245 }
246
247 let mut abs_vals: Vec<f32> = data.iter().map(|v| v.abs()).collect();
249 abs_vals.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
250 let cutoff_idx = ((n as f32 * sparsity) as usize).min(n - 1);
251 let threshold = abs_vals[cutoff_idx];
252
253 let mut mask_data = Vec::with_capacity(n);
254 let mut pruned_data = Vec::with_capacity(n);
255 let mut pruned_count = 0usize;
256 for &v in data {
257 if v.abs() <= threshold {
258 mask_data.push(0.0f32);
259 pruned_data.push(0.0f32);
260 pruned_count += 1;
261 } else {
262 mask_data.push(1.0f32);
263 pruned_data.push(v);
264 }
265 }
266
267 let actual_sparsity = pruned_count as f32 / n as f32;
268 let mask = Tensor::from_vec(weights.shape().to_vec(), mask_data)?;
269 let pruned_weights = Tensor::from_vec(weights.shape().to_vec(), pruned_data)?;
270
271 Ok(PrunedTensor {
272 mask,
273 pruned_weights,
274 sparsity: actual_sparsity,
275 })
276}
277
278pub fn apply_pruning_mask(weights: &Tensor, mask: &Tensor) -> Result<Tensor, ModelError> {
280 weights.mul(mask).map_err(ModelError::Tensor)
281}