kwaai_distributed/
coordinator.rs1use crate::averaging::{AveragingConfig, DecentralizedAverager};
4use crate::error::DistributedResult;
5use crate::moe::DistributedMoE;
6use crate::DistributedConfig;
7use tracing::{debug, info};
8
9pub struct DistributedCoordinator {
11 config: DistributedConfig,
13 moe: Option<DistributedMoE>,
15 averager: Option<DecentralizedAverager>,
17 is_running: bool,
19}
20
21impl DistributedCoordinator {
22 pub fn new(config: DistributedConfig) -> Self {
24 Self {
25 config,
26 moe: None,
27 averager: None,
28 is_running: false,
29 }
30 }
31
32 pub fn initialize(&mut self) -> DistributedResult<()> {
34 info!(
35 moe = self.config.enable_moe,
36 averaging = self.config.enable_averaging,
37 "Initializing DistributedCoordinator"
38 );
39 if self.config.enable_averaging {
40 let averaging_config = AveragingConfig {
41 group_size: self.config.averaging_group_size,
42 ..Default::default()
43 };
44 self.averager = Some(DecentralizedAverager::new(averaging_config));
45 debug!(
46 group_size = self.config.averaging_group_size,
47 "Parameter averager initialized"
48 );
49 }
50
51 self.is_running = true;
55 info!("DistributedCoordinator initialized");
56 Ok(())
57 }
58
59 pub fn is_enabled(&self) -> bool {
61 self.config.enable_moe || self.config.enable_averaging
62 }
63
64 pub fn moe(&self) -> Option<&DistributedMoE> {
66 self.moe.as_ref()
67 }
68
69 pub fn moe_mut(&mut self) -> Option<&mut DistributedMoE> {
71 self.moe.as_mut()
72 }
73
74 pub fn averager(&self) -> Option<&DecentralizedAverager> {
76 self.averager.as_ref()
77 }
78
79 pub fn averager_mut(&mut self) -> Option<&mut DecentralizedAverager> {
81 self.averager.as_mut()
82 }
83
84 pub fn is_running(&self) -> bool {
86 self.is_running
87 }
88
89 pub fn stop(&mut self) {
91 info!("DistributedCoordinator stopping");
92 self.is_running = false;
93 }
94}
95
96impl Default for DistributedCoordinator {
97 fn default() -> Self {
98 Self::new(DistributedConfig::default())
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration with only the two subsystem flags overridden.
    fn config_with(enable_moe: bool, enable_averaging: bool) -> DistributedConfig {
        DistributedConfig {
            enable_moe,
            enable_averaging,
            ..DistributedConfig::default()
        }
    }

    /// Configuration with every distributed subsystem switched off.
    fn disabled_config() -> DistributedConfig {
        config_with(false, false)
    }

    /// Configuration with parameter averaging on and MoE off.
    fn averaging_only_config() -> DistributedConfig {
        config_with(false, true)
    }

    /// Builds a coordinator from `cfg` and runs `initialize`, failing the test
    /// if initialization returns an error.
    fn initialized(cfg: DistributedConfig) -> DistributedCoordinator {
        let mut coord = DistributedCoordinator::new(cfg);
        coord.initialize().unwrap();
        coord
    }

    #[test]
    fn test_disabled_coordinator_not_enabled() {
        assert!(!DistributedCoordinator::new(disabled_config()).is_enabled());
    }

    #[test]
    fn test_not_running_before_init() {
        assert!(!DistributedCoordinator::new(DistributedConfig::default()).is_running());
    }

    #[test]
    fn test_initialize_sets_running() {
        assert!(initialized(DistributedConfig::default()).is_running());
    }

    #[test]
    fn test_stop_clears_running() {
        let mut coord = initialized(DistributedConfig::default());
        coord.stop();
        assert!(!coord.is_running());
    }

    #[test]
    fn test_averaging_enabled_creates_averager() {
        let coord = initialized(averaging_only_config());
        assert!(coord.averager().is_some());
        assert!(coord.moe().is_none());
    }

    #[test]
    fn test_averaging_disabled_no_averager() {
        assert!(initialized(disabled_config()).averager().is_none());
    }

    #[test]
    fn test_is_enabled_with_averaging_only() {
        assert!(DistributedCoordinator::new(averaging_only_config()).is_enabled());
    }

    #[test]
    fn test_default_coordinator() {
        assert!(DistributedCoordinator::default().is_enabled());
    }
}