use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::distributed::comm_utils::broadcast_tensor;
use crate::distributed::grad_sync::all_reduce_grads;
use crate::distributed::zero_base::{ZeroOptimizer, ZeroOptimizerBase};
use crate::error::{Error, Result};
use crate::optimizer::AdamWConfig;
use numr::autograd::GradStore;
use numr::dtype::DType;
use numr::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
use crate::ops::FusedOptimizerOps;
use numr::runtime::{Communicator, Runtime, RuntimeClient};
use numr::tensor::{Tensor, TensorId};
/// Shape and element type recorded for a parameter when the optimizer is built.
///
/// `ZeroStage3::gather_full_params` reads this to allocate a zero-filled
/// buffer for a parameter that is absent from the local map.
#[derive(Debug, Clone)]
struct ParamMeta {
    /// Dimensions of the tensor, copied from `Tensor::shape()`.
    shape: Vec<usize>,
    /// Element type of the tensor, copied from `Tensor::dtype()`.
    dtype: DType,
}
/// ZeRO stage 3 optimizer wrapper.
///
/// Pairs a [`ZeroOptimizerBase`] with per-parameter shape/dtype metadata so
/// that parameters missing from the local map can be re-allocated and filled
/// by a broadcast from their owning rank.
pub struct ZeroStage3<R>
where
    R: Runtime,
{
    /// Communicator, world size, ownership tables and the wrapped optimizer.
    base: ZeroOptimizerBase<R>,
    /// Shape/dtype of every parameter passed to `new`, keyed by tensor id.
    param_meta: HashMap<TensorId, ParamMeta>,
}
impl<R: Runtime<DType = DType>> ZeroStage3<R> {
pub fn new(
config: AdamWConfig,
comm: Arc<dyn Communicator>,
params: &HashMap<TensorId, Tensor<R>>,
) -> Self {
let param_ids: Vec<TensorId> = params.keys().copied().collect();
let base = ZeroOptimizerBase::new(config, comm, ¶m_ids);
let mut param_meta = HashMap::with_capacity(params.len());
for (&id, tensor) in params {
param_meta.insert(
id,
ParamMeta {
shape: tensor.shape().to_vec(),
dtype: tensor.dtype(),
},
);
}
Self { base, param_meta }
}
pub fn gather_full_params(
&self,
params: &mut HashMap<TensorId, Tensor<R>>,
device: &R::Device,
) -> Result<()> {
if self.base.world_size <= 1 {
return Ok(());
}
for &(id, owner) in &self.base.param_owners {
use std::collections::hash_map::Entry;
if let Entry::Vacant(e) = params.entry(id) {
let meta = self
.param_meta
.get(&id)
.ok_or_else(|| Error::DistributedError {
reason: format!(
"missing metadata for param {id:?} — was it in the \
original params passed to ZeroStage3::new?"
),
})?;
let buf = Tensor::<R>::zeros(&meta.shape, meta.dtype, device);
e.insert(buf);
}
let tensor = params.get(&id).expect("just ensured it exists");
broadcast_tensor(self.base.comm.as_ref(), tensor, owner)?;
}
self.base.comm.sync().map_err(|e| Error::DistributedError {
reason: format!("sync after ZeRO Stage 3 gather failed: {e}"),
})?;
Ok(())
}
pub fn release_params(&self, params: &mut HashMap<TensorId, Tensor<R>>) {
if self.base.world_size <= 1 {
return;
}
params.retain(|id, _| self.base.owned_params.contains(id));
}
pub fn step<C>(
&mut self,
client: &C,
params: &mut HashMap<TensorId, Tensor<R>>,
grads: &mut GradStore<R>,
) -> Result<()>
where
C: RuntimeClient<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ScalarOps<R>
+ ReduceOps<R>
+ FusedOptimizerOps<R>,
{
if self.base.world_size <= 1 {
return self.base.optimizer.step(client, params, grads);
}
all_reduce_grads(self.base.comm.as_ref(), client, grads)?;
let owned_grads = self.base.filter_to_owned(grads);
grads.clear();
self.base.step_owned(client, params, &owned_grads)?;
Ok(())
}
pub fn param_owners(&self) -> &[(TensorId, usize)] {
&self.base.param_owners
}
}
/// [`ZeroOptimizer`] implementation that forwards to the base optimizer state
/// and to the inherent gather/release methods of [`ZeroStage3`].
impl<R> ZeroOptimizer<R> for ZeroStage3<R>
where
    R: Runtime<DType = DType>,
{
    /// Ids of the parameters in the base optimizer's owned set.
    fn owned_param_ids(&self) -> &HashSet<TensorId> {
        &self.base.owned_params
    }

    /// Timestep reported by the wrapped optimizer.
    fn timestep(&self) -> u64 {
        self.base.optimizer.timestep()
    }

    /// Forwards the new learning rate to the wrapped optimizer.
    fn set_lr(&mut self, lr: f64) {
        self.base.optimizer.set_lr(lr);
    }

    /// Configuration held by the wrapped optimizer.
    fn config(&self) -> &AdamWConfig {
        self.base.optimizer.config()
    }

    /// Forwards to the inherent `ZeroStage3::gather_full_params`.
    fn gather_full_params(
        &self,
        params: &mut HashMap<TensorId, Tensor<R>>,
        device: &R::Device,
    ) -> Result<()> {
        ZeroStage3::<R>::gather_full_params(self, params, device)
    }

    /// Forwards to the inherent `ZeroStage3::release_params`.
    fn release_params(&self, params: &mut HashMap<TensorId, Tensor<R>>) {
        ZeroStage3::<R>::release_params(self, params)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::cpu_setup;
    use numr::runtime::NoOpCommunicator;
    use numr::runtime::cpu::CpuRuntime;

    /// Stepping through `ZeroStage3` with `NoOpCommunicator` must produce the
    /// same values as stepping a plain `AdamW` with the same inputs.
    #[test]
    fn test_zero3_single_rank_matches_adamw() {
        use crate::optimizer::AdamW;

        let (client, device) = cpu_setup();
        let comm = Arc::new(NoOpCommunicator);
        let config = AdamWConfig {
            lr: 0.1,
            weight_decay: 0.0,
            ..Default::default()
        };

        let id1 = TensorId::new();
        let id2 = TensorId::new();
        let t1 = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
        let t2 = Tensor::<CpuRuntime>::from_slice(&[3.0f32, 4.0], &[2], &device);

        let mut zero3_params = HashMap::new();
        zero3_params.insert(id1, t1.clone());
        zero3_params.insert(id2, t2.clone());
        let mut adam_params = HashMap::new();
        adam_params.insert(id1, t1);
        adam_params.insert(id2, t2);

        let g1 = Tensor::<CpuRuntime>::from_slice(&[0.1f32, 0.2], &[2], &device);
        let g2 = Tensor::<CpuRuntime>::from_slice(&[0.3f32, 0.4], &[2], &device);
        let mut grads = GradStore::new();
        grads.insert(id1, g1.clone());
        grads.insert(id2, g2.clone());
        let mut grads2 = GradStore::new();
        grads2.insert(id1, g1);
        grads2.insert(id2, g2);

        let mut zero3 = ZeroStage3::<CpuRuntime>::new(config.clone(), comm, &zero3_params);
        zero3.step(&client, &mut zero3_params, &mut grads).unwrap();

        let mut adam = AdamW::<CpuRuntime>::new(config);
        adam.step(&client, &mut adam_params, &grads2).unwrap();

        // Compare every parameter, not just the first one.
        for id in [id1, id2] {
            let z: Vec<f32> = zero3_params[&id].to_vec();
            let a: Vec<f32> = adam_params[&id].to_vec();
            assert_eq!(z.len(), a.len(), "length mismatch for {id:?}");
            for (z, a) in z.iter().zip(a.iter()) {
                assert!((z - a).abs() < 1e-6, "mismatch: zero3={z}, adam={a}");
            }
        }
    }

    /// Gather and release must leave the parameter map at its original size
    /// when run with `NoOpCommunicator`.
    #[test]
    fn test_zero3_gather_release_lifecycle() {
        let (_, device) = cpu_setup();
        let comm = Arc::new(NoOpCommunicator);

        let id1 = TensorId::new();
        let id2 = TensorId::new();
        let t1 = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
        let t2 = Tensor::<CpuRuntime>::from_slice(&[3.0f32, 4.0], &[2], &device);
        let mut params = HashMap::new();
        params.insert(id1, t1);
        params.insert(id2, t2);

        let zero3 = ZeroStage3::<CpuRuntime>::new(AdamWConfig::default(), comm, &params);
        assert_eq!(params.len(), 2);

        zero3.gather_full_params(&mut params, &device).unwrap();
        assert_eq!(params.len(), 2);

        zero3.release_params(&mut params);
        assert_eq!(params.len(), 2);
    }

    /// A step must change the parameter values and advance the timestep to 1.
    #[test]
    fn test_zero3_step_updates_owned_params() {
        let (client, device) = cpu_setup();
        let comm = Arc::new(NoOpCommunicator);

        let id = TensorId::new();
        let t = Tensor::<CpuRuntime>::from_slice(&[5.0f32, 5.0], &[2], &device);
        let original: Vec<f32> = t.to_vec();
        let mut params = HashMap::new();
        params.insert(id, t);

        let g = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0], &[2], &device);
        let mut grads = GradStore::new();
        grads.insert(id, g);

        let config = AdamWConfig {
            lr: 0.1,
            weight_decay: 0.0,
            ..Default::default()
        };
        let mut zero3 = ZeroStage3::<CpuRuntime>::new(config, comm, &params);
        zero3.step(&client, &mut params, &mut grads).unwrap();

        let updated: Vec<f32> = params[&id].to_vec();
        assert_ne!(updated, original, "params should change after step");
        assert_eq!(zero3.timestep(), 1);
    }

    /// Compile-time check that `ZeroStage3<CpuRuntime>` is `Send + Sync`.
    #[test]
    fn test_zero3_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ZeroStage3<CpuRuntime>>();
    }
}