1use crate::onnx_proto::{attribute_proto, tensor_proto, AttributeProto, NodeProto, TensorProto};
23
/// Tensor and node names used when one float initializer is replaced by a
/// `DequantizeLinear` node fed from quantized initializers.
///
/// Every name is derived from the original initializer name. The node's output
/// reuses that original name unchanged.
#[derive(Debug, Clone)]
pub struct DequantLinearNames {
    /// Name of the quantized integer initializer: `<original>_quantized`.
    pub quantized_name: String,
    /// Name of the scale initializer: `<original>_scale`.
    pub scale_name: String,
    /// Name of the zero-point initializer: `<original>_zp`.
    pub zp_name: String,
    /// Name of the `DequantizeLinear` node: `DequantizeLinear_<original>`.
    pub node_name: String,
    /// Output tensor of the node; identical to the original name.
    pub output_name: String,
}

impl DequantLinearNames {
    /// Derives all five names from `original_name` using fixed suffixes and prefix.
    ///
    /// The original name is inserted verbatim (dots and all); no sanitising is done.
    pub fn from_original(original_name: &str) -> Self {
        let suffixed = |suffix: &str| format!("{original_name}{suffix}");
        Self {
            quantized_name: suffixed("_quantized"),
            scale_name: suffixed("_scale"),
            zp_name: suffixed("_zp"),
            node_name: format!("DequantizeLinear_{original_name}"),
            output_name: original_name.to_owned(),
        }
    }
}
56
57pub fn build_dequantize_linear_node(names: &DequantLinearNames, axis: Option<usize>) -> NodeProto {
71 let attribute = match axis {
72 Some(a) => vec![AttributeProto {
73 name: "axis".to_string(),
74 r#type: attribute_proto::AttributeType::Int as i32,
75 i: a as i64,
76 ..Default::default()
77 }],
78 None => vec![],
79 };
80
81 NodeProto {
82 op_type: "DequantizeLinear".to_string(),
83 name: names.node_name.clone(),
84 input: vec![
85 names.quantized_name.clone(),
86 names.scale_name.clone(),
87 names.zp_name.clone(),
88 ],
89 output: vec![names.output_name.clone()],
90 attribute,
91 ..Default::default()
92 }
93}
94
/// Byte encoding used for quantized values inside an ONNX `TensorProto`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageFormat {
    /// One value per byte, tagged with the ONNX `INT8` data type.
    Int8Widened,
    /// Two values per byte (even index in the low nibble), tagged `INT4`.
    NativeInt4,
}
113
/// Packs signed 4-bit values into bytes using the ONNX `INT4` layout.
///
/// Element `2k` is stored in the low nibble of byte `k` and element `2k + 1`
/// in its high nibble. For an odd-length input, the final high nibble is left
/// as zero. The result always has `values.len().div_ceil(2)` bytes.
///
/// Every value must lie in the signed int4 range `-8..=7`. Anything outside
/// that range would be truncated to its low four bits and read back by
/// `unpack_int4_onnx` as a different number. That misuse is caught by a
/// `debug_assert!` rather than being silently corrupted.
pub(crate) fn pack_int4_onnx(values: &[i8]) -> Vec<u8> {
    debug_assert!(
        values.iter().all(|v| (-8..=7).contains(v)),
        "pack_int4_onnx: value outside signed int4 range -8..=7"
    );

    // Keep only the two's-complement low nibble of a value.
    let nibble = |v: i8| (v as u8) & 0x0F;

    let pairs = values.chunks_exact(2);
    // The remainder borrows from `values`, not from `pairs`, so it is read
    // before the iterator is consumed. It holds at most one element.
    let tail = pairs.remainder().first().map(|&v| nibble(v));

    let mut packed = Vec::with_capacity(values.len().div_ceil(2));
    packed.extend(pairs.map(|pair| (nibble(pair[1]) << 4) | nibble(pair[0])));
    packed.extend(tail);
    packed
}
133
/// Expands ONNX `INT4`-packed bytes back into `i8` values; the inverse of
/// `pack_int4_onnx`.
///
/// Each byte contributes its low nibble first, then its high nibble. Every
/// nibble is sign-extended from 4-bit two's complement.
///
/// At most `num_values` values are returned, so padding nibbles past that
/// count are dropped. If `packed` holds fewer than `num_values` nibbles, all
/// available nibbles are returned and no error is raised.
pub(crate) fn unpack_int4_onnx(packed: &[u8], num_values: usize) -> Vec<i8> {
    let nibbles = packed
        .iter()
        .flat_map(|&byte| [byte & 0x0F, byte >> 4])
        .take(num_values);

    let mut values = Vec::with_capacity(num_values);
    values.extend(nibbles.map(sign_extend_nibble));
    values
}

/// Sign-extends a 4-bit two's-complement value held in the low bits of `nibble`.
///
/// Inputs `0..=7` map to themselves. Inputs `8..=15` map to `-8..=-1` by
/// setting the upper four bits of the result.
#[inline]
fn sign_extend_nibble(nibble: u8) -> i8 {
    match nibble {
        0..=7 => nibble as i8,
        _ => (nibble as i8) | (0xF0_u8 as i8),
    }
}
158
159pub fn build_quantized_weight_tensor(
170 names: &DequantLinearNames,
171 values: &[i8],
172 shape: &[i64],
173 format: StorageFormat,
174) -> TensorProto {
175 match format {
176 StorageFormat::Int8Widened => TensorProto {
177 name: names.quantized_name.clone(),
178 data_type: tensor_proto::DataType::Int8 as i32,
179 dims: shape.to_vec(),
180 raw_data: values.iter().map(|&v| v as u8).collect(),
182 ..Default::default()
183 },
184 StorageFormat::NativeInt4 => TensorProto {
185 name: names.quantized_name.clone(),
186 data_type: tensor_proto::DataType::Int4 as i32,
187 dims: shape.to_vec(),
188 raw_data: pack_int4_onnx(values),
189 ..Default::default()
190 },
191 }
192}
193
194pub fn build_scale_tensor(names: &DequantLinearNames, scales: &[f32]) -> TensorProto {
200 let mut t = TensorProto {
201 name: names.scale_name.clone(),
202 data_type: tensor_proto::DataType::Float as i32,
203 float_data: scales.to_vec(),
204 ..Default::default()
205 };
206 if scales.len() > 1 {
207 t.dims = vec![scales.len() as i64];
209 }
210 t
212}
213
214pub fn build_zero_point_tensor(
221 names: &DequantLinearNames,
222 zps: &[i8],
223 format: StorageFormat,
224) -> TensorProto {
225 let (data_type, raw_data) = match format {
226 StorageFormat::Int8Widened => (
227 tensor_proto::DataType::Int8 as i32,
228 zps.iter().map(|&v| v as u8).collect(),
229 ),
230 StorageFormat::NativeInt4 => (tensor_proto::DataType::Int4 as i32, pack_int4_onnx(zps)),
231 };
232
233 let mut t = TensorProto {
234 name: names.zp_name.clone(),
235 data_type,
236 raw_data,
237 ..Default::default()
238 };
239 if zps.len() > 1 {
240 t.dims = vec![zps.len() as i64];
242 }
243 t
245}
246
/// Unit tests for name derivation, node/tensor construction, and the ONNX
/// `INT4` nibble packing helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::onnx_proto::tensor_proto;

    /// Names derived from the placeholder weight `"w"`, shared by the tensor tests.
    fn w_names() -> DequantLinearNames {
        DequantLinearNames::from_original("w")
    }

    #[test]
    fn test_names_from_simple_weight() {
        let names = DequantLinearNames::from_original("conv1.weight");
        let derived = [
            names.quantized_name.as_str(),
            names.scale_name.as_str(),
            names.zp_name.as_str(),
            names.node_name.as_str(),
            names.output_name.as_str(),
        ];
        assert_eq!(
            derived,
            [
                "conv1.weight_quantized",
                "conv1.weight_scale",
                "conv1.weight_zp",
                "DequantizeLinear_conv1.weight",
                "conv1.weight",
            ]
        );
    }

    #[test]
    fn test_names_from_dotted_path() {
        // Dots and numeric path segments are kept verbatim in the derived names.
        let names = DequantLinearNames::from_original("layer1.0.conv1.weight");
        assert_eq!(names.quantized_name, "layer1.0.conv1.weight_quantized");
        assert_eq!(names.output_name, "layer1.0.conv1.weight");
    }

    #[test]
    fn test_dequantize_linear_node_inputs_outputs() {
        let node =
            build_dequantize_linear_node(&DequantLinearNames::from_original("fc.weight"), None);

        assert_eq!(node.op_type, "DequantizeLinear");
        assert_eq!(node.name, "DequantizeLinear_fc.weight");
        // Inputs are wired as quantized tensor, scale, zero point.
        assert_eq!(
            node.input,
            ["fc.weight_quantized", "fc.weight_scale", "fc.weight_zp"]
        );
        assert_eq!(node.output, ["fc.weight"]);
        assert!(node.attribute.is_empty());
    }

    #[test]
    fn test_dequantize_linear_node_with_axis() {
        let node = build_dequantize_linear_node(
            &DequantLinearNames::from_original("conv.weight"),
            Some(0),
        );

        assert_eq!(node.attribute.len(), 1);
        let axis = &node.attribute[0];
        assert_eq!(axis.name, "axis");
        assert_eq!(axis.i, 0);
    }

    #[test]
    fn test_quantized_weight_tensor_shape_and_data() {
        let values = [1i8, -2, 3, -4, 5, 6];
        let t = build_quantized_weight_tensor(
            &w_names(),
            &values,
            &[2, 3],
            StorageFormat::Int8Widened,
        );

        assert_eq!(t.name, "w_quantized");
        assert_eq!(t.data_type, tensor_proto::DataType::Int8 as i32);
        assert_eq!(t.dims, [2, 3]);

        // Int8Widened stores each value as its two's-complement byte.
        let recovered: Vec<i8> = t.raw_data.iter().map(|&b| b as i8).collect();
        assert_eq!(recovered, values);
    }

    #[test]
    fn test_scale_tensor_scalar() {
        let t = build_scale_tensor(&w_names(), &[0.003921]);

        assert_eq!(t.name, "w_scale");
        assert_eq!(t.data_type, tensor_proto::DataType::Float as i32);
        assert!(t.dims.is_empty(), "single scale must be rank-0 scalar");
        assert!((t.float_data[0] - 0.003921).abs() < 1e-6);
    }

    #[test]
    fn test_scale_tensor_per_channel() {
        let t = build_scale_tensor(&w_names(), &[0.01, 0.02, 0.03]);

        assert_eq!(t.dims, [3]);
        assert_eq!(t.float_data.len(), 3);
    }

    #[test]
    fn test_zero_point_tensor_scalar() {
        let t = build_zero_point_tensor(&w_names(), &[-3], StorageFormat::Int8Widened);

        assert_eq!(t.name, "w_zp");
        assert_eq!(t.data_type, tensor_proto::DataType::Int8 as i32);
        assert!(t.dims.is_empty(), "single zp must be rank-0 scalar");
        assert_eq!(t.raw_data[0], (-3i8) as u8);
    }

    #[test]
    fn test_zero_point_tensor_per_channel() {
        let t = build_zero_point_tensor(&w_names(), &[-3, 0, 5], StorageFormat::Int8Widened);

        assert_eq!(t.dims, [3]);
        assert_eq!(t.raw_data.len(), 3);
    }

    #[test]
    fn test_int4_range_values_round_trip() {
        // The int4 extremes must survive the widened int8 encoding unchanged.
        let values = [-8i8, -1, 0, 7];
        let t =
            build_quantized_weight_tensor(&w_names(), &values, &[4], StorageFormat::Int8Widened);

        let recovered: Vec<i8> = t.raw_data.iter().map(|&b| b as i8).collect();
        assert_eq!(recovered, values);
    }

    #[test]
    fn test_onnx_pack_layout_even_index_in_low_nibble() {
        // Element 0 occupies the low nibble, element 1 the high nibble.
        assert_eq!(pack_int4_onnx(&[1, 2]), [0x21]);
        assert_eq!(pack_int4_onnx(&[0, 0x7]), [0x70]);
    }

    #[test]
    fn test_onnx_pack_negative_values() {
        // Negatives keep only their two's-complement low nibble: -1 -> 0xF, -8 -> 0x8.
        assert_eq!(pack_int4_onnx(&[-1, -1]), [0xFF]);
        assert_eq!(pack_int4_onnx(&[-8, 7]), [0x78]);
    }

    #[test]
    fn test_onnx_pack_odd_length_zero_pads_high_nibble() {
        assert_eq!(pack_int4_onnx(&[0x3]), [0x03]);
        assert_eq!(pack_int4_onnx(&[-1]), [0x0F]);
    }

    #[test]
    fn test_onnx_pack_unpack_round_trip_all_values() {
        let values: Vec<i8> = (-8..=7).collect();
        let packed = pack_int4_onnx(&values);

        assert_eq!(packed.len(), 8, "16 values must pack to exactly 8 bytes");
        assert_eq!(unpack_int4_onnx(&packed, values.len()), values);
    }

    #[test]
    fn test_onnx_pack_unpack_round_trip_odd_length() {
        let values = vec![-8i8, -1, 0, 7, -3];
        let packed = pack_int4_onnx(&values);

        assert_eq!(packed.len(), 3, "5 values must pack to ceil(5/2) = 3 bytes");
        assert_eq!(unpack_int4_onnx(&packed, values.len()), values);
    }

    #[test]
    fn test_native_int4_weight_tensor_uses_int4_data_type() {
        let values = [-8i8, -1, 0, 7];
        let t =
            build_quantized_weight_tensor(&w_names(), &values, &[4], StorageFormat::NativeInt4);

        assert_eq!(t.data_type, tensor_proto::DataType::Int4 as i32);
        assert_eq!(t.dims, vec![4], "dims should be logical element count");
        assert_eq!(t.raw_data.len(), 2, "4 values → 2 packed bytes");
        assert_eq!(unpack_int4_onnx(&t.raw_data, values.len()), values);
    }

    #[test]
    fn test_native_int4_zero_point_scalar() {
        let t = build_zero_point_tensor(&w_names(), &[-3], StorageFormat::NativeInt4);

        assert_eq!(t.data_type, tensor_proto::DataType::Int4 as i32);
        assert!(t.dims.is_empty(), "scalar zp has rank 0");
        assert_eq!(t.raw_data.len(), 1);
        assert_eq!(unpack_int4_onnx(&t.raw_data, 1), [-3]);
    }

    #[test]
    fn test_native_int4_zero_point_per_channel() {
        let zps = [-3i8, 0, 5, -1, 7];
        let t = build_zero_point_tensor(&w_names(), &zps, StorageFormat::NativeInt4);

        assert_eq!(t.data_type, tensor_proto::DataType::Int4 as i32);
        assert_eq!(t.dims, vec![5], "per-channel zp has rank 1");
        assert_eq!(t.raw_data.len(), 3, "5 values → 3 packed bytes");
        assert_eq!(unpack_int4_onnx(&t.raw_data, zps.len()), zps);
    }
}