axonml_serialize/convert.rs

//! Model Conversion Utilities
//!
//! Helpers for converting between Axonml and other formats (`PyTorch`, ONNX).

use crate::{StateDict, TensorData};
use std::collections::HashMap;

// =============================================================================
// PyTorch Conversion
// =============================================================================

/// Convert a `PyTorch`-style key to Axonml format.
///
/// `PyTorch` uses keys like:
/// - `module.layer1.weight`
/// - `encoder.layers.0.self_attn.q_proj.weight`
///
/// This function normalizes them for Axonml.
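///
/// # Example
///
/// A minimal doctest sketch; the `axonml_serialize::convert` import path is
/// an assumption based on this file's location:
///
/// ```
/// # use axonml_serialize::convert::from_pytorch_key; // path assumed
/// assert_eq!(from_pytorch_key("module.layer1.weight"), "layer1.weight");
/// assert_eq!(from_pytorch_key("_orig_mod.encoder.weight"), "encoder.weight");
/// ```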
#[must_use]
pub fn from_pytorch_key(key: &str) -> String {
    // Strip common wrapper prefixes: `module.` (added by
    // `DataParallel`/`DistributedDataParallel`) and `_orig_mod.`
    // (added by `torch.compile`). Each is removed at most once.
    let key = key.strip_prefix("module.").unwrap_or(key);
    let key = key.strip_prefix("_orig_mod.").unwrap_or(key);
    key.to_string()
}

/// Convert an Axonml key to `PyTorch` format.
///
/// Currently the identity conversion: prefixes stripped by
/// [`from_pytorch_key`] (such as the DDP `module.` prefix) are not re-added.
#[must_use]
pub fn to_pytorch_key(key: &str) -> String {
    key.to_string()
}

/// Map of `PyTorch` layer names to Axonml equivalents.
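///
/// # Example
///
/// A doctest sketch of a lookup (import path assumed, as above):
///
/// ```
/// # use axonml_serialize::convert::pytorch_layer_mapping; // path assumed
/// let map = pytorch_layer_mapping();
/// assert_eq!(map["fc"], "linear");
/// assert_eq!(map["bn"], "batch_norm");
/// ```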
#[must_use]
pub fn pytorch_layer_mapping() -> HashMap<&'static str, &'static str> {
    let mut map = HashMap::new();

    // Linear layers
    map.insert("fc", "linear");
    map.insert("dense", "linear");

    // Convolutions
    map.insert("conv", "conv");

    // Normalization
    map.insert("bn", "batch_norm");
    map.insert("batch_norm", "batch_norm");
    map.insert("layer_norm", "layer_norm");
    map.insert("ln", "layer_norm");

    // Attention
    map.insert("self_attn", "attention");
    map.insert("multihead_attn", "attention");

    map
}

// =============================================================================
// ONNX Conversion
// =============================================================================

/// Convert a shape to ONNX format (with batch dimension handling).
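///
/// # Example
///
/// A doctest sketch (import path assumed, as above):
///
/// ```
/// # use axonml_serialize::convert::to_onnx_shape; // path assumed
/// assert_eq!(to_onnx_shape(&[3, 4], false), vec![3, 4]);
/// assert_eq!(to_onnx_shape(&[3, 4], true), vec![-1, 3, 4]);
/// ```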
#[must_use]
pub fn to_onnx_shape(shape: &[usize], include_batch: bool) -> Vec<i64> {
    if include_batch {
        // ONNX uses -1 for a dynamic batch dimension.
        std::iter::once(-1i64)
            .chain(shape.iter().map(|&d| d as i64))
            .collect()
    } else {
        shape.iter().map(|&d| d as i64).collect()
    }
}

/// Convert from ONNX shape (handling -1 for dynamic dimensions).
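///
/// # Example
///
/// A doctest sketch (import path assumed, as above):
///
/// ```
/// # use axonml_serialize::convert::from_onnx_shape; // path assumed
/// assert_eq!(from_onnx_shape(&[-1, 3, 4], 8), vec![8, 3, 4]);
/// ```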
#[must_use]
pub fn from_onnx_shape(shape: &[i64], default_dynamic: usize) -> Vec<usize> {
    shape
        .iter()
        .map(|&d| if d < 0 { default_dynamic } else { d as usize })
        .collect()
}

/// ONNX operator type mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OnnxOpType {
    /// Element-wise addition.
    Add,
    /// Element-wise subtraction.
    Sub,
    /// Element-wise multiplication.
    Mul,
    /// Element-wise division.
    Div,
    /// Matrix multiplication.
    MatMul,
    /// General matrix multiplication (with bias).
    Gemm,
    /// Rectified Linear Unit activation.
    Relu,
    /// Sigmoid activation.
    Sigmoid,
    /// Hyperbolic tangent activation.
    Tanh,
    /// Softmax activation.
    Softmax,
    /// Gaussian Error Linear Unit activation.
    Gelu,
    /// Convolution operation.
    Conv,
    /// Transposed convolution (deconvolution).
    ConvTranspose,
    /// Max pooling operation.
    MaxPool,
    /// Average pooling operation.
    AveragePool,
    /// Global average pooling operation.
    GlobalAveragePool,
    /// Batch normalization.
    BatchNormalization,
    /// Layer normalization.
    LayerNormalization,
    /// Reshape tensor dimensions.
    Reshape,
    /// Transpose tensor dimensions.
    Transpose,
    /// Flatten tensor to 2D.
    Flatten,
    /// Remove dimensions of size 1.
    Squeeze,
    /// Add dimension of size 1.
    Unsqueeze,
    /// Concatenate tensors along axis.
    Concat,
    /// Sum reduction along axis.
    ReduceSum,
    /// Mean reduction along axis.
    ReduceMean,
    /// Max reduction along axis.
    ReduceMax,
    /// Min reduction along axis.
    ReduceMin,
    /// Dropout layer (training regularization).
    Dropout,
    /// Constant tensor.
    Constant,
    /// Identity pass-through.
    Identity,
    /// Unknown or unsupported operator.
    Unknown,
}

impl OnnxOpType {
    /// Parse an ONNX operator type from its string name.
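    ///
    /// # Example
    ///
    /// A doctest sketch (import path assumed, as above); unrecognized
    /// names map to [`OnnxOpType::Unknown`]:
    ///
    /// ```
    /// # use axonml_serialize::convert::OnnxOpType; // path assumed
    /// assert_eq!(OnnxOpType::from_str("Relu"), OnnxOpType::Relu);
    /// assert_eq!(OnnxOpType::from_str("NoSuchOp"), OnnxOpType::Unknown);
    /// ```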
    #[must_use]
    pub fn from_str(s: &str) -> Self {
        match s {
            "Add" => Self::Add,
            "Sub" => Self::Sub,
            "Mul" => Self::Mul,
            "Div" => Self::Div,
            "MatMul" => Self::MatMul,
            "Gemm" => Self::Gemm,
            "Relu" => Self::Relu,
            "Sigmoid" => Self::Sigmoid,
            "Tanh" => Self::Tanh,
            "Softmax" => Self::Softmax,
            "Gelu" => Self::Gelu,
            "Conv" => Self::Conv,
            "ConvTranspose" => Self::ConvTranspose,
            "MaxPool" => Self::MaxPool,
            "AveragePool" => Self::AveragePool,
            "GlobalAveragePool" => Self::GlobalAveragePool,
            "BatchNormalization" => Self::BatchNormalization,
            "LayerNormalization" => Self::LayerNormalization,
            "Reshape" => Self::Reshape,
            "Transpose" => Self::Transpose,
            "Flatten" => Self::Flatten,
            "Squeeze" => Self::Squeeze,
            "Unsqueeze" => Self::Unsqueeze,
            "Concat" => Self::Concat,
            "ReduceSum" => Self::ReduceSum,
            "ReduceMean" => Self::ReduceMean,
            "ReduceMax" => Self::ReduceMax,
            "ReduceMin" => Self::ReduceMin,
            "Dropout" => Self::Dropout,
            "Constant" => Self::Constant,
            "Identity" => Self::Identity,
            _ => Self::Unknown,
        }
    }

    /// Get the ONNX operator name.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Add => "Add",
            Self::Sub => "Sub",
            Self::Mul => "Mul",
            Self::Div => "Div",
            Self::MatMul => "MatMul",
            Self::Gemm => "Gemm",
            Self::Relu => "Relu",
            Self::Sigmoid => "Sigmoid",
            Self::Tanh => "Tanh",
            Self::Softmax => "Softmax",
            Self::Gelu => "Gelu",
            Self::Conv => "Conv",
            Self::ConvTranspose => "ConvTranspose",
            Self::MaxPool => "MaxPool",
            Self::AveragePool => "AveragePool",
            Self::GlobalAveragePool => "GlobalAveragePool",
            Self::BatchNormalization => "BatchNormalization",
            Self::LayerNormalization => "LayerNormalization",
            Self::Reshape => "Reshape",
            Self::Transpose => "Transpose",
            Self::Flatten => "Flatten",
            Self::Squeeze => "Squeeze",
            Self::Unsqueeze => "Unsqueeze",
            Self::Concat => "Concat",
            Self::ReduceSum => "ReduceSum",
            Self::ReduceMean => "ReduceMean",
            Self::ReduceMax => "ReduceMax",
            Self::ReduceMin => "ReduceMin",
            Self::Dropout => "Dropout",
            Self::Constant => "Constant",
            Self::Identity => "Identity",
            Self::Unknown => "Unknown",
        }
    }
}

// =============================================================================
// State Dict Conversion
// =============================================================================

/// Convert a state dict from `PyTorch` naming conventions.
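///
/// # Example
///
/// A sketch; the crate-root re-exports of `StateDict`/`TensorData` and the
/// module path are assumptions:
///
/// ```
/// # use axonml_serialize::{StateDict, TensorData}; // paths assumed
/// # use axonml_serialize::convert::convert_from_pytorch;
/// let mut dict = StateDict::new();
/// dict.insert(
///     "module.linear.weight".to_string(),
///     TensorData { shape: vec![2, 2], values: vec![0.0; 4] },
/// );
/// let converted = convert_from_pytorch(&dict);
/// assert!(converted.contains("linear.weight"));
/// ```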
#[must_use]
pub fn convert_from_pytorch(state_dict: &StateDict) -> StateDict {
    let mut converted = StateDict::new();

    for (key, entry) in state_dict.entries() {
        let new_key = from_pytorch_key(key);
        converted.insert_entry(new_key, entry.clone());
    }

    converted
}

/// Transpose weights if needed for format conversion.
///
/// `PyTorch` `Linear` stores weights as `[out_features, in_features]`;
/// some frameworks use `[in_features, out_features]`.
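///
/// # Example
///
/// A doctest sketch (import paths assumed, as above):
///
/// ```
/// # use axonml_serialize::TensorData; // paths assumed
/// # use axonml_serialize::convert::transpose_linear_weights;
/// let data = TensorData {
///     shape: vec![2, 3],
///     values: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
/// };
/// let t = transpose_linear_weights(&data);
/// assert_eq!(t.shape, vec![3, 2]);
/// assert_eq!(t.values, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
/// ```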
#[must_use]
pub fn transpose_linear_weights(data: &TensorData) -> TensorData {
    // Only 2-D weight matrices are transposed; other shapes are returned
    // unchanged.
    if data.shape.len() != 2 {
        return data.clone();
    }

    let (rows, cols) = (data.shape[0], data.shape[1]);
    let mut transposed = vec![0.0; data.values.len()];

    // Row-major transpose: transposed[(j, i)] = data[(i, j)].
    for i in 0..rows {
        for j in 0..cols {
            transposed[j * rows + i] = data.values[i * cols + j];
        }
    }

    TensorData {
        shape: vec![cols, rows],
        values: transposed,
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_pytorch_key() {
        assert_eq!(from_pytorch_key("module.layer1.weight"), "layer1.weight");
        assert_eq!(from_pytorch_key("layer1.weight"), "layer1.weight");
        assert_eq!(
            from_pytorch_key("_orig_mod.encoder.weight"),
            "encoder.weight"
        );
    }

    #[test]
    fn test_to_onnx_shape() {
        assert_eq!(to_onnx_shape(&[3, 4], false), vec![3, 4]);
        assert_eq!(to_onnx_shape(&[3, 4], true), vec![-1, 3, 4]);
    }

    #[test]
    fn test_from_onnx_shape() {
        assert_eq!(from_onnx_shape(&[3, 4], 1), vec![3, 4]);
        assert_eq!(from_onnx_shape(&[-1, 3, 4], 8), vec![8, 3, 4]);
    }

    #[test]
    fn test_onnx_op_type() {
        assert_eq!(OnnxOpType::from_str("Relu"), OnnxOpType::Relu);
        assert_eq!(OnnxOpType::from_str("MatMul"), OnnxOpType::MatMul);
        assert_eq!(OnnxOpType::from_str("Unknown"), OnnxOpType::Unknown);

        assert_eq!(OnnxOpType::Relu.as_str(), "Relu");
    }

    #[test]
    fn test_transpose_linear_weights() {
        let data = TensorData {
            shape: vec![2, 3],
            values: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        };

        let transposed = transpose_linear_weights(&data);
        assert_eq!(transposed.shape, vec![3, 2]);
        assert_eq!(transposed.values, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_convert_from_pytorch() {
        let mut pytorch_dict = StateDict::new();
        pytorch_dict.insert(
            "module.linear.weight".to_string(),
            TensorData {
                shape: vec![10, 5],
                values: vec![0.0; 50],
            },
        );

        let converted = convert_from_pytorch(&pytorch_dict);
        assert!(converted.contains("linear.weight"));
        assert!(!converted.contains("module.linear.weight"));
    }
}