1#![warn(missing_docs)]
29#![warn(clippy::all)]
30#![warn(clippy::pedantic)]
31#![allow(clippy::cast_possible_truncation)]
33#![allow(clippy::cast_sign_loss)]
34#![allow(clippy::cast_precision_loss)]
35#![allow(clippy::cast_possible_wrap)]
36#![allow(clippy::missing_errors_doc)]
37#![allow(clippy::missing_panics_doc)]
38#![allow(clippy::must_use_candidate)]
39#![allow(clippy::module_name_repetitions)]
40#![allow(clippy::similar_names)]
41#![allow(clippy::many_single_char_names)]
42#![allow(clippy::too_many_arguments)]
43#![allow(clippy::doc_markdown)]
44#![allow(clippy::cast_lossless)]
45#![allow(clippy::needless_pass_by_value)]
46#![allow(clippy::redundant_closure_for_method_calls)]
47#![allow(clippy::uninlined_format_args)]
48#![allow(clippy::ptr_arg)]
49#![allow(clippy::return_self_not_must_use)]
50#![allow(clippy::not_unsafe_ptr_arg_deref)]
51#![allow(clippy::items_after_statements)]
52#![allow(clippy::unreadable_literal)]
53#![allow(clippy::if_same_then_else)]
54#![allow(clippy::needless_range_loop)]
55#![allow(clippy::trivially_copy_pass_by_ref)]
56#![allow(clippy::unnecessary_wraps)]
57#![allow(clippy::match_same_arms)]
58#![allow(clippy::unused_self)]
59#![allow(clippy::too_many_lines)]
60#![allow(clippy::single_match_else)]
61#![allow(clippy::fn_params_excessive_bools)]
62#![allow(clippy::struct_excessive_bools)]
63#![allow(clippy::format_push_string)]
64#![allow(clippy::erasing_op)]
65#![allow(clippy::type_repetition_in_bounds)]
66#![allow(clippy::iter_without_into_iter)]
67#![allow(clippy::should_implement_trait)]
68#![allow(clippy::use_debug)]
69#![allow(clippy::case_sensitive_file_extension_comparisons)]
70#![allow(clippy::large_enum_variant)]
71#![allow(clippy::panic)]
72#![allow(clippy::struct_field_names)]
73#![allow(clippy::missing_fields_in_debug)]
74#![allow(clippy::upper_case_acronyms)]
75#![allow(clippy::assigning_clones)]
76#![allow(clippy::option_if_let_else)]
77#![allow(clippy::manual_let_else)]
78#![allow(clippy::explicit_iter_loop)]
79#![allow(clippy::default_trait_access)]
80#![allow(clippy::only_used_in_recursion)]
81#![allow(clippy::manual_clamp)]
82#![allow(clippy::ref_option)]
83#![allow(clippy::multiple_bound_locations)]
84#![allow(clippy::comparison_chain)]
85#![allow(clippy::manual_assert)]
86#![allow(clippy::unnecessary_debug_formatting)]
87
88mod checkpoint;
93mod convert;
94mod format;
95mod state_dict;
96
97pub use checkpoint::{Checkpoint, CheckpointBuilder, TrainingState};
102pub use convert::{
103 convert_from_pytorch, from_onnx_shape, from_pytorch_key, pytorch_layer_mapping, to_onnx_shape,
104 to_pytorch_key, transpose_linear_weights, OnnxOpType,
105};
106pub use format::{detect_format, detect_format_from_bytes, Format};
107pub use state_dict::{StateDict, StateDictEntry, TensorData};
108
109use axonml_core::{Error, Result};
114use axonml_nn::Module;
115use std::fs::File;
116use std::io::{BufReader, BufWriter, Read, Write};
117use std::path::Path;
118
119pub fn save_model<M: Module, P: AsRef<Path>>(model: &M, path: P) -> Result<()> {
127 let path = path.as_ref();
128 let format = detect_format(path);
129 let state_dict = StateDict::from_module(model);
130
131 save_state_dict(&state_dict, path, format)
132}
133
134pub fn save_state_dict<P: AsRef<Path>>(
136 state_dict: &StateDict,
137 path: P,
138 format: Format,
139) -> Result<()> {
140 let path = path.as_ref();
141 let file = File::create(path).map_err(|e| Error::InvalidOperation {
142 message: e.to_string(),
143 })?;
144 let mut writer = BufWriter::new(file);
145
146 match format {
147 Format::Axonml => {
148 let encoded = bincode::serialize(state_dict).map_err(|e| Error::InvalidOperation {
149 message: e.to_string(),
150 })?;
151 writer
152 .write_all(&encoded)
153 .map_err(|e| Error::InvalidOperation {
154 message: e.to_string(),
155 })?;
156 }
157 Format::Json => {
158 serde_json::to_writer_pretty(&mut writer, state_dict).map_err(|e| {
159 Error::InvalidOperation {
160 message: e.to_string(),
161 }
162 })?;
163 }
164 #[cfg(feature = "safetensors")]
165 Format::SafeTensors => {
166 save_safetensors(state_dict, path)?;
167 }
168 #[cfg(not(feature = "safetensors"))]
169 Format::SafeTensors => {
170 return Err(Error::InvalidOperation {
171 message: "SafeTensors format requires 'safetensors' feature".to_string(),
172 });
173 }
174 }
175
176 Ok(())
177}
178
179pub fn load_state_dict<P: AsRef<Path>>(path: P) -> Result<StateDict> {
181 let path = path.as_ref();
182 let format = detect_format(path);
183
184 let file = File::open(path).map_err(|e| Error::InvalidOperation {
185 message: e.to_string(),
186 })?;
187 let mut reader = BufReader::new(file);
188
189 match format {
190 Format::Axonml => {
191 let mut bytes = Vec::new();
192 reader
193 .read_to_end(&mut bytes)
194 .map_err(|e| Error::InvalidOperation {
195 message: e.to_string(),
196 })?;
197 bincode::deserialize(&bytes).map_err(|e| Error::InvalidOperation {
198 message: e.to_string(),
199 })
200 }
201 Format::Json => serde_json::from_reader(reader).map_err(|e| Error::InvalidOperation {
202 message: e.to_string(),
203 }),
204 #[cfg(feature = "safetensors")]
205 Format::SafeTensors => load_safetensors(path),
206 #[cfg(not(feature = "safetensors"))]
207 Format::SafeTensors => Err(Error::InvalidOperation {
208 message: "SafeTensors format requires 'safetensors' feature".to_string(),
209 }),
210 }
211}
212
213pub fn save_checkpoint<P: AsRef<Path>>(checkpoint: &Checkpoint, path: P) -> Result<()> {
215 let path = path.as_ref();
216 let file = File::create(path).map_err(|e| Error::InvalidOperation {
217 message: e.to_string(),
218 })?;
219 let writer = BufWriter::new(file);
220
221 bincode::serialize_into(writer, checkpoint).map_err(|e| Error::InvalidOperation {
222 message: e.to_string(),
223 })
224}
225
226pub fn load_checkpoint<P: AsRef<Path>>(path: P) -> Result<Checkpoint> {
228 let path = path.as_ref();
229 let file = File::open(path).map_err(|e| Error::InvalidOperation {
230 message: e.to_string(),
231 })?;
232 let reader = BufReader::new(file);
233
234 bincode::deserialize_from(reader).map_err(|e| Error::InvalidOperation {
235 message: e.to_string(),
236 })
237}
238
/// Writes `state_dict` to `path` as a SafeTensors file.
///
/// Every entry's `values` are encoded as little-endian `f32` bytes and
/// registered as an `F32` tensor view carrying the entry's shape. For each
/// entry a `"<name>.shape"` metadata string (the `Debug` rendering of the
/// shape) is also stored.
///
/// # Errors
///
/// Returns `Error::InvalidOperation` carrying the underlying message when a
/// tensor view cannot be built (for example when the shape does not match the
/// number of values), when serialization fails, or when the file cannot be
/// written.
#[cfg(feature = "safetensors")]
fn save_safetensors<P: AsRef<Path>>(state_dict: &StateDict, path: P) -> Result<()> {
    use safetensors::tensor::{Dtype, TensorView};
    use std::collections::HashMap;

    // Wraps any displayable failure in `Error::InvalidOperation`.
    fn op_err<E: std::fmt::Display>(e: E) -> Error {
        Error::InvalidOperation {
            message: e.to_string(),
        }
    }

    let mut metadata: HashMap<String, String> = HashMap::new();
    // Tensor views only borrow their bytes, so the encoded buffers have to
    // outlive the views; they are collected first and kept alive until the
    // serialized output has been produced.
    let mut described: Vec<(String, Vec<usize>)> = Vec::new();
    let mut buffers: Vec<Vec<u8>> = Vec::new();

    for (name, entry) in state_dict.entries() {
        let data_bytes: Vec<u8> = entry
            .data
            .values
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        metadata.insert(format!("{}.shape", name), format!("{:?}", entry.data.shape));
        described.push((name.clone(), entry.data.shape.clone()));
        buffers.push(data_bytes);
    }

    let mut tensors: HashMap<String, TensorView<'_>> = HashMap::with_capacity(buffers.len());
    for ((name, shape), data_bytes) in described.into_iter().zip(&buffers) {
        let view = TensorView::new(Dtype::F32, shape, data_bytes).map_err(op_err)?;
        tensors.insert(name, view);
    }

    let bytes = safetensors::serialize(&tensors, &Some(metadata)).map_err(op_err)?;

    std::fs::write(path, bytes).map_err(op_err)
}
272
/// Reads a SafeTensors file at `path` into a [`StateDict`].
///
/// Each tensor's bytes are decoded as little-endian `f32` values and stored,
/// together with the tensor's shape, under the tensor's name.
///
/// Only `F32` tensors are accepted: decoding any other dtype four bytes at a
/// time would yield meaningless numbers, so such tensors are rejected instead
/// of being silently misread. A payload whose length is not a multiple of
/// four is rejected for the same reason rather than being padded with zeros.
///
/// # Errors
///
/// Returns `Error::InvalidOperation` when the file cannot be read, when it
/// cannot be parsed as SafeTensors, when a tensor is not `F32`, or when a
/// tensor's byte length is not a multiple of four.
#[cfg(feature = "safetensors")]
fn load_safetensors<P: AsRef<Path>>(path: P) -> Result<StateDict> {
    // Wraps any displayable failure in `Error::InvalidOperation`.
    fn op_err<E: std::fmt::Display>(e: E) -> Error {
        Error::InvalidOperation {
            message: e.to_string(),
        }
    }

    let bytes = std::fs::read(path).map_err(op_err)?;
    let tensors = safetensors::SafeTensors::deserialize(&bytes).map_err(op_err)?;

    let mut state_dict = StateDict::new();

    for (name, tensor) in tensors.tensors() {
        if tensor.dtype() != safetensors::tensor::Dtype::F32 {
            return Err(Error::InvalidOperation {
                message: format!(
                    "tensor '{}' has dtype {:?}; only F32 tensors are supported",
                    name,
                    tensor.dtype()
                ),
            });
        }

        let data = tensor.data();
        let shape: Vec<usize> = tensor.shape().to_vec();

        let chunks = data.chunks_exact(4);
        if !chunks.remainder().is_empty() {
            return Err(Error::InvalidOperation {
                message: format!(
                    "tensor '{}' has {} bytes, which is not a multiple of 4",
                    name,
                    data.len()
                ),
            });
        }

        // `chunks_exact(4)` guarantees every chunk has exactly four bytes.
        let values: Vec<f32> = chunks
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect();

        state_dict.insert(name.to_string(), TensorData { shape, values });
    }

    Ok(state_dict)
}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Each file name maps to the expected format; `.bin` yields `Axonml`.
    #[test]
    fn test_format_detection() {
        let cases = [
            ("model.axonml", Format::Axonml),
            ("model.json", Format::Json),
            ("model.safetensors", Format::SafeTensors),
            ("model.bin", Format::Axonml),
        ];

        for (file_name, expected) in cases {
            assert_eq!(detect_format(file_name), expected);
        }
    }

    /// A freshly constructed state dict holds no entries.
    #[test]
    fn test_state_dict_creation() {
        let fresh = StateDict::new();
        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
    }

    /// An inserted entry is counted, found by name and keeps its shape.
    #[test]
    fn test_state_dict_insert_get() {
        let key = "layer.weight";
        let mut dict = StateDict::new();

        dict.insert(
            key.to_string(),
            TensorData {
                shape: vec![2, 3],
                values: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            },
        );

        assert_eq!(dict.len(), 1);
        assert!(dict.contains(key));

        let entry = dict.get(key).unwrap();
        assert_eq!(entry.data.shape, vec![2, 3]);
    }

    /// Converting tensor data into a tensor preserves shape and values.
    #[test]
    fn test_tensor_data_to_tensor() {
        let raw = TensorData {
            shape: vec![2, 2],
            values: vec![1.0, 2.0, 3.0, 4.0],
        };

        let converted = raw.to_tensor().unwrap();
        assert_eq!(converted.shape(), &[2, 2]);
        assert_eq!(converted.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }
}