1use crate::onnx_proto::{attribute_proto, tensor_proto, AttributeProto, NodeProto, TensorProto};
23
/// The set of graph names involved in wiring one `DequantizeLinear` node.
///
/// [`build_dequantize_linear_node`] consumes `quantized_name`, `scale_name`
/// and `zp_name` (in that order) as the node inputs, `output_name` as its
/// single output and `node_name` as the node's own name.
///
/// The type is plain data, so it derives equality and hashing in addition to
/// `Debug`/`Clone`; this lets callers compare or de-duplicate name sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DequantLinearNames {
    /// Name of the quantized tensor (first node input).
    pub quantized_name: String,
    /// Name of the scale tensor (second node input).
    pub scale_name: String,
    /// Name of the zero-point tensor (third node input).
    pub zp_name: String,
    /// Name given to the `DequantizeLinear` node itself.
    pub node_name: String,
    /// Name of the node's output.
    pub output_name: String,
}

impl DequantLinearNames {
    /// Derives every name from the name of the original tensor.
    ///
    /// For an input of `"w"` the result is:
    ///
    /// | field            | value                |
    /// |------------------|----------------------|
    /// | `quantized_name` | `w_quantized`        |
    /// | `scale_name`     | `w_scale`            |
    /// | `zp_name`        | `w_zp`               |
    /// | `node_name`      | `DequantizeLinear_w` |
    /// | `output_name`    | `w`                  |
    ///
    /// The output keeps the original name unchanged, so anything that
    /// referred to `original_name` now refers to the node's output.
    pub fn from_original(original_name: &str) -> Self {
        Self {
            quantized_name: format!("{}_quantized", original_name),
            scale_name: format!("{}_scale", original_name),
            zp_name: format!("{}_zp", original_name),
            node_name: format!("DequantizeLinear_{}", original_name),
            output_name: original_name.to_string(),
        }
    }
}
56
57pub fn build_dequantize_linear_node(names: &DequantLinearNames, axis: Option<usize>) -> NodeProto {
71 let attribute = match axis {
72 Some(a) => vec![AttributeProto {
73 name: "axis".to_string(),
74 r#type: attribute_proto::AttributeType::Int as i32,
75 i: a as i64,
76 ..Default::default()
77 }],
78 None => vec![],
79 };
80
81 NodeProto {
82 op_type: "DequantizeLinear".to_string(),
83 name: names.node_name.clone(),
84 input: vec![
85 names.quantized_name.clone(),
86 names.scale_name.clone(),
87 names.zp_name.clone(),
88 ],
89 output: vec![names.output_name.clone()],
90 attribute,
91 ..Default::default()
92 }
93}
94
95pub fn build_quantized_weight_tensor(
105 names: &DequantLinearNames,
106 values: &[i8],
107 shape: &[i64],
108) -> TensorProto {
109 TensorProto {
110 name: names.quantized_name.clone(),
111 data_type: tensor_proto::DataType::Int8 as i32,
112 dims: shape.to_vec(),
113 raw_data: values.iter().map(|&v| v as u8).collect(),
115 ..Default::default()
116 }
117}
118
119pub fn build_scale_tensor(names: &DequantLinearNames, scales: &[f32]) -> TensorProto {
125 let mut t = TensorProto {
126 name: names.scale_name.clone(),
127 data_type: tensor_proto::DataType::Float as i32,
128 float_data: scales.to_vec(),
129 ..Default::default()
130 };
131 if scales.len() > 1 {
132 t.dims = vec![scales.len() as i64];
134 }
135 t
137}
138
139pub fn build_zero_point_tensor(names: &DequantLinearNames, zps: &[i8]) -> TensorProto {
144 let mut t = TensorProto {
145 name: names.zp_name.clone(),
146 data_type: tensor_proto::DataType::Int8 as i32,
147 raw_data: zps.iter().map(|&v| v as u8).collect(),
148 ..Default::default()
149 };
150 if zps.len() > 1 {
151 t.dims = vec![zps.len() as i64];
153 }
154 t
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::onnx_proto::tensor_proto;

    /// Reinterprets raw tensor bytes as the signed 8-bit values they encode.
    fn decode_i8(raw: &[u8]) -> Vec<i8> {
        raw.iter().map(|&byte| byte as i8).collect()
    }

    /// Shorthand for the name set used by most tensor tests.
    fn names_for(original: &str) -> DequantLinearNames {
        DequantLinearNames::from_original(original)
    }

    // ---- name derivation -------------------------------------------------

    #[test]
    fn test_names_from_simple_weight() {
        let DequantLinearNames {
            quantized_name,
            scale_name,
            zp_name,
            node_name,
            output_name,
        } = names_for("conv1.weight");

        assert_eq!(quantized_name, "conv1.weight_quantized");
        assert_eq!(scale_name, "conv1.weight_scale");
        assert_eq!(zp_name, "conv1.weight_zp");
        assert_eq!(node_name, "DequantizeLinear_conv1.weight");
        assert_eq!(output_name, "conv1.weight");
    }

    #[test]
    fn test_names_from_dotted_path() {
        // Dots inside the original name must be carried through untouched.
        let derived = names_for("layer1.0.conv1.weight");

        assert_eq!(derived.quantized_name, "layer1.0.conv1.weight_quantized");
        assert_eq!(derived.output_name, "layer1.0.conv1.weight");
    }

    // ---- node construction -----------------------------------------------

    #[test]
    fn test_dequantize_linear_node_inputs_outputs() {
        let node = build_dequantize_linear_node(&names_for("fc.weight"), None);

        assert_eq!(node.op_type, "DequantizeLinear");
        assert_eq!(node.name, "DequantizeLinear_fc.weight");

        // Exactly three inputs, in quantized / scale / zero-point order.
        assert_eq!(
            node.input,
            ["fc.weight_quantized", "fc.weight_scale", "fc.weight_zp"]
        );

        // Exactly one output, named after the original tensor.
        assert_eq!(node.output, ["fc.weight"]);

        // No axis requested, so no attributes.
        assert!(node.attribute.is_empty());
    }

    #[test]
    fn test_dequantize_linear_node_with_axis() {
        let node = build_dequantize_linear_node(&names_for("conv.weight"), Some(0));

        assert_eq!(node.attribute.len(), 1);
        let axis_attr = &node.attribute[0];
        assert_eq!(axis_attr.name, "axis");
        assert_eq!(axis_attr.i, 0);
    }

    // ---- quantized weight tensor -----------------------------------------

    #[test]
    fn test_quantized_weight_tensor_shape_and_data() {
        let values = [1i8, -2, 3, -4, 5, 6];
        let shape = [2i64, 3];
        let tensor = build_quantized_weight_tensor(&names_for("w"), &values, &shape);

        assert_eq!(tensor.name, "w_quantized");
        assert_eq!(tensor.data_type, tensor_proto::DataType::Int8 as i32);
        assert_eq!(tensor.dims, shape);

        // Raw bytes must decode back to the signed inputs.
        assert_eq!(decode_i8(&tensor.raw_data), values);
    }

    // ---- scale tensor ----------------------------------------------------

    #[test]
    fn test_scale_tensor_scalar() {
        let tensor = build_scale_tensor(&names_for("w"), &[0.003921]);

        assert_eq!(tensor.name, "w_scale");
        assert_eq!(tensor.data_type, tensor_proto::DataType::Float as i32);
        assert!(tensor.dims.is_empty(), "single scale must be rank-0 scalar");
        assert!((tensor.float_data[0] - 0.003921).abs() < 1e-6);
    }

    #[test]
    fn test_scale_tensor_per_channel() {
        let tensor = build_scale_tensor(&names_for("w"), &[0.01, 0.02, 0.03]);

        assert_eq!(tensor.dims, [3i64]);
        assert_eq!(tensor.float_data.len(), 3);
    }

    // ---- zero-point tensor -----------------------------------------------

    #[test]
    fn test_zero_point_tensor_scalar() {
        let tensor = build_zero_point_tensor(&names_for("w"), &[-3]);

        assert_eq!(tensor.name, "w_zp");
        assert_eq!(tensor.data_type, tensor_proto::DataType::Int8 as i32);
        assert!(tensor.dims.is_empty(), "single zp must be rank-0 scalar");
        assert_eq!(tensor.raw_data[0] as i8, -3);
    }

    #[test]
    fn test_zero_point_tensor_per_channel() {
        let tensor = build_zero_point_tensor(&names_for("w"), &[-3, 0, 5]);

        assert_eq!(tensor.dims, [3i64]);
        assert_eq!(tensor.raw_data.len(), 3);
    }

    // ---- value-range round trip ------------------------------------------

    #[test]
    fn test_int4_range_values_round_trip() {
        // Extremes and midpoints of the signed 4-bit range, stored as i8.
        let values = [-8i8, -1, 0, 7];
        let tensor = build_quantized_weight_tensor(&names_for("w"), &values, &[4i64]);

        assert_eq!(decode_i8(&tensor.raw_data), values);
    }
}