Skip to main content

axonml_serialize/
lib.rs

1//! Axonml Serialize - Model Serialization for Axonml ML Framework
2//!
3//! # File
4//! `crates/axonml-serialize/src/lib.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17#![warn(missing_docs)]
18#![warn(clippy::all)]
19#![warn(clippy::pedantic)]
20// ML/tensor-specific allowances
21#![allow(clippy::cast_possible_truncation)]
22#![allow(clippy::cast_sign_loss)]
23#![allow(clippy::cast_precision_loss)]
24#![allow(clippy::cast_possible_wrap)]
25#![allow(clippy::missing_errors_doc)]
26#![allow(clippy::missing_panics_doc)]
27#![allow(clippy::must_use_candidate)]
28#![allow(clippy::module_name_repetitions)]
29#![allow(clippy::similar_names)]
30#![allow(clippy::many_single_char_names)]
31#![allow(clippy::too_many_arguments)]
32#![allow(clippy::doc_markdown)]
33#![allow(clippy::cast_lossless)]
34#![allow(clippy::needless_pass_by_value)]
35#![allow(clippy::redundant_closure_for_method_calls)]
36#![allow(clippy::uninlined_format_args)]
37#![allow(clippy::ptr_arg)]
38#![allow(clippy::return_self_not_must_use)]
39#![allow(clippy::not_unsafe_ptr_arg_deref)]
40#![allow(clippy::items_after_statements)]
41#![allow(clippy::unreadable_literal)]
42#![allow(clippy::if_same_then_else)]
43#![allow(clippy::needless_range_loop)]
44#![allow(clippy::trivially_copy_pass_by_ref)]
45#![allow(clippy::unnecessary_wraps)]
46#![allow(clippy::match_same_arms)]
47#![allow(clippy::unused_self)]
48#![allow(clippy::too_many_lines)]
49#![allow(clippy::single_match_else)]
50#![allow(clippy::fn_params_excessive_bools)]
51#![allow(clippy::struct_excessive_bools)]
52#![allow(clippy::format_push_string)]
53#![allow(clippy::erasing_op)]
54#![allow(clippy::type_repetition_in_bounds)]
55#![allow(clippy::iter_without_into_iter)]
56#![allow(clippy::should_implement_trait)]
57#![allow(clippy::use_debug)]
58#![allow(clippy::case_sensitive_file_extension_comparisons)]
59#![allow(clippy::large_enum_variant)]
60#![allow(clippy::panic)]
61#![allow(clippy::struct_field_names)]
62#![allow(clippy::missing_fields_in_debug)]
63#![allow(clippy::upper_case_acronyms)]
64#![allow(clippy::assigning_clones)]
65#![allow(clippy::option_if_let_else)]
66#![allow(clippy::manual_let_else)]
67#![allow(clippy::explicit_iter_loop)]
68#![allow(clippy::default_trait_access)]
69#![allow(clippy::only_used_in_recursion)]
70#![allow(clippy::manual_clamp)]
71#![allow(clippy::ref_option)]
72#![allow(clippy::multiple_bound_locations)]
73#![allow(clippy::comparison_chain)]
74#![allow(clippy::manual_assert)]
75#![allow(clippy::unnecessary_debug_formatting)]
76
77// =============================================================================
78// Modules
79// =============================================================================
80
81mod checkpoint;
82mod convert;
83mod format;
84mod state_dict;
85
86// =============================================================================
87// Re-exports
88// =============================================================================
89
90pub use checkpoint::{Checkpoint, CheckpointBuilder, TrainingState};
91pub use convert::{
92    OnnxOpType, convert_from_pytorch, from_onnx_shape, from_pytorch_key, pytorch_layer_mapping,
93    to_onnx_shape, to_pytorch_key, transpose_linear_weights,
94};
95pub use format::{Format, detect_format, detect_format_from_bytes};
96pub use state_dict::{StateDict, StateDictEntry, TensorData};
97
98// =============================================================================
99// Imports
100// =============================================================================
101
102use axonml_core::{Error, Result};
103use axonml_nn::Module;
104use std::fs::File;
105use std::io::{BufReader, BufWriter, Read, Write};
106use std::path::Path;
107
108// =============================================================================
109// High-Level API
110// =============================================================================
111
112/// Save a model's state dictionary to a file.
113///
114/// The format is automatically determined from the file extension.
115pub fn save_model<M: Module, P: AsRef<Path>>(model: &M, path: P) -> Result<()> {
116    let path = path.as_ref();
117    let format = detect_format(path);
118    let state_dict = StateDict::from_module(model);
119
120    save_state_dict(&state_dict, path, format)
121}
122
123/// Save a state dictionary to a file with specified format.
124pub fn save_state_dict<P: AsRef<Path>>(
125    state_dict: &StateDict,
126    path: P,
127    format: Format,
128) -> Result<()> {
129    let path = path.as_ref();
130    let file = File::create(path).map_err(|e| Error::InvalidOperation {
131        message: e.to_string(),
132    })?;
133    let mut writer = BufWriter::new(file);
134
135    match format {
136        Format::Axonml => {
137            let encoded = bincode::serialize(state_dict).map_err(|e| Error::InvalidOperation {
138                message: e.to_string(),
139            })?;
140            writer
141                .write_all(&encoded)
142                .map_err(|e| Error::InvalidOperation {
143                    message: e.to_string(),
144                })?;
145        }
146        Format::Json => {
147            serde_json::to_writer_pretty(&mut writer, state_dict).map_err(|e| {
148                Error::InvalidOperation {
149                    message: e.to_string(),
150                }
151            })?;
152        }
153        #[cfg(feature = "safetensors")]
154        Format::SafeTensors => {
155            save_safetensors(state_dict, path)?;
156        }
157        #[cfg(not(feature = "safetensors"))]
158        Format::SafeTensors => {
159            return Err(Error::InvalidOperation {
160                message: "SafeTensors format requires 'safetensors' feature".to_string(),
161            });
162        }
163    }
164
165    Ok(())
166}
167
168/// Load a state dictionary from a file.
169pub fn load_state_dict<P: AsRef<Path>>(path: P) -> Result<StateDict> {
170    let path = path.as_ref();
171    let format = detect_format(path);
172
173    let file = File::open(path).map_err(|e| Error::InvalidOperation {
174        message: e.to_string(),
175    })?;
176    let mut reader = BufReader::new(file);
177
178    match format {
179        Format::Axonml => {
180            let mut bytes = Vec::new();
181            reader
182                .read_to_end(&mut bytes)
183                .map_err(|e| Error::InvalidOperation {
184                    message: e.to_string(),
185                })?;
186            bincode::deserialize(&bytes).map_err(|e| Error::InvalidOperation {
187                message: e.to_string(),
188            })
189        }
190        Format::Json => serde_json::from_reader(reader).map_err(|e| Error::InvalidOperation {
191            message: e.to_string(),
192        }),
193        #[cfg(feature = "safetensors")]
194        Format::SafeTensors => load_safetensors(path),
195        #[cfg(not(feature = "safetensors"))]
196        Format::SafeTensors => Err(Error::InvalidOperation {
197            message: "SafeTensors format requires 'safetensors' feature".to_string(),
198        }),
199    }
200}
201
202/// Save a complete training checkpoint.
203pub fn save_checkpoint<P: AsRef<Path>>(checkpoint: &Checkpoint, path: P) -> Result<()> {
204    let path = path.as_ref();
205    let file = File::create(path).map_err(|e| Error::InvalidOperation {
206        message: e.to_string(),
207    })?;
208    let writer = BufWriter::new(file);
209
210    bincode::serialize_into(writer, checkpoint).map_err(|e| Error::InvalidOperation {
211        message: e.to_string(),
212    })
213}
214
215/// Load a training checkpoint.
216pub fn load_checkpoint<P: AsRef<Path>>(path: P) -> Result<Checkpoint> {
217    let path = path.as_ref();
218    let file = File::open(path).map_err(|e| Error::InvalidOperation {
219        message: e.to_string(),
220    })?;
221    let reader = BufReader::new(file);
222
223    bincode::deserialize_from(reader).map_err(|e| Error::InvalidOperation {
224        message: e.to_string(),
225    })
226}
227
228// =============================================================================
229// SafeTensors Support
230// =============================================================================
231
/// Write `state_dict` to `path` in the SafeTensors format.
///
/// Every entry is stored as an `F32` tensor with its real shape, using the
/// little-endian bytes of `entry.data.values`. For compatibility with files
/// written previously, the shape is additionally recorded in the file-level
/// metadata under the key `"<name>.shape"`, formatted with `{:?}`.
///
/// # Errors
/// Returns `Error::InvalidOperation` with the underlying error text if a
/// tensor view cannot be built (for example when the number of values does
/// not match the shape), if serialization fails, or if the file cannot be
/// written.
#[cfg(feature = "safetensors")]
fn save_safetensors<P: AsRef<Path>>(state_dict: &StateDict, path: P) -> Result<()> {
    use safetensors::tensor::{Dtype, TensorView};
    use std::collections::HashMap;

    /// Wrap any displayable error in `Error::InvalidOperation`.
    fn invalid<E: std::fmt::Display>(err: E) -> Error {
        Error::InvalidOperation {
            message: err.to_string(),
        }
    }

    let mut metadata: HashMap<String, String> = HashMap::new();

    // The byte buffers are collected up front so the tensor views built below
    // can borrow from them for the duration of the serialize call.
    let mut buffers: Vec<(String, Vec<usize>, Vec<u8>)> = Vec::new();
    for (name, entry) in state_dict.entries() {
        let data_bytes: Vec<u8> = entry
            .data
            .values
            .iter()
            .flat_map(|value| value.to_le_bytes())
            .collect();
        metadata.insert(format!("{}.shape", name), format!("{:?}", entry.data.shape));
        buffers.push((name.clone(), entry.data.shape.clone(), data_bytes));
    }

    // The serializer needs typed, shaped views rather than raw byte vectors;
    // the loader reads dtype and shape back from these tensor headers.
    let views = buffers
        .iter()
        .map(|(name, shape, data_bytes)| {
            TensorView::new(Dtype::F32, shape.clone(), data_bytes)
                .map(|view| (name.as_str(), view))
                .map_err(invalid)
        })
        .collect::<Result<Vec<_>>>()?;

    let encoded = safetensors::serialize(views, &Some(metadata)).map_err(invalid)?;

    std::fs::write(path, encoded).map_err(invalid)
}
261
/// Read a SafeTensors file at `path` into a [`StateDict`].
///
/// Each tensor's bytes are decoded as little-endian `f32` values and stored,
/// together with the tensor's shape, under the tensor's name.
///
/// # Errors
/// Returns `Error::InvalidOperation` if the file cannot be read or parsed, if
/// a tensor's dtype is not `F32`, or if a tensor's byte length is not a
/// multiple of four. The last two cases used to be decoded silently into
/// meaningless values.
#[cfg(feature = "safetensors")]
fn load_safetensors<P: AsRef<Path>>(path: P) -> Result<StateDict> {
    use safetensors::tensor::Dtype;

    const F32_SIZE: usize = std::mem::size_of::<f32>();

    let bytes = std::fs::read(path).map_err(|e| Error::InvalidOperation {
        message: e.to_string(),
    })?;

    let tensors =
        safetensors::SafeTensors::deserialize(&bytes).map_err(|e| Error::InvalidOperation {
            message: e.to_string(),
        })?;

    let mut state_dict = StateDict::new();

    for (name, tensor) in tensors.tensors() {
        // Only f32 payloads can be reinterpreted as f32 below; anything else
        // would yield garbage values and a value count that disagrees with
        // the shape.
        if tensor.dtype() != Dtype::F32 {
            return Err(Error::InvalidOperation {
                message: format!(
                    "tensor '{}' has dtype {:?}; only F32 tensors are supported",
                    name,
                    tensor.dtype()
                ),
            });
        }

        let data = tensor.data();
        let shape: Vec<usize> = tensor.shape().to_vec();

        // Convert bytes to f32, refusing a trailing partial element instead of
        // padding it with zeros.
        let chunks = data.chunks_exact(F32_SIZE);
        if !chunks.remainder().is_empty() {
            return Err(Error::InvalidOperation {
                message: format!(
                    "tensor '{}' has {} bytes, which is not a multiple of {}",
                    name,
                    data.len(),
                    F32_SIZE
                ),
            });
        }
        let values: Vec<f32> = chunks
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect();

        state_dict.insert(name.to_string(), TensorData { shape, values });
    }

    Ok(state_dict)
}
293
294// =============================================================================
295// Tests
296// =============================================================================
297
#[cfg(test)]
mod tests {
    use super::*;

    /// `detect_format` maps each file name below to the listed format.
    #[test]
    fn test_format_detection() {
        let cases = [
            ("model.axonml", Format::Axonml),
            ("model.json", Format::Json),
            ("model.safetensors", Format::SafeTensors),
            ("model.bin", Format::Axonml), // default
        ];

        for (file_name, expected) in cases {
            assert_eq!(detect_format(file_name), expected);
        }
    }

    /// A freshly constructed state dictionary holds no entries.
    #[test]
    fn test_state_dict_creation() {
        let fresh = StateDict::new();
        assert_eq!(fresh.len(), 0);
        assert!(fresh.is_empty());
    }

    /// An inserted tensor can be looked up by name with its shape intact.
    #[test]
    fn test_state_dict_insert_get() {
        let mut dict = StateDict::new();
        dict.insert(
            "layer.weight".to_string(),
            TensorData {
                shape: vec![2, 3],
                values: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            },
        );

        assert!(dict.contains("layer.weight"));
        assert_eq!(dict.len(), 1);

        let entry = dict.get("layer.weight").unwrap();
        assert_eq!(entry.data.shape, vec![2, 3]);
    }

    /// Converting `TensorData` to a tensor keeps both shape and values.
    #[test]
    fn test_tensor_data_to_tensor() {
        let source = TensorData {
            shape: vec![2, 2],
            values: vec![1.0, 2.0, 3.0, 4.0],
        };

        let converted = source.to_tensor().unwrap();
        assert_eq!(converted.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(converted.shape(), &[2, 2]);
    }
}
345}