1use std::fmt;
28use std::net::SocketAddr;
29
/// Role a node takes on in a distributed run.
///
/// The `Display` impl renders the variants as the lowercase strings
/// `"coordinator"` and `"worker"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeRole {
    /// Role set by [`DistributedConfig::coordinator`].
    Coordinator,
    /// Role set by [`DistributedConfig::worker`].
    Worker,
}
38
39impl fmt::Display for NodeRole {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 Self::Coordinator => write!(f, "coordinator"),
43 Self::Worker => write!(f, "worker"),
44 }
45 }
46}
47
/// Configuration for a single node of a distributed run.
///
/// Build it with [`DistributedConfig::coordinator`] or [`DistributedConfig::worker`];
/// the `Default` impl yields a coordinator bound to `0.0.0.0:9000` expecting one worker.
///
/// NOTE(review): the code that consumes these fields is not part of this file, so the
/// field docs below only describe how the constructors populate them.
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Whether this node is the coordinator or a worker.
    pub role: NodeRole,
    /// Local socket address of this node. The worker constructor uses `0.0.0.0:0`.
    pub bind_addr: SocketAddr,
    /// Address of the coordinator: `None` for a coordinator config, `Some` for a worker.
    pub coordinator_addr: Option<SocketAddr>,
    /// Number of workers a coordinator expects; the worker constructor sets this to 0.
    pub expect_workers: usize,
    /// Heartbeat interval in milliseconds (both constructors use 5000).
    pub heartbeat_interval_ms: u64,
    /// Heartbeat timeout in milliseconds (both constructors use 30000).
    pub heartbeat_timeout_ms: u64,
    /// Identifier of this node; the constructors default it to `"{hostname}-{pid}"`.
    pub node_id: String,
}
66
67impl DistributedConfig {
68 #[must_use]
74 pub fn coordinator(bind_addr: SocketAddr, expect_workers: usize) -> Self {
75 Self {
76 role: NodeRole::Coordinator,
77 bind_addr,
78 coordinator_addr: None,
79 expect_workers,
80 heartbeat_interval_ms: 5000,
81 heartbeat_timeout_ms: 30000,
82 node_id: Self::default_node_id(),
83 }
84 }
85
86 #[must_use]
91 pub fn worker(coordinator_addr: SocketAddr) -> Self {
92 Self {
93 role: NodeRole::Worker,
94 bind_addr: "0.0.0.0:0".parse().expect("valid addr"),
95 coordinator_addr: Some(coordinator_addr),
96 expect_workers: 0,
97 heartbeat_interval_ms: 5000,
98 heartbeat_timeout_ms: 30000,
99 node_id: Self::default_node_id(),
100 }
101 }
102
103 #[must_use]
105 pub fn is_coordinator(&self) -> bool {
106 self.role == NodeRole::Coordinator
107 }
108
109 fn default_node_id() -> String {
110 let hostname = hostname::get()
111 .map_or_else(|_| "unknown".to_string(), |h| h.to_string_lossy().to_string());
112 let pid = std::process::id();
113 format!("{hostname}-{pid}")
114 }
115}
116
117impl Default for DistributedConfig {
118 fn default() -> Self {
119 Self::coordinator("0.0.0.0:9000".parse().expect("valid addr"), 1)
120 }
121}
122
/// Messages exchanged between nodes, with a hand-written binary encoding.
///
/// A payload is a one-byte tag followed by the variant's fields in declaration-like
/// order. Integers and floats are little-endian; `usize` fields travel as `u64`;
/// strings are a `u32` byte length followed by UTF-8 bytes; `Vec<f32>` is a `u64`
/// element count followed by the elements. `to_bytes` additionally prepends a
/// 4-byte big-endian payload length.
///
/// NOTE(review): which side sends which message is not visible in this file; the
/// variant names suggest workers send the `*Payload`/`JoinRequest` messages and the
/// coordinator sends the `Averaged*`/`JoinAccepted` ones. Confirm against the callers.
#[derive(Debug, Clone)]
pub enum WireMessage {
    /// Tag `0x01`: node id, GPU count and backend name of the joining node.
    JoinRequest { node_id: String, gpu_count: u32, backend: String },
    /// Tag `0x02`: assigned worker id and the total worker count.
    JoinAccepted { worker_id: u32, total_workers: u32 },
    /// Tag `0x03`: shard bounds for a step (presumably a half-open range; verify).
    ShardAssignment { step: u64, shard_start: usize, shard_end: usize },
    /// Tag `0x04`: a flat gradient vector plus loss and accuracy counters for a step.
    GradientPayload {
        step: u64,
        worker_id: u32,
        gradients: Vec<f32>,
        loss: f32,
        correct: usize,
        total: usize,
    },
    /// Tag `0x05`: a flat gradient vector together with a global loss value.
    AveragedGradient { step: u64, gradients: Vec<f32>, global_loss: f32 },
    /// Tag `0x06`: heartbeat carrying the sender's node id and a millisecond timestamp.
    Heartbeat { node_id: String, timestamp_ms: u64 },
    /// Tag `0x07`: no fields.
    Shutdown,

    /// Tag `0x08`: gradients for one block out of `num_blocks`.
    ///
    /// On the wire `component_sizes` (a `u32` count plus `u32` entries) precedes
    /// `gradients`. Presumably the sizes partition `gradients`; the codec does not
    /// enforce that.
    BlockGradientPayload {
        step: u64,
        worker_id: u32,
        block_idx: u32,
        num_blocks: u32,
        gradients: Vec<f32>,
        component_sizes: Vec<u32>,
    },
    /// Tag `0x09`: like `BlockGradientPayload` but without `worker_id`/`num_blocks`.
    AveragedBlockGradient {
        step: u64,
        block_idx: u32,
        gradients: Vec<f32>,
        component_sizes: Vec<u32>,
    },
    /// Tag `0x0A`: gradients for a single non-block component, identified by a `u8`
    /// (the meaning of the component ids is not defined in this file).
    NonBlockGradientPayload {
        step: u64,
        worker_id: u32,
        component: u8,
        gradients: Vec<f32>,
    },
    /// Tag `0x0B`: like `NonBlockGradientPayload` but without `worker_id`.
    AveragedNonBlockGradient {
        step: u64,
        component: u8,
        gradients: Vec<f32>,
    },
}
197
198impl WireMessage {
199 pub fn to_bytes(&self) -> Vec<u8> {
203 let payload = self.serialize_payload();
204 let len = payload.len() as u32;
205 let mut buf = Vec::with_capacity(4 + payload.len());
206 buf.extend_from_slice(&len.to_be_bytes());
207 buf.extend_from_slice(&payload);
208 buf
209 }
210
211 pub fn from_payload(payload: &[u8]) -> Result<Self, String> {
216 Self::deserialize_payload(payload)
217 }
218
219 fn serialize_payload(&self) -> Vec<u8> {
221 let mut buf = Vec::new();
222 match self {
223 Self::JoinRequest { node_id, gpu_count, backend } => {
224 buf.push(0x01);
225 write_string(&mut buf, node_id);
226 buf.extend_from_slice(&gpu_count.to_le_bytes());
227 write_string(&mut buf, backend);
228 }
229 Self::JoinAccepted { worker_id, total_workers } => {
230 buf.push(0x02);
231 buf.extend_from_slice(&worker_id.to_le_bytes());
232 buf.extend_from_slice(&total_workers.to_le_bytes());
233 }
234 Self::ShardAssignment { step, shard_start, shard_end } => {
235 buf.push(0x03);
236 buf.extend_from_slice(&step.to_le_bytes());
237 buf.extend_from_slice(&(*shard_start as u64).to_le_bytes());
238 buf.extend_from_slice(&(*shard_end as u64).to_le_bytes());
239 }
240 Self::GradientPayload { step, worker_id, gradients, loss, correct, total } => {
241 buf.push(0x04);
242 buf.extend_from_slice(&step.to_le_bytes());
243 buf.extend_from_slice(&worker_id.to_le_bytes());
244 write_f32_vec(&mut buf, gradients);
245 buf.extend_from_slice(&loss.to_le_bytes());
246 buf.extend_from_slice(&(*correct as u64).to_le_bytes());
247 buf.extend_from_slice(&(*total as u64).to_le_bytes());
248 }
249 Self::AveragedGradient { step, gradients, global_loss } => {
250 buf.push(0x05);
251 buf.extend_from_slice(&step.to_le_bytes());
252 write_f32_vec(&mut buf, gradients);
253 buf.extend_from_slice(&global_loss.to_le_bytes());
254 }
255 Self::Heartbeat { node_id, timestamp_ms } => {
256 buf.push(0x06);
257 write_string(&mut buf, node_id);
258 buf.extend_from_slice(×tamp_ms.to_le_bytes());
259 }
260 Self::Shutdown => buf.push(0x07),
261 Self::BlockGradientPayload {
262 step,
263 worker_id,
264 block_idx,
265 num_blocks,
266 gradients,
267 component_sizes,
268 } => serialize_block_grad(
269 &mut buf,
270 0x08,
271 *step,
272 *worker_id,
273 *block_idx,
274 *num_blocks,
275 gradients,
276 component_sizes,
277 ),
278 Self::AveragedBlockGradient { step, block_idx, gradients, component_sizes } => {
279 serialize_averaged_block(&mut buf, *step, *block_idx, gradients, component_sizes);
280 }
281 Self::NonBlockGradientPayload { step, worker_id, component, gradients } => {
282 serialize_non_block_grad(&mut buf, *step, *worker_id, *component, gradients);
283 }
284 Self::AveragedNonBlockGradient { step, component, gradients } => {
285 serialize_averaged_non_block(&mut buf, *step, *component, gradients);
286 }
287 }
288 buf
289 }
290
291 fn deserialize_payload(data: &[u8]) -> Result<Self, String> {
292 if data.is_empty() {
293 return Err("empty payload".to_string());
294 }
295 let tag = data[0];
296 let rest = &data[1..];
297 match tag {
298 0x01 => decode_join_request(rest),
299 0x02 => decode_join_accepted(rest),
300 0x03 => decode_shard_assignment(rest),
301 0x04 => decode_gradient_payload(rest),
302 0x05 => decode_averaged_gradient(rest),
303 0x06 => decode_heartbeat(rest),
304 0x07 => Ok(Self::Shutdown),
305 0x08 => decode_block_gradient_payload(rest),
306 0x09 => decode_averaged_block_gradient(rest),
307 0x0A => decode_non_block_gradient_payload(rest),
308 0x0B => decode_averaged_non_block_gradient(rest),
309 other => Err(format!("unknown message tag: 0x{other:02x}")),
310 }
311 }
312}
313
314fn decode_join_request(rest: &[u8]) -> Result<WireMessage, String> {
315 let (node_id, rest) = read_string(rest)?;
316 if rest.len() < 4 {
317 return Err("truncated JoinRequest".to_string());
318 }
319 let gpu_count = u32::from_le_bytes(rest[..4].try_into().expect("4 bytes"));
320 let (backend, _) = read_string(&rest[4..])?;
321 Ok(WireMessage::JoinRequest { node_id, gpu_count, backend })
322}
323
324fn decode_join_accepted(rest: &[u8]) -> Result<WireMessage, String> {
325 if rest.len() < 8 {
326 return Err("truncated JoinAccepted".to_string());
327 }
328 let worker_id = u32::from_le_bytes(rest[..4].try_into().expect("4 bytes"));
329 let total_workers = u32::from_le_bytes(rest[4..8].try_into().expect("4 bytes"));
330 Ok(WireMessage::JoinAccepted { worker_id, total_workers })
331}
332
333fn decode_shard_assignment(rest: &[u8]) -> Result<WireMessage, String> {
334 if rest.len() < 24 {
335 return Err("truncated ShardAssignment".to_string());
336 }
337 let step = u64::from_le_bytes(rest[..8].try_into().expect("8 bytes"));
338 let shard_start = u64::from_le_bytes(rest[8..16].try_into().expect("8 bytes")) as usize;
339 let shard_end = u64::from_le_bytes(rest[16..24].try_into().expect("8 bytes")) as usize;
340 Ok(WireMessage::ShardAssignment { step, shard_start, shard_end })
341}
342
343fn decode_gradient_payload(rest: &[u8]) -> Result<WireMessage, String> {
344 if rest.len() < 20 {
345 return Err("truncated GradientPayload header".to_string());
346 }
347 let step = u64::from_le_bytes(rest[..8].try_into().expect("8 bytes"));
348 let worker_id = u32::from_le_bytes(rest[8..12].try_into().expect("4 bytes"));
349 let grad_len = u64::from_le_bytes(rest[12..20].try_into().expect("8 bytes")) as usize;
350 let grad_bytes = grad_len * 4;
351 if rest.len() < 20 + grad_bytes + 4 + 8 + 8 {
352 return Err("truncated GradientPayload data".to_string());
353 }
354 let gradients = read_f32_vec(rest, 20, grad_len);
355 let tail = &rest[20 + grad_bytes..];
356 let loss = f32::from_le_bytes(tail[..4].try_into().expect("4 bytes"));
357 let correct = u64::from_le_bytes(tail[4..12].try_into().expect("8 bytes")) as usize;
358 let total = u64::from_le_bytes(tail[12..20].try_into().expect("8 bytes")) as usize;
359 Ok(WireMessage::GradientPayload { step, worker_id, gradients, loss, correct, total })
360}
361
362fn decode_averaged_gradient(rest: &[u8]) -> Result<WireMessage, String> {
363 if rest.len() < 16 {
364 return Err("truncated AveragedGradient header".to_string());
365 }
366 let step = u64::from_le_bytes(rest[..8].try_into().expect("8 bytes"));
367 let grad_len = u64::from_le_bytes(rest[8..16].try_into().expect("8 bytes")) as usize;
368 let grad_bytes = grad_len * 4;
369 if rest.len() < 16 + grad_bytes + 4 {
370 return Err("truncated AveragedGradient data".to_string());
371 }
372 let gradients = read_f32_vec(rest, 16, grad_len);
373 let global_loss =
374 f32::from_le_bytes(rest[16 + grad_bytes..16 + grad_bytes + 4].try_into().expect("4 bytes"));
375 Ok(WireMessage::AveragedGradient { step, gradients, global_loss })
376}
377
378fn decode_heartbeat(rest: &[u8]) -> Result<WireMessage, String> {
379 let (node_id, rest) = read_string(rest)?;
380 if rest.len() < 8 {
381 return Err("truncated Heartbeat".to_string());
382 }
383 let timestamp_ms = u64::from_le_bytes(rest[..8].try_into().expect("8 bytes"));
384 Ok(WireMessage::Heartbeat { node_id, timestamp_ms })
385}
386
387fn decode_block_gradient_payload(rest: &[u8]) -> Result<WireMessage, String> {
388 if rest.len() < 24 {
390 return Err("truncated BlockGradientPayload header".to_string());
391 }
392 let step = u64::from_le_bytes(rest[..8].try_into().expect("8 bytes"));
393 let worker_id = u32::from_le_bytes(rest[8..12].try_into().expect("4 bytes"));
394 let block_idx = u32::from_le_bytes(rest[12..16].try_into().expect("4 bytes"));
395 let num_blocks = u32::from_le_bytes(rest[16..20].try_into().expect("4 bytes"));
396 let num_components = u32::from_le_bytes(rest[20..24].try_into().expect("4 bytes")) as usize;
397
398 let comp_end = 24 + num_components * 4;
399 if rest.len() < comp_end + 8 {
400 return Err("truncated BlockGradientPayload component_sizes".to_string());
401 }
402 let mut component_sizes = Vec::with_capacity(num_components);
403 for i in 0..num_components {
404 let start = 24 + i * 4;
405 component_sizes
406 .push(u32::from_le_bytes(rest[start..start + 4].try_into().expect("4 bytes")));
407 }
408
409 let grad_len =
410 u64::from_le_bytes(rest[comp_end..comp_end + 8].try_into().expect("8 bytes")) as usize;
411 let grad_start = comp_end + 8;
412 if rest.len() < grad_start + grad_len * 4 {
413 return Err("truncated BlockGradientPayload gradients".to_string());
414 }
415 let gradients = read_f32_vec(rest, grad_start, grad_len);
416
417 Ok(WireMessage::BlockGradientPayload {
418 step,
419 worker_id,
420 block_idx,
421 num_blocks,
422 gradients,
423 component_sizes,
424 })
425}
426
427fn decode_averaged_block_gradient(rest: &[u8]) -> Result<WireMessage, String> {
428 if rest.len() < 16 {
430 return Err("truncated AveragedBlockGradient header".to_string());
431 }
432 let step = u64::from_le_bytes(rest[..8].try_into().expect("8 bytes"));
433 let block_idx = u32::from_le_bytes(rest[8..12].try_into().expect("4 bytes"));
434 let num_components = u32::from_le_bytes(rest[12..16].try_into().expect("4 bytes")) as usize;
435
436 let comp_end = 16 + num_components * 4;
437 if rest.len() < comp_end + 8 {
438 return Err("truncated AveragedBlockGradient component_sizes".to_string());
439 }
440 let mut component_sizes = Vec::with_capacity(num_components);
441 for i in 0..num_components {
442 let start = 16 + i * 4;
443 component_sizes
444 .push(u32::from_le_bytes(rest[start..start + 4].try_into().expect("4 bytes")));
445 }
446
447 let grad_len =
448 u64::from_le_bytes(rest[comp_end..comp_end + 8].try_into().expect("8 bytes")) as usize;
449 let grad_start = comp_end + 8;
450 if rest.len() < grad_start + grad_len * 4 {
451 return Err("truncated AveragedBlockGradient gradients".to_string());
452 }
453 let gradients = read_f32_vec(rest, grad_start, grad_len);
454
455 Ok(WireMessage::AveragedBlockGradient { step, block_idx, gradients, component_sizes })
456}
457
458fn decode_non_block_gradient_payload(rest: &[u8]) -> Result<WireMessage, String> {
459 if rest.len() < 21 {
461 return Err("truncated NonBlockGradientPayload header".to_string());
462 }
463 let step = u64::from_le_bytes(rest[..8].try_into().expect("8 bytes"));
464 let worker_id = u32::from_le_bytes(rest[8..12].try_into().expect("4 bytes"));
465 let component = rest[12];
466 let grad_len = u64::from_le_bytes(rest[13..21].try_into().expect("8 bytes")) as usize;
467 if rest.len() < 21 + grad_len * 4 {
468 return Err("truncated NonBlockGradientPayload gradients".to_string());
469 }
470 let gradients = read_f32_vec(rest, 21, grad_len);
471
472 Ok(WireMessage::NonBlockGradientPayload { step, worker_id, component, gradients })
473}
474
475fn decode_averaged_non_block_gradient(rest: &[u8]) -> Result<WireMessage, String> {
476 if rest.len() < 17 {
478 return Err("truncated AveragedNonBlockGradient header".to_string());
479 }
480 let step = u64::from_le_bytes(rest[..8].try_into().expect("8 bytes"));
481 let component = rest[8];
482 let grad_len = u64::from_le_bytes(rest[9..17].try_into().expect("8 bytes")) as usize;
483 if rest.len() < 17 + grad_len * 4 {
484 return Err("truncated AveragedNonBlockGradient gradients".to_string());
485 }
486 let gradients = read_f32_vec(rest, 17, grad_len);
487
488 Ok(WireMessage::AveragedNonBlockGradient { step, component, gradients })
489}
490
/// Reads `count` little-endian `f32` values from `data`, starting at byte `offset`.
///
/// # Panics
///
/// Panics if any requested element lies outside `data`; callers are expected to
/// have validated the length beforehand.
fn read_f32_vec(data: &[u8], offset: usize, count: usize) -> Vec<f32> {
    (0..count)
        .map(|idx| {
            let at = offset + idx * 4;
            f32::from_le_bytes(data[at..at + 4].try_into().expect("4 bytes"))
        })
        .collect()
}
500
/// Appends `v` to `buf` as a little-endian `u64` element count followed by each
/// element's little-endian bytes.
fn write_f32_vec(buf: &mut Vec<u8>, v: &[f32]) {
    let count = v.len() as u64;
    buf.extend_from_slice(&count.to_le_bytes());
    v.iter().for_each(|value| buf.extend_from_slice(&value.to_le_bytes()));
}
508
/// Appends `sizes` to `buf` as a little-endian `u32` entry count followed by each
/// entry's little-endian bytes.
fn write_component_sizes(buf: &mut Vec<u8>, sizes: &[u32]) {
    let count = sizes.len() as u32;
    buf.extend_from_slice(&count.to_le_bytes());
    sizes.iter().for_each(|size| buf.extend_from_slice(&size.to_le_bytes()));
}
516
517fn serialize_block_grad(
519 buf: &mut Vec<u8>,
520 tag: u8,
521 step: u64,
522 worker_id: u32,
523 block_idx: u32,
524 num_blocks: u32,
525 gradients: &[f32],
526 component_sizes: &[u32],
527) {
528 buf.push(tag);
529 buf.extend_from_slice(&step.to_le_bytes());
530 buf.extend_from_slice(&worker_id.to_le_bytes());
531 buf.extend_from_slice(&block_idx.to_le_bytes());
532 buf.extend_from_slice(&num_blocks.to_le_bytes());
533 write_component_sizes(buf, component_sizes);
534 write_f32_vec(buf, gradients);
535}
536
537fn serialize_averaged_block(
539 buf: &mut Vec<u8>,
540 step: u64,
541 block_idx: u32,
542 gradients: &[f32],
543 component_sizes: &[u32],
544) {
545 buf.push(0x09);
546 buf.extend_from_slice(&step.to_le_bytes());
547 buf.extend_from_slice(&block_idx.to_le_bytes());
548 write_component_sizes(buf, component_sizes);
549 write_f32_vec(buf, gradients);
550}
551
552fn serialize_non_block_grad(
554 buf: &mut Vec<u8>,
555 step: u64,
556 worker_id: u32,
557 component: u8,
558 gradients: &[f32],
559) {
560 buf.push(0x0A);
561 buf.extend_from_slice(&step.to_le_bytes());
562 buf.extend_from_slice(&worker_id.to_le_bytes());
563 buf.push(component);
564 write_f32_vec(buf, gradients);
565}
566
567fn serialize_averaged_non_block(buf: &mut Vec<u8>, step: u64, component: u8, gradients: &[f32]) {
569 buf.push(0x0B);
570 buf.extend_from_slice(&step.to_le_bytes());
571 buf.push(component);
572 write_f32_vec(buf, gradients);
573}
574
/// Appends `s` to `buf` as a little-endian `u32` byte length followed by its UTF-8
/// bytes.
fn write_string(buf: &mut Vec<u8>, s: &str) {
    let byte_len = s.len() as u32;
    buf.extend_from_slice(&byte_len.to_le_bytes());
    buf.extend_from_slice(s.as_bytes());
}
580
/// Reads a length-prefixed UTF-8 string from the front of `data`.
///
/// The encoding is a little-endian `u32` byte length followed by that many bytes.
/// On success returns the decoded string together with the bytes that follow it.
///
/// # Errors
///
/// Returns an error when the length prefix or the string bytes are incomplete, or
/// when the bytes are not valid UTF-8.
fn read_string(data: &[u8]) -> Result<(String, &[u8]), String> {
    if data.len() < 4 {
        return Err("truncated string length".to_string());
    }
    let (len_bytes, body) = data.split_at(4);
    let len = u32::from_le_bytes(len_bytes.try_into().expect("4 bytes")) as usize;
    // Compare against the bytes left after the prefix instead of computing `4 + len`,
    // which could overflow `usize` on 32-bit targets for a hostile length.
    if body.len() < len {
        return Err("truncated string data".to_string());
    }
    let (text, remainder) = body.split_at(len);
    let s = String::from_utf8(text.to_vec()).map_err(|e| format!("invalid utf8: {e}"))?;
    Ok((s, remainder))
}
593
#[cfg(test)]
mod tests {
    // Tests deliberately use `unwrap_err` to inspect error values.
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Encodes `msg` as a frame, checks that the 4-byte big-endian length prefix
    /// matches the payload size, and decodes the payload back into a message.
    fn roundtrip(msg: &WireMessage) -> WireMessage {
        let frame = msg.to_bytes();
        let (prefix, payload) = frame.split_at(4);
        let declared = u32::from_be_bytes(prefix.try_into().expect("4-byte prefix"));
        assert_eq!(declared as usize, payload.len(), "length prefix must match payload size");
        WireMessage::from_payload(payload).expect("valid")
    }

    // ---- configuration ----

    #[test]
    fn test_coordinator_config() {
        let config = DistributedConfig::coordinator("0.0.0.0:9000".parse().expect("valid"), 3);
        assert!(config.is_coordinator());
        assert_eq!(config.role, NodeRole::Coordinator);
        assert_eq!(config.expect_workers, 3);
        assert!(config.coordinator_addr.is_none());
    }

    #[test]
    fn test_worker_config() {
        let config = DistributedConfig::worker("192.168.50.100:9000".parse().expect("valid"));
        assert!(!config.is_coordinator());
        assert_eq!(config.role, NodeRole::Worker);
        assert_eq!(config.coordinator_addr, Some("192.168.50.100:9000".parse().expect("valid")));
    }

    #[test]
    fn test_default_config() {
        let config = DistributedConfig::default();
        assert!(config.is_coordinator());
        assert_eq!(config.expect_workers, 1);
    }

    #[test]
    fn test_node_role_display() {
        assert_eq!(NodeRole::Coordinator.to_string(), "coordinator");
        assert_eq!(NodeRole::Worker.to_string(), "worker");
    }

    #[test]
    fn test_node_id_not_empty() {
        let config = DistributedConfig::default();
        assert!(!config.node_id.is_empty());
    }

    // ---- basic wire messages ----

    #[test]
    fn test_wire_join_request_roundtrip() {
        let msg = WireMessage::JoinRequest {
            node_id: "intel-1234".to_string(),
            gpu_count: 2,
            backend: "wgpu".to_string(),
        };
        match roundtrip(&msg) {
            WireMessage::JoinRequest { node_id, gpu_count, backend } => {
                assert_eq!(node_id, "intel-1234");
                assert_eq!(gpu_count, 2);
                assert_eq!(backend, "wgpu");
            }
            other => panic!("expected JoinRequest, got {other:?}"),
        }
    }

    #[test]
    fn test_wire_join_accepted_roundtrip() {
        let msg = WireMessage::JoinAccepted { worker_id: 1, total_workers: 3 };
        match roundtrip(&msg) {
            WireMessage::JoinAccepted { worker_id, total_workers } => {
                assert_eq!(worker_id, 1);
                assert_eq!(total_workers, 3);
            }
            other => panic!("expected JoinAccepted, got {other:?}"),
        }
    }

    #[test]
    fn test_wire_shard_assignment_roundtrip() {
        let msg = WireMessage::ShardAssignment { step: 42, shard_start: 100, shard_end: 200 };
        match roundtrip(&msg) {
            WireMessage::ShardAssignment { step, shard_start, shard_end } => {
                assert_eq!(step, 42);
                assert_eq!(shard_start, 100);
                assert_eq!(shard_end, 200);
            }
            other => panic!("expected ShardAssignment, got {other:?}"),
        }
    }

    #[test]
    fn test_wire_gradient_payload_roundtrip() {
        let grads = vec![1.0f32, 2.0, 3.0, -0.5, 0.0];
        let msg = WireMessage::GradientPayload {
            step: 10,
            worker_id: 2,
            gradients: grads.clone(),
            loss: 0.456,
            correct: 8,
            total: 10,
        };
        match roundtrip(&msg) {
            WireMessage::GradientPayload { step, worker_id, gradients, loss, correct, total } => {
                assert_eq!(step, 10);
                assert_eq!(worker_id, 2);
                assert_eq!(gradients, grads);
                assert!((loss - 0.456).abs() < 1e-6);
                assert_eq!(correct, 8);
                assert_eq!(total, 10);
            }
            other => panic!("expected GradientPayload, got {other:?}"),
        }
    }

    #[test]
    fn test_wire_averaged_gradient_roundtrip() {
        let grads = vec![0.5f32, 1.0, 1.5];
        let msg =
            WireMessage::AveragedGradient { step: 5, gradients: grads.clone(), global_loss: 0.789 };
        match roundtrip(&msg) {
            WireMessage::AveragedGradient { step, gradients, global_loss } => {
                assert_eq!(step, 5);
                assert_eq!(gradients, grads);
                assert!((global_loss - 0.789).abs() < 1e-6);
            }
            other => panic!("expected AveragedGradient, got {other:?}"),
        }
    }

    #[test]
    fn test_wire_heartbeat_roundtrip() {
        let msg = WireMessage::Heartbeat {
            node_id: "lambda-5678".to_string(),
            timestamp_ms: 1_709_000_000_000,
        };
        match roundtrip(&msg) {
            WireMessage::Heartbeat { node_id, timestamp_ms } => {
                assert_eq!(node_id, "lambda-5678");
                assert_eq!(timestamp_ms, 1_709_000_000_000);
            }
            other => panic!("expected Heartbeat, got {other:?}"),
        }
    }

    #[test]
    fn test_wire_shutdown_roundtrip() {
        let decoded = roundtrip(&WireMessage::Shutdown);
        assert!(matches!(decoded, WireMessage::Shutdown));
    }

    #[test]
    fn test_wire_empty_payload_error() {
        let result = WireMessage::from_payload(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_wire_unknown_tag_error() {
        let result = WireMessage::from_payload(&[0xFF]);
        assert!(result.is_err());
    }

    // ---- block / non-block gradient messages ----

    #[test]
    fn test_wire_block_gradient_payload_roundtrip() {
        let component_sizes = vec![100, 50, 50, 100, 200, 200, 200, 10, 10];
        let total: u32 = component_sizes.iter().sum();
        let grads: Vec<f32> = (0..total).map(|i| i as f32 * 0.01).collect();
        let msg = WireMessage::BlockGradientPayload {
            step: 42,
            worker_id: 1,
            block_idx: 5,
            num_blocks: 24,
            gradients: grads.clone(),
            component_sizes: component_sizes.clone(),
        };
        match roundtrip(&msg) {
            WireMessage::BlockGradientPayload {
                step,
                worker_id,
                block_idx,
                num_blocks,
                gradients,
                component_sizes: cs,
            } => {
                assert_eq!(step, 42);
                assert_eq!(worker_id, 1);
                assert_eq!(block_idx, 5);
                assert_eq!(num_blocks, 24);
                assert_eq!(gradients, grads);
                assert_eq!(cs, component_sizes);
            }
            other => panic!("expected BlockGradientPayload, got {other:?}"),
        }
    }

    #[test]
    fn test_wire_averaged_block_gradient_roundtrip() {
        let component_sizes = vec![100, 50, 50, 100, 200, 200, 200, 10, 10];
        let total: u32 = component_sizes.iter().sum();
        let grads: Vec<f32> = (0..total).map(|i| i as f32 * -0.005).collect();
        let msg = WireMessage::AveragedBlockGradient {
            step: 99,
            block_idx: 23,
            gradients: grads.clone(),
            component_sizes: component_sizes.clone(),
        };
        match roundtrip(&msg) {
            WireMessage::AveragedBlockGradient {
                step,
                block_idx,
                gradients,
                component_sizes: cs,
            } => {
                assert_eq!(step, 99);
                assert_eq!(block_idx, 23);
                assert_eq!(gradients, grads);
                assert_eq!(cs, component_sizes);
            }
            other => panic!("expected AveragedBlockGradient, got {other:?}"),
        }
    }

    #[test]
    fn test_wire_non_block_gradient_payload_roundtrip() {
        let grads = vec![1.0f32, -2.0, 3.5, 0.0, f32::MIN_POSITIVE];
        let msg = WireMessage::NonBlockGradientPayload {
            step: 10,
            worker_id: 0,
            component: 2,
            gradients: grads.clone(),
        };
        match roundtrip(&msg) {
            WireMessage::NonBlockGradientPayload { step, worker_id, component, gradients } => {
                assert_eq!(step, 10);
                assert_eq!(worker_id, 0);
                assert_eq!(component, 2);
                assert_eq!(gradients, grads);
            }
            other => panic!("expected NonBlockGradientPayload, got {other:?}"),
        }
    }

    #[test]
    fn test_wire_averaged_non_block_gradient_roundtrip() {
        let grads = vec![0.5f32; 32768];
        let msg = WireMessage::AveragedNonBlockGradient {
            step: 50,
            component: 0,
            gradients: grads.clone(),
        };
        match roundtrip(&msg) {
            WireMessage::AveragedNonBlockGradient { step, component, gradients } => {
                assert_eq!(step, 50);
                assert_eq!(component, 0);
                assert_eq!(gradients, grads);
            }
            other => panic!("expected AveragedNonBlockGradient, got {other:?}"),
        }
    }

    #[test]
    fn test_wire_block_gradient_truncated_error() {
        // Tag 0x08 followed by only three bytes: far shorter than the fixed header.
        let result = WireMessage::from_payload(&[0x08, 0x00, 0x00, 0x00]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("truncated"));
    }

    #[test]
    fn test_wire_non_block_gradient_special_values() {
        let grads = vec![f32::NAN, f32::INFINITY, f32::NEG_INFINITY, 0.0, -0.0];
        let msg = WireMessage::NonBlockGradientPayload {
            step: 1,
            worker_id: 0,
            component: 1,
            gradients: grads,
        };
        match roundtrip(&msg) {
            WireMessage::NonBlockGradientPayload { gradients, .. } => {
                assert!(gradients[0].is_nan());
                assert!(gradients[1].is_infinite() && gradients[1].is_sign_positive());
                assert!(gradients[2].is_infinite() && gradients[2].is_sign_negative());
                assert_eq!(gradients[3], 0.0);
                assert_eq!(gradients[4], -0.0);
            }
            other => panic!("expected NonBlockGradientPayload, got {other:?}"),
        }
    }

    #[test]
    fn test_wire_large_gradient_roundtrip() {
        let grad_len = 1_378_050;
        let grads: Vec<f32> = (0..grad_len).map(|i| (i as f32) * 0.0001).collect();
        let msg = WireMessage::GradientPayload {
            step: 100,
            worker_id: 0,
            gradients: grads.clone(),
            loss: 0.123,
            correct: 95,
            total: 100,
        };
        let bytes = msg.to_bytes();
        // prefix + tag + step + worker_id + count + gradients + loss + correct + total
        let expected_size = 4 + 1 + 8 + 4 + 8 + grad_len * 4 + 4 + 8 + 8;
        assert_eq!(bytes.len(), expected_size);

        // The prefix must describe everything after itself.
        let declared = u32::from_be_bytes(bytes[..4].try_into().expect("4-byte prefix"));
        assert_eq!(declared as usize, expected_size - 4);

        let decoded = WireMessage::from_payload(&bytes[4..]).expect("valid");
        match decoded {
            WireMessage::GradientPayload { gradients, loss, .. } => {
                assert_eq!(gradients.len(), grad_len);
                assert!((loss - 0.123).abs() < 1e-6);
            }
            other => panic!("expected GradientPayload, got {other:?}"),
        }
    }
}