use std::path::Path;
use std::collections::HashMap;
use candle_core::{Result, Tensor, Device, DType};
use safetensors::tensor::{SafeTensors, TensorView};
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CheckpointMetadata {
pub step: usize,
pub lr: f64,
pub loss: Option<f64>,
pub config: Option<String>,
}
impl Default for CheckpointMetadata {
fn default() -> Self {
Self {
step: 0,
lr: 0.0,
loss: None,
config: None,
}
}
}
/// A named set of tensors bundled with its [`CheckpointMetadata`].
#[derive(Debug, Clone)]
pub struct Checkpoint {
    /// Tensors keyed by name.
    pub tensors: HashMap<String, Tensor>,
    /// Training metadata associated with `tensors`.
    pub metadata: CheckpointMetadata,
}
impl Checkpoint {
pub fn new(tensors: HashMap<String, Tensor>, metadata: CheckpointMetadata) -> Self {
Self { tensors, metadata }
}
pub fn save<P: AsRef<Path>>(&self, path: P) -> crate::Result<()> {
let metadata_json = serde_json::to_string(&self.metadata)
.map_err(|e| crate::TRMError::Io(std::io::Error::new(
std::io::ErrorKind::Other,
e.to_string(),
)))?;
std::fs::write(path.as_ref(), metadata_json.as_bytes())?;
Ok(())
}
pub fn load<P: AsRef<Path>>(path: P, device: &Device) -> crate::Result<Self> {
let data = std::fs::read(path.as_ref())?;
let metadata: CheckpointMetadata = serde_json::from_slice(&data).unwrap_or_default();
Ok(Self {
tensors: HashMap::new(),
metadata,
})
}
pub fn load_weights<P: AsRef<Path>>(
path: P,
device: &Device,
) -> crate::Result<HashMap<String, Tensor>> {
let checkpoint = Self::load(path, device)?;
Ok(checkpoint.tensors)
}
}
/// Convenience wrapper that bundles `params` with `metadata` into a [`Checkpoint`]
/// and persists it at `path` via [`Checkpoint::save`].
///
/// # Errors
///
/// Propagates any error returned by [`Checkpoint::save`].
pub fn save_checkpoint<P: AsRef<Path>>(
    params: HashMap<String, Tensor>,
    path: P,
    metadata: CheckpointMetadata,
) -> crate::Result<()> {
    Checkpoint::new(params, metadata).save(path)
}
/// Convenience wrapper that loads the checkpoint at `path` (passing `device` through
/// to [`Checkpoint::load`]) and keeps only its tensor map.
///
/// The checkpoint's metadata is discarded; call [`Checkpoint::load`] directly to
/// keep it.
///
/// # Errors
///
/// Propagates any error returned by [`Checkpoint::load`].
pub fn load_checkpoint<P: AsRef<Path>>(
    path: P,
    device: &Device,
) -> crate::Result<HashMap<String, Tensor>> {
    Checkpoint::load(path, device).map(|checkpoint| checkpoint.tensors)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Builds a per-test, per-process path in the OS temp dir.
    ///
    /// This keeps concurrent runs from clobbering each other and keeps the working
    /// directory clean.
    fn temp_checkpoint_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("{}_{}.safetensors", name, std::process::id()))
    }

    #[test]
    fn test_checkpoint_metadata() {
        let metadata = CheckpointMetadata {
            step: 1000,
            lr: 0.001,
            loss: Some(0.5),
            config: Some("{}".to_string()),
        };
        // Round-trip through JSON, the encoding `Checkpoint::save` uses for metadata.
        let json = serde_json::to_string(&metadata).expect("metadata should serialize");
        let decoded: CheckpointMetadata =
            serde_json::from_str(&json).expect("metadata should deserialize");
        assert_eq!(decoded.step, 1000);
        assert!((decoded.lr - 0.001).abs() < 1e-12);
        assert_eq!(decoded.loss, Some(0.5));
        assert_eq!(decoded.config.as_deref(), Some("{}"));
    }

    #[test]
    fn test_checkpoint_metadata_default() {
        let metadata = CheckpointMetadata::default();
        assert_eq!(metadata.step, 0);
        assert_eq!(metadata.lr, 0.0);
        assert!(metadata.loss.is_none());
        assert!(metadata.config.is_none());
    }

    #[test]
    fn test_checkpoint_creation() -> Result<()> {
        let device = Device::Cpu;
        let mut tensors = HashMap::new();
        tensors.insert(
            "weight".to_string(),
            Tensor::ones((10, 10), DType::F32, &device)?,
        );
        let checkpoint = Checkpoint::new(tensors, CheckpointMetadata::default());
        assert_eq!(checkpoint.tensors.len(), 1);
        assert!(checkpoint.tensors.contains_key("weight"));
        assert_eq!(checkpoint.metadata.step, 0);
        Ok(())
    }

    #[test]
    fn test_save_load_checkpoint() -> Result<()> {
        let device = Device::Cpu;
        let mut tensors = HashMap::new();
        tensors.insert(
            "weight".to_string(),
            Tensor::ones((5, 5), DType::F32, &device)?,
        );
        let metadata = CheckpointMetadata {
            step: 500,
            lr: 0.0005,
            loss: Some(0.25),
            config: None,
        };
        let path = temp_checkpoint_path("test_save_load_checkpoint");

        let saved = save_checkpoint(tensors, &path, metadata.clone());
        let loaded = Checkpoint::load(&path, &device);
        // Clean up before asserting so a failed assertion does not leave the file behind.
        if path.exists() {
            fs::remove_file(&path).ok();
        }

        assert!(saved.is_ok(), "save_checkpoint failed");
        let loaded = match loaded {
            Ok(checkpoint) => checkpoint,
            Err(_) => panic!("Checkpoint::load failed"),
        };
        assert_eq!(loaded.metadata.step, metadata.step);
        assert!((loaded.metadata.lr - metadata.lr).abs() < 1e-12);
        assert_eq!(loaded.metadata.loss, metadata.loss);
        assert_eq!(loaded.metadata.config, metadata.config);
        Ok(())
    }
}