1use std::collections::HashMap;
38use std::fs::File;
39use std::io::{BufReader, Read};
40use std::path::Path;
41
42use anyhow::{bail, Context, Result};
43use serde::{Deserialize, Serialize};
44
45use crate::safetensors_support::SafetensorsWriter;
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct PyTorchCheckpoint {
53 pub state_dict: StateDict,
55
56 pub optimizer_state: Option<OptimizerState>,
58
59 pub epoch: Option<usize>,
61
62 pub loss_history: Option<Vec<f32>>,
64
65 pub metadata: HashMap<String, String>,
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct StateDict {
72 pub tensors: HashMap<String, TensorData>,
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct TensorData {
79 pub shape: Vec<usize>,
81
82 pub dtype: String,
84
85 pub data: Vec<u8>,
87
88 pub requires_grad: bool,
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct OptimizerState {
95 pub optimizer_type: String,
97
98 pub param_state: HashMap<String, ParamState>,
100
101 pub hyperparameters: HashMap<String, f64>,
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct ParamState {
108 pub momentum: Option<Vec<u8>>,
110
111 pub velocity: Option<Vec<u8>>,
113
114 pub step: Option<usize>,
116
117 pub custom: HashMap<String, Vec<u8>>,
119}
120
/// Summary statistics describing the contents of a checkpoint.
#[derive(Clone, Debug)]
pub struct CheckpointMetadata {
    /// Sum over all tensors of the product of their shape dimensions.
    pub total_parameters: usize,

    /// Names of the tensors in the state dict.
    pub layer_names: Vec<String>,

    /// Sum of the raw byte lengths of all tensors.
    pub total_size_bytes: usize,

    /// Number of tensors per dtype tag.
    pub dtypes: HashMap<String, usize>,

    /// Whether the checkpoint carries optimizer state.
    pub has_optimizer_state: bool,

    /// Epoch recorded in the checkpoint, if any.
    pub epoch: Option<usize>,
}
142
143impl PyTorchCheckpoint {
144 #[allow(dead_code)]
151 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
152 let file = File::open(path.as_ref()).context("Failed to open checkpoint file")?;
153 let mut reader = BufReader::new(file);
154
155 let mut bytes = Vec::new();
157 reader
158 .read_to_end(&mut bytes)
159 .context("Failed to read checkpoint file")?;
160
161 Self::from_pickle_bytes(&bytes)
163 }
164
165 fn from_pickle_bytes(bytes: &[u8]) -> Result<Self> {
169 let value: serde_pickle::Value = serde_pickle::from_slice(bytes, Default::default())
173 .context("Failed to deserialize pickle data")?;
174
175 Self::parse_pickle_value(value)
177 }
178
179 fn parse_pickle_value(value: serde_pickle::Value) -> Result<Self> {
181 use serde_pickle::{HashableValue, Value};
182
183 let dict = match value {
185 Value::Dict(d) => d,
186 _ => bail!("Expected dictionary at root of checkpoint"),
187 };
188
189 let mut state_dict_tensors = HashMap::new();
190 let mut optimizer_state = None;
191 let mut epoch = None;
192 let mut loss_history = None;
193 let mut metadata = HashMap::new();
194
195 let has_state_dict_key = dict.iter().any(|(k, _)| {
197 matches!(k, HashableValue::String(ref s) if s == "state_dict" || s == "model_state_dict")
198 });
199
200 for (key, val) in &dict {
202 let key_str = match key {
203 HashableValue::String(s) => s.clone(),
204 HashableValue::Bytes(b) => String::from_utf8_lossy(b).to_string(),
205 _ => continue,
206 };
207
208 match key_str.as_str() {
209 "state_dict" | "model_state_dict" => {
210 if let Value::Dict(sd) = val {
211 state_dict_tensors = Self::parse_state_dict(sd.clone())?;
212 }
213 }
214 "optimizer_state_dict" | "optimizer" => {
215 optimizer_state = Self::parse_optimizer_state(val.clone()).ok();
216 }
217 "epoch" => {
218 if let Value::I64(e) = val {
219 epoch = Some(*e as usize);
220 }
221 }
222 "loss_history" => {
223 loss_history = Self::parse_loss_history(val.clone()).ok();
224 }
225 _ => {
226 if let Value::String(s) = val {
228 metadata.insert(key_str, s.clone());
229 }
230 }
231 }
232 }
233
234 if state_dict_tensors.is_empty() && !has_state_dict_key {
236 state_dict_tensors = Self::parse_state_dict(dict)?;
237 }
238
239 Ok(PyTorchCheckpoint {
240 state_dict: StateDict {
241 tensors: state_dict_tensors,
242 },
243 optimizer_state,
244 epoch,
245 loss_history,
246 metadata,
247 })
248 }
249
250 fn parse_state_dict(
252 dict: std::collections::BTreeMap<serde_pickle::HashableValue, serde_pickle::Value>,
253 ) -> Result<HashMap<String, TensorData>> {
254 use serde_pickle::HashableValue;
255
256 let mut tensors = HashMap::new();
257
258 for (key, val) in dict {
259 let key_str = match key {
260 HashableValue::String(s) => s,
261 HashableValue::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
262 _ => continue,
263 };
264
265 if let Ok(tensor_data) = Self::parse_tensor_value(val) {
267 tensors.insert(key_str, tensor_data);
268 }
269 }
270
271 Ok(tensors)
272 }
273
274 fn parse_tensor_value(value: serde_pickle::Value) -> Result<TensorData> {
276 use serde_pickle::{HashableValue, Value};
277
278 match value {
283 Value::Dict(d) => {
284 let mut shape = Vec::new();
286 let mut data = Vec::new();
287 let mut dtype = "float32".to_string();
288 let mut requires_grad = false;
289
290 for (k, v) in d {
291 let key = match k {
292 HashableValue::String(s) => s,
293 HashableValue::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
294 _ => continue,
295 };
296
297 match key.as_str() {
298 "shape" | "size" => {
299 if let Value::List(list) = v {
300 shape = list
301 .into_iter()
302 .filter_map(|v| match v {
303 Value::I64(i) => Some(i as usize),
304 _ => None,
305 })
306 .collect();
307 }
308 }
309 "data" | "storage" => {
310 if let Value::Bytes(b) = v {
311 data = b;
312 }
313 }
314 "dtype" => {
315 if let Value::String(s) = v {
316 dtype = s;
317 }
318 }
319 "requires_grad" => {
320 if let Value::Bool(b) = v {
321 requires_grad = b;
322 }
323 }
324 _ => {}
325 }
326 }
327
328 if !shape.is_empty() && !data.is_empty() {
329 Ok(TensorData {
330 shape,
331 dtype,
332 data,
333 requires_grad,
334 })
335 } else {
336 bail!("Incomplete tensor data")
337 }
338 }
339 Value::Bytes(data) => {
340 Ok(TensorData {
342 shape: vec![data.len() / 4],
343 dtype: "float32".to_string(),
344 data,
345 requires_grad: false,
346 })
347 }
348 _ => bail!("Unsupported tensor value type"),
349 }
350 }
351
352 #[allow(dead_code)]
354 fn parse_optimizer_state(_value: serde_pickle::Value) -> Result<OptimizerState> {
355 Ok(OptimizerState {
357 optimizer_type: "Unknown".to_string(),
358 param_state: HashMap::new(),
359 hyperparameters: HashMap::new(),
360 })
361 }
362
363 #[allow(dead_code)]
365 fn parse_loss_history(value: serde_pickle::Value) -> Result<Vec<f32>> {
366 use serde_pickle::Value;
367
368 match value {
369 Value::List(list) => {
370 let losses = list
371 .into_iter()
372 .filter_map(|v| match v {
373 Value::F64(f) => Some(f as f32),
374 _ => None,
375 })
376 .collect();
377 Ok(losses)
378 }
379 _ => bail!("Expected list for loss history"),
380 }
381 }
382
383 pub fn metadata(&self) -> CheckpointMetadata {
385 let mut total_parameters = 0;
386 let mut layer_names = Vec::new();
387 let mut total_size_bytes = 0;
388 let mut dtypes = HashMap::new();
389
390 for (name, tensor) in &self.state_dict.tensors {
391 layer_names.push(name.clone());
392
393 let num_elements: usize = tensor.shape.iter().product();
394 total_parameters += num_elements;
395
396 total_size_bytes += tensor.data.len();
397
398 *dtypes.entry(tensor.dtype.clone()).or_insert(0) += 1;
399 }
400
401 CheckpointMetadata {
402 total_parameters,
403 layer_names,
404 total_size_bytes,
405 dtypes,
406 has_optimizer_state: self.optimizer_state.is_some(),
407 epoch: self.epoch,
408 }
409 }
410
411 pub fn state_dict(&self) -> &StateDict {
413 &self.state_dict
414 }
415
416 pub fn to_safetensors(&self) -> Result<Vec<u8>> {
420 let mut writer = SafetensorsWriter::new();
421
422 for (name, tensor) in &self.state_dict.tensors {
423 let shape = tensor.shape.clone();
425
426 match tensor.dtype.as_str() {
428 "float32" | "Float" => {
429 if tensor.data.len() % 4 != 0 {
431 bail!("Invalid float32 data length for tensor {}", name);
432 }
433
434 let float_data: Vec<f32> = tensor
435 .data
436 .chunks_exact(4)
437 .map(|chunk| {
438 let bytes: [u8; 4] = chunk.try_into().unwrap();
439 f32::from_le_bytes(bytes)
440 })
441 .collect();
442
443 writer.add_f32(name, shape, &float_data);
444 }
445 "float64" | "Double" => {
446 if tensor.data.len() % 8 != 0 {
447 bail!("Invalid float64 data length for tensor {}", name);
448 }
449
450 let float_data: Vec<f64> = tensor
451 .data
452 .chunks_exact(8)
453 .map(|chunk| {
454 let bytes: [u8; 8] = chunk.try_into().unwrap();
455 f64::from_le_bytes(bytes)
456 })
457 .collect();
458
459 writer.add_f64(name, shape, &float_data);
460 }
461 _ => {
462 bail!("Unsupported dtype: {}", tensor.dtype);
463 }
464 }
465 }
466
467 writer
468 .serialize()
469 .context("Failed to serialize to safetensors")
470 }
471
472 #[allow(dead_code)]
476 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
477 let bytes = self.to_pickle_bytes()?;
478 std::fs::write(path, bytes).context("Failed to write checkpoint file")?;
479 Ok(())
480 }
481
482 fn to_pickle_bytes(&self) -> Result<Vec<u8>> {
484 use serde_pickle::ser;
485
486 #[derive(Serialize)]
491 struct CheckpointSer {
492 state_dict: HashMap<String, TensorSer>,
493 #[serde(skip_serializing_if = "Option::is_none")]
494 epoch: Option<usize>,
495 #[serde(skip_serializing_if = "Option::is_none")]
496 loss_history: Option<Vec<f32>>,
497 metadata: HashMap<String, String>,
498 }
499
500 #[derive(Serialize)]
501 struct TensorSer {
502 shape: Vec<usize>,
503 dtype: String,
504 data_len: usize,
505 }
506
507 let state_dict_ser: HashMap<String, TensorSer> = self
508 .state_dict
509 .tensors
510 .iter()
511 .map(|(name, tensor)| {
512 (
513 name.clone(),
514 TensorSer {
515 shape: tensor.shape.clone(),
516 dtype: tensor.dtype.clone(),
517 data_len: tensor.data.len(),
518 },
519 )
520 })
521 .collect();
522
523 let checkpoint_ser = CheckpointSer {
524 state_dict: state_dict_ser,
525 epoch: self.epoch,
526 loss_history: self.loss_history.clone(),
527 metadata: self.metadata.clone(),
528 };
529
530 ser::to_vec(&checkpoint_ser, Default::default()).context("Failed to serialize to pickle")
532 }
533
534 #[allow(dead_code)]
538 fn tensor_to_pickle_value(_tensor: &TensorData) -> HashMap<String, String> {
539 HashMap::new()
542 }
543
544 pub fn new() -> Self {
546 PyTorchCheckpoint {
547 state_dict: StateDict {
548 tensors: HashMap::new(),
549 },
550 optimizer_state: None,
551 epoch: None,
552 loss_history: None,
553 metadata: HashMap::new(),
554 }
555 }
556
557 pub fn add_tensor(&mut self, name: String, tensor: TensorData) {
559 self.state_dict.tensors.insert(name, tensor);
560 }
561
562 pub fn set_epoch(&mut self, epoch: usize) {
564 self.epoch = Some(epoch);
565 }
566
567 pub fn add_metadata(&mut self, key: String, value: String) {
569 self.metadata.insert(key, value);
570 }
571}
572
573impl Default for PyTorchCheckpoint {
574 fn default() -> Self {
575 Self::new()
576 }
577}
578
579impl StateDict {
580 pub fn get(&self, name: &str) -> Option<&TensorData> {
582 self.tensors.get(name)
583 }
584
585 pub fn iter(&self) -> impl Iterator<Item = (&String, &TensorData)> {
587 self.tensors.iter()
588 }
589
590 pub fn len(&self) -> usize {
592 self.tensors.len()
593 }
594
595 pub fn is_empty(&self) -> bool {
597 self.tensors.is_empty()
598 }
599}
600
601impl TensorData {
602 pub fn from_f32(shape: Vec<usize>, data: &[f32]) -> Self {
604 let bytes: Vec<u8> = data.iter().flat_map(|&f| f.to_le_bytes()).collect();
605
606 TensorData {
607 shape,
608 dtype: "float32".to_string(),
609 data: bytes,
610 requires_grad: false,
611 }
612 }
613
614 pub fn from_f64(shape: Vec<usize>, data: &[f64]) -> Self {
616 let bytes: Vec<u8> = data.iter().flat_map(|&f| f.to_le_bytes()).collect();
617
618 TensorData {
619 shape,
620 dtype: "float64".to_string(),
621 data: bytes,
622 requires_grad: false,
623 }
624 }
625
626 pub fn as_f32(&self) -> Result<Vec<f32>> {
628 if self.dtype != "float32" && self.dtype != "Float" {
629 bail!("Expected float32 dtype, got {}", self.dtype);
630 }
631
632 if !self.data.len().is_multiple_of(4) {
633 bail!("Invalid data length for float32");
634 }
635
636 Ok(self
637 .data
638 .chunks_exact(4)
639 .map(|chunk| {
640 let bytes: [u8; 4] = chunk.try_into().unwrap();
641 f32::from_le_bytes(bytes)
642 })
643 .collect())
644 }
645
646 pub fn as_f64(&self) -> Result<Vec<f64>> {
648 if self.dtype != "float64" && self.dtype != "Double" {
649 bail!("Expected float64 dtype, got {}", self.dtype);
650 }
651
652 if !self.data.len().is_multiple_of(8) {
653 bail!("Invalid data length for float64");
654 }
655
656 Ok(self
657 .data
658 .chunks_exact(8)
659 .map(|chunk| {
660 let bytes: [u8; 8] = chunk.try_into().unwrap();
661 f64::from_le_bytes(bytes)
662 })
663 .collect())
664 }
665
666 pub fn num_elements(&self) -> usize {
668 self.shape.iter().product()
669 }
670}
671
#[cfg(test)]
mod tests {
    use super::*;

    /// Building a checkpoint through the mutator API stores every piece.
    #[test]
    fn test_checkpoint_creation() {
        let mut ckpt = PyTorchCheckpoint::new();

        ckpt.add_tensor(
            "layer1.weight".to_string(),
            TensorData::from_f32(vec![2, 3], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
        );
        ckpt.set_epoch(10);
        ckpt.add_metadata("model_type".to_string(), "CNN".to_string());

        assert_eq!(ckpt.state_dict().len(), 1);
        assert_eq!(ckpt.epoch, Some(10));
        assert_eq!(ckpt.metadata.get("model_type").unwrap(), "CNN");
    }

    /// `from_f32` followed by `as_f32` returns the original values.
    #[test]
    fn test_tensor_data_f32() {
        let values = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = TensorData::from_f32(vec![2, 2], &values);

        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(tensor.dtype, "float32");
        assert_eq!(tensor.num_elements(), 4);

        let decoded = tensor.as_f32().unwrap();
        assert_eq!(decoded, values);
    }

    /// `from_f64` followed by `as_f64` returns the original values.
    #[test]
    fn test_tensor_data_f64() {
        let values = vec![1.0f64, 2.0, 3.0, 4.0];
        let tensor = TensorData::from_f64(vec![2, 2], &values);

        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(tensor.dtype, "float64");

        let decoded = tensor.as_f64().unwrap();
        assert_eq!(decoded, values);
    }

    /// `metadata()` aggregates parameter counts, names and dtype tallies.
    #[test]
    fn test_metadata_extraction() {
        let mut ckpt = PyTorchCheckpoint::new();

        ckpt.add_tensor(
            "layer1.weight".to_string(),
            TensorData::from_f32(vec![10, 10], &[0.0; 100]),
        );
        ckpt.add_tensor(
            "layer1.bias".to_string(),
            TensorData::from_f32(vec![10], &[0.0; 10]),
        );
        ckpt.add_tensor(
            "layer2.weight".to_string(),
            TensorData::from_f64(vec![5, 10], &[0.0; 50]),
        );

        let summary = ckpt.metadata();

        assert_eq!(summary.total_parameters, 160);
        assert_eq!(summary.layer_names.len(), 3);
        assert_eq!(summary.dtypes.get("float32"), Some(&2));
        assert_eq!(summary.dtypes.get("float64"), Some(&1));
    }

    /// Tensors added to a checkpoint are reachable through the state dict.
    #[test]
    fn test_state_dict_access() {
        let mut ckpt = PyTorchCheckpoint::new();
        ckpt.add_tensor(
            "test".to_string(),
            TensorData::from_f32(vec![3], &[1.0, 2.0, 3.0]),
        );

        let tensors = ckpt.state_dict();
        assert_eq!(tensors.len(), 1);
        assert!(!tensors.is_empty());

        let found = tensors.get("test").unwrap();
        assert_eq!(found.shape, vec![3]);
    }

    /// Pickle serialization succeeds and yields a non-empty buffer.
    #[test]
    fn test_checkpoint_serialization() -> Result<()> {
        let mut ckpt = PyTorchCheckpoint::new();

        ckpt.add_tensor(
            "weight".to_string(),
            TensorData::from_f32(vec![2, 2], &[1.0, 2.0, 3.0, 4.0]),
        );
        ckpt.set_epoch(5);
        ckpt.add_metadata("arch".to_string(), "ResNet".to_string());

        let pickled = ckpt.to_pickle_bytes()?;
        assert!(!pickled.is_empty());

        Ok(())
    }

    /// Safetensors conversion succeeds and yields a non-empty buffer.
    #[test]
    fn test_to_safetensors() -> Result<()> {
        let mut ckpt = PyTorchCheckpoint::new();

        ckpt.add_tensor(
            "layer1.weight".to_string(),
            TensorData::from_f32(vec![3, 3], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]),
        );
        ckpt.add_tensor(
            "layer1.bias".to_string(),
            TensorData::from_f32(vec![3], &[0.1, 0.2, 0.3]),
        );

        let encoded = ckpt.to_safetensors()?;
        assert!(!encoded.is_empty());

        Ok(())
    }
}