axonml-serialize
Overview
axonml-serialize provides model serialization for the AxonML machine learning framework. It saves and loads trained models as state dictionaries or as full training checkpoints, and includes format conversion utilities for interoperability with PyTorch and ONNX.
Features
- Multiple Formats - Support for AxonML native binary (.axonml), JSON (.json), and SafeTensors (.safetensors) formats
- State Dictionaries - PyTorch-style state_dict for storing and loading model parameters
- Training Checkpoints - Save complete training state including model, optimizer, epoch, and metrics
- Format Detection - Automatic format detection from file extensions and magic bytes
- PyTorch Conversion - Utilities for converting between PyTorch and AxonML naming conventions
- ONNX Shape Utilities - Helper functions for ONNX shape conversion with dynamic dimension support
- Metadata Support - Attach custom metadata to state dictionaries and checkpoints
Modules
| Module | Description |
|---|---|
| state_dict | StateDict and TensorData for storing model parameters by name |
| checkpoint | Checkpoint and TrainingState for saving/resuming training sessions |
| format | Format enum and detection utilities for different serialization formats |
| convert | Conversion utilities for PyTorch and ONNX interoperability |
Usage
Add the dependency to your Cargo.toml:
```toml
[dependencies]
axonml-serialize = "0.1.0"
```
Saving and Loading Models
```rust
use ;
use Linear;
// Save a model (format detected from extension)
let model = new;
save_model?; // Binary format
save_model?; // JSON format
// Load state dictionary
let state_dict = load_state_dict?;
println!;
println!;
```
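Whichever file format is chosen, saving ultimately means encoding name → tensor pairs as bytes. The standalone sketch below shows one possible little-endian layout and its round trip. The `Tensor` struct, the `encode`/`decode` functions, and the byte layout are all hypothetical: this is not the real `.axonml` format or the crate's API.

```rust
use std::collections::BTreeMap;
use std::convert::TryInto;

/// Hypothetical tensor record: a shape plus row-major f32 values.
#[derive(Debug, Clone, PartialEq)]
struct Tensor {
    shape: Vec<usize>,
    values: Vec<f32>,
}

/// Hypothetical layout (NOT the real `.axonml` format), all little-endian:
/// u32 entry count, then per entry: u32 name length, UTF-8 name,
/// u32 rank, u64 dims, f32 values. Assumes values.len() == product(shape).
fn encode(params: &BTreeMap<String, Tensor>) -> Vec<u8> {
    let mut out = Vec::new();
    out.extend_from_slice(&(params.len() as u32).to_le_bytes());
    for (name, t) in params {
        out.extend_from_slice(&(name.len() as u32).to_le_bytes());
        out.extend_from_slice(name.as_bytes());
        out.extend_from_slice(&(t.shape.len() as u32).to_le_bytes());
        for &d in &t.shape {
            out.extend_from_slice(&(d as u64).to_le_bytes());
        }
        for &v in &t.values {
            out.extend_from_slice(&v.to_le_bytes());
        }
    }
    out
}

/// Returns the next `n` bytes and advances the cursor, or `None` if input is too short.
fn take<'a>(bytes: &'a [u8], pos: &mut usize, n: usize) -> Option<&'a [u8]> {
    let slice = bytes.get(*pos..pos.checked_add(n)?)?;
    *pos += n;
    Some(slice)
}

fn read_u32(bytes: &[u8], pos: &mut usize) -> Option<u32> {
    Some(u32::from_le_bytes(take(bytes, pos, 4)?.try_into().ok()?))
}

/// Inverse of `encode`; returns `None` on truncated or malformed input.
fn decode(bytes: &[u8]) -> Option<BTreeMap<String, Tensor>> {
    let mut pos = 0;
    let count = read_u32(bytes, &mut pos)?;
    let mut params = BTreeMap::new();
    for _ in 0..count {
        let name_len = read_u32(bytes, &mut pos)? as usize;
        let name = String::from_utf8(take(bytes, &mut pos, name_len)?.to_vec()).ok()?;
        let rank = read_u32(bytes, &mut pos)?;
        let mut shape = Vec::new();
        for _ in 0..rank {
            let dim = u64::from_le_bytes(take(bytes, &mut pos, 8)?.try_into().ok()?);
            shape.push(dim as usize);
        }
        // Element count, guarding against overflow from corrupt dims.
        let numel = shape.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d))?;
        let mut values = Vec::new();
        for _ in 0..numel {
            values.push(f32::from_le_bytes(take(bytes, &mut pos, 4)?.try_into().ok()?));
        }
        params.insert(name, Tensor { shape, values });
    }
    Some(params)
}

fn main() {
    let mut params = BTreeMap::new();
    params.insert(
        "linear.weight".to_string(),
        Tensor { shape: vec![2, 3], values: vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6] },
    );
    params.insert("linear.bias".to_string(), Tensor { shape: vec![2], values: vec![0.0, 1.0] });
    let bytes = encode(&params);
    let restored = decode(&bytes).expect("buffer produced by encode should decode");
    assert_eq!(restored, params);
    println!("round-tripped {} tensors in {} bytes", restored.len(), bytes.len());
}
```

Keys are kept in a `BTreeMap` so the encoded byte order is deterministic for a given set of parameters.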
Working with State Dictionaries
```rust
use ;
// Create a state dictionary
let mut state_dict = new;
let weights = TensorData ;
state_dict.insert;
let bias = TensorData ;
state_dict.insert;
// Query the state dictionary
assert!;
println!;
// Filter by prefix
let linear_params = state_dict.filter_prefix;
// Strip prefix from keys
let stripped = state_dict.strip_prefix;
assert!;
```
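Conceptually, a state dictionary is a map from parameter names to tensors, and the prefix helpers are key filters and key rewrites over that map. The standalone sketch below illustrates the idea with a hypothetical `MiniStateDict` over a `BTreeMap`; it is not the crate's `StateDict`. The `total_params` helper is also hypothetical, and keeping non-matching keys unchanged in `strip_prefix` is an assumed design choice.

```rust
use std::collections::BTreeMap;

/// Hypothetical tensor entry: shape plus flat values.
#[derive(Debug, Clone, PartialEq)]
struct MiniTensor {
    shape: Vec<usize>,
    values: Vec<f32>,
}

/// Hypothetical name-keyed parameter store (not the crate's `StateDict`).
#[derive(Debug, Default)]
struct MiniStateDict {
    entries: BTreeMap<String, MiniTensor>,
}

impl MiniStateDict {
    fn insert(&mut self, name: &str, t: MiniTensor) {
        self.entries.insert(name.to_string(), t);
    }

    fn contains(&self, name: &str) -> bool {
        self.entries.contains_key(name)
    }

    /// Total number of scalar values across all tensors.
    fn total_params(&self) -> usize {
        self.entries.values().map(|t| t.values.len()).sum()
    }

    /// Keeps only entries whose key starts with `prefix`.
    fn filter_prefix(&self, prefix: &str) -> MiniStateDict {
        let entries = self
            .entries
            .iter()
            .filter(|(k, _)| k.starts_with(prefix))
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect();
        MiniStateDict { entries }
    }

    /// Removes `prefix` from every key that has it; other keys stay unchanged (assumption).
    fn strip_prefix(&self, prefix: &str) -> MiniStateDict {
        let entries = self
            .entries
            .iter()
            .map(|(k, v)| (k.strip_prefix(prefix).unwrap_or(k.as_str()).to_string(), v.clone()))
            .collect();
        MiniStateDict { entries }
    }
}

fn main() {
    let mut sd = MiniStateDict::default();
    sd.insert("linear.weight", MiniTensor { shape: vec![2, 3], values: vec![0.0; 6] });
    sd.insert("linear.bias", MiniTensor { shape: vec![2], values: vec![0.0; 2] });
    sd.insert("head.weight", MiniTensor { shape: vec![1, 2], values: vec![0.0; 2] });

    let linear = sd.filter_prefix("linear.");
    let stripped = linear.strip_prefix("linear.");
    assert!(stripped.contains("weight"));
    println!("{} params total, {} linear entries", sd.total_params(), linear.entries.len());
}
```

Filtering and then stripping, as in `main`, is a common way to extract one submodule's parameters under local names.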
Training Checkpoints
```rust
use ;
// Track training state
let mut training_state = new;
training_state.record_loss;
training_state.record_loss;
training_state.update_best; // lower is better
training_state.next_epoch;
training_state.next_step;
// Create checkpoint with builder pattern
let checkpoint = builder
    .model_state
    .optimizer_state
    .training_state
    .epoch
    .global_step
    .config
    .config
    .build;
// Save and load checkpoints
save_checkpoint?;
let loaded = load_checkpoint?;
println!;
println!;
```
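The training-state calls above amount to counters plus a best-metric comparison, and the checkpoint builder collects optional parts before `build`. A minimal sketch under those assumptions follows. `MiniTrainingState`, `MiniCheckpoint`, `MiniCheckpointBuilder`, and the `lower_is_better` flag are hypothetical stand-ins, not the crate's types; model and optimizer state are omitted for brevity, and the config keys are illustrative.

```rust
use std::collections::BTreeMap;

/// Hypothetical training-progress tracker mirroring the calls shown above.
#[derive(Debug, Default, Clone)]
struct MiniTrainingState {
    epoch: usize,
    global_step: usize,
    losses: Vec<f32>,
    best_metric: Option<f32>,
}

impl MiniTrainingState {
    fn record_loss(&mut self, loss: f32) {
        self.losses.push(loss);
    }
    fn next_epoch(&mut self) {
        self.epoch += 1;
    }
    fn next_step(&mut self) {
        self.global_step += 1;
    }
    /// Stores `value` if it beats the current best; returns true when it does.
    fn update_best(&mut self, value: f32, lower_is_better: bool) -> bool {
        let improved = match self.best_metric {
            None => true,
            Some(best) => if lower_is_better { value < best } else { value > best },
        };
        if improved {
            self.best_metric = Some(value);
        }
        improved
    }
}

/// Hypothetical checkpoint: training progress plus string config entries.
#[derive(Debug)]
struct MiniCheckpoint {
    training: MiniTrainingState,
    config: BTreeMap<String, String>,
}

/// Builder collecting optional parts; unset parts fall back to defaults.
#[derive(Default)]
struct MiniCheckpointBuilder {
    training: MiniTrainingState,
    config: BTreeMap<String, String>,
}

impl MiniCheckpointBuilder {
    fn training_state(mut self, state: MiniTrainingState) -> Self {
        self.training = state;
        self
    }
    fn config(mut self, key: &str, value: &str) -> Self {
        self.config.insert(key.to_string(), value.to_string());
        self
    }
    fn build(self) -> MiniCheckpoint {
        MiniCheckpoint { training: self.training, config: self.config }
    }
}

fn main() {
    let mut state = MiniTrainingState::default();
    state.record_loss(0.9);
    state.record_loss(0.7);
    state.update_best(0.7, true); // lower is better
    state.next_epoch();
    state.next_step();

    let ckpt = MiniCheckpointBuilder::default()
        .training_state(state)
        .config("learning_rate", "0.001")
        .config("batch_size", "32")
        .build();
    println!(
        "epoch {} step {} best {:?} config {:?}",
        ckpt.training.epoch, ckpt.training.global_step, ckpt.training.best_metric, ckpt.config
    );
}
```

Each builder method takes `self` by value and returns it, which is what allows the chained `.config(...).config(...).build()` style.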
Format Detection
```rust
use ;
// Detect from file extension
let format = detect_format;
assert_eq!;
let format = detect_format;
assert_eq!;
// Detect from file contents
let bytes = b"{\"key\": \"value\"}";
let format = detect_format_from_bytes;
assert_eq!;
// Format properties
assert!;
assert!;
```
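Detection by extension is a lookup, while detection from bytes relies on each format's leading structure. The sketch below uses its own hypothetical `Format` enum and functions, not the crate's, but its SafeTensors check follows that format's published layout: an 8-byte little-endian header length followed by a JSON header that begins with `{`. This README does not describe the native `.axonml` magic bytes, so the sketch does not sniff them.

```rust
use std::path::Path;

/// Hypothetical format enum used only by this sketch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Format {
    Axonml,
    Json,
    SafeTensors,
}

impl Format {
    /// JSON is text; the other two formats are binary.
    fn is_binary(self) -> bool {
        !matches!(self, Format::Json)
    }
}

/// Maps a file extension to a format; unknown or missing extensions give `None`.
fn detect_from_extension(path: &str) -> Option<Format> {
    match Path::new(path).extension()?.to_str()? {
        "axonml" => Some(Format::Axonml),
        "json" => Some(Format::Json),
        "safetensors" => Some(Format::SafeTensors),
        _ => None,
    }
}

/// Sniffs leading bytes. Native `.axonml` magic bytes are not modeled here.
fn detect_from_bytes(bytes: &[u8]) -> Option<Format> {
    // SafeTensors: u64 LE header length N, then an N-byte JSON header starting with '{'.
    // Checked first, because a JSON text's first 8 bytes decode to an implausibly large N.
    if bytes.len() >= 10 {
        let mut len_bytes = [0u8; 8];
        len_bytes.copy_from_slice(&bytes[..8]);
        let header_len = u64::from_le_bytes(len_bytes);
        if header_len >= 2 && header_len <= (bytes.len() - 8) as u64 && bytes[8] == b'{' {
            return Some(Format::SafeTensors);
        }
    }
    // JSON: first non-whitespace byte is '{'.
    if bytes.iter().find(|b| !b.is_ascii_whitespace()) == Some(&b'{') {
        return Some(Format::Json);
    }
    None
}

fn main() {
    assert_eq!(detect_from_extension("model.axonml"), Some(Format::Axonml));
    assert_eq!(detect_from_bytes(b"{\"key\": \"value\"}"), Some(Format::Json));

    let mut st = 2u64.to_le_bytes().to_vec();
    st.extend_from_slice(b"{}");
    let format = detect_from_bytes(&st);
    println!("{:?}, binary = {:?}", format, format.map(Format::is_binary));
}
```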
PyTorch Conversion
```rust
use ;
// Convert PyTorch key naming to AxonML
let key = from_pytorch_key;
assert_eq!;
// Convert entire state dictionary
let axonml_dict = convert_from_pytorch;
// Transpose linear weights if needed (PyTorch uses [out, in])
let transposed = transpose_linear_weights;
```
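PyTorch's `nn.Linear` stores its weight with shape `[out_features, in_features]`, so converting it for a layout that expects `[in, out]` is a row-major matrix transpose. A minimal sketch follows. `transpose_row_major` is a hypothetical helper, not the crate's `transpose_linear_weights`, and whether AxonML's layout requires the transpose is an assumption (the comment above says only "if needed").

```rust
/// Hypothetical helper: transposes a row-major [rows, cols] matrix into [cols, rows].
/// Applied to a PyTorch nn.Linear weight ([out, in]) it yields [in, out].
fn transpose_row_major(values: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(values.len(), rows * cols, "shape does not match data length");
    let mut out = vec![0.0; values.len()];
    for r in 0..rows {
        for c in 0..cols {
            // Element (r, c) of the input becomes element (c, r) of the output.
            out[c * rows + r] = values[r * cols + c];
        }
    }
    out
}

fn main() {
    // nn.Linear(in_features=3, out_features=2) stores its weight as [2, 3].
    let pytorch_weight = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f32];
    let transposed = transpose_row_major(&pytorch_weight, 2, 3);
    println!("{:?}", transposed);
}
```

Applying the transpose twice (with rows and cols swapped the second time) restores the original data, which makes a convenient sanity check after conversion.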
ONNX Shape Utilities
```rust
use ;
// Convert to ONNX shape (with dynamic batch)
let onnx_shape = to_onnx_shape;
assert_eq!;
// Convert from ONNX shape (replace -1 with default)
let shape = from_onnx_shape;
assert_eq!;
// ONNX operator type mapping
let op = from_str;
assert_eq!;
```
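The shape helpers reduce to mapping between `usize` sizes and signed ONNX-style dims that use -1 for a dynamic axis, as described above. (In the ONNX protobuf itself, an unknown dimension is expressed with a symbolic `dim_param` or an unset `dim_value`; the -1 convention follows this README.) The functions and the three-variant `OnnxOp` enum below are hypothetical sketches, not the crate's `convert` API; the operator names are real ONNX op types.

```rust
use std::str::FromStr;

/// Hypothetical: concrete shape -> ONNX-style i64 dims.
/// With `dynamic_batch`, the leading (batch) dim is written as -1.
fn to_onnx_shape(shape: &[usize], dynamic_batch: bool) -> Vec<i64> {
    shape
        .iter()
        .enumerate()
        .map(|(i, &d)| if i == 0 && dynamic_batch { -1 } else { d as i64 })
        .collect()
}

/// Hypothetical: ONNX-style dims -> concrete shape, replacing every negative dim with `default`.
fn from_onnx_shape(dims: &[i64], default: usize) -> Vec<usize> {
    dims.iter().map(|&d| if d < 0 { default } else { d as usize }).collect()
}

/// Hypothetical subset of ONNX operator types.
#[derive(Debug, PartialEq)]
enum OnnxOp {
    Gemm,
    MatMul,
    Relu,
}

impl FromStr for OnnxOp {
    type Err = String;

    /// Parses the ONNX op_type string; anything outside this subset is an error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "Gemm" => Ok(OnnxOp::Gemm),
            "MatMul" => Ok(OnnxOp::MatMul),
            "Relu" => Ok(OnnxOp::Relu),
            other => Err(format!("unsupported ONNX op: {}", other)),
        }
    }
}

fn main() {
    let onnx = to_onnx_shape(&[32, 3, 224, 224], true);
    let concrete = from_onnx_shape(&onnx, 1);
    let op: OnnxOp = "Relu".parse().unwrap();
    println!("{:?} -> {:?}, op = {:?}", onnx, concrete, op);
}
```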
State Dictionary Metadata
```rust
use StateDict;
let mut state_dict = new;
state_dict.set_metadata;
state_dict.set_metadata;
if let Some = state_dict.get_metadata
```
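Metadata here is free-form string key/value data stored alongside the tensors. A minimal sketch with a hypothetical `WithMetadata` type follows; it only mirrors the `set_metadata`/`get_metadata` pattern above and is not the crate's implementation. The keys `model_name` and `version` are illustrative only.

```rust
use std::collections::BTreeMap;

/// Hypothetical container holding free-form string metadata.
#[derive(Debug, Default)]
struct WithMetadata {
    metadata: BTreeMap<String, String>,
}

impl WithMetadata {
    /// Inserts or overwrites a metadata entry.
    fn set_metadata(&mut self, key: &str, value: &str) {
        self.metadata.insert(key.to_string(), value.to_string());
    }

    /// Looks up an entry; returns `None` if the key was never set.
    fn get_metadata(&self, key: &str) -> Option<&str> {
        self.metadata.get(key).map(|v| v.as_str())
    }
}

fn main() {
    let mut sd = WithMetadata::default();
    sd.set_metadata("model_name", "mlp");
    sd.set_metadata("version", "1");
    if let Some(name) = sd.get_metadata("model_name") {
        println!("model: {}", name);
    }
}
```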
Tests
Run the test suite:
License
Licensed under either of:
- MIT License (LICENSE-MIT or http://opensource.org/licenses/MIT)
- Apache License, Version 2.0 (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
at your option.