1use super::data_parallel::{average_gradients, has_non_finite};
19use super::distributed::{DistributedConfig, WireMessage};
20use std::io::{Read, Write};
21use std::net::{TcpListener, TcpStream};
22use std::time::Instant;
23
/// One accepted worker: the identity it announced in its join request plus
/// the TCP stream the coordinator keeps open to it.
#[derive(Debug)]
struct WorkerConnection {
    /// Id handed out by the coordinator; equal to the worker's position in
    /// the join order (the first worker to join gets 0).
    worker_id: u32,
    /// Node identifier the worker reported when joining. Stored but not read
    /// by the coordinator code visible in this file.
    #[allow(dead_code)]
    node_id: String,
    /// GPU count the worker reported when joining. Stored but not read here.
    #[allow(dead_code)]
    gpu_count: u32,
    /// Backend name the worker reported when joining. Stored but not read here.
    #[allow(dead_code)]
    backend: String,
    /// Connection used for every later message exchanged with this worker.
    stream: TcpStream,
}
36
37pub struct GradientServer {
39 config: DistributedConfig,
40 listener: TcpListener,
41 workers: Vec<WorkerConnection>,
42 total_samples: usize,
43}
44
/// Outcome of one whole-model gradient collection round
/// (see `GradientServer::collect_and_reduce`).
#[derive(Debug, Clone)]
pub struct AllReduceResult {
    /// Gradients averaged across all workers' payloads.
    pub avg_gradients: Vec<f32>,
    /// Mean of the workers' losses weighted by each worker's sample count;
    /// `0.0` when the workers reported no samples at all.
    pub global_loss: f32,
    /// Sum of the `correct` counts reported by the workers.
    pub total_correct: usize,
    /// Sum of the `total` sample counts reported by the workers.
    pub total_samples: usize,
    /// Wall-clock time spent collecting and averaging, in milliseconds.
    pub allreduce_ms: f64,
}
59
/// Outcome of one per-block gradient collection round
/// (see `GradientServer::collect_and_reduce_block`).
#[derive(Debug, Clone)]
pub struct BlockAllReduceResult {
    /// Index of the block these gradients belong to.
    pub block_idx: u32,
    /// Block gradients averaged across all workers' payloads.
    pub avg_gradients: Vec<f32>,
    /// Component size list copied from a worker payload and echoed back in
    /// the broadcast. NOTE(review): presumably the lengths of the sub-tensors
    /// packed into `avg_gradients` — confirm against the worker side.
    pub component_sizes: Vec<u32>,
    /// Wall-clock time spent collecting and averaging, in milliseconds.
    pub allreduce_ms: f64,
}
72
/// Outcome of one collection round for a non-block component
/// (see `GradientServer::collect_and_reduce_non_block`).
#[derive(Debug, Clone)]
pub struct NonBlockAllReduceResult {
    /// Component id the round was run for (the caller's expected component).
    pub component: u8,
    /// Component gradients averaged across all workers' payloads.
    pub avg_gradients: Vec<f32>,
    /// Wall-clock time spent collecting and averaging, in milliseconds.
    pub allreduce_ms: f64,
}
83
84impl GradientServer {
85 pub fn bind(config: DistributedConfig) -> Result<Self, String> {
90 let listener = TcpListener::bind(config.bind_addr)
91 .map_err(|e| format!("failed to bind {}: {e}", config.bind_addr))?;
92 eprintln!(
93 "[coordinator] Listening on {} (expecting {} workers)",
94 config.bind_addr, config.expect_workers
95 );
96 Ok(Self { config, listener, workers: Vec::new(), total_samples: 0 })
97 }
98
99 pub fn wait_for_workers(&mut self) -> Result<(), String> {
106 let expected = self.config.expect_workers;
107 eprintln!("[coordinator] Waiting for {expected} workers to connect...");
108
109 while self.workers.len() < expected {
110 let (stream, addr) =
111 self.listener.accept().map_err(|e| format!("accept failed: {e}"))?;
112 eprintln!("[coordinator] Connection from {addr}");
113
114 let msg = read_wire_message(&stream)?;
116 match msg {
117 WireMessage::JoinRequest { node_id, gpu_count, backend } => {
118 let worker_id = self.workers.len() as u32;
119 eprintln!(
120 "[coordinator] Worker {worker_id} joined: {node_id} ({gpu_count} GPUs, {backend})"
121 );
122
123 let response =
125 WireMessage::JoinAccepted { worker_id, total_workers: expected as u32 };
126 send_wire_message(&stream, &response)?;
127
128 self.workers.push(WorkerConnection {
129 worker_id,
130 node_id,
131 gpu_count,
132 backend,
133 stream,
134 });
135 }
136 other => {
137 return Err(format!("expected JoinRequest, got {other:?}"));
138 }
139 }
140 }
141
142 eprintln!("[coordinator] All {expected} workers connected");
143 Ok(())
144 }
145
146 pub fn set_total_samples(&mut self, n: usize) {
148 self.total_samples = n;
149 }
150
151 pub fn send_shard_assignments(&mut self, step: u64) -> Result<(), String> {
156 let n = self.workers.len();
157 let shard_size = self.total_samples / n;
158
159 for (i, worker) in self.workers.iter().enumerate() {
160 let start = i * shard_size;
161 let end = if i == n - 1 { self.total_samples } else { start + shard_size };
162 let msg = WireMessage::ShardAssignment { step, shard_start: start, shard_end: end };
163 send_wire_message(&worker.stream, &msg)?;
164 }
165 Ok(())
166 }
167
168 pub fn collect_and_reduce(&mut self, step: u64) -> Result<AllReduceResult, String> {
177 let start = Instant::now();
178 let n = self.workers.len();
179 let mut all_grads: Vec<Vec<f32>> = Vec::with_capacity(n);
180 let mut total_loss = 0.0f32;
181 let mut total_correct = 0usize;
182 let mut total_samples = 0usize;
183
184 for worker in &self.workers {
185 let msg = read_wire_message(&worker.stream)?;
186 match msg {
187 WireMessage::GradientPayload {
188 step: recv_step,
189 gradients,
190 loss,
191 correct,
192 total,
193 ..
194 } => {
195 if recv_step != step {
196 return Err(format!("step mismatch: expected {step}, got {recv_step}"));
197 }
198
199 if has_non_finite(&gradients) {
201 return Err(format!(
202 "JIDOKA HALT: worker {} sent non-finite gradient at step {step}",
203 worker.worker_id
204 ));
205 }
206
207 total_loss += loss * total as f32;
208 total_correct += correct;
209 total_samples += total;
210 all_grads.push(gradients);
211 }
212 other => {
213 return Err(format!(
214 "expected GradientPayload from worker {}, got {other:?}",
215 worker.worker_id
216 ));
217 }
218 }
219 }
220
221 let avg_gradients = average_gradients(&all_grads);
223 let global_loss = if total_samples > 0 { total_loss / total_samples as f32 } else { 0.0 };
224
225 let allreduce_ms = start.elapsed().as_secs_f64() * 1000.0;
226
227 Ok(AllReduceResult {
228 avg_gradients,
229 global_loss,
230 total_correct,
231 total_samples,
232 allreduce_ms,
233 })
234 }
235
236 pub fn broadcast_averaged(
241 &mut self,
242 step: u64,
243 result: &AllReduceResult,
244 ) -> Result<(), String> {
245 let msg = WireMessage::AveragedGradient {
246 step,
247 gradients: result.avg_gradients.clone(),
248 global_loss: result.global_loss,
249 };
250 for worker in &self.workers {
251 send_wire_message(&worker.stream, &msg)?;
252 }
253 Ok(())
254 }
255
256 pub fn shutdown_workers(&mut self) {
258 for worker in &self.workers {
259 let _ = send_wire_message(&worker.stream, &WireMessage::Shutdown);
260 }
261 }
262
263 #[must_use]
265 pub fn worker_count(&self) -> usize {
266 self.workers.len()
267 }
268
269 pub fn collect_and_reduce_block(
283 &mut self,
284 step: u64,
285 block_idx: u32,
286 ) -> Result<BlockAllReduceResult, String> {
287 let start = Instant::now();
288 let n = self.workers.len();
289 let mut all_grads: Vec<Vec<f32>> = Vec::with_capacity(n);
290 let mut component_sizes = Vec::new();
291
292 for worker in &self.workers {
293 let msg = read_wire_message(&worker.stream)?;
294 match msg {
295 WireMessage::BlockGradientPayload {
296 step: recv_step,
297 block_idx: recv_block_idx,
298 gradients,
299 component_sizes: cs,
300 ..
301 } => {
302 if recv_step != step {
303 return Err(format!("step mismatch: expected {step}, got {recv_step}"));
304 }
305 if recv_block_idx != block_idx {
306 return Err(format!(
307 "block_idx mismatch: expected {block_idx}, got {recv_block_idx}"
308 ));
309 }
310 if has_non_finite(&gradients) {
311 return Err(format!(
312 "JIDOKA HALT: worker {} sent non-finite block {block_idx} gradient at step {step}",
313 worker.worker_id
314 ));
315 }
316 if component_sizes.is_empty() {
317 component_sizes = cs;
318 }
319 all_grads.push(gradients);
320 }
321 other => {
322 return Err(format!(
323 "expected BlockGradientPayload from worker {}, got {other:?}",
324 worker.worker_id
325 ));
326 }
327 }
328 }
329
330 let avg_gradients = average_gradients(&all_grads);
331 let allreduce_ms = start.elapsed().as_secs_f64() * 1000.0;
332
333 Ok(BlockAllReduceResult { block_idx, avg_gradients, component_sizes, allreduce_ms })
334 }
335
336 pub fn broadcast_averaged_block(
341 &mut self,
342 step: u64,
343 result: &BlockAllReduceResult,
344 ) -> Result<(), String> {
345 let msg = WireMessage::AveragedBlockGradient {
346 step,
347 block_idx: result.block_idx,
348 gradients: result.avg_gradients.clone(),
349 component_sizes: result.component_sizes.clone(),
350 };
351 for worker in &self.workers {
352 send_wire_message(&worker.stream, &msg)?;
353 }
354 Ok(())
355 }
356
357 pub fn collect_and_reduce_non_block(
364 &mut self,
365 step: u64,
366 expected_component: u8,
367 ) -> Result<NonBlockAllReduceResult, String> {
368 let start = Instant::now();
369 let n = self.workers.len();
370 let mut all_grads: Vec<Vec<f32>> = Vec::with_capacity(n);
371
372 for worker in &self.workers {
373 let msg = read_wire_message(&worker.stream)?;
374 match msg {
375 WireMessage::NonBlockGradientPayload {
376 step: recv_step,
377 component,
378 gradients,
379 ..
380 } => {
381 if recv_step != step {
382 return Err(format!("step mismatch: expected {step}, got {recv_step}"));
383 }
384 if component != expected_component {
385 return Err(format!(
386 "component mismatch: expected {expected_component}, got {component}"
387 ));
388 }
389 if has_non_finite(&gradients) {
390 return Err(format!(
391 "JIDOKA HALT: worker {} sent non-finite component {component} gradient at step {step}",
392 worker.worker_id
393 ));
394 }
395 all_grads.push(gradients);
396 }
397 other => {
398 return Err(format!(
399 "expected NonBlockGradientPayload from worker {}, got {other:?}",
400 worker.worker_id
401 ));
402 }
403 }
404 }
405
406 let avg_gradients = average_gradients(&all_grads);
407 let allreduce_ms = start.elapsed().as_secs_f64() * 1000.0;
408
409 Ok(NonBlockAllReduceResult { component: expected_component, avg_gradients, allreduce_ms })
410 }
411
412 pub fn broadcast_averaged_non_block(
414 &mut self,
415 step: u64,
416 result: &NonBlockAllReduceResult,
417 ) -> Result<(), String> {
418 let msg = WireMessage::AveragedNonBlockGradient {
419 step,
420 component: result.component,
421 gradients: result.avg_gradients.clone(),
422 };
423 for worker in &self.workers {
424 send_wire_message(&worker.stream, &msg)?;
425 }
426 Ok(())
427 }
428}
429
430pub(crate) fn read_wire_message(stream: &TcpStream) -> Result<WireMessage, String> {
434 let mut len_buf = [0u8; 4];
435 (&*stream).read_exact(&mut len_buf).map_err(|e| format!("read length failed: {e}"))?;
436 let len = u32::from_be_bytes(len_buf) as usize;
437
438 if len > 100_000_000 {
439 return Err(format!("message too large: {len} bytes"));
440 }
441
442 let mut payload = vec![0u8; len];
443 (&*stream).read_exact(&mut payload).map_err(|e| format!("read payload failed: {e}"))?;
444
445 WireMessage::from_payload(&payload)
446}
447
448pub(crate) fn send_wire_message(stream: &TcpStream, msg: &WireMessage) -> Result<(), String> {
450 let bytes = msg.to_bytes();
451 (&*stream).write_all(&bytes).map_err(|e| format!("send failed: {e}"))?;
452 (&*stream).flush().map_err(|e| format!("flush failed: {e}"))?;
453 Ok(())
454}
455
456impl GradientServer {
457 #[must_use]
461 pub fn local_addr(&self) -> std::net::SocketAddr {
462 self.listener.local_addr().expect("listener has local addr")
463 }
464}
465
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use std::net::TcpStream;
    use std::thread;

    /// Connects to the coordinator at `addr`, sends a single-GPU "cpu" join
    /// request under `node_id`, and returns the stream together with the
    /// coordinator's reply.
    fn connect_and_join(addr: std::net::SocketAddr, node_id: &str) -> (TcpStream, WireMessage) {
        let stream = TcpStream::connect(addr).expect("valid");
        let join = WireMessage::JoinRequest {
            node_id: node_id.to_string(),
            gpu_count: 1,
            backend: "cpu".to_string(),
        };
        send_wire_message(&stream, &join).expect("valid");
        let reply = read_wire_message(&stream).expect("valid");
        (stream, reply)
    }

    /// Asserts that `actual` is within `1e-5` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 1e-5);
    }

    #[test]
    fn test_server_bind() {
        // Port 0 lets the OS pick a free port.
        let config = DistributedConfig::coordinator("127.0.0.1:0".parse().expect("valid"), 1);
        assert!(GradientServer::bind(config).is_ok());
    }

    #[test]
    fn test_server_worker_count_initially_zero() {
        let config = DistributedConfig::coordinator("127.0.0.1:0".parse().expect("valid"), 1);
        let server = GradientServer::bind(config).expect("valid");
        assert_eq!(server.worker_count(), 0);
    }

    #[test]
    fn test_server_accept_worker() {
        let config = DistributedConfig::coordinator("127.0.0.1:0".parse().expect("valid"), 1);
        let mut server = GradientServer::bind(config).expect("valid");
        let addr = server.local_addr();

        // Worker side: join and check the acceptance details.
        let worker = thread::spawn(move || {
            let (stream, reply) = connect_and_join(addr, "test-worker");
            match reply {
                WireMessage::JoinAccepted { worker_id, total_workers } => {
                    assert_eq!(worker_id, 0);
                    assert_eq!(total_workers, 1);
                }
                other => panic!("expected JoinAccepted, got {other:?}"),
            }
            // Hand the stream back so the connection outlives the thread.
            stream
        });

        server.wait_for_workers().expect("valid");
        assert_eq!(server.worker_count(), 1);

        let _stream = worker.join().expect("valid");
    }

    #[test]
    fn test_server_shard_and_reduce() {
        let config = DistributedConfig::coordinator("127.0.0.1:0".parse().expect("valid"), 2);
        let mut server = GradientServer::bind(config).expect("valid");
        let addr = server.local_addr();

        // Two workers: each joins, takes its shard, reports gradients, and
        // checks the averaged gradients it gets back.
        let workers: Vec<_> = (0..2)
            .map(|i| {
                thread::spawn(move || {
                    let (stream, _accepted) = connect_and_join(addr, &format!("worker-{i}"));

                    let shard_len = match read_wire_message(&stream).expect("valid") {
                        WireMessage::ShardAssignment { shard_start, shard_end, .. } => {
                            shard_end - shard_start
                        }
                        other => panic!("expected ShardAssignment, got {other:?}"),
                    };

                    let payload = WireMessage::GradientPayload {
                        step: 0,
                        worker_id: i,
                        gradients: vec![1.0 + i as f32, 2.0 + i as f32],
                        loss: 0.5 + i as f32 * 0.1,
                        correct: shard_len,
                        total: shard_len,
                    };
                    send_wire_message(&stream, &payload).expect("valid");

                    match read_wire_message(&stream).expect("valid") {
                        WireMessage::AveragedGradient { gradients, .. } => {
                            // mean of [1, 2] and [2, 3]
                            assert_close(gradients[0], 1.5);
                            assert_close(gradients[1], 2.5);
                        }
                        other => panic!("expected AveragedGradient, got {other:?}"),
                    }

                    stream
                })
            })
            .collect();

        server.wait_for_workers().expect("valid");
        server.set_total_samples(100);
        server.send_shard_assignments(0).expect("valid");
        let result = server.collect_and_reduce(0).expect("valid");

        assert_close(result.avg_gradients[0], 1.5);
        assert_close(result.avg_gradients[1], 2.5);
        assert_eq!(result.total_samples, 100);
        assert!(result.allreduce_ms >= 0.0);

        server.broadcast_averaged(0, &result).expect("valid");

        for worker in workers {
            let _stream = worker.join().expect("valid");
        }
    }

    #[test]
    fn test_server_jidoka_halt_on_nan() {
        let config = DistributedConfig::coordinator("127.0.0.1:0".parse().expect("valid"), 1);
        let mut server = GradientServer::bind(config).expect("valid");
        let addr = server.local_addr();

        // Worker side: join, consume the shard assignment, then send a
        // gradient that contains NaN.
        let worker = thread::spawn(move || {
            let (stream, _accepted) = connect_and_join(addr, "bad-worker");
            let _shard = read_wire_message(&stream).expect("valid");

            let payload = WireMessage::GradientPayload {
                step: 0,
                worker_id: 0,
                gradients: vec![1.0, f32::NAN, 3.0],
                loss: 0.5,
                correct: 5,
                total: 10,
            };
            send_wire_message(&stream, &payload).expect("valid");
            stream
        });

        server.wait_for_workers().expect("valid");
        server.set_total_samples(10);
        server.send_shard_assignments(0).expect("valid");

        let outcome = server.collect_and_reduce(0);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err();
        assert!(message.contains("JIDOKA HALT"));

        let _stream = worker.join().expect("valid");
    }
}