use crate::averaging::{AveragingConfig, DecentralizedAverager};
use crate::error::DistributedResult;
use crate::moe::DistributedMoE;
use crate::DistributedConfig;
use tracing::{debug, info};
/// Owns the optional distributed subsystems (MoE and parameter averaging)
/// together with the configuration that decides which of them are enabled.
///
/// A coordinator starts out idle with no subsystems; `initialize` builds the
/// enabled ones and marks it as running, and `stop` clears the running flag.
pub struct DistributedCoordinator {
/// Configuration supplied at construction; provides the `enable_moe` /
/// `enable_averaging` flags and the averaging group size.
config: DistributedConfig,
/// Distributed MoE handle. Starts as `None`.
/// NOTE(review): nothing in the visible code ever assigns this field —
/// confirm whether it is populated elsewhere.
moe: Option<DistributedMoE>,
/// Parameter averager; set by `initialize` when `enable_averaging` is true.
averager: Option<DecentralizedAverager>,
/// `true` after `initialize` has completed, `false` initially and after `stop`.
is_running: bool,
}
impl DistributedCoordinator {
    /// Builds an idle coordinator around `config`.
    ///
    /// No subsystem is created here: both the MoE and the averager slots start
    /// empty and the coordinator reports itself as not running.
    pub fn new(config: DistributedConfig) -> Self {
        let (moe, averager) = (None, None);
        Self {
            config,
            moe,
            averager,
            is_running: false,
        }
    }

    /// Brings up the subsystems enabled in the configuration and marks the
    /// coordinator as running.
    ///
    /// When `enable_averaging` is set, a fresh [`DecentralizedAverager`] is
    /// created with the configured group size (all other averaging options
    /// take their defaults) and replaces any averager already held.
    ///
    /// NOTE(review): `enable_moe` is logged but no `DistributedMoE` is built
    /// in this method — confirm whether that is wired up elsewhere.
    ///
    /// # Errors
    ///
    /// The body as written always returns `Ok(())`.
    pub fn initialize(&mut self) -> DistributedResult<()> {
        let moe_enabled = self.config.enable_moe;
        let averaging_enabled = self.config.enable_averaging;
        info!(
            moe = moe_enabled,
            averaging = averaging_enabled,
            "Initializing DistributedCoordinator"
        );

        if averaging_enabled {
            let group_size = self.config.averaging_group_size;
            let averager = DecentralizedAverager::new(AveragingConfig {
                group_size,
                ..Default::default()
            });
            self.averager = Some(averager);
            debug!(group_size, "Parameter averager initialized");
        }

        self.is_running = true;
        info!("DistributedCoordinator initialized");
        Ok(())
    }

    /// Reports whether the configuration turns on at least one distributed
    /// feature (MoE or parameter averaging).
    pub fn is_enabled(&self) -> bool {
        let cfg = &self.config;
        cfg.enable_averaging || cfg.enable_moe
    }

    /// Shared access to the MoE subsystem, if one is present.
    pub fn moe(&self) -> Option<&DistributedMoE> {
        let slot = &self.moe;
        slot.as_ref()
    }

    /// Exclusive access to the MoE subsystem, if one is present.
    pub fn moe_mut(&mut self) -> Option<&mut DistributedMoE> {
        let slot = &mut self.moe;
        slot.as_mut()
    }

    /// Shared access to the parameter averager, if one has been created.
    pub fn averager(&self) -> Option<&DecentralizedAverager> {
        let slot = &self.averager;
        slot.as_ref()
    }

    /// Exclusive access to the parameter averager, if one has been created.
    pub fn averager_mut(&mut self) -> Option<&mut DecentralizedAverager> {
        let slot = &mut self.averager;
        slot.as_mut()
    }

    /// Returns `true` between a completed `initialize` and the next `stop`.
    pub fn is_running(&self) -> bool {
        self.is_running
    }

    /// Logs the shutdown and clears the running flag.
    ///
    /// Subsystems that were already created are left in place.
    pub fn stop(&mut self) {
        info!("DistributedCoordinator stopping");
        self.is_running = false;
    }
}
impl Default for DistributedCoordinator {
    /// Creates an idle coordinator backed by the default `DistributedConfig`.
    fn default() -> Self {
        let config = DistributedConfig::default();
        Self::new(config)
    }
}