1#![warn(missing_docs)]
18#![warn(clippy::all)]
19#![warn(clippy::pedantic)]
20#![allow(clippy::cast_possible_truncation)]
22#![allow(clippy::cast_sign_loss)]
23#![allow(clippy::cast_precision_loss)]
24#![allow(clippy::cast_possible_wrap)]
25#![allow(clippy::missing_errors_doc)]
26#![allow(clippy::missing_panics_doc)]
27#![allow(clippy::must_use_candidate)]
28#![allow(clippy::module_name_repetitions)]
29#![allow(clippy::similar_names)]
30#![allow(clippy::many_single_char_names)]
31#![allow(clippy::too_many_arguments)]
32#![allow(clippy::doc_markdown)]
33#![allow(clippy::cast_lossless)]
34#![allow(clippy::needless_pass_by_value)]
35#![allow(clippy::redundant_closure_for_method_calls)]
36#![allow(clippy::uninlined_format_args)]
37#![allow(clippy::ptr_arg)]
38#![allow(clippy::return_self_not_must_use)]
39#![allow(clippy::not_unsafe_ptr_arg_deref)]
40#![allow(clippy::items_after_statements)]
41#![allow(clippy::unreadable_literal)]
42#![allow(clippy::if_same_then_else)]
43#![allow(clippy::needless_range_loop)]
44#![allow(clippy::trivially_copy_pass_by_ref)]
45#![allow(clippy::unnecessary_wraps)]
46#![allow(clippy::match_same_arms)]
47#![allow(clippy::unused_self)]
48#![allow(clippy::too_many_lines)]
49#![allow(clippy::single_match_else)]
50#![allow(clippy::fn_params_excessive_bools)]
51#![allow(clippy::struct_excessive_bools)]
52#![allow(clippy::format_push_string)]
53#![allow(clippy::erasing_op)]
54#![allow(clippy::type_repetition_in_bounds)]
55#![allow(clippy::iter_without_into_iter)]
56#![allow(clippy::should_implement_trait)]
57#![allow(clippy::use_debug)]
58#![allow(clippy::case_sensitive_file_extension_comparisons)]
59#![allow(clippy::large_enum_variant)]
60#![allow(clippy::panic)]
61#![allow(clippy::struct_field_names)]
62#![allow(clippy::missing_fields_in_debug)]
63#![allow(clippy::upper_case_acronyms)]
64#![allow(clippy::assigning_clones)]
65#![allow(clippy::option_if_let_else)]
66#![allow(clippy::manual_let_else)]
67#![allow(clippy::explicit_iter_loop)]
68#![allow(clippy::default_trait_access)]
69#![allow(clippy::only_used_in_recursion)]
70#![allow(clippy::manual_clamp)]
71#![allow(clippy::ref_option)]
72#![allow(clippy::multiple_bound_locations)]
73#![allow(clippy::comparison_chain)]
74#![allow(clippy::manual_assert)]
75#![allow(clippy::unnecessary_debug_formatting)]
76
77mod checkpoint;
82mod convert;
83mod format;
84mod state_dict;
85
86pub use checkpoint::{Checkpoint, CheckpointBuilder, TrainingState};
91pub use convert::{
92 OnnxOpType, convert_from_pytorch, from_onnx_shape, from_pytorch_key, pytorch_layer_mapping,
93 to_onnx_shape, to_pytorch_key, transpose_linear_weights,
94};
95pub use format::{Format, detect_format, detect_format_from_bytes};
96pub use state_dict::{StateDict, StateDictEntry, TensorData};
97
98use axonml_core::{Error, Result};
103use axonml_nn::Module;
104use std::fs::File;
105use std::io::{BufReader, BufWriter, Read, Write};
106use std::path::Path;
107
108pub fn save_model<M: Module, P: AsRef<Path>>(model: &M, path: P) -> Result<()> {
116 let path = path.as_ref();
117 let format = detect_format(path);
118 let state_dict = StateDict::from_module(model);
119
120 save_state_dict(&state_dict, path, format)
121}
122
123pub fn save_state_dict<P: AsRef<Path>>(
125 state_dict: &StateDict,
126 path: P,
127 format: Format,
128) -> Result<()> {
129 let path = path.as_ref();
130 let file = File::create(path).map_err(|e| Error::InvalidOperation {
131 message: e.to_string(),
132 })?;
133 let mut writer = BufWriter::new(file);
134
135 match format {
136 Format::Axonml => {
137 let encoded = bincode::serialize(state_dict).map_err(|e| Error::InvalidOperation {
138 message: e.to_string(),
139 })?;
140 writer
141 .write_all(&encoded)
142 .map_err(|e| Error::InvalidOperation {
143 message: e.to_string(),
144 })?;
145 }
146 Format::Json => {
147 serde_json::to_writer_pretty(&mut writer, state_dict).map_err(|e| {
148 Error::InvalidOperation {
149 message: e.to_string(),
150 }
151 })?;
152 }
153 #[cfg(feature = "safetensors")]
154 Format::SafeTensors => {
155 save_safetensors(state_dict, path)?;
156 }
157 #[cfg(not(feature = "safetensors"))]
158 Format::SafeTensors => {
159 return Err(Error::InvalidOperation {
160 message: "SafeTensors format requires 'safetensors' feature".to_string(),
161 });
162 }
163 }
164
165 Ok(())
166}
167
168pub fn load_state_dict<P: AsRef<Path>>(path: P) -> Result<StateDict> {
170 let path = path.as_ref();
171 let format = detect_format(path);
172
173 let file = File::open(path).map_err(|e| Error::InvalidOperation {
174 message: e.to_string(),
175 })?;
176 let mut reader = BufReader::new(file);
177
178 match format {
179 Format::Axonml => {
180 let mut bytes = Vec::new();
181 reader
182 .read_to_end(&mut bytes)
183 .map_err(|e| Error::InvalidOperation {
184 message: e.to_string(),
185 })?;
186 bincode::deserialize(&bytes).map_err(|e| Error::InvalidOperation {
187 message: e.to_string(),
188 })
189 }
190 Format::Json => serde_json::from_reader(reader).map_err(|e| Error::InvalidOperation {
191 message: e.to_string(),
192 }),
193 #[cfg(feature = "safetensors")]
194 Format::SafeTensors => load_safetensors(path),
195 #[cfg(not(feature = "safetensors"))]
196 Format::SafeTensors => Err(Error::InvalidOperation {
197 message: "SafeTensors format requires 'safetensors' feature".to_string(),
198 }),
199 }
200}
201
202pub fn save_checkpoint<P: AsRef<Path>>(checkpoint: &Checkpoint, path: P) -> Result<()> {
204 let path = path.as_ref();
205 let file = File::create(path).map_err(|e| Error::InvalidOperation {
206 message: e.to_string(),
207 })?;
208 let writer = BufWriter::new(file);
209
210 bincode::serialize_into(writer, checkpoint).map_err(|e| Error::InvalidOperation {
211 message: e.to_string(),
212 })
213}
214
215pub fn load_checkpoint<P: AsRef<Path>>(path: P) -> Result<Checkpoint> {
217 let path = path.as_ref();
218 let file = File::open(path).map_err(|e| Error::InvalidOperation {
219 message: e.to_string(),
220 })?;
221 let reader = BufReader::new(file);
222
223 bincode::deserialize_from(reader).map_err(|e| Error::InvalidOperation {
224 message: e.to_string(),
225 })
226}
227
/// Writes `state_dict` to `path` as a SafeTensors file.
///
/// Every entry is stored as an `F32` tensor with its own shape, encoded as
/// little-endian bytes. The shape is additionally recorded in the file
/// metadata under the key `"<name>.shape"`.
///
/// NOTE(review): this assumes the `safetensors` 0.3+ API, where
/// `TensorView::new` returns a `Result` and `serialize` accepts an iterator
/// of `(name, view)` pairs — confirm against the version pinned in
/// `Cargo.toml`.
///
/// # Errors
///
/// Returns [`Error::InvalidOperation`] carrying the underlying message when a
/// tensor view cannot be built (for example when the number of values does
/// not match the shape), serialization fails, or the file cannot be written.
#[cfg(feature = "safetensors")]
fn save_safetensors<P: AsRef<Path>>(state_dict: &StateDict, path: P) -> Result<()> {
    use safetensors::tensor::{Dtype, TensorView};
    use std::collections::HashMap;

    // The byte buffers are owned here because the tensor views only borrow them.
    let mut payloads: Vec<(String, Vec<usize>, Vec<u8>)> = Vec::new();
    let mut metadata: HashMap<String, String> = HashMap::new();

    for (name, entry) in state_dict.entries() {
        let data_bytes: Vec<u8> = entry
            .data
            .values
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        metadata.insert(format!("{}.shape", name), format!("{:?}", entry.data.shape));
        payloads.push((name.clone(), entry.data.shape.clone(), data_bytes));
    }

    let views = payloads
        .iter()
        .map(|(name, shape, data_bytes)| {
            TensorView::new(Dtype::F32, shape.clone(), data_bytes)
                .map(|view| (name.as_str(), view))
        })
        .collect::<std::result::Result<Vec<_>, _>>()
        .map_err(|e| Error::InvalidOperation {
            message: e.to_string(),
        })?;

    let bytes =
        safetensors::serialize(views, &Some(metadata)).map_err(|e| Error::InvalidOperation {
            message: e.to_string(),
        })?;

    std::fs::write(path, bytes).map_err(|e| Error::InvalidOperation {
        message: e.to_string(),
    })
}
261
/// Reads a SafeTensors file at `path` into a [`StateDict`].
///
/// Each tensor's bytes are decoded as little-endian `f32` values and stored
/// together with the tensor's shape under the tensor's name.
///
/// Only `F32` tensors are accepted. Any other dtype is rejected rather than
/// having its bytes reinterpreted as `f32`, and a payload whose length is not
/// a multiple of four bytes is rejected rather than padded with `0.0`.
///
/// # Errors
///
/// Returns [`Error::InvalidOperation`] when the file cannot be read, is not a
/// valid SafeTensors file, contains a non-`F32` tensor, or contains a payload
/// whose length is not a multiple of four bytes.
#[cfg(feature = "safetensors")]
fn load_safetensors<P: AsRef<Path>>(path: P) -> Result<StateDict> {
    use safetensors::tensor::Dtype;

    let bytes = std::fs::read(path).map_err(|e| Error::InvalidOperation {
        message: e.to_string(),
    })?;

    let tensors =
        safetensors::SafeTensors::deserialize(&bytes).map_err(|e| Error::InvalidOperation {
            message: e.to_string(),
        })?;

    let mut state_dict = StateDict::new();

    for (name, tensor) in tensors.tensors() {
        // `TensorData` holds `f32` values only; other dtypes cannot be represented.
        if tensor.dtype() != Dtype::F32 {
            return Err(Error::InvalidOperation {
                message: format!(
                    "tensor '{}' has dtype {:?}; only F32 tensors are supported",
                    name,
                    tensor.dtype()
                ),
            });
        }

        let shape: Vec<usize> = tensor.shape().to_vec();

        let chunks = tensor.data().chunks_exact(4);
        if !chunks.remainder().is_empty() {
            return Err(Error::InvalidOperation {
                message: format!(
                    "tensor '{}' has a byte length that is not a multiple of 4",
                    name
                ),
            });
        }
        // `chunks_exact(4)` guarantees every chunk has exactly four bytes.
        let values: Vec<f32> = chunks
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect();

        state_dict.insert(name.to_string(), TensorData { shape, values });
    }

    Ok(state_dict)
}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Each file name must resolve to the expected format; `.bin` is expected
    /// to resolve to `Format::Axonml`.
    #[test]
    fn test_format_detection() {
        let expectations = [
            ("model.axonml", Format::Axonml),
            ("model.json", Format::Json),
            ("model.safetensors", Format::SafeTensors),
            ("model.bin", Format::Axonml),
        ];

        for (file_name, expected) in expectations {
            assert_eq!(detect_format(file_name), expected);
        }
    }

    /// A freshly constructed state dict reports itself as empty.
    #[test]
    fn test_state_dict_creation() {
        let fresh = StateDict::new();
        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
    }

    /// An inserted entry is counted, found by name, and keeps its shape.
    #[test]
    fn test_state_dict_insert_get() {
        let key = "layer.weight";
        let mut dict = StateDict::new();

        dict.insert(
            key.to_string(),
            TensorData {
                shape: vec![2, 3],
                values: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            },
        );

        assert_eq!(dict.len(), 1);
        assert!(dict.contains(key));

        let entry = dict.get(key).unwrap();
        assert_eq!(entry.data.shape, vec![2, 3]);
    }

    /// Converting tensor data to a tensor preserves shape and values.
    #[test]
    fn test_tensor_data_to_tensor() {
        let source = TensorData {
            shape: vec![2, 2],
            values: vec![1.0, 2.0, 3.0, 4.0],
        };

        let converted = source.to_tensor().unwrap();
        assert_eq!(converted.shape(), &[2, 2]);
        assert_eq!(converted.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }
}