use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::error::{Error, Result};
use crate::optimizer::{AdamW, AdamWConfig};
use numr::autograd::GradStore;
use numr::dtype::DType;
use numr::ops::{BinaryOps, ScalarOps, UnaryOps};
use crate::ops::FusedOptimizerOps;
use numr::runtime::{Communicator, Runtime, RuntimeClient};
use numr::tensor::{Tensor, TensorId};
use crate::distributed::comm_utils::broadcast_tensor;
/// Common interface implemented by every ZeRO optimizer stage.
///
/// Implementors wrap an inner AdamW optimizer plus per-rank ownership
/// metadata; the gather/release hooks default to no-ops for stages that keep
/// full parameters resident on every rank.
pub trait ZeroOptimizer<R: Runtime> {
    /// Ids of the parameters this rank is responsible for updating.
    fn owned_param_ids(&self) -> &HashSet<TensorId>;
    /// Number of optimizer steps taken so far by the wrapped optimizer.
    fn timestep(&self) -> u64;
    /// Updates the learning rate of the wrapped optimizer.
    fn set_lr(&mut self, lr: f64);
    /// Hyperparameter configuration of the wrapped AdamW optimizer.
    fn config(&self) -> &AdamWConfig;
    /// Materializes full parameters into `params` on this rank (e.g. before
    /// checkpointing). Default implementation is a no-op for stages that do
    /// not shard parameters.
    fn gather_full_params(
        &self,
        _params: &mut HashMap<TensorId, Tensor<R>>,
        _device: &R::Device,
    ) -> Result<()> {
        Ok(())
    }
    /// Releases any parameters temporarily gathered by
    /// `gather_full_params`. Default implementation is a no-op.
    fn release_params(&self, _params: &mut HashMap<TensorId, Tensor<R>>) {}
}
/// Implements [`ZeroOptimizer`] for a stage struct `$ty<R>` that has a
/// `base: ZeroOptimizerBase<R>` field, by delegating every trait method to
/// that field.
///
/// All paths are fully qualified (`$crate::…`, `numr::…`) so the expansion
/// compiles regardless of what the call site has imported.
macro_rules! impl_zero_optimizer {
    ($ty:ident) => {
        impl<R: numr::runtime::Runtime<DType = numr::dtype::DType>>
            $crate::distributed::zero_base::ZeroOptimizer<R> for $ty<R>
        {
            fn owned_param_ids(&self) -> &std::collections::HashSet<numr::tensor::TensorId> {
                &self.base.owned_params
            }
            fn timestep(&self) -> u64 {
                self.base.optimizer.timestep()
            }
            fn set_lr(&mut self, lr: f64) {
                self.base.optimizer.set_lr(lr);
            }
            fn config(&self) -> &$crate::optimizer::AdamWConfig {
                self.base.optimizer.config()
            }
        }
    };
}
// Re-export so sibling modules can invoke the macro via a normal path.
pub(crate) use impl_zero_optimizer;
/// State shared by all ZeRO optimizer stages: the wrapped AdamW optimizer,
/// the collective communicator, and the round-robin parameter-ownership
/// assignment computed at construction time.
pub(crate) struct ZeroOptimizerBase<R: Runtime> {
    /// Underlying AdamW optimizer; stepped only over this rank's owned params.
    pub(crate) optimizer: AdamW<R>,
    /// Collective-communication backend shared by all ranks.
    pub(crate) comm: Arc<dyn Communicator>,
    /// Total number of participating ranks (from `comm.world_size()`).
    pub(crate) world_size: usize,
    /// Ids of the parameters assigned to this rank.
    pub(crate) owned_params: HashSet<TensorId>,
    /// Every parameter id paired with its owning rank, in a deterministic
    /// (sorted-by-raw-id) order that all ranks agree on.
    pub(crate) param_owners: Vec<(TensorId, usize)>,
}
impl<R: Runtime<DType = DType>> ZeroOptimizerBase<R> {
    /// Creates the shared state, assigning parameter ownership round-robin
    /// across ranks.
    ///
    /// Parameter ids are sorted by raw id first so that every rank derives
    /// the identical owner assignment regardless of the order in which the
    /// caller supplied `param_ids`.
    pub(crate) fn new(
        config: AdamWConfig,
        comm: Arc<dyn Communicator>,
        param_ids: &[TensorId],
    ) -> Self {
        let rank = comm.rank();
        let world_size = comm.world_size();
        // Deterministic global ordering is required for cross-rank agreement.
        let mut sorted_ids: Vec<TensorId> = param_ids.to_vec();
        // Ids are unique, so stability is irrelevant; unstable sort is
        // faster and allocation-free.
        sorted_ids.sort_unstable_by_key(|id| id.raw());
        // Roughly len / world_size params land on this rank; pre-allocate.
        // (`max(1)` only guards the capacity hint — a zero world_size would
        // still panic at `i % world_size` below, exactly as before.)
        let mut owned_params =
            HashSet::with_capacity(sorted_ids.len() / world_size.max(1) + 1);
        let mut param_owners = Vec::with_capacity(sorted_ids.len());
        for (i, &id) in sorted_ids.iter().enumerate() {
            let owner = i % world_size;
            param_owners.push((id, owner));
            if owner == rank {
                owned_params.insert(id);
            }
        }
        Self {
            optimizer: AdamW::new(config),
            comm,
            world_size,
            owned_params,
            param_owners,
        }
    }

    /// Returns a new `GradStore` containing clones of only the gradients
    /// belonging to parameters owned by this rank.
    pub(crate) fn filter_to_owned(&self, grads: &GradStore<R>) -> GradStore<R> {
        let mut owned = GradStore::new();
        for &id in &self.owned_params {
            if let Some(g) = grads.get(id) {
                owned.insert(id, g.clone());
            }
        }
        owned
    }

    /// Moves this rank's owned parameters out of `params` into a separate
    /// map (the inverse of [`Self::restore_owned_params`]).
    pub(crate) fn extract_owned_params(
        &self,
        params: &mut HashMap<TensorId, Tensor<R>>,
    ) -> HashMap<TensorId, Tensor<R>> {
        let mut owned_map = HashMap::with_capacity(self.owned_params.len());
        for &id in &self.owned_params {
            if let Some(t) = params.remove(&id) {
                owned_map.insert(id, t);
            }
        }
        owned_map
    }

    /// Puts previously extracted owned parameters back into `params`.
    pub(crate) fn restore_owned_params(
        &self,
        owned_map: HashMap<TensorId, Tensor<R>>,
        params: &mut HashMap<TensorId, Tensor<R>>,
    ) {
        // `extend` consumes the map and reserves once, replacing the manual
        // insert loop.
        params.extend(owned_map);
    }

    /// Broadcasts every parameter from its owning rank to all other ranks,
    /// then synchronizes the communicator.
    ///
    /// `stage_name` appears only in the sync error message. Parameters absent
    /// from `params` are skipped; NOTE(review): if a param were present on
    /// some ranks but missing on others, the broadcast collectives would
    /// desynchronize — confirm callers always pass the full parameter map.
    pub(crate) fn broadcast_owned_params(
        &self,
        params: &HashMap<TensorId, Tensor<R>>,
        stage_name: &str,
    ) -> Result<()> {
        // Iterate `param_owners` (a Vec) rather than a HashSet so every rank
        // issues the collectives in the same deterministic order.
        for &(id, owner) in &self.param_owners {
            if let Some(tensor) = params.get(&id) {
                broadcast_tensor(self.comm.as_ref(), tensor, owner)?;
            }
        }
        self.comm.sync().map_err(|e| Error::DistributedError {
            reason: format!("sync after {stage_name} broadcast failed: {e}"),
        })?;
        Ok(())
    }

    /// Runs one optimizer step over only the parameters owned by this rank.
    ///
    /// Owned params are temporarily moved out of `params`, stepped, and then
    /// restored. BUG FIX: the restore now happens even when `step` fails —
    /// previously the `?` on `step` dropped the extracted tensors, leaving
    /// the caller's map permanently missing this rank's parameters on error.
    pub(crate) fn step_owned<C>(
        &mut self,
        client: &C,
        params: &mut HashMap<TensorId, Tensor<R>>,
        grads: &GradStore<R>,
    ) -> Result<()>
    where
        C: RuntimeClient<R> + BinaryOps<R> + UnaryOps<R> + ScalarOps<R> + FusedOptimizerOps<R>,
    {
        let mut owned_map = self.extract_owned_params(params);
        let result = self.optimizer.step(client, &mut owned_map, grads);
        // Always give the tensors back, success or failure.
        self.restore_owned_params(owned_map, params);
        result
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::NoOpCommunicator;
    use numr::runtime::cpu::CpuRuntime;

    /// With a single-rank no-op communicator, round-robin assignment must
    /// give every parameter to the local rank.
    #[test]
    fn test_base_ownership_round_robin() {
        let communicator = Arc::new(NoOpCommunicator);
        let param_ids: Vec<TensorId> =
            std::iter::repeat_with(TensorId::new).take(6).collect();
        let base = ZeroOptimizerBase::<CpuRuntime>::new(
            AdamWConfig::default(),
            communicator,
            &param_ids,
        );
        assert_eq!(base.owned_params.len(), 6);
    }

    /// A learning-rate update on the inner optimizer must be visible through
    /// its config.
    #[test]
    fn test_base_set_lr() {
        let communicator = Arc::new(NoOpCommunicator);
        let mut base = ZeroOptimizerBase::<CpuRuntime>::new(
            AdamWConfig::default(),
            communicator,
            &[],
        );
        base.optimizer.set_lr(0.01);
        assert_eq!(base.optimizer.config().lr, 0.01);
    }
}