ipfrs_tensorlogic/
pytorch_checkpoint.rs

1//! PyTorch model checkpoint support for ipfrs-tensorlogic.
2//!
3//! This module provides functionality to load and work with PyTorch model checkpoints
4//! (.pt/.pth files). PyTorch checkpoints are Python pickle files containing state_dict
5//! structures with model weights and optionally optimizer state.
6//!
7//! # Safety and Security
8//!
9//! Python pickle format can execute arbitrary code during deserialization. This module
10//! provides a safe subset of pickle deserialization focused on tensor data structures.
11//! For maximum security, consider converting PyTorch checkpoints to Safetensors format.
12//!
13//! # Example
14//!
15//! ```rust,no_run
16//! use ipfrs_tensorlogic::pytorch_checkpoint::{PyTorchCheckpoint, CheckpointMetadata};
17//! use std::path::Path;
18//!
19//! # fn main() -> anyhow::Result<()> {
20//! // Load a PyTorch checkpoint
21//! let checkpoint = PyTorchCheckpoint::load(Path::new("model.pt"))?;
22//!
23//! // Extract metadata
24//! let metadata = checkpoint.metadata();
25//! println!("Model has {} parameters", metadata.total_parameters);
26//! println!("Layers: {:?}", metadata.layer_names);
27//!
28//! // Get state dict
29//! let state_dict = checkpoint.state_dict();
30//! for (key, tensor_info) in &state_dict.tensors {
31//!     println!("{}: {:?}", key, tensor_info.shape);
32//! }
33//! # Ok(())
34//! # }
35//! ```
36
37use std::collections::HashMap;
38use std::fs::File;
39use std::io::{BufReader, Read};
40use std::path::Path;
41
42use anyhow::{bail, Context, Result};
43use serde::{Deserialize, Serialize};
44
45use crate::safetensors_support::SafetensorsWriter;
46
47/// PyTorch checkpoint structure.
48///
49/// Contains the model state_dict and optional optimizer state, epoch information,
50/// and other training metadata commonly saved in PyTorch checkpoints.
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct PyTorchCheckpoint {
53    /// Model state dictionary
54    pub state_dict: StateDict,
55
56    /// Optimizer state (if saved)
57    pub optimizer_state: Option<OptimizerState>,
58
59    /// Training epoch (if saved)
60    pub epoch: Option<usize>,
61
62    /// Training loss history (if saved)
63    pub loss_history: Option<Vec<f32>>,
64
65    /// Custom metadata
66    pub metadata: HashMap<String, String>,
67}
68
69/// Model state dictionary containing named tensors.
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct StateDict {
72    /// Map from layer/parameter name to tensor information
73    pub tensors: HashMap<String, TensorData>,
74}
75
76/// Tensor data with shape and values.
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct TensorData {
79    /// Tensor shape (dimensions)
80    pub shape: Vec<usize>,
81
82    /// Data type identifier
83    pub dtype: String,
84
85    /// Flattened tensor values (stored as bytes)
86    pub data: Vec<u8>,
87
88    /// Whether this tensor requires gradient
89    pub requires_grad: bool,
90}
91
92/// Optimizer state containing parameter state and hyperparameters.
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct OptimizerState {
95    /// Optimizer name (e.g., "Adam", "SGD")
96    pub optimizer_type: String,
97
98    /// Per-parameter state (momentum buffers, etc.)
99    pub param_state: HashMap<String, ParamState>,
100
101    /// Global optimizer hyperparameters
102    pub hyperparameters: HashMap<String, f64>,
103}
104
105/// Per-parameter optimizer state (momentum, velocity, etc.).
106#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct ParamState {
108    /// Momentum buffer (for SGD with momentum, Adam, etc.)
109    pub momentum: Option<Vec<u8>>,
110
111    /// Velocity buffer (for Adam, RMSprop, etc.)
112    pub velocity: Option<Vec<u8>>,
113
114    /// Step count (for Adam)
115    pub step: Option<usize>,
116
117    /// Custom state fields
118    pub custom: HashMap<String, Vec<u8>>,
119}
120
/// Summary statistics of a checkpoint, as produced by `PyTorchCheckpoint::metadata`.
#[derive(Clone, Debug)]
pub struct CheckpointMetadata {
    /// Sum over all tensors of the product of their shape dimensions.
    pub total_parameters: usize,

    /// Names of all tensors in the state dict.
    pub layer_names: Vec<String>,

    /// Sum of the byte lengths of all tensor data buffers.
    pub total_size_bytes: usize,

    /// Number of tensors per dtype name (dtype -> count).
    pub dtypes: HashMap<String, usize>,

    /// `true` when the checkpoint carries optimizer state.
    pub has_optimizer_state: bool,

    /// Epoch recorded in the checkpoint, if any.
    pub epoch: Option<usize>,
}
142
143impl PyTorchCheckpoint {
144    /// Load a PyTorch checkpoint from a file.
145    ///
146    /// # Security Note
147    ///
148    /// This uses pickle deserialization which can be unsafe with untrusted files.
149    /// Only load checkpoints from trusted sources.
150    #[allow(dead_code)]
151    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
152        let file = File::open(path.as_ref()).context("Failed to open checkpoint file")?;
153        let mut reader = BufReader::new(file);
154
155        // Read all bytes
156        let mut bytes = Vec::new();
157        reader
158            .read_to_end(&mut bytes)
159            .context("Failed to read checkpoint file")?;
160
161        // Try to deserialize as pickle
162        Self::from_pickle_bytes(&bytes)
163    }
164
165    /// Deserialize checkpoint from pickle bytes.
166    ///
167    /// This provides a safe subset of pickle deserialization focused on tensor data.
168    fn from_pickle_bytes(bytes: &[u8]) -> Result<Self> {
169        // Attempt to deserialize the pickle data
170        // Note: This is a simplified version. Real PyTorch checkpoints may need
171        // more sophisticated handling of numpy arrays and torch tensors.
172        let value: serde_pickle::Value = serde_pickle::from_slice(bytes, Default::default())
173            .context("Failed to deserialize pickle data")?;
174
175        // Parse the pickle value into our checkpoint structure
176        Self::parse_pickle_value(value)
177    }
178
179    /// Parse a pickle value into a checkpoint structure.
180    fn parse_pickle_value(value: serde_pickle::Value) -> Result<Self> {
181        use serde_pickle::{HashableValue, Value};
182
183        // PyTorch checkpoints are typically dictionaries
184        let dict = match value {
185            Value::Dict(d) => d,
186            _ => bail!("Expected dictionary at root of checkpoint"),
187        };
188
189        let mut state_dict_tensors = HashMap::new();
190        let mut optimizer_state = None;
191        let mut epoch = None;
192        let mut loss_history = None;
193        let mut metadata = HashMap::new();
194
195        // Check if dict contains state_dict key
196        let has_state_dict_key = dict.iter().any(|(k, _)| {
197            matches!(k, HashableValue::String(ref s) if s == "state_dict" || s == "model_state_dict")
198        });
199
200        // Parse dictionary entries
201        for (key, val) in &dict {
202            let key_str = match key {
203                HashableValue::String(s) => s.clone(),
204                HashableValue::Bytes(b) => String::from_utf8_lossy(b).to_string(),
205                _ => continue,
206            };
207
208            match key_str.as_str() {
209                "state_dict" | "model_state_dict" => {
210                    if let Value::Dict(sd) = val {
211                        state_dict_tensors = Self::parse_state_dict(sd.clone())?;
212                    }
213                }
214                "optimizer_state_dict" | "optimizer" => {
215                    optimizer_state = Self::parse_optimizer_state(val.clone()).ok();
216                }
217                "epoch" => {
218                    if let Value::I64(e) = val {
219                        epoch = Some(*e as usize);
220                    }
221                }
222                "loss_history" => {
223                    loss_history = Self::parse_loss_history(val.clone()).ok();
224                }
225                _ => {
226                    // Store as metadata
227                    if let Value::String(s) = val {
228                        metadata.insert(key_str, s.clone());
229                    }
230                }
231            }
232        }
233
234        // If no explicit state_dict key, assume the whole dict is the state_dict
235        if state_dict_tensors.is_empty() && !has_state_dict_key {
236            state_dict_tensors = Self::parse_state_dict(dict)?;
237        }
238
239        Ok(PyTorchCheckpoint {
240            state_dict: StateDict {
241                tensors: state_dict_tensors,
242            },
243            optimizer_state,
244            epoch,
245            loss_history,
246            metadata,
247        })
248    }
249
250    /// Parse state_dict from pickle dictionary.
251    fn parse_state_dict(
252        dict: std::collections::BTreeMap<serde_pickle::HashableValue, serde_pickle::Value>,
253    ) -> Result<HashMap<String, TensorData>> {
254        use serde_pickle::HashableValue;
255
256        let mut tensors = HashMap::new();
257
258        for (key, val) in dict {
259            let key_str = match key {
260                HashableValue::String(s) => s,
261                HashableValue::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
262                _ => continue,
263            };
264
265            // Try to parse tensor data
266            if let Ok(tensor_data) = Self::parse_tensor_value(val) {
267                tensors.insert(key_str, tensor_data);
268            }
269        }
270
271        Ok(tensors)
272    }
273
274    /// Parse a tensor value from pickle.
275    fn parse_tensor_value(value: serde_pickle::Value) -> Result<TensorData> {
276        use serde_pickle::{HashableValue, Value};
277
278        // This is simplified - real PyTorch tensors are more complex
279        // In practice, you'd need to handle torch.Tensor objects which contain
280        // references to storage objects
281
282        match value {
283            Value::Dict(d) => {
284                // Look for tensor-like dictionary structure
285                let mut shape = Vec::new();
286                let mut data = Vec::new();
287                let mut dtype = "float32".to_string();
288                let mut requires_grad = false;
289
290                for (k, v) in d {
291                    let key = match k {
292                        HashableValue::String(s) => s,
293                        HashableValue::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
294                        _ => continue,
295                    };
296
297                    match key.as_str() {
298                        "shape" | "size" => {
299                            if let Value::List(list) = v {
300                                shape = list
301                                    .into_iter()
302                                    .filter_map(|v| match v {
303                                        Value::I64(i) => Some(i as usize),
304                                        _ => None,
305                                    })
306                                    .collect();
307                            }
308                        }
309                        "data" | "storage" => {
310                            if let Value::Bytes(b) = v {
311                                data = b;
312                            }
313                        }
314                        "dtype" => {
315                            if let Value::String(s) = v {
316                                dtype = s;
317                            }
318                        }
319                        "requires_grad" => {
320                            if let Value::Bool(b) = v {
321                                requires_grad = b;
322                            }
323                        }
324                        _ => {}
325                    }
326                }
327
328                if !shape.is_empty() && !data.is_empty() {
329                    Ok(TensorData {
330                        shape,
331                        dtype,
332                        data,
333                        requires_grad,
334                    })
335                } else {
336                    bail!("Incomplete tensor data")
337                }
338            }
339            Value::Bytes(data) => {
340                // Raw bytes - assume 1D float32 array
341                Ok(TensorData {
342                    shape: vec![data.len() / 4],
343                    dtype: "float32".to_string(),
344                    data,
345                    requires_grad: false,
346                })
347            }
348            _ => bail!("Unsupported tensor value type"),
349        }
350    }
351
352    /// Parse optimizer state from pickle value.
353    #[allow(dead_code)]
354    fn parse_optimizer_state(_value: serde_pickle::Value) -> Result<OptimizerState> {
355        // Simplified - would need full implementation for real use
356        Ok(OptimizerState {
357            optimizer_type: "Unknown".to_string(),
358            param_state: HashMap::new(),
359            hyperparameters: HashMap::new(),
360        })
361    }
362
363    /// Parse loss history from pickle value.
364    #[allow(dead_code)]
365    fn parse_loss_history(value: serde_pickle::Value) -> Result<Vec<f32>> {
366        use serde_pickle::Value;
367
368        match value {
369            Value::List(list) => {
370                let losses = list
371                    .into_iter()
372                    .filter_map(|v| match v {
373                        Value::F64(f) => Some(f as f32),
374                        _ => None,
375                    })
376                    .collect();
377                Ok(losses)
378            }
379            _ => bail!("Expected list for loss history"),
380        }
381    }
382
383    /// Get checkpoint metadata.
384    pub fn metadata(&self) -> CheckpointMetadata {
385        let mut total_parameters = 0;
386        let mut layer_names = Vec::new();
387        let mut total_size_bytes = 0;
388        let mut dtypes = HashMap::new();
389
390        for (name, tensor) in &self.state_dict.tensors {
391            layer_names.push(name.clone());
392
393            let num_elements: usize = tensor.shape.iter().product();
394            total_parameters += num_elements;
395
396            total_size_bytes += tensor.data.len();
397
398            *dtypes.entry(tensor.dtype.clone()).or_insert(0) += 1;
399        }
400
401        CheckpointMetadata {
402            total_parameters,
403            layer_names,
404            total_size_bytes,
405            dtypes,
406            has_optimizer_state: self.optimizer_state.is_some(),
407            epoch: self.epoch,
408        }
409    }
410
411    /// Get reference to state dict.
412    pub fn state_dict(&self) -> &StateDict {
413        &self.state_dict
414    }
415
416    /// Convert checkpoint to Safetensors format.
417    ///
418    /// This provides a safe, efficient format for storing model weights.
419    pub fn to_safetensors(&self) -> Result<Vec<u8>> {
420        let mut writer = SafetensorsWriter::new();
421
422        for (name, tensor) in &self.state_dict.tensors {
423            // Determine shape for safetensors
424            let shape = tensor.shape.clone();
425
426            // Convert data based on dtype
427            match tensor.dtype.as_str() {
428                "float32" | "Float" => {
429                    // Convert bytes to f32 slice
430                    if tensor.data.len() % 4 != 0 {
431                        bail!("Invalid float32 data length for tensor {}", name);
432                    }
433
434                    let float_data: Vec<f32> = tensor
435                        .data
436                        .chunks_exact(4)
437                        .map(|chunk| {
438                            let bytes: [u8; 4] = chunk.try_into().unwrap();
439                            f32::from_le_bytes(bytes)
440                        })
441                        .collect();
442
443                    writer.add_f32(name, shape, &float_data);
444                }
445                "float64" | "Double" => {
446                    if tensor.data.len() % 8 != 0 {
447                        bail!("Invalid float64 data length for tensor {}", name);
448                    }
449
450                    let float_data: Vec<f64> = tensor
451                        .data
452                        .chunks_exact(8)
453                        .map(|chunk| {
454                            let bytes: [u8; 8] = chunk.try_into().unwrap();
455                            f64::from_le_bytes(bytes)
456                        })
457                        .collect();
458
459                    writer.add_f64(name, shape, &float_data);
460                }
461                _ => {
462                    bail!("Unsupported dtype: {}", tensor.dtype);
463                }
464            }
465        }
466
467        writer
468            .serialize()
469            .context("Failed to serialize to safetensors")
470    }
471
472    /// Save checkpoint in PyTorch format.
473    ///
474    /// Note: This creates a simplified pickle format compatible with PyTorch.
475    #[allow(dead_code)]
476    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
477        let bytes = self.to_pickle_bytes()?;
478        std::fs::write(path, bytes).context("Failed to write checkpoint file")?;
479        Ok(())
480    }
481
482    /// Serialize checkpoint to pickle bytes.
483    fn to_pickle_bytes(&self) -> Result<Vec<u8>> {
484        use serde_pickle::ser;
485
486        // Note: serde_pickle::Value doesn't implement Serialize, so we need to
487        // serialize our checkpoint structure directly via serde
488        // We'll use a simplified serializable format
489
490        #[derive(Serialize)]
491        struct CheckpointSer {
492            state_dict: HashMap<String, TensorSer>,
493            #[serde(skip_serializing_if = "Option::is_none")]
494            epoch: Option<usize>,
495            #[serde(skip_serializing_if = "Option::is_none")]
496            loss_history: Option<Vec<f32>>,
497            metadata: HashMap<String, String>,
498        }
499
500        #[derive(Serialize)]
501        struct TensorSer {
502            shape: Vec<usize>,
503            dtype: String,
504            data_len: usize,
505        }
506
507        let state_dict_ser: HashMap<String, TensorSer> = self
508            .state_dict
509            .tensors
510            .iter()
511            .map(|(name, tensor)| {
512                (
513                    name.clone(),
514                    TensorSer {
515                        shape: tensor.shape.clone(),
516                        dtype: tensor.dtype.clone(),
517                        data_len: tensor.data.len(),
518                    },
519                )
520            })
521            .collect();
522
523        let checkpoint_ser = CheckpointSer {
524            state_dict: state_dict_ser,
525            epoch: self.epoch,
526            loss_history: self.loss_history.clone(),
527            metadata: self.metadata.clone(),
528        };
529
530        // Serialize using serde_pickle's serializer
531        ser::to_vec(&checkpoint_ser, Default::default()).context("Failed to serialize to pickle")
532    }
533
534    /// Convert TensorData to pickle value.
535    ///
536    /// Note: This is a simplified helper for internal use.
537    #[allow(dead_code)]
538    fn tensor_to_pickle_value(_tensor: &TensorData) -> HashMap<String, String> {
539        // Simplified version for internal use
540        // In practice, you'd serialize the full tensor data
541        HashMap::new()
542    }
543
544    /// Create a new empty checkpoint.
545    pub fn new() -> Self {
546        PyTorchCheckpoint {
547            state_dict: StateDict {
548                tensors: HashMap::new(),
549            },
550            optimizer_state: None,
551            epoch: None,
552            loss_history: None,
553            metadata: HashMap::new(),
554        }
555    }
556
557    /// Add a tensor to the state dict.
558    pub fn add_tensor(&mut self, name: String, tensor: TensorData) {
559        self.state_dict.tensors.insert(name, tensor);
560    }
561
562    /// Set the epoch.
563    pub fn set_epoch(&mut self, epoch: usize) {
564        self.epoch = Some(epoch);
565    }
566
567    /// Add metadata entry.
568    pub fn add_metadata(&mut self, key: String, value: String) {
569        self.metadata.insert(key, value);
570    }
571}
572
573impl Default for PyTorchCheckpoint {
574    fn default() -> Self {
575        Self::new()
576    }
577}
578
579impl StateDict {
580    /// Get a tensor by name.
581    pub fn get(&self, name: &str) -> Option<&TensorData> {
582        self.tensors.get(name)
583    }
584
585    /// Iterate over tensors.
586    pub fn iter(&self) -> impl Iterator<Item = (&String, &TensorData)> {
587        self.tensors.iter()
588    }
589
590    /// Get number of tensors.
591    pub fn len(&self) -> usize {
592        self.tensors.len()
593    }
594
595    /// Check if state dict is empty.
596    pub fn is_empty(&self) -> bool {
597        self.tensors.is_empty()
598    }
599}
600
601impl TensorData {
602    /// Create new tensor data from f32 values.
603    pub fn from_f32(shape: Vec<usize>, data: &[f32]) -> Self {
604        let bytes: Vec<u8> = data.iter().flat_map(|&f| f.to_le_bytes()).collect();
605
606        TensorData {
607            shape,
608            dtype: "float32".to_string(),
609            data: bytes,
610            requires_grad: false,
611        }
612    }
613
614    /// Create new tensor data from f64 values.
615    pub fn from_f64(shape: Vec<usize>, data: &[f64]) -> Self {
616        let bytes: Vec<u8> = data.iter().flat_map(|&f| f.to_le_bytes()).collect();
617
618        TensorData {
619            shape,
620            dtype: "float64".to_string(),
621            data: bytes,
622            requires_grad: false,
623        }
624    }
625
626    /// Get tensor as f32 slice.
627    pub fn as_f32(&self) -> Result<Vec<f32>> {
628        if self.dtype != "float32" && self.dtype != "Float" {
629            bail!("Expected float32 dtype, got {}", self.dtype);
630        }
631
632        if !self.data.len().is_multiple_of(4) {
633            bail!("Invalid data length for float32");
634        }
635
636        Ok(self
637            .data
638            .chunks_exact(4)
639            .map(|chunk| {
640                let bytes: [u8; 4] = chunk.try_into().unwrap();
641                f32::from_le_bytes(bytes)
642            })
643            .collect())
644    }
645
646    /// Get tensor as f64 slice.
647    pub fn as_f64(&self) -> Result<Vec<f64>> {
648        if self.dtype != "float64" && self.dtype != "Double" {
649            bail!("Expected float64 dtype, got {}", self.dtype);
650        }
651
652        if !self.data.len().is_multiple_of(8) {
653            bail!("Invalid data length for float64");
654        }
655
656        Ok(self
657            .data
658            .chunks_exact(8)
659            .map(|chunk| {
660                let bytes: [u8; 8] = chunk.try_into().unwrap();
661                f64::from_le_bytes(bytes)
662            })
663            .collect())
664    }
665
666    /// Get number of elements.
667    pub fn num_elements(&self) -> usize {
668        self.shape.iter().product()
669    }
670}
671
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built checkpoint keeps the tensor, epoch and metadata it was given.
    #[test]
    fn test_checkpoint_creation() {
        let mut ckpt = PyTorchCheckpoint::new();

        let values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        ckpt.add_tensor(
            "layer1.weight".to_string(),
            TensorData::from_f32(vec![2, 3], &values),
        );
        ckpt.set_epoch(10);
        ckpt.add_metadata("model_type".to_string(), "CNN".to_string());

        assert_eq!(ckpt.state_dict().len(), 1);
        assert_eq!(ckpt.epoch, Some(10));
        assert_eq!(ckpt.metadata.get("model_type").unwrap(), "CNN");
    }

    /// f32 values survive the byte encoding round trip.
    #[test]
    fn test_tensor_data_f32() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = TensorData::from_f32(vec![2, 2], &data);

        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(tensor.dtype, "float32");
        assert_eq!(tensor.num_elements(), 4);

        let recovered = tensor.as_f32().unwrap();
        assert_eq!(recovered, data);
    }

    /// f64 values survive the byte encoding round trip.
    #[test]
    fn test_tensor_data_f64() {
        let data = vec![1.0f64, 2.0, 3.0, 4.0];
        let tensor = TensorData::from_f64(vec![2, 2], &data);

        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(tensor.dtype, "float64");

        let recovered = tensor.as_f64().unwrap();
        assert_eq!(recovered, data);
    }

    /// Metadata counts parameters, layers and dtypes across all tensors.
    #[test]
    fn test_metadata_extraction() {
        let mut ckpt = PyTorchCheckpoint::new();

        let layers = [
            ("layer1.weight", TensorData::from_f32(vec![10, 10], &[0.0; 100])),
            ("layer1.bias", TensorData::from_f32(vec![10], &[0.0; 10])),
            ("layer2.weight", TensorData::from_f64(vec![5, 10], &[0.0; 50])),
        ];
        for (name, tensor) in layers {
            ckpt.add_tensor(name.to_string(), tensor);
        }

        let summary = ckpt.metadata();

        assert_eq!(summary.total_parameters, 160);
        assert_eq!(summary.layer_names.len(), 3);
        assert_eq!(summary.dtypes.get("float32"), Some(&2));
        assert_eq!(summary.dtypes.get("float64"), Some(&1));
    }

    /// The state dict exposes length, emptiness and lookup by name.
    #[test]
    fn test_state_dict_access() {
        let mut ckpt = PyTorchCheckpoint::new();
        ckpt.add_tensor(
            "test".to_string(),
            TensorData::from_f32(vec![3], &[1.0, 2.0, 3.0]),
        );

        let state_dict = ckpt.state_dict();
        assert_eq!(state_dict.len(), 1);
        assert!(!state_dict.is_empty());

        let retrieved = state_dict.get("test").unwrap();
        assert_eq!(retrieved.shape, vec![3]);
    }

    /// Pickle serialization succeeds and yields a non-empty buffer.
    #[test]
    fn test_checkpoint_serialization() -> Result<()> {
        let mut ckpt = PyTorchCheckpoint::new();

        ckpt.add_tensor(
            "weight".to_string(),
            TensorData::from_f32(vec![2, 2], &[1.0, 2.0, 3.0, 4.0]),
        );
        ckpt.set_epoch(5);
        ckpt.add_metadata("arch".to_string(), "ResNet".to_string());

        // Only checks that serialization runs and produces output; a full
        // PyTorch pickle round trip would need handling of complex tensor
        // structures that this module does not implement.
        let bytes = ckpt.to_pickle_bytes()?;
        assert!(!bytes.is_empty());

        Ok(())
    }

    /// Conversion of float32 tensors to Safetensors yields a non-empty buffer.
    #[test]
    fn test_to_safetensors() -> Result<()> {
        let mut ckpt = PyTorchCheckpoint::new();

        let weights = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        ckpt.add_tensor(
            "layer1.weight".to_string(),
            TensorData::from_f32(vec![3, 3], &weights),
        );
        ckpt.add_tensor(
            "layer1.bias".to_string(),
            TensorData::from_f32(vec![3], &[0.1, 0.2, 0.3]),
        );

        let safetensors_bytes = ckpt.to_safetensors()?;
        assert!(!safetensors_bytes.is_empty());

        Ok(())
    }
}
797}