Skip to main content

axonml_serialize/
lib.rs

1//! Axonml Serialize - Model Serialization for Axonml ML Framework
2//!
3//! This crate provides functionality for saving and loading trained models,
4//! including state dictionaries, model checkpoints, and format conversion.
5//!
6//! # Supported Formats
7//!
8//! - **Axonml Native** (.axonml) - Efficient binary format
9//! - **JSON** (.json) - Human-readable format for debugging
10//! - **`SafeTensors`** (.safetensors) - Safe, fast format (optional feature)
11//!
12//! # Example
13//!
14//! ```ignore
//! use axonml_serialize::{save_model, load_state_dict, StateDict};
16//!
17//! // Save model
18//! save_model(&model, "model.axonml")?;
19//!
20//! // Load model
21//! let state_dict = load_state_dict("model.axonml")?;
22//! model.load_state_dict(&state_dict)?;
23//! ```
24//!
25//! @version 0.1.0
26//! @author `AutomataNexus` Development Team
27
28#![warn(missing_docs)]
29#![warn(clippy::all)]
30#![warn(clippy::pedantic)]
31// ML/tensor-specific allowances
32#![allow(clippy::cast_possible_truncation)]
33#![allow(clippy::cast_sign_loss)]
34#![allow(clippy::cast_precision_loss)]
35#![allow(clippy::cast_possible_wrap)]
36#![allow(clippy::missing_errors_doc)]
37#![allow(clippy::missing_panics_doc)]
38#![allow(clippy::must_use_candidate)]
39#![allow(clippy::module_name_repetitions)]
40#![allow(clippy::similar_names)]
41#![allow(clippy::many_single_char_names)]
42#![allow(clippy::too_many_arguments)]
43#![allow(clippy::doc_markdown)]
44#![allow(clippy::cast_lossless)]
45#![allow(clippy::needless_pass_by_value)]
46#![allow(clippy::redundant_closure_for_method_calls)]
47#![allow(clippy::uninlined_format_args)]
48#![allow(clippy::ptr_arg)]
49#![allow(clippy::return_self_not_must_use)]
50#![allow(clippy::not_unsafe_ptr_arg_deref)]
51#![allow(clippy::items_after_statements)]
52#![allow(clippy::unreadable_literal)]
53#![allow(clippy::if_same_then_else)]
54#![allow(clippy::needless_range_loop)]
55#![allow(clippy::trivially_copy_pass_by_ref)]
56#![allow(clippy::unnecessary_wraps)]
57#![allow(clippy::match_same_arms)]
58#![allow(clippy::unused_self)]
59#![allow(clippy::too_many_lines)]
60#![allow(clippy::single_match_else)]
61#![allow(clippy::fn_params_excessive_bools)]
62#![allow(clippy::struct_excessive_bools)]
63#![allow(clippy::format_push_string)]
64#![allow(clippy::erasing_op)]
65#![allow(clippy::type_repetition_in_bounds)]
66#![allow(clippy::iter_without_into_iter)]
67#![allow(clippy::should_implement_trait)]
68#![allow(clippy::use_debug)]
69#![allow(clippy::case_sensitive_file_extension_comparisons)]
70#![allow(clippy::large_enum_variant)]
71#![allow(clippy::panic)]
72#![allow(clippy::struct_field_names)]
73#![allow(clippy::missing_fields_in_debug)]
74#![allow(clippy::upper_case_acronyms)]
75#![allow(clippy::assigning_clones)]
76#![allow(clippy::option_if_let_else)]
77#![allow(clippy::manual_let_else)]
78#![allow(clippy::explicit_iter_loop)]
79#![allow(clippy::default_trait_access)]
80#![allow(clippy::only_used_in_recursion)]
81#![allow(clippy::manual_clamp)]
82#![allow(clippy::ref_option)]
83#![allow(clippy::multiple_bound_locations)]
84#![allow(clippy::comparison_chain)]
85#![allow(clippy::manual_assert)]
86#![allow(clippy::unnecessary_debug_formatting)]
87
88// =============================================================================
89// Modules
90// =============================================================================
91
92mod checkpoint;
93mod convert;
94mod format;
95mod state_dict;
96
97// =============================================================================
98// Re-exports
99// =============================================================================
100
101pub use checkpoint::{Checkpoint, CheckpointBuilder, TrainingState};
102pub use convert::{
103    convert_from_pytorch, from_onnx_shape, from_pytorch_key, pytorch_layer_mapping, to_onnx_shape,
104    to_pytorch_key, transpose_linear_weights, OnnxOpType,
105};
106pub use format::{detect_format, detect_format_from_bytes, Format};
107pub use state_dict::{StateDict, StateDictEntry, TensorData};
108
109// =============================================================================
110// Imports
111// =============================================================================
112
113use axonml_core::{Error, Result};
114use axonml_nn::Module;
115use std::fs::File;
116use std::io::{BufReader, BufWriter, Read, Write};
117use std::path::Path;
118
119// =============================================================================
120// High-Level API
121// =============================================================================
122
123/// Save a model's state dictionary to a file.
124///
125/// The format is automatically determined from the file extension.
126pub fn save_model<M: Module, P: AsRef<Path>>(model: &M, path: P) -> Result<()> {
127    let path = path.as_ref();
128    let format = detect_format(path);
129    let state_dict = StateDict::from_module(model);
130
131    save_state_dict(&state_dict, path, format)
132}
133
134/// Save a state dictionary to a file with specified format.
135pub fn save_state_dict<P: AsRef<Path>>(
136    state_dict: &StateDict,
137    path: P,
138    format: Format,
139) -> Result<()> {
140    let path = path.as_ref();
141    let file = File::create(path).map_err(|e| Error::InvalidOperation {
142        message: e.to_string(),
143    })?;
144    let mut writer = BufWriter::new(file);
145
146    match format {
147        Format::Axonml => {
148            let encoded = bincode::serialize(state_dict).map_err(|e| Error::InvalidOperation {
149                message: e.to_string(),
150            })?;
151            writer
152                .write_all(&encoded)
153                .map_err(|e| Error::InvalidOperation {
154                    message: e.to_string(),
155                })?;
156        }
157        Format::Json => {
158            serde_json::to_writer_pretty(&mut writer, state_dict).map_err(|e| {
159                Error::InvalidOperation {
160                    message: e.to_string(),
161                }
162            })?;
163        }
164        #[cfg(feature = "safetensors")]
165        Format::SafeTensors => {
166            save_safetensors(state_dict, path)?;
167        }
168        #[cfg(not(feature = "safetensors"))]
169        Format::SafeTensors => {
170            return Err(Error::InvalidOperation {
171                message: "SafeTensors format requires 'safetensors' feature".to_string(),
172            });
173        }
174    }
175
176    Ok(())
177}
178
179/// Load a state dictionary from a file.
180pub fn load_state_dict<P: AsRef<Path>>(path: P) -> Result<StateDict> {
181    let path = path.as_ref();
182    let format = detect_format(path);
183
184    let file = File::open(path).map_err(|e| Error::InvalidOperation {
185        message: e.to_string(),
186    })?;
187    let mut reader = BufReader::new(file);
188
189    match format {
190        Format::Axonml => {
191            let mut bytes = Vec::new();
192            reader
193                .read_to_end(&mut bytes)
194                .map_err(|e| Error::InvalidOperation {
195                    message: e.to_string(),
196                })?;
197            bincode::deserialize(&bytes).map_err(|e| Error::InvalidOperation {
198                message: e.to_string(),
199            })
200        }
201        Format::Json => serde_json::from_reader(reader).map_err(|e| Error::InvalidOperation {
202            message: e.to_string(),
203        }),
204        #[cfg(feature = "safetensors")]
205        Format::SafeTensors => load_safetensors(path),
206        #[cfg(not(feature = "safetensors"))]
207        Format::SafeTensors => Err(Error::InvalidOperation {
208            message: "SafeTensors format requires 'safetensors' feature".to_string(),
209        }),
210    }
211}
212
213/// Save a complete training checkpoint.
214pub fn save_checkpoint<P: AsRef<Path>>(checkpoint: &Checkpoint, path: P) -> Result<()> {
215    let path = path.as_ref();
216    let file = File::create(path).map_err(|e| Error::InvalidOperation {
217        message: e.to_string(),
218    })?;
219    let writer = BufWriter::new(file);
220
221    bincode::serialize_into(writer, checkpoint).map_err(|e| Error::InvalidOperation {
222        message: e.to_string(),
223    })
224}
225
226/// Load a training checkpoint.
227pub fn load_checkpoint<P: AsRef<Path>>(path: P) -> Result<Checkpoint> {
228    let path = path.as_ref();
229    let file = File::open(path).map_err(|e| Error::InvalidOperation {
230        message: e.to_string(),
231    })?;
232    let reader = BufReader::new(file);
233
234    bincode::deserialize_from(reader).map_err(|e| Error::InvalidOperation {
235        message: e.to_string(),
236    })
237}
238
239// =============================================================================
240// SafeTensors Support
241// =============================================================================
242
/// Write `state_dict` to `path` in SafeTensors format.
///
/// Every entry is stored as a little-endian `F32` tensor with its recorded
/// shape. For compatibility with files written previously, each shape is also
/// recorded in the header metadata under `"<name>.shape"`. A shape/value-count
/// mismatch is rejected by `TensorView::new` and reported as
/// `Error::InvalidOperation`.
#[cfg(feature = "safetensors")]
fn save_safetensors<P: AsRef<Path>>(state_dict: &StateDict, path: P) -> Result<()> {
    use safetensors::tensor::{Dtype, TensorView};
    use std::collections::HashMap;

    /// Wrap any displayable failure in the crate's error type.
    fn invalid(e: impl std::fmt::Display) -> Error {
        Error::InvalidOperation {
            message: e.to_string(),
        }
    }

    // Encode all values first: the views below borrow these byte buffers.
    let mut encoded: Vec<(String, Vec<usize>, Vec<u8>)> = Vec::new();
    let mut metadata: HashMap<String, String> = HashMap::new();
    for (name, entry) in state_dict.entries() {
        let data: Vec<u8> = entry
            .data
            .values
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        metadata.insert(format!("{}.shape", name), format!("{:?}", entry.data.shape));
        encoded.push((name.clone(), entry.data.shape.clone(), data));
    }

    // A raw byte buffer is not a safetensors `View`; wrap each one in a typed,
    // shaped `TensorView` so dtype and shape land in the file header.
    let mut views: HashMap<String, TensorView<'_>> = HashMap::new();
    for (name, shape, data) in &encoded {
        let view = TensorView::new(Dtype::F32, shape.clone(), data).map_err(invalid)?;
        views.insert(name.clone(), view);
    }

    let serialized = safetensors::serialize(&views, &Some(metadata)).map_err(invalid)?;
    std::fs::write(path, serialized).map_err(invalid)
}
272
/// Read a SafeTensors file at `path` into a [`StateDict`].
///
/// Only `F32` tensors are supported. A tensor of any other dtype is rejected
/// with `Error::InvalidOperation` rather than being reinterpreted as f32
/// garbage. The safetensors parser validates that each tensor's byte length
/// matches its shape and dtype, so every buffer splits exactly into 4-byte values.
#[cfg(feature = "safetensors")]
fn load_safetensors<P: AsRef<Path>>(path: P) -> Result<StateDict> {
    use safetensors::tensor::Dtype;

    /// Wrap any displayable failure in the crate's error type.
    fn invalid(e: impl std::fmt::Display) -> Error {
        Error::InvalidOperation {
            message: e.to_string(),
        }
    }

    let bytes = std::fs::read(path).map_err(invalid)?;
    let tensors = safetensors::SafeTensors::deserialize(&bytes).map_err(invalid)?;

    let mut state_dict = StateDict::new();

    for (name, tensor) in tensors.tensors() {
        if tensor.dtype() != Dtype::F32 {
            return Err(Error::InvalidOperation {
                message: format!(
                    "tensor '{}' has unsupported dtype {:?} (only F32 is supported)",
                    name,
                    tensor.dtype()
                ),
            });
        }

        let shape: Vec<usize> = tensor.shape().to_vec();
        // chunks_exact never yields a short chunk, so no silent zero-padding.
        let values: Vec<f32> = tensor
            .data()
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect();

        state_dict.insert(name.to_string(), TensorData { shape, values });
    }

    Ok(state_dict)
}
304
305// =============================================================================
306// Tests
307// =============================================================================
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Extension-based format detection, including the Axonml fallback.
    #[test]
    fn test_format_detection() {
        assert_eq!(detect_format("model.axonml"), Format::Axonml);
        assert_eq!(detect_format("model.json"), Format::Json);
        assert_eq!(detect_format("model.safetensors"), Format::SafeTensors);
        assert_eq!(detect_format("model.bin"), Format::Axonml); // default
    }

    #[test]
    fn test_state_dict_creation() {
        let state_dict = StateDict::new();
        assert!(state_dict.is_empty());
        assert_eq!(state_dict.len(), 0);
    }

    #[test]
    fn test_state_dict_insert_get() {
        let mut state_dict = StateDict::new();
        let data = TensorData {
            shape: vec![2, 3],
            values: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        };

        state_dict.insert("layer.weight".to_string(), data);

        assert_eq!(state_dict.len(), 1);
        assert!(state_dict.contains("layer.weight"));

        let retrieved = state_dict.get("layer.weight").unwrap();
        assert_eq!(retrieved.data.shape, vec![2, 3]);
    }

    #[test]
    fn test_tensor_data_to_tensor() {
        let data = TensorData {
            shape: vec![2, 2],
            values: vec![1.0, 2.0, 3.0, 4.0],
        };

        let tensor = data.to_tensor().unwrap();
        assert_eq!(tensor.shape(), &[2, 2]);
        assert_eq!(tensor.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// Save then load through the real file I/O paths for every format that
    /// needs no optional feature, checking shape and values survive intact.
    #[test]
    fn test_state_dict_file_round_trip() {
        let mut state_dict = StateDict::new();
        state_dict.insert(
            "layer.weight".to_string(),
            TensorData {
                shape: vec![2, 2],
                values: vec![1.0, -2.5, 3.0, 0.0],
            },
        );

        for ext in ["axonml", "json"] {
            let path = std::env::temp_dir().join(format!(
                "axonml_serialize_round_trip_{}.{}",
                std::process::id(),
                ext
            ));

            save_state_dict(&state_dict, &path, detect_format(&path)).unwrap();
            let loaded = load_state_dict(&path);
            // Clean up before asserting so a failure does not leave the file behind.
            let _ = std::fs::remove_file(&path);
            let loaded = loaded.unwrap();

            let entry = loaded.get("layer.weight").unwrap();
            assert_eq!(entry.data.shape, vec![2, 2]);
            assert_eq!(entry.data.values, vec![1.0, -2.5, 3.0, 0.0]);
        }
    }
}