1use super::distributed::{DistributedConfig, WireMessage};
16use super::gradient_server::{read_wire_message, send_wire_message};
17use std::net::TcpStream;
18
19pub struct WorkerClient {
21 config: DistributedConfig,
22 stream: TcpStream,
23 worker_id: u32,
24 total_workers: u32,
25}
26
/// Data-shard assignment received from the coordinator for one step.
///
/// Produced by `WorkerClient::receive_shard` from a
/// `WireMessage::ShardAssignment`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShardAssignment {
    /// Training step this assignment belongs to.
    pub step: u64,
    /// First sample index of the shard.
    pub shard_start: usize,
    /// End sample index of the shard.
    // NOTE(review): presumably exclusive (half-open range) — confirm against
    // the coordinator's shard computation.
    pub shard_end: usize,
}
34
/// Averaged gradients received from the coordinator for one step.
///
/// Produced by `WorkerClient::receive_averaged` from a
/// `WireMessage::AveragedGradient`.
#[derive(Debug, Clone, PartialEq)]
pub struct AveragedResult {
    /// Training step the gradients belong to.
    pub step: u64,
    /// Flat gradient buffer as sent by the coordinator.
    pub gradients: Vec<f32>,
    /// Loss value reported by the coordinator alongside the gradients.
    pub global_loss: f32,
}
42
/// Averaged gradients for a single block, received from the coordinator.
///
/// Produced by `WorkerClient::receive_averaged_block` from a
/// `WireMessage::AveragedBlockGradient`.
#[derive(Debug, Clone, PartialEq)]
pub struct AveragedBlockResult {
    /// Training step the gradients belong to.
    pub step: u64,
    /// Index of the block these gradients belong to.
    pub block_idx: u32,
    /// Flat gradient buffer for the block.
    pub gradients: Vec<f32>,
    /// Per-component element counts.
    // NOTE(review): presumably these partition `gradients` in order and sum to
    // its length (the tests build the buffer that way) — confirm.
    pub component_sizes: Vec<u32>,
}
51
/// Averaged gradients for one non-block component, received from the
/// coordinator.
///
/// Produced by `WorkerClient::receive_averaged_non_block` from a
/// `WireMessage::AveragedNonBlockGradient`.
#[derive(Debug, Clone, PartialEq)]
pub struct AveragedNonBlockResult {
    /// Training step the gradients belong to.
    pub step: u64,
    /// Identifier of the non-block component (the meaning of each value is
    /// defined by the wire protocol, not by this type).
    pub component: u8,
    /// Flat gradient buffer for the component.
    pub gradients: Vec<f32>,
}
59
60impl WorkerClient {
61 pub fn connect(
71 config: DistributedConfig,
72 gpu_count: u32,
73 backend: &str,
74 ) -> Result<Self, String> {
75 let coord_addr = config
76 .coordinator_addr
77 .ok_or_else(|| "worker config must have coordinator_addr".to_string())?;
78
79 eprintln!("[worker {}] Connecting to coordinator at {coord_addr}...", config.node_id);
80
81 let stream = TcpStream::connect(coord_addr)
82 .map_err(|e| format!("failed to connect to {coord_addr}: {e}"))?;
83
84 let join = WireMessage::JoinRequest {
86 node_id: config.node_id.clone(),
87 gpu_count,
88 backend: backend.to_string(),
89 };
90 send_wire_message(&stream, &join)?;
91
92 let response = read_wire_message(&stream)?;
94 match response {
95 WireMessage::JoinAccepted { worker_id, total_workers } => {
96 eprintln!(
97 "[worker {}] Joined as worker {worker_id}/{total_workers}",
98 config.node_id
99 );
100 Ok(Self { config, stream, worker_id, total_workers })
101 }
102 other => Err(format!("expected JoinAccepted, got {other:?}")),
103 }
104 }
105
106 pub fn receive_shard(&self) -> Result<Option<ShardAssignment>, String> {
113 let msg = read_wire_message(&self.stream)?;
114 match msg {
115 WireMessage::ShardAssignment { step, shard_start, shard_end } => {
116 Ok(Some(ShardAssignment { step, shard_start, shard_end }))
117 }
118 WireMessage::Shutdown => {
119 eprintln!("[worker {}] Received shutdown from coordinator", self.config.node_id);
120 Ok(None)
121 }
122 other => Err(format!("expected ShardAssignment or Shutdown, got {other:?}")),
123 }
124 }
125
126 pub fn send_gradients(
138 &self,
139 step: u64,
140 gradients: Vec<f32>,
141 loss: f32,
142 correct: usize,
143 total: usize,
144 ) -> Result<(), String> {
145 let msg = WireMessage::GradientPayload {
146 step,
147 worker_id: self.worker_id,
148 gradients,
149 loss,
150 correct,
151 total,
152 };
153 send_wire_message(&self.stream, &msg)
154 }
155
156 pub fn receive_averaged(&self) -> Result<AveragedResult, String> {
161 let msg = read_wire_message(&self.stream)?;
162 match msg {
163 WireMessage::AveragedGradient { step, gradients, global_loss } => {
164 Ok(AveragedResult { step, gradients, global_loss })
165 }
166 WireMessage::Shutdown => Err("shutdown during AllReduce".to_string()),
167 other => Err(format!("expected AveragedGradient, got {other:?}")),
168 }
169 }
170
171 pub fn send_block_gradient(
182 &self,
183 step: u64,
184 block_idx: u32,
185 num_blocks: u32,
186 gradients: Vec<f32>,
187 component_sizes: Vec<u32>,
188 ) -> Result<(), String> {
189 let msg = WireMessage::BlockGradientPayload {
190 step,
191 worker_id: self.worker_id,
192 block_idx,
193 num_blocks,
194 gradients,
195 component_sizes,
196 };
197 send_wire_message(&self.stream, &msg)
198 }
199
200 pub fn receive_averaged_block(&self) -> Result<AveragedBlockResult, String> {
202 let msg = read_wire_message(&self.stream)?;
203 match msg {
204 WireMessage::AveragedBlockGradient { step, block_idx, gradients, component_sizes } => {
205 Ok(AveragedBlockResult { step, block_idx, gradients, component_sizes })
206 }
207 WireMessage::Shutdown => Err("shutdown during block AllReduce".to_string()),
208 other => Err(format!("expected AveragedBlockGradient, got {other:?}")),
209 }
210 }
211
212 pub fn send_non_block_gradient(
219 &self,
220 step: u64,
221 component: u8,
222 gradients: Vec<f32>,
223 ) -> Result<(), String> {
224 let msg = WireMessage::NonBlockGradientPayload {
225 step,
226 worker_id: self.worker_id,
227 component,
228 gradients,
229 };
230 send_wire_message(&self.stream, &msg)
231 }
232
233 pub fn receive_averaged_non_block(&self) -> Result<AveragedNonBlockResult, String> {
235 let msg = read_wire_message(&self.stream)?;
236 match msg {
237 WireMessage::AveragedNonBlockGradient { step, component, gradients } => {
238 Ok(AveragedNonBlockResult { step, component, gradients })
239 }
240 WireMessage::Shutdown => Err("shutdown during non-block AllReduce".to_string()),
241 other => Err(format!("expected AveragedNonBlockGradient, got {other:?}")),
242 }
243 }
244
245 #[must_use]
247 pub fn worker_id(&self) -> u32 {
248 self.worker_id
249 }
250
251 #[must_use]
253 pub fn total_workers(&self) -> u32 {
254 self.total_workers
255 }
256}
257
#[cfg(test)]
mod tests {
    // Test code is allowed to panic on failure.
    #![allow(clippy::unwrap_used)]
    use super::super::distributed::DistributedConfig;
    use super::super::gradient_server::GradientServer;
    use super::*;
    use std::thread;

    /// Asserts that two gradient buffers have the same length and agree
    /// element-wise within `1e-6`.
    ///
    /// The explicit length check matters: `zip` stops at the shorter input,
    /// so without it an empty or truncated result would pass vacuously.
    fn assert_gradients_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "gradient length mismatch");
        for (a, b) in actual.iter().zip(expected) {
            assert!((a - b).abs() < 1e-6, "gradient mismatch: {a} != {b}");
        }
    }

    /// A single worker joins and is told it is worker 0 of 1.
    #[test]
    fn test_worker_connect_and_join() {
        let server_config = DistributedConfig::coordinator(
            "127.0.0.1:0".parse().expect("literal socket address should parse"),
            1,
        );
        let mut server =
            GradientServer::bind(server_config).expect("coordinator should bind");
        let addr = server.local_addr();

        let handle = thread::spawn(move || {
            let worker_config = DistributedConfig::worker(addr);
            let client = WorkerClient::connect(worker_config, 1, "cpu")
                .expect("worker should join the coordinator");
            assert_eq!(client.worker_id(), 0);
            assert_eq!(client.total_workers(), 1);
            // Returned so the client (and its stream) outlives this thread.
            client
        });

        server.wait_for_workers().expect("coordinator should accept the worker");
        let _client = handle.join().expect("worker thread should not panic");
    }

    /// With one worker, the averaged block gradient equals what was sent.
    #[test]
    fn test_worker_block_gradient_roundtrip() {
        let server_config = DistributedConfig::coordinator(
            "127.0.0.1:0".parse().expect("literal socket address should parse"),
            1,
        );
        let mut server =
            GradientServer::bind(server_config).expect("coordinator should bind");
        let addr = server.local_addr();

        let component_sizes = vec![4, 2, 2, 4, 8, 8, 8, 1, 1];
        let total: u32 = component_sizes.iter().sum();
        let grads: Vec<f32> = (0..total).map(|i| i as f32 * 0.1).collect();

        let grads_clone = grads.clone();
        let sizes_clone = component_sizes.clone();
        let handle = thread::spawn(move || {
            let worker_config = DistributedConfig::worker(addr);
            let client = WorkerClient::connect(worker_config, 1, "cuda")
                .expect("worker should join the coordinator");

            client
                .send_block_gradient(0, 5, 24, grads_clone, sizes_clone)
                .expect("worker should send its block gradient");

            let avg = client
                .receive_averaged_block()
                .expect("worker should receive the averaged block");
            assert_eq!(avg.step, 0);
            assert_eq!(avg.block_idx, 5);
            assert_eq!(avg.gradients.len(), total as usize);
            avg
        });

        server.wait_for_workers().expect("coordinator should accept the worker");
        let result = server
            .collect_and_reduce_block(0, 5)
            .expect("coordinator should reduce the block");
        assert_eq!(result.block_idx, 5);
        assert_eq!(result.avg_gradients.len(), total as usize);
        server
            .broadcast_averaged_block(0, &result)
            .expect("coordinator should broadcast the averaged block");

        let avg = handle.join().expect("worker thread should not panic");
        assert_gradients_close(&avg.gradients, &grads);
    }

    /// With one worker, the averaged non-block gradient equals what was sent.
    #[test]
    fn test_worker_non_block_gradient_roundtrip() {
        let server_config = DistributedConfig::coordinator(
            "127.0.0.1:0".parse().expect("literal socket address should parse"),
            1,
        );
        let mut server =
            GradientServer::bind(server_config).expect("coordinator should bind");
        let addr = server.local_addr();

        let grads = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let grads_clone = grads.clone();
        let handle = thread::spawn(move || {
            let worker_config = DistributedConfig::worker(addr);
            let client = WorkerClient::connect(worker_config, 1, "cuda")
                .expect("worker should join the coordinator");

            client
                .send_non_block_gradient(0, 0, grads_clone)
                .expect("worker should send its non-block gradient");

            let avg = client
                .receive_averaged_non_block()
                .expect("worker should receive the averaged non-block gradient");
            assert_eq!(avg.step, 0);
            assert_eq!(avg.component, 0);
            avg
        });

        server.wait_for_workers().expect("coordinator should accept the worker");
        let result = server
            .collect_and_reduce_non_block(0, 0)
            .expect("coordinator should reduce the non-block component");
        assert_eq!(result.component, 0);
        server
            .broadcast_averaged_non_block(0, &result)
            .expect("coordinator should broadcast the averaged non-block gradient");

        let avg = handle.join().expect("worker thread should not panic");
        assert_gradients_close(&avg.gradients, &grads);
    }

    /// Two workers sending all-1.0 and all-3.0 both receive all-2.0.
    #[test]
    fn test_two_worker_block_allreduce() {
        let server_config = DistributedConfig::coordinator(
            "127.0.0.1:0".parse().expect("literal socket address should parse"),
            2,
        );
        let mut server =
            GradientServer::bind(server_config).expect("coordinator should bind");
        let addr = server.local_addr();

        let component_sizes = vec![2, 1, 1, 2, 2, 2, 2, 1, 1];
        let total: u32 = component_sizes.iter().sum();

        let sizes0 = component_sizes.clone();
        let h0 = thread::spawn(move || {
            let cfg = DistributedConfig::worker(addr);
            let client = WorkerClient::connect(cfg, 1, "cuda")
                .expect("first worker should join the coordinator");
            let grads = vec![1.0f32; total as usize];
            client
                .send_block_gradient(0, 0, 1, grads, sizes0)
                .expect("first worker should send its block gradient");
            client
                .receive_averaged_block()
                .expect("first worker should receive the averaged block")
        });

        let sizes1 = component_sizes.clone();
        let h1 = thread::spawn(move || {
            let cfg = DistributedConfig::worker(addr);
            let client = WorkerClient::connect(cfg, 1, "cuda")
                .expect("second worker should join the coordinator");
            let grads = vec![3.0f32; total as usize];
            client
                .send_block_gradient(0, 0, 1, grads, sizes1)
                .expect("second worker should send its block gradient");
            client
                .receive_averaged_block()
                .expect("second worker should receive the averaged block")
        });

        server.wait_for_workers().expect("coordinator should accept both workers");
        let result = server
            .collect_and_reduce_block(0, 0)
            .expect("coordinator should reduce the block");
        server
            .broadcast_averaged_block(0, &result)
            .expect("coordinator should broadcast the averaged block");

        let avg0 = h0.join().expect("first worker thread should not panic");
        let avg1 = h1.join().expect("second worker thread should not panic");

        // The mean of 1.0 and 3.0 is 2.0 for every element.
        let expected = vec![2.0f32; total as usize];
        assert_gradients_close(&avg0.gradients, &expected);
        assert_gradients_close(&avg1.gradients, &expected);
    }

    /// One full step: shard assignment, gradient upload, averaged download.
    #[test]
    fn test_worker_full_training_step() {
        let server_config = DistributedConfig::coordinator(
            "127.0.0.1:0".parse().expect("literal socket address should parse"),
            1,
        );
        let mut server =
            GradientServer::bind(server_config).expect("coordinator should bind");
        let addr = server.local_addr();

        let handle = thread::spawn(move || {
            let worker_config = DistributedConfig::worker(addr);
            let client = WorkerClient::connect(worker_config, 1, "cpu")
                .expect("worker should join the coordinator");

            let shard = client
                .receive_shard()
                .expect("worker should read a shard message")
                .expect("should get shard");
            assert_eq!(shard.step, 0);
            assert_eq!(shard.shard_start, 0);
            assert_eq!(shard.shard_end, 50);

            client
                .send_gradients(0, vec![1.0, 2.0, 3.0], 0.5, 48, 50)
                .expect("worker should send its gradients");

            let avg = client
                .receive_averaged()
                .expect("worker should receive the averaged gradients");
            assert_eq!(avg.step, 0);
            assert_eq!(avg.gradients, vec![1.0, 2.0, 3.0]);
            assert!((avg.global_loss - 0.5).abs() < 1e-5);

            // Returned so the client (and its stream) outlives this thread.
            client
        });

        server.wait_for_workers().expect("coordinator should accept the worker");
        server.set_total_samples(50);
        server
            .send_shard_assignments(0)
            .expect("coordinator should send shard assignments");
        let result = server
            .collect_and_reduce(0)
            .expect("coordinator should reduce the gradients");
        server
            .broadcast_averaged(0, &result)
            .expect("coordinator should broadcast the averaged gradients");

        let _client = handle.join().expect("worker thread should not panic");
    }
}