use crate::{StateDict, TensorData};
use std::collections::HashMap;

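/// Strips common PyTorch wrapper prefixes from a `state_dict` key:
/// `module.` (added by `nn.DataParallel` / DDP) and `_orig_mod.`
/// (added by `torch.compile`).
///
/// ```text
/// from_pytorch_key("module.layer1.weight")     == "layer1.weight"
/// from_pytorch_key("_orig_mod.encoder.weight") == "encoder.weight"
/// from_pytorch_key("layer1.weight")            == "layer1.weight"
/// ```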
#[must_use]
pub fn from_pytorch_key(key: &str) -> String {
    let mut result = key.to_string();

    if let Some(stripped) = result.strip_prefix("module.") {
        result = stripped.to_string();
    }
    if let Some(stripped) = result.strip_prefix("_orig_mod.") {
        result = stripped.to_string();
    }

    result
}

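/// Converts an internal key back to PyTorch naming. The two schemes
/// currently coincide, so the key is returned unchanged; wrapper prefixes
/// such as `module.` are deliberately not re-added.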
#[must_use]
pub fn to_pytorch_key(key: &str) -> String {
    key.to_string()
}

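/// Returns a lookup table from common PyTorch layer-name fragments to the
/// canonical names used here (e.g. `fc` and `dense` both map to `linear`).
/// `conv` maps to itself so lookups on convolution layers still succeed.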
#[must_use]
pub fn pytorch_layer_mapping() -> HashMap<&'static str, &'static str> {
    let mut map = HashMap::new();

    // Fully connected layers.
    map.insert("fc", "linear");
    map.insert("dense", "linear");

    // Convolution keeps its name.
    map.insert("conv", "conv");

    // Normalization layers.
    map.insert("bn", "batch_norm");
    map.insert("batch_norm", "batch_norm");
    map.insert("layer_norm", "layer_norm");
    map.insert("ln", "layer_norm");

    // Attention layers.
    map.insert("self_attn", "attention");
    map.insert("multihead_attn", "attention");

    map
}

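/// Converts a shape to ONNX's signed-integer convention. When
/// `include_batch` is set, a leading `-1` (ONNX's marker for a dynamic
/// dimension) is prepended for the batch axis.
///
/// ```text
/// to_onnx_shape(&[3, 4], false) == [3, 4]
/// to_onnx_shape(&[3, 4], true)  == [-1, 3, 4]
/// ```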
#[must_use]
pub fn to_onnx_shape(shape: &[usize], include_batch: bool) -> Vec<i64> {
    if include_batch {
        std::iter::once(-1i64)
            .chain(shape.iter().map(|&d| d as i64))
            .collect()
    } else {
        shape.iter().map(|&d| d as i64).collect()
    }
}

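/// Converts an ONNX shape back to `usize` dimensions, substituting
/// `default_dynamic` for every negative (dynamic) dimension.
///
/// ```text
/// from_onnx_shape(&[-1, 3, 4], 8) == [8, 3, 4]
/// ```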
#[must_use]
pub fn from_onnx_shape(shape: &[i64], default_dynamic: usize) -> Vec<usize> {
    shape
        .iter()
        .map(|&d| if d < 0 { default_dynamic } else { d as usize })
        .collect()
}

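/// The subset of ONNX operator types recognized by this module. Anything
/// outside this list parses as [`OnnxOpType::Unknown`].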
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OnnxOpType {
    Add,
    Sub,
    Mul,
    Div,
    MatMul,
    Gemm,
    Relu,
    Sigmoid,
    Tanh,
    Softmax,
    Gelu,
    Conv,
    ConvTranspose,
    MaxPool,
    AveragePool,
    GlobalAveragePool,
    BatchNormalization,
    LayerNormalization,
    Reshape,
    Transpose,
    Flatten,
    Squeeze,
    Unsqueeze,
    Concat,
    ReduceSum,
    ReduceMean,
    ReduceMax,
    ReduceMin,
    Dropout,
    Constant,
    Identity,
    Unknown,
}

impl OnnxOpType {
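    /// Parses an ONNX op name. Unrecognized names map to `Self::Unknown`
    /// instead of an error, so parsing never fails.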
    #[must_use]
    pub fn from_str(s: &str) -> Self {
        match s {
            "Add" => Self::Add,
            "Sub" => Self::Sub,
            "Mul" => Self::Mul,
            "Div" => Self::Div,
            "MatMul" => Self::MatMul,
            "Gemm" => Self::Gemm,
            "Relu" => Self::Relu,
            "Sigmoid" => Self::Sigmoid,
            "Tanh" => Self::Tanh,
            "Softmax" => Self::Softmax,
            "Gelu" => Self::Gelu,
            "Conv" => Self::Conv,
            "ConvTranspose" => Self::ConvTranspose,
            "MaxPool" => Self::MaxPool,
            "AveragePool" => Self::AveragePool,
            "GlobalAveragePool" => Self::GlobalAveragePool,
            "BatchNormalization" => Self::BatchNormalization,
            "LayerNormalization" => Self::LayerNormalization,
            "Reshape" => Self::Reshape,
            "Transpose" => Self::Transpose,
            "Flatten" => Self::Flatten,
            "Squeeze" => Self::Squeeze,
            "Unsqueeze" => Self::Unsqueeze,
            "Concat" => Self::Concat,
            "ReduceSum" => Self::ReduceSum,
            "ReduceMean" => Self::ReduceMean,
            "ReduceMax" => Self::ReduceMax,
            "ReduceMin" => Self::ReduceMin,
            "Dropout" => Self::Dropout,
            "Constant" => Self::Constant,
            "Identity" => Self::Identity,
            _ => Self::Unknown,
        }
    }

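    /// Returns the canonical ONNX spelling of this op type.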
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Add => "Add",
            Self::Sub => "Sub",
            Self::Mul => "Mul",
            Self::Div => "Div",
            Self::MatMul => "MatMul",
            Self::Gemm => "Gemm",
            Self::Relu => "Relu",
            Self::Sigmoid => "Sigmoid",
            Self::Tanh => "Tanh",
            Self::Softmax => "Softmax",
            Self::Gelu => "Gelu",
            Self::Conv => "Conv",
            Self::ConvTranspose => "ConvTranspose",
            Self::MaxPool => "MaxPool",
            Self::AveragePool => "AveragePool",
            Self::GlobalAveragePool => "GlobalAveragePool",
            Self::BatchNormalization => "BatchNormalization",
            Self::LayerNormalization => "LayerNormalization",
            Self::Reshape => "Reshape",
            Self::Transpose => "Transpose",
            Self::Flatten => "Flatten",
            Self::Squeeze => "Squeeze",
            Self::Unsqueeze => "Unsqueeze",
            Self::Concat => "Concat",
            Self::ReduceSum => "ReduceSum",
            Self::ReduceMean => "ReduceMean",
            Self::ReduceMax => "ReduceMax",
            Self::ReduceMin => "ReduceMin",
            Self::Dropout => "Dropout",
            Self::Constant => "Constant",
            Self::Identity => "Identity",
            Self::Unknown => "Unknown",
        }
    }
}

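/// Rewrites every key of a PyTorch `StateDict` through
/// [`from_pytorch_key`], cloning the tensor data unchanged. Linear weights
/// are not transposed here; see [`transpose_linear_weights`].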
#[must_use]
pub fn convert_from_pytorch(state_dict: &StateDict) -> StateDict {
    let mut converted = StateDict::new();

    for (key, entry) in state_dict.entries() {
        let new_key = from_pytorch_key(key);
        converted.insert_entry(new_key, entry.clone());
    }

    converted
}

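/// Transposes a 2-D weight matrix (PyTorch stores linear weights as
/// `[out_features, in_features]`); tensors of any other rank are returned
/// unchanged. Assumes `values.len() == shape[0] * shape[1]`.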
#[must_use]
pub fn transpose_linear_weights(data: &TensorData) -> TensorData {
    if data.shape.len() != 2 {
        return data.clone();
    }

    let (rows, cols) = (data.shape[0], data.shape[1]);
    let mut transposed = vec![0.0; data.values.len()];

    for i in 0..rows {
        for j in 0..cols {
            transposed[j * rows + i] = data.values[i * cols + j];
        }
    }

    TensorData {
        shape: vec![cols, rows],
        values: transposed,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_pytorch_key() {
        assert_eq!(from_pytorch_key("module.layer1.weight"), "layer1.weight");
        assert_eq!(from_pytorch_key("layer1.weight"), "layer1.weight");
        assert_eq!(
            from_pytorch_key("_orig_mod.encoder.weight"),
            "encoder.weight"
        );
    }

    #[test]
    fn test_to_onnx_shape() {
        assert_eq!(to_onnx_shape(&[3, 4], false), vec![3, 4]);
        assert_eq!(to_onnx_shape(&[3, 4], true), vec![-1, 3, 4]);
    }

    #[test]
    fn test_from_onnx_shape() {
        assert_eq!(from_onnx_shape(&[3, 4], 1), vec![3, 4]);
        assert_eq!(from_onnx_shape(&[-1, 3, 4], 8), vec![8, 3, 4]);
    }

    #[test]
    fn test_onnx_op_type() {
        assert_eq!(OnnxOpType::from_str("Relu"), OnnxOpType::Relu);
        assert_eq!(OnnxOpType::from_str("MatMul"), OnnxOpType::MatMul);
        assert_eq!(OnnxOpType::from_str("Unknown"), OnnxOpType::Unknown);

        assert_eq!(OnnxOpType::Relu.as_str(), "Relu");
    }

    #[test]
    fn test_transpose_linear_weights() {
        let data = TensorData {
            shape: vec![2, 3],
            values: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        };

        let transposed = transpose_linear_weights(&data);
        assert_eq!(transposed.shape, vec![3, 2]);
        assert_eq!(transposed.values, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_convert_from_pytorch() {
        let mut pytorch_dict = StateDict::new();
        pytorch_dict.insert(
            "module.linear.weight".to_string(),
            TensorData {
                shape: vec![10, 5],
                values: vec![0.0; 50],
            },
        );

        let converted = convert_from_pytorch(&pytorch_dict);
        assert!(converted.contains("linear.weight"));
        assert!(!converted.contains("module.linear.weight"));
    }
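    // Sanity checks on the mapping table and the identity key conversion.
    #[test]
    fn test_pytorch_layer_mapping() {
        let map = pytorch_layer_mapping();
        assert_eq!(map.get("fc"), Some(&"linear"));
        assert_eq!(map.get("dense"), Some(&"linear"));
        assert_eq!(map.get("bn"), Some(&"batch_norm"));
        assert_eq!(map.get("self_attn"), Some(&"attention"));
        assert_eq!(map.get("nonexistent"), None);
    }

    #[test]
    fn test_to_pytorch_key_is_identity() {
        assert_eq!(to_pytorch_key("layer1.weight"), "layer1.weight");
    }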
}