use crate::error::{Error, Result};
use crate::trainer::checkpoint::{CHECKPOINT_VERSION, TrainingState, save_checkpoint};
use crate::trainer::distributed_checkpoint::types::{ShardingConfig, ShardingMeta};
use numr::runtime::cpu::CpuRuntime;
use numr::tensor::Tensor;
use std::collections::HashMap;
use std::path::Path;
#[allow(clippy::too_many_arguments)]
pub fn save_distributed_checkpoint<P: AsRef<Path>>(
dir: P,
rank: usize,
world_size: usize,
model_state: &HashMap<String, Tensor<CpuRuntime>>,
optimizer_state: Option<&HashMap<String, Tensor<CpuRuntime>>>,
training_state: &TrainingState,
sharding: ShardingConfig,
) -> Result<()> {
let dir = dir.as_ref();
let rank_dir = dir.join(format!("rank_{rank}"));
save_checkpoint(&rank_dir, model_state, optimizer_state, training_state)?;
if rank == 0 {
let meta = ShardingMeta {
version: CHECKPOINT_VERSION,
world_size,
shards: (0..world_size)
.map(|r| {
if r == rank {
sharding.clone()
} else {
ShardingConfig {
world_size,
rank: r,
owned_params: Vec::new(),
strategy: sharding.strategy.clone(),
split_dims: HashMap::new(),
}
}
})
.collect(),
};
let json = serde_json::to_string_pretty(&meta).map_err(|e| Error::TrainingError {
reason: format!("failed to serialize sharding meta: {e}"),
})?;
std::fs::create_dir_all(dir).map_err(|e| Error::TrainingError {
reason: format!("failed to create checkpoint dir: {e}"),
})?;
std::fs::write(dir.join("sharding_meta.json"), json).map_err(|e| Error::TrainingError {
reason: format!("failed to write sharding meta: {e}"),
})?;
}
Ok(())
}