1use crate::onnx_proto::{
23 AttributeProto, NodeProto, TensorProto,
24 attribute_proto, tensor_proto,
25};
26
/// The set of graph names used when replacing a float initializer with a
/// `DequantizeLinear` node fed by quantized initializers.
///
/// Equality and hashing are by value over all five names, so two instances
/// derived from the same original name compare equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DequantLinearNames {
    /// Name of the int8 initializer holding the quantized values.
    pub quantized_name: String,
    /// Name of the float initializer holding the scale(s).
    pub scale_name: String,
    /// Name of the int8 initializer holding the zero point(s).
    pub zp_name: String,
    /// Name given to the `DequantizeLinear` node itself.
    pub node_name: String,
    /// Name of the node's output tensor.
    pub output_name: String,
}
46
47impl DequantLinearNames {
48 pub fn from_original(original_name: &str) -> Self {
50 Self {
51 quantized_name: format!("{}_quantized", original_name),
52 scale_name: format!("{}_scale", original_name),
53 zp_name: format!("{}_zp", original_name),
54 node_name: format!("DequantizeLinear_{}", original_name),
55 output_name: original_name.to_string(),
56 }
57 }
58}
59
60pub fn build_dequantize_linear_node(
74 names: &DequantLinearNames,
75 axis: Option<usize>,
76) -> NodeProto {
77 let attribute = match axis {
78 Some(a) => vec![AttributeProto {
79 name: "axis".to_string(),
80 r#type: attribute_proto::AttributeType::Int as i32,
81 i: a as i64,
82 ..Default::default()
83 }],
84 None => vec![],
85 };
86
87 NodeProto {
88 op_type: "DequantizeLinear".to_string(),
89 name: names.node_name.clone(),
90 input: vec![
91 names.quantized_name.clone(),
92 names.scale_name.clone(),
93 names.zp_name.clone(),
94 ],
95 output: vec![names.output_name.clone()],
96 attribute,
97 ..Default::default()
98 }
99}
100
101pub fn build_quantized_weight_tensor(
111 names: &DequantLinearNames,
112 values: &[i8],
113 shape: &[i64],
114) -> TensorProto {
115 TensorProto {
116 name: names.quantized_name.clone(),
117 data_type: tensor_proto::DataType::Int8 as i32,
118 dims: shape.to_vec(),
119 raw_data: values.iter().map(|&v| v as u8).collect(),
121 ..Default::default()
122 }
123}
124
125pub fn build_scale_tensor(names: &DequantLinearNames, scales: &[f32]) -> TensorProto {
131 let mut t = TensorProto {
132 name: names.scale_name.clone(),
133 data_type: tensor_proto::DataType::Float as i32,
134 float_data: scales.to_vec(),
135 ..Default::default()
136 };
137 if scales.len() > 1 {
138 t.dims = vec![scales.len() as i64];
140 }
141 t
143}
144
145pub fn build_zero_point_tensor(names: &DequantLinearNames, zps: &[i8]) -> TensorProto {
150 let mut t = TensorProto {
151 name: names.zp_name.clone(),
152 data_type: tensor_proto::DataType::Int8 as i32,
153 raw_data: zps.iter().map(|&v| v as u8).collect(),
154 ..Default::default()
155 };
156 if zps.len() > 1 {
157 t.dims = vec![zps.len() as i64];
159 }
160 t
162}
163
#[cfg(test)]
mod tests {
    // `tensor_proto` arrives through this glob along with the builders.
    use super::*;

    /// Reinterprets a tensor's raw bytes as the signed 8-bit values they encode.
    fn raw_as_i8(tensor: &TensorProto) -> Vec<i8> {
        tensor.raw_data.iter().map(|&byte| byte as i8).collect()
    }

    #[test]
    fn test_names_from_simple_weight() {
        let names = DequantLinearNames::from_original("conv1.weight");
        assert_eq!(
            (
                names.quantized_name.as_str(),
                names.scale_name.as_str(),
                names.zp_name.as_str(),
                names.node_name.as_str(),
                names.output_name.as_str(),
            ),
            (
                "conv1.weight_quantized",
                "conv1.weight_scale",
                "conv1.weight_zp",
                "DequantizeLinear_conv1.weight",
                "conv1.weight",
            ),
        );
    }

    #[test]
    fn test_names_from_dotted_path() {
        // Dots from nested module paths must pass through untouched.
        let names = DequantLinearNames::from_original("layer1.0.conv1.weight");
        assert_eq!(names.quantized_name, "layer1.0.conv1.weight_quantized");
        assert_eq!(names.output_name, "layer1.0.conv1.weight");
    }

    #[test]
    fn test_dequantize_linear_node_inputs_outputs() {
        let node =
            build_dequantize_linear_node(&DequantLinearNames::from_original("fc.weight"), None);

        assert_eq!(node.op_type, "DequantizeLinear");
        assert_eq!(node.name, "DequantizeLinear_fc.weight");
        assert_eq!(
            node.input,
            ["fc.weight_quantized", "fc.weight_scale", "fc.weight_zp"]
        );
        assert_eq!(node.output, ["fc.weight"]);
        assert!(node.attribute.is_empty());
    }

    #[test]
    fn test_dequantize_linear_node_with_axis() {
        let node = build_dequantize_linear_node(
            &DequantLinearNames::from_original("conv.weight"),
            Some(0),
        );

        assert_eq!(node.attribute.len(), 1);
        let attr = &node.attribute[0];
        assert_eq!(attr.name, "axis");
        assert_eq!(attr.i, 0);
    }

    #[test]
    fn test_quantized_weight_tensor_shape_and_data() {
        let values = [1i8, -2, 3, -4, 5, 6];
        let tensor = build_quantized_weight_tensor(
            &DequantLinearNames::from_original("w"),
            &values,
            &[2i64, 3],
        );

        assert_eq!(tensor.name, "w_quantized");
        assert_eq!(tensor.data_type, tensor_proto::DataType::Int8 as i32);
        assert_eq!(tensor.dims, [2i64, 3]);
        assert_eq!(raw_as_i8(&tensor), values);
    }

    #[test]
    fn test_scale_tensor_scalar() {
        let tensor = build_scale_tensor(&DequantLinearNames::from_original("w"), &[0.003921]);

        assert_eq!(tensor.name, "w_scale");
        assert_eq!(tensor.data_type, tensor_proto::DataType::Float as i32);
        assert!(tensor.dims.is_empty(), "single scale must be rank-0 scalar");
        assert!((tensor.float_data[0] - 0.003921).abs() < 1e-6);
    }

    #[test]
    fn test_scale_tensor_per_channel() {
        let tensor =
            build_scale_tensor(&DequantLinearNames::from_original("w"), &[0.01, 0.02, 0.03]);

        assert_eq!(tensor.dims, [3i64]);
        assert_eq!(tensor.float_data.len(), 3);
    }

    #[test]
    fn test_zero_point_tensor_scalar() {
        let tensor = build_zero_point_tensor(&DequantLinearNames::from_original("w"), &[-3]);

        assert_eq!(tensor.name, "w_zp");
        assert_eq!(tensor.data_type, tensor_proto::DataType::Int8 as i32);
        assert!(tensor.dims.is_empty(), "single zp must be rank-0 scalar");
        assert_eq!(raw_as_i8(&tensor)[0], -3);
    }

    #[test]
    fn test_zero_point_tensor_per_channel() {
        let tensor = build_zero_point_tensor(&DequantLinearNames::from_original("w"), &[-3, 0, 5]);

        assert_eq!(tensor.dims, [3i64]);
        assert_eq!(tensor.raw_data.len(), 3);
    }

    #[test]
    fn test_int4_range_values_round_trip() {
        // The extremes of the signed 4-bit range, carried in int8 storage.
        let nibbles = [-8i8, -1, 0, 7];
        let tensor = build_quantized_weight_tensor(
            &DequantLinearNames::from_original("w"),
            &nibbles,
            &[4i64],
        );

        assert_eq!(raw_as_i8(&tensor), nibbles);
    }
}