Skip to main content

kwaai_distributed/
coordinator.rs

1//! Distributed operations coordinator
2
3use crate::averaging::{AveragingConfig, DecentralizedAverager};
4use crate::error::DistributedResult;
5use crate::moe::DistributedMoE;
6use crate::DistributedConfig;
7use tracing::{debug, info};
8
9/// Coordinator for all distributed ML operations
10pub struct DistributedCoordinator {
11    /// Configuration
12    config: DistributedConfig,
13    /// MoE layer (if enabled)
14    moe: Option<DistributedMoE>,
15    /// Parameter averager (if enabled)
16    averager: Option<DecentralizedAverager>,
17    /// Whether coordinator is running
18    is_running: bool,
19}
20
21impl DistributedCoordinator {
22    /// Create a new coordinator
23    pub fn new(config: DistributedConfig) -> Self {
24        Self {
25            config,
26            moe: None,
27            averager: None,
28            is_running: false,
29        }
30    }
31
32    /// Initialize the coordinator
33    pub fn initialize(&mut self) -> DistributedResult<()> {
34        info!(
35            moe = self.config.enable_moe,
36            averaging = self.config.enable_averaging,
37            "Initializing DistributedCoordinator"
38        );
39        if self.config.enable_averaging {
40            let averaging_config = AveragingConfig {
41                group_size: self.config.averaging_group_size,
42                ..Default::default()
43            };
44            self.averager = Some(DecentralizedAverager::new(averaging_config));
45            debug!(
46                group_size = self.config.averaging_group_size,
47                "Parameter averager initialized"
48            );
49        }
50
51        // MoE initialization would require router weights
52        // Left as None for now, to be initialized when model is loaded
53
54        self.is_running = true;
55        info!("DistributedCoordinator initialized");
56        Ok(())
57    }
58
59    /// Check if distributed mode is enabled
60    pub fn is_enabled(&self) -> bool {
61        self.config.enable_moe || self.config.enable_averaging
62    }
63
64    /// Get the MoE layer
65    pub fn moe(&self) -> Option<&DistributedMoE> {
66        self.moe.as_ref()
67    }
68
69    /// Get the MoE layer mutably
70    pub fn moe_mut(&mut self) -> Option<&mut DistributedMoE> {
71        self.moe.as_mut()
72    }
73
74    /// Get the averager
75    pub fn averager(&self) -> Option<&DecentralizedAverager> {
76        self.averager.as_ref()
77    }
78
79    /// Get the averager mutably
80    pub fn averager_mut(&mut self) -> Option<&mut DecentralizedAverager> {
81        self.averager.as_mut()
82    }
83
84    /// Check if coordinator is running
85    pub fn is_running(&self) -> bool {
86        self.is_running
87    }
88
89    /// Stop the coordinator
90    pub fn stop(&mut self) {
91        info!("DistributedCoordinator stopping");
92        self.is_running = false;
93    }
94}
95
96impl Default for DistributedCoordinator {
97    fn default() -> Self {
98        Self::new(DistributedConfig::default())
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with every distributed feature switched off.
    fn disabled_config() -> DistributedConfig {
        DistributedConfig {
            enable_moe: false,
            enable_averaging: false,
            ..DistributedConfig::default()
        }
    }

    /// Config with parameter averaging on and MoE off; all else default.
    fn averaging_only_config() -> DistributedConfig {
        DistributedConfig {
            enable_averaging: true,
            ..disabled_config()
        }
    }

    /// Build a coordinator from `config` and run `initialize`, panicking on failure.
    fn initialized(config: DistributedConfig) -> DistributedCoordinator {
        let mut coord = DistributedCoordinator::new(config);
        coord.initialize().unwrap();
        coord
    }

    #[test]
    fn test_disabled_coordinator_not_enabled() {
        assert!(!DistributedCoordinator::new(disabled_config()).is_enabled());
    }

    #[test]
    fn test_not_running_before_init() {
        assert!(!DistributedCoordinator::new(DistributedConfig::default()).is_running());
    }

    #[test]
    fn test_initialize_sets_running() {
        assert!(initialized(DistributedConfig::default()).is_running());
    }

    #[test]
    fn test_stop_clears_running() {
        let mut coord = initialized(DistributedConfig::default());
        coord.stop();
        assert!(!coord.is_running());
    }

    #[test]
    fn test_averaging_enabled_creates_averager() {
        let coord = initialized(averaging_only_config());
        assert!(coord.averager().is_some());
        assert!(coord.moe().is_none());
    }

    #[test]
    fn test_averaging_disabled_no_averager() {
        assert!(initialized(disabled_config()).averager().is_none());
    }

    #[test]
    fn test_is_enabled_with_averaging_only() {
        assert!(DistributedCoordinator::new(averaging_only_config()).is_enabled());
    }

    #[test]
    fn test_default_coordinator() {
        // Relies on `DistributedConfig::default()` enabling at least one
        // distributed feature.
        assert!(DistributedCoordinator::default().is_enabled());
    }
}