use crate::error::{Error, Result};
use crate::trainer::checkpoint::{TrainingState, load_checkpoint, save_checkpoint};
use crate::trainer::distributed_checkpoint::types::{ShardingMeta, ShardingStrategy};
use numr::ops::ShapeOps;
use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};
use numr::tensor::Tensor;
use std::collections::HashMap;
use std::path::Path;
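
/// Consolidate a sharded distributed checkpoint into a single full checkpoint.
///
/// Reads `sharding_meta.json` from `sharded_dir`, loads each per-rank
/// `rank_*` directory, merges the model (and optimizer state, if present)
/// according to the recorded [`ShardingStrategy`], and writes the result to
/// `output_dir` via [`save_checkpoint`].
///
/// # Example
///
/// A minimal usage sketch; the checkpoint paths are hypothetical:
///
/// ```ignore
/// consolidate_checkpoint("checkpoints/step_1000_sharded", "checkpoints/step_1000")?;
/// ```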
pub fn consolidate_checkpoint<P: AsRef<Path>>(sharded_dir: P, output_dir: P) -> Result<()> {
let sharded_dir = sharded_dir.as_ref();
let output_dir = output_dir.as_ref();
let meta_path = sharded_dir.join("sharding_meta.json");
let meta_json = std::fs::read_to_string(&meta_path).map_err(|e| Error::TrainingError {
reason: format!("failed to read sharding meta: {e}"),
})?;
let meta: ShardingMeta =
serde_json::from_str(&meta_json).map_err(|e| Error::TrainingError {
reason: format!("failed to parse sharding meta: {e}"),
})?;
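    // All merging happens on the CPU runtime, regardless of the device the
    // shards were trained on.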
let device = CpuDevice::new();
let mut merged_model: HashMap<String, Tensor<CpuRuntime>> = HashMap::new();
let mut merged_opt: HashMap<String, Tensor<CpuRuntime>> = HashMap::new();
let mut training_state: Option<TrainingState> = None;
    if meta.shards.len() != meta.world_size {
        return Err(Error::TrainingError {
            reason: format!(
                "sharding meta lists {} shard(s) but world_size is {}",
                meta.shards.len(),
                meta.world_size
            ),
        });
    }
    // Rank 0's strategy is taken as representative; mixed strategies are
    // rejected below.
    let is_tensor_parallel = matches!(
        meta.shards[0].strategy,
        ShardingStrategy::TensorParallel
    );
if is_tensor_parallel {
consolidate_tensor_parallel(
&meta,
sharded_dir,
&device,
&mut merged_model,
&mut merged_opt,
&mut training_state,
)?;
} else {
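        // Replicated and ZeRO-partitioned shards merge rank by rank without
        // any tensor reshaping.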
for rank in 0..meta.world_size {
let rank_dir = sharded_dir.join(format!("rank_{rank}"));
let (model, opt, state) = load_checkpoint::<CpuRuntime, _>(&rank_dir, &device)?;
if training_state.is_none() {
training_state = Some(state);
}
let strategy = &meta.shards[rank].strategy;
match strategy {
ShardingStrategy::Replicated => {
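                    // Every rank holds an identical full copy; rank 0's is kept.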
if rank == 0 {
merged_model = model;
if let Some(opt) = opt {
merged_opt = opt;
}
}
}
ShardingStrategy::ZeroPartitioned { .. } => {
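                    // Each rank owns a disjoint subset of the state; take the union.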
merged_model.extend(model);
if let Some(opt) = opt {
merged_opt.extend(opt);
}
}
                ShardingStrategy::TensorParallel => {
                    // Reaching this arm means the ranks disagree on strategy.
                    return Err(Error::TrainingError {
                        reason: format!(
                            "rank {rank} is TensorParallel but rank 0 is not; \
                             mixed sharding strategies are not supported"
                        ),
                    });
                }
}
}
}
let state = training_state.ok_or_else(|| Error::TrainingError {
reason: "no ranks found in sharded checkpoint".to_string(),
})?;
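    // Write the merged result as an ordinary single-process checkpoint.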
save_checkpoint(
output_dir,
&merged_model,
if merged_opt.is_empty() {
None
} else {
Some(&merged_opt)
},
&state,
)
}
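
/// Merge tensor-parallel shards: every parameter listed in `split_dims` is
/// concatenated across ranks along its recorded dimension; parameters without
/// an entry are treated as replicated and copied from rank 0.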
fn consolidate_tensor_parallel(
meta: &ShardingMeta,
sharded_dir: &Path,
device: &CpuDevice,
merged_model: &mut HashMap<String, Tensor<CpuRuntime>>,
merged_opt: &mut HashMap<String, Tensor<CpuRuntime>>,
training_state: &mut Option<TrainingState>,
) -> Result<()> {
let split_dims = &meta.shards[0].split_dims;
if split_dims.is_empty() {
return Err(Error::TrainingError {
reason: "TensorParallel consolidation requires split_dims metadata \
in ShardingConfig (missing for rank 0)"
.to_string(),
});
}
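    // Load every rank up front: merging a split parameter needs all of its
    // slices at once.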
let mut all_models = Vec::new();
let mut all_opts = Vec::new();
for rank in 0..meta.world_size {
let rank_dir = sharded_dir.join(format!("rank_{rank}"));
let (model, opt, state) = load_checkpoint::<CpuRuntime, _>(&rank_dir, device)?;
if training_state.is_none() {
*training_state = Some(state);
}
all_models.push(model);
all_opts.push(opt);
}
let client = CpuClient::new(device.clone());
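    // Reassemble each split parameter by concatenating its per-rank slices
    // along the dimension recorded at save time.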
for (param_name, &split_dim) in split_dims {
let shards: Vec<&Tensor<CpuRuntime>> = (0..meta.world_size)
.filter_map(|r| all_models[r].get(param_name))
.collect();
        // A parameter missing from any rank cannot be reassembled; skip it
        // rather than concatenating a partial tensor.
        if shards.len() != meta.world_size {
            continue;
        }
        let merged = client
            .cat(&shards, split_dim as isize)
            .map_err(|e| Error::TrainingError {
                reason: format!(
                    "failed to concatenate param '{param_name}' along dim {split_dim}: {e}"
                ),
            })?;
        merged_model.insert(param_name.clone(), merged);
}
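    // Parameters without a split dim are replicated; rank 0's copy is
    // authoritative.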
if let Some(rank0_model) = all_models.first() {
for (name, tensor) in rank0_model {
if !split_dims.contains_key(name) {
merged_model.insert(name.clone(), tensor.clone());
}
}
}
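    // Optimizer state mirrors the parameter sharding: concatenate split
    // entries, copy replicated ones from rank 0.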
    if let Some(Some(opt0)) = all_opts.first() {
        for (name, tensor) in opt0 {
            if let Some(&split_dim) = split_dims.get(name) {
                let shards: Vec<&Tensor<CpuRuntime>> = (0..meta.world_size)
                    .filter_map(|r| all_opts[r].as_ref().and_then(|o| o.get(name)))
                    .collect();
                // As with parameters, only merge state that every rank
                // contributed; a cat failure is a real error, not a skip.
                if shards.len() == meta.world_size {
                    let merged = client.cat(&shards, split_dim as isize).map_err(|e| {
                        Error::TrainingError {
                            reason: format!(
                                "failed to concatenate optimizer state '{name}' \
                                 along dim {split_dim}: {e}"
                            ),
                        }
                    })?;
                    merged_opt.insert(name.clone(), merged);
                }
            } else {
                merged_opt.insert(name.clone(), tensor.clone());
            }
        }
    }
Ok(())
}