use crate::averaging::{AveragingConfig, DecentralizedAverager};
use crate::error::DistributedResult;
use crate::moe::DistributedMoE;
use crate::DistributedConfig;
use tracing::{debug, info};
/// Owns the distributed-training configuration together with the optional
/// components (MoE and parameter averager) and a simple running flag.
pub struct DistributedCoordinator {
/// Settings supplied at construction; read by `initialize` and `is_enabled`.
config: DistributedConfig,
/// Optional MoE component; `None` after `new`.
moe: Option<DistributedMoE>,
/// Optional parameter averager; `None` after `new`, populated by
/// `initialize` when `config.enable_averaging` is true.
averager: Option<DecentralizedAverager>,
/// Set to `true` by `initialize` and back to `false` by `stop`.
is_running: bool,
}
impl DistributedCoordinator {
    /// Creates a coordinator that owns `config`.
    ///
    /// No component is constructed here: both the MoE and averager slots start
    /// as `None`, and the coordinator is not running until [`Self::initialize`]
    /// is called.
    pub fn new(config: DistributedConfig) -> Self {
        Self {
            is_running: false,
            averager: None,
            moe: None,
            config,
        }
    }

    /// Builds the components enabled in the configuration and marks the
    /// coordinator as running.
    ///
    /// When `enable_averaging` is set, a [`DecentralizedAverager`] is created
    /// with `group_size` taken from `averaging_group_size` and every other
    /// [`AveragingConfig`] field left at its default. Any averager already
    /// stored is replaced.
    ///
    /// NOTE(review): this method does not construct an MoE component even when
    /// `enable_moe` is set — confirm that it is wired up elsewhere.
    ///
    /// # Errors
    ///
    /// As written, this method always returns `Ok(())`.
    pub fn initialize(&mut self) -> DistributedResult<()> {
        let moe_enabled = self.config.enable_moe;
        let averaging_enabled = self.config.enable_averaging;
        info!(
            moe = moe_enabled,
            averaging = averaging_enabled,
            "Initializing DistributedCoordinator"
        );

        if averaging_enabled {
            let group_size = self.config.averaging_group_size;
            let averager = DecentralizedAverager::new(AveragingConfig {
                group_size,
                ..Default::default()
            });
            self.averager = Some(averager);
            debug!(group_size = group_size, "Parameter averager initialized");
        }

        self.is_running = true;
        info!("DistributedCoordinator initialized");
        Ok(())
    }

    /// Returns `true` when the configuration enables MoE, averaging, or both.
    ///
    /// This reflects the configuration only; it does not depend on whether
    /// `initialize` has been called.
    pub fn is_enabled(&self) -> bool {
        let cfg = &self.config;
        cfg.enable_moe || cfg.enable_averaging
    }

    /// Shared access to the MoE component, or `None` if none is stored.
    pub fn moe(&self) -> Option<&DistributedMoE> {
        Option::as_ref(&self.moe)
    }

    /// Exclusive access to the MoE component, or `None` if none is stored.
    pub fn moe_mut(&mut self) -> Option<&mut DistributedMoE> {
        Option::as_mut(&mut self.moe)
    }

    /// Shared access to the parameter averager, or `None` if none is stored.
    pub fn averager(&self) -> Option<&DecentralizedAverager> {
        Option::as_ref(&self.averager)
    }

    /// Exclusive access to the parameter averager, or `None` if none is stored.
    pub fn averager_mut(&mut self) -> Option<&mut DecentralizedAverager> {
        Option::as_mut(&mut self.averager)
    }

    /// Reports the running flag: `true` after `initialize`, `false` initially
    /// and after `stop`.
    pub fn is_running(&self) -> bool {
        self.is_running
    }

    /// Logs the shutdown and clears the running flag.
    ///
    /// Stored components are left in place; only the flag changes.
    pub fn stop(&mut self) {
        info!("DistributedCoordinator stopping");
        self.is_running = false;
    }
}
impl Default for DistributedCoordinator {
fn default() -> Self {
Self::new(DistributedConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration with both the MoE and averaging features switched off.
    fn disabled_config() -> DistributedConfig {
        DistributedConfig {
            enable_moe: false,
            enable_averaging: false,
            ..Default::default()
        }
    }

    /// Configuration with averaging switched on and MoE switched off.
    fn averaging_only_config() -> DistributedConfig {
        DistributedConfig {
            enable_moe: false,
            enable_averaging: true,
            ..Default::default()
        }
    }

    /// Builds a coordinator from `cfg` and runs `initialize`, panicking if it
    /// reports an error.
    fn initialized(cfg: DistributedConfig) -> DistributedCoordinator {
        let mut coordinator = DistributedCoordinator::new(cfg);
        coordinator.initialize().unwrap();
        coordinator
    }

    #[test]
    fn test_disabled_coordinator_not_enabled() {
        let coordinator = DistributedCoordinator::new(disabled_config());
        assert!(!coordinator.is_enabled());
    }

    #[test]
    fn test_not_running_before_init() {
        let coordinator = DistributedCoordinator::new(DistributedConfig::default());
        assert!(!coordinator.is_running());
    }

    #[test]
    fn test_initialize_sets_running() {
        let coordinator = initialized(DistributedConfig::default());
        assert!(coordinator.is_running());
    }

    #[test]
    fn test_stop_clears_running() {
        let mut coordinator = initialized(DistributedConfig::default());
        coordinator.stop();
        assert!(!coordinator.is_running());
    }

    #[test]
    fn test_averaging_enabled_creates_averager() {
        let coordinator = initialized(averaging_only_config());
        assert!(coordinator.averager().is_some());
        assert!(coordinator.moe().is_none());
    }

    #[test]
    fn test_averaging_disabled_no_averager() {
        let coordinator = initialized(disabled_config());
        assert!(coordinator.averager().is_none());
    }

    #[test]
    fn test_is_enabled_with_averaging_only() {
        let coordinator = DistributedCoordinator::new(averaging_only_config());
        assert!(coordinator.is_enabled());
    }

    #[test]
    fn test_default_coordinator() {
        let coordinator = DistributedCoordinator::default();
        assert!(coordinator.is_enabled());
    }
}