Skip to main content

entrenar/finetune/
distributed.rs

1//! Distributed training configuration and protocol (SPEC-DIST-2026-001 Phase 2)
2//!
3//! Defines the coordination protocol for multi-node heterogeneous training
4//! across CUDA and wgpu backends over TCP.
5//!
6//! # Architecture
7//!
8//! ```text
9//! Coordinator (intel:9000)
10//!   ├── Worker 0: intel:gpu0 (wgpu)
11//!   ├── Worker 1: intel:gpu1 (wgpu)
12//!   └── Worker 2: lambda:gpu0 (CUDA)
13//! ```
14//!
15//! # Protocol
16//!
17//! 1. Workers connect to coordinator via TCP
18//! 2. Coordinator broadcasts model config + initial weights
19//! 3. Per step: coordinator sends shard assignment, workers compute gradients,
20//!    coordinator AllReduces, broadcasts averaged gradients
21//! 4. Workers detect coordinator failure via heartbeat timeout
22//!
23//! # Contract: F-DP-001 (Weight Consistency)
24//!
25//! After AllReduce + optimizer step, all workers hold identical LoRA weights.
26
27use std::fmt;
28use std::net::SocketAddr;
29
/// Which side of the distributed-training protocol a node plays.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeRole {
    /// Accepts workers, shards the data and AllReduces the gradients.
    Coordinator,
    /// Runs forward/backward on its assigned shard and sends gradients to the coordinator.
    Worker,
}

/// Renders the role as its lowercase name: `coordinator` or `worker`.
impl fmt::Display for NodeRole {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Coordinator => "coordinator",
            Self::Worker => "worker",
        };
        f.write_str(label)
    }
}
47
48/// Configuration for distributed training.
49#[derive(Debug, Clone)]
50pub struct DistributedConfig {
51    /// Role of this node
52    pub role: NodeRole,
53    /// Address to bind (coordinator) or connect to (worker)
54    pub bind_addr: SocketAddr,
55    /// Coordinator address (workers only; coordinators use bind_addr)
56    pub coordinator_addr: Option<SocketAddr>,
57    /// Expected number of workers (coordinator only, for barrier)
58    pub expect_workers: usize,
59    /// Heartbeat interval in milliseconds
60    pub heartbeat_interval_ms: u64,
61    /// Heartbeat timeout in milliseconds (detect failure after this)
62    pub heartbeat_timeout_ms: u64,
63    /// Node identifier (auto-assigned from hostname + pid)
64    pub node_id: String,
65}
66
67impl DistributedConfig {
68    /// Create a coordinator config.
69    ///
70    /// # Arguments
71    /// * `bind_addr` - Address to listen on (e.g., `0.0.0.0:9000`)
72    /// * `expect_workers` - Total worker count (including coordinator's own GPUs)
73    #[must_use]
74    pub fn coordinator(bind_addr: SocketAddr, expect_workers: usize) -> Self {
75        Self {
76            role: NodeRole::Coordinator,
77            bind_addr,
78            coordinator_addr: None,
79            expect_workers,
80            heartbeat_interval_ms: 5000,
81            heartbeat_timeout_ms: 30000,
82            node_id: Self::default_node_id(),
83        }
84    }
85
86    /// Create a worker config.
87    ///
88    /// # Arguments
89    /// * `coordinator_addr` - Address of the coordinator (e.g., `intel:9000`)
90    #[must_use]
91    pub fn worker(coordinator_addr: SocketAddr) -> Self {
92        Self {
93            role: NodeRole::Worker,
94            bind_addr: "0.0.0.0:0".parse().expect("valid addr"),
95            coordinator_addr: Some(coordinator_addr),
96            expect_workers: 0,
97            heartbeat_interval_ms: 5000,
98            heartbeat_timeout_ms: 30000,
99            node_id: Self::default_node_id(),
100        }
101    }
102
103    /// Check if this node is a coordinator
104    #[must_use]
105    pub fn is_coordinator(&self) -> bool {
106        self.role == NodeRole::Coordinator
107    }
108
109    fn default_node_id() -> String {
110        let hostname = hostname::get()
111            .map_or_else(|_| "unknown".to_string(), |h| h.to_string_lossy().to_string());
112        let pid = std::process::id();
113        format!("{hostname}-{pid}")
114    }
115}
116
117impl Default for DistributedConfig {
118    fn default() -> Self {
119        Self::coordinator("0.0.0.0:9000".parse().expect("valid addr"), 1)
120    }
121}
122
123// ─── Wire Protocol ───────────────────────────────────────────────────────────
124
/// Messages sent between coordinator and workers over TCP.
///
/// Each message is length-prefixed: `[u32 big-endian length][payload bytes]`.
/// The payload is a hand-rolled binary encoding (no serde/bincode): a one-byte
/// message tag followed by the variant's fields, with integers and floats in
/// little-endian order. Strings and vectors carry their own length prefix.
///
/// `PartialEq` compares field by field; gradient vectors use `f32` equality,
/// so a message containing `NaN` never compares equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub enum WireMessage {
    /// Worker → Coordinator: request to join training (tag `0x01`)
    JoinRequest { node_id: String, gpu_count: u32, backend: String },
    /// Coordinator → Worker: join accepted with assigned worker ID (tag `0x02`)
    JoinAccepted { worker_id: u32, total_workers: u32 },
    /// Coordinator → Worker: here is your shard for this step (tag `0x03`)
    ShardAssignment { step: u64, shard_start: usize, shard_end: usize },
    /// Worker → Coordinator: gradient data for this step (tag `0x04`)
    GradientPayload {
        step: u64,
        worker_id: u32,
        /// Serialized f32 gradient vector (LoRA params + classifier head)
        gradients: Vec<f32>,
        loss: f32,
        correct: usize,
        total: usize,
    },
    /// Coordinator → Worker: averaged gradient for optimizer step (tag `0x05`)
    AveragedGradient { step: u64, gradients: Vec<f32>, global_loss: f32 },
    /// Bidirectional: heartbeat ping/pong (tag `0x06`)
    Heartbeat { node_id: String, timestamp_ms: u64 },
    /// Coordinator → Worker: training complete, shut down (tag `0x07`)
    Shutdown,

    // === v2: Per-block gradient messages for DDP pretraining (tags 0x08-0x0B) ===
    /// Worker → Coordinator: gradient data for a single transformer block (tag `0x08`)
    ///
    /// Sent after backward pass for block[block_idx]. Contains 9 gradient
    /// components (w_q, w_k, w_v, w_o, gate, up, down, input_norm, post_attn_norm)
    /// concatenated into a single f32 vector with per-component sizes.
    BlockGradientPayload {
        step: u64,
        worker_id: u32,
        block_idx: u32,
        num_blocks: u32,
        /// Concatenated f32 gradient vector for all components
        gradients: Vec<f32>,
        /// Element count for each component (e.g., [H*H, H*D_kv, H*D_kv, H*H, H*I, H*I, I*H, H, H])
        component_sizes: Vec<u32>,
    },
    /// Coordinator → Worker: averaged gradient for a single transformer block (tag `0x09`)
    AveragedBlockGradient {
        step: u64,
        block_idx: u32,
        /// Concatenated averaged f32 gradient vector
        gradients: Vec<f32>,
        /// Element count for each component (same layout as BlockGradientPayload)
        component_sizes: Vec<u32>,
    },
    /// Worker → Coordinator: gradient for a non-block component (LM head, final norm,
    /// embedding) (tag `0x0A`)
    NonBlockGradientPayload {
        step: u64,
        worker_id: u32,
        /// Component ID: 0=lm_head, 1=final_norm, 2=embedding
        component: u8,
        /// Flattened f32 gradient
        gradients: Vec<f32>,
    },
    /// Coordinator → Worker: averaged gradient for a non-block component (tag `0x0B`)
    AveragedNonBlockGradient {
        step: u64,
        /// Component ID: 0=lm_head, 1=final_norm, 2=embedding
        component: u8,
        /// Averaged f32 gradient
        gradients: Vec<f32>,
    },
}
197
198impl WireMessage {
199    /// Serialize this message to bytes (length-prefixed).
200    ///
201    /// Format: `[4 bytes big-endian length][payload]`
202    pub fn to_bytes(&self) -> Vec<u8> {
203        let payload = self.serialize_payload();
204        let len = payload.len() as u32;
205        let mut buf = Vec::with_capacity(4 + payload.len());
206        buf.extend_from_slice(&len.to_be_bytes());
207        buf.extend_from_slice(&payload);
208        buf
209    }
210
211    /// Deserialize from a complete payload (without length prefix).
212    ///
213    /// # Errors
214    /// Returns error if payload is malformed.
215    pub fn from_payload(payload: &[u8]) -> Result<Self, String> {
216        Self::deserialize_payload(payload)
217    }
218
219    // Simple binary serialization (avoiding serde dependency for wire protocol)
220    fn serialize_payload(&self) -> Vec<u8> {
221        let mut buf = Vec::new();
222        match self {
223            Self::JoinRequest { node_id, gpu_count, backend } => {
224                buf.push(0x01);
225                write_string(&mut buf, node_id);
226                buf.extend_from_slice(&gpu_count.to_le_bytes());
227                write_string(&mut buf, backend);
228            }
229            Self::JoinAccepted { worker_id, total_workers } => {
230                buf.push(0x02);
231                buf.extend_from_slice(&worker_id.to_le_bytes());
232                buf.extend_from_slice(&total_workers.to_le_bytes());
233            }
234            Self::ShardAssignment { step, shard_start, shard_end } => {
235                buf.push(0x03);
236                buf.extend_from_slice(&step.to_le_bytes());
237                buf.extend_from_slice(&(*shard_start as u64).to_le_bytes());
238                buf.extend_from_slice(&(*shard_end as u64).to_le_bytes());
239            }
240            Self::GradientPayload { step, worker_id, gradients, loss, correct, total } => {
241                buf.push(0x04);
242                buf.extend_from_slice(&step.to_le_bytes());
243                buf.extend_from_slice(&worker_id.to_le_bytes());
244                write_f32_vec(&mut buf, gradients);
245                buf.extend_from_slice(&loss.to_le_bytes());
246                buf.extend_from_slice(&(*correct as u64).to_le_bytes());
247                buf.extend_from_slice(&(*total as u64).to_le_bytes());
248            }
249            Self::AveragedGradient { step, gradients, global_loss } => {
250                buf.push(0x05);
251                buf.extend_from_slice(&step.to_le_bytes());
252                write_f32_vec(&mut buf, gradients);
253                buf.extend_from_slice(&global_loss.to_le_bytes());
254            }
255            Self::Heartbeat { node_id, timestamp_ms } => {
256                buf.push(0x06);
257                write_string(&mut buf, node_id);
258                buf.extend_from_slice(&timestamp_ms.to_le_bytes());
259            }
260            Self::Shutdown => buf.push(0x07),
261            Self::BlockGradientPayload {
262                step,
263                worker_id,
264                block_idx,
265                num_blocks,
266                gradients,
267                component_sizes,
268            } => serialize_block_grad(
269                &mut buf,
270                0x08,
271                *step,
272                *worker_id,
273                *block_idx,
274                *num_blocks,
275                gradients,
276                component_sizes,
277            ),
278            Self::AveragedBlockGradient { step, block_idx, gradients, component_sizes } => {
279                serialize_averaged_block(&mut buf, *step, *block_idx, gradients, component_sizes);
280            }
281            Self::NonBlockGradientPayload { step, worker_id, component, gradients } => {
282                serialize_non_block_grad(&mut buf, *step, *worker_id, *component, gradients);
283            }
284            Self::AveragedNonBlockGradient { step, component, gradients } => {
285                serialize_averaged_non_block(&mut buf, *step, *component, gradients);
286            }
287        }
288        buf
289    }
290
291    fn deserialize_payload(data: &[u8]) -> Result<Self, String> {
292        if data.is_empty() {
293            return Err("empty payload".to_string());
294        }
295        let tag = data[0];
296        let rest = &data[1..];
297        match tag {
298            0x01 => decode_join_request(rest),
299            0x02 => decode_join_accepted(rest),
300            0x03 => decode_shard_assignment(rest),
301            0x04 => decode_gradient_payload(rest),
302            0x05 => decode_averaged_gradient(rest),
303            0x06 => decode_heartbeat(rest),
304            0x07 => Ok(Self::Shutdown),
305            0x08 => decode_block_gradient_payload(rest),
306            0x09 => decode_averaged_block_gradient(rest),
307            0x0A => decode_non_block_gradient_payload(rest),
308            0x0B => decode_averaged_non_block_gradient(rest),
309            other => Err(format!("unknown message tag: 0x{other:02x}")),
310        }
311    }
312}
313
314fn decode_join_request(rest: &[u8]) -> Result<WireMessage, String> {
315    let (node_id, rest) = read_string(rest)?;
316    if rest.len() < 4 {
317        return Err("truncated JoinRequest".to_string());
318    }
319    let gpu_count = u32::from_le_bytes(rest[..4].try_into().expect("4 bytes"));
320    let (backend, _) = read_string(&rest[4..])?;
321    Ok(WireMessage::JoinRequest { node_id, gpu_count, backend })
322}
323
324fn decode_join_accepted(rest: &[u8]) -> Result<WireMessage, String> {
325    if rest.len() < 8 {
326        return Err("truncated JoinAccepted".to_string());
327    }
328    let worker_id = u32::from_le_bytes(rest[..4].try_into().expect("4 bytes"));
329    let total_workers = u32::from_le_bytes(rest[4..8].try_into().expect("4 bytes"));
330    Ok(WireMessage::JoinAccepted { worker_id, total_workers })
331}
332
333fn decode_shard_assignment(rest: &[u8]) -> Result<WireMessage, String> {
334    if rest.len() < 24 {
335        return Err("truncated ShardAssignment".to_string());
336    }
337    let step = u64::from_le_bytes(rest[..8].try_into().expect("8 bytes"));
338    let shard_start = u64::from_le_bytes(rest[8..16].try_into().expect("8 bytes")) as usize;
339    let shard_end = u64::from_le_bytes(rest[16..24].try_into().expect("8 bytes")) as usize;
340    Ok(WireMessage::ShardAssignment { step, shard_start, shard_end })
341}
342
343fn decode_gradient_payload(rest: &[u8]) -> Result<WireMessage, String> {
344    if rest.len() < 20 {
345        return Err("truncated GradientPayload header".to_string());
346    }
347    let step = u64::from_le_bytes(rest[..8].try_into().expect("8 bytes"));
348    let worker_id = u32::from_le_bytes(rest[8..12].try_into().expect("4 bytes"));
349    let grad_len = u64::from_le_bytes(rest[12..20].try_into().expect("8 bytes")) as usize;
350    let grad_bytes = grad_len * 4;
351    if rest.len() < 20 + grad_bytes + 4 + 8 + 8 {
352        return Err("truncated GradientPayload data".to_string());
353    }
354    let gradients = read_f32_vec(rest, 20, grad_len);
355    let tail = &rest[20 + grad_bytes..];
356    let loss = f32::from_le_bytes(tail[..4].try_into().expect("4 bytes"));
357    let correct = u64::from_le_bytes(tail[4..12].try_into().expect("8 bytes")) as usize;
358    let total = u64::from_le_bytes(tail[12..20].try_into().expect("8 bytes")) as usize;
359    Ok(WireMessage::GradientPayload { step, worker_id, gradients, loss, correct, total })
360}
361
362fn decode_averaged_gradient(rest: &[u8]) -> Result<WireMessage, String> {
363    if rest.len() < 16 {
364        return Err("truncated AveragedGradient header".to_string());
365    }
366    let step = u64::from_le_bytes(rest[..8].try_into().expect("8 bytes"));
367    let grad_len = u64::from_le_bytes(rest[8..16].try_into().expect("8 bytes")) as usize;
368    let grad_bytes = grad_len * 4;
369    if rest.len() < 16 + grad_bytes + 4 {
370        return Err("truncated AveragedGradient data".to_string());
371    }
372    let gradients = read_f32_vec(rest, 16, grad_len);
373    let global_loss =
374        f32::from_le_bytes(rest[16 + grad_bytes..16 + grad_bytes + 4].try_into().expect("4 bytes"));
375    Ok(WireMessage::AveragedGradient { step, gradients, global_loss })
376}
377
378fn decode_heartbeat(rest: &[u8]) -> Result<WireMessage, String> {
379    let (node_id, rest) = read_string(rest)?;
380    if rest.len() < 8 {
381        return Err("truncated Heartbeat".to_string());
382    }
383    let timestamp_ms = u64::from_le_bytes(rest[..8].try_into().expect("8 bytes"));
384    Ok(WireMessage::Heartbeat { node_id, timestamp_ms })
385}
386
387fn decode_block_gradient_payload(rest: &[u8]) -> Result<WireMessage, String> {
388    // step(8) + worker_id(4) + block_idx(4) + num_blocks(4) + num_components(4) = 24
389    if rest.len() < 24 {
390        return Err("truncated BlockGradientPayload header".to_string());
391    }
392    let step = u64::from_le_bytes(rest[..8].try_into().expect("8 bytes"));
393    let worker_id = u32::from_le_bytes(rest[8..12].try_into().expect("4 bytes"));
394    let block_idx = u32::from_le_bytes(rest[12..16].try_into().expect("4 bytes"));
395    let num_blocks = u32::from_le_bytes(rest[16..20].try_into().expect("4 bytes"));
396    let num_components = u32::from_le_bytes(rest[20..24].try_into().expect("4 bytes")) as usize;
397
398    let comp_end = 24 + num_components * 4;
399    if rest.len() < comp_end + 8 {
400        return Err("truncated BlockGradientPayload component_sizes".to_string());
401    }
402    let mut component_sizes = Vec::with_capacity(num_components);
403    for i in 0..num_components {
404        let start = 24 + i * 4;
405        component_sizes
406            .push(u32::from_le_bytes(rest[start..start + 4].try_into().expect("4 bytes")));
407    }
408
409    let grad_len =
410        u64::from_le_bytes(rest[comp_end..comp_end + 8].try_into().expect("8 bytes")) as usize;
411    let grad_start = comp_end + 8;
412    if rest.len() < grad_start + grad_len * 4 {
413        return Err("truncated BlockGradientPayload gradients".to_string());
414    }
415    let gradients = read_f32_vec(rest, grad_start, grad_len);
416
417    Ok(WireMessage::BlockGradientPayload {
418        step,
419        worker_id,
420        block_idx,
421        num_blocks,
422        gradients,
423        component_sizes,
424    })
425}
426
427fn decode_averaged_block_gradient(rest: &[u8]) -> Result<WireMessage, String> {
428    // step(8) + block_idx(4) + num_components(4) = 16
429    if rest.len() < 16 {
430        return Err("truncated AveragedBlockGradient header".to_string());
431    }
432    let step = u64::from_le_bytes(rest[..8].try_into().expect("8 bytes"));
433    let block_idx = u32::from_le_bytes(rest[8..12].try_into().expect("4 bytes"));
434    let num_components = u32::from_le_bytes(rest[12..16].try_into().expect("4 bytes")) as usize;
435
436    let comp_end = 16 + num_components * 4;
437    if rest.len() < comp_end + 8 {
438        return Err("truncated AveragedBlockGradient component_sizes".to_string());
439    }
440    let mut component_sizes = Vec::with_capacity(num_components);
441    for i in 0..num_components {
442        let start = 16 + i * 4;
443        component_sizes
444            .push(u32::from_le_bytes(rest[start..start + 4].try_into().expect("4 bytes")));
445    }
446
447    let grad_len =
448        u64::from_le_bytes(rest[comp_end..comp_end + 8].try_into().expect("8 bytes")) as usize;
449    let grad_start = comp_end + 8;
450    if rest.len() < grad_start + grad_len * 4 {
451        return Err("truncated AveragedBlockGradient gradients".to_string());
452    }
453    let gradients = read_f32_vec(rest, grad_start, grad_len);
454
455    Ok(WireMessage::AveragedBlockGradient { step, block_idx, gradients, component_sizes })
456}
457
458fn decode_non_block_gradient_payload(rest: &[u8]) -> Result<WireMessage, String> {
459    // step(8) + worker_id(4) + component(1) + grad_len(8) = 21
460    if rest.len() < 21 {
461        return Err("truncated NonBlockGradientPayload header".to_string());
462    }
463    let step = u64::from_le_bytes(rest[..8].try_into().expect("8 bytes"));
464    let worker_id = u32::from_le_bytes(rest[8..12].try_into().expect("4 bytes"));
465    let component = rest[12];
466    let grad_len = u64::from_le_bytes(rest[13..21].try_into().expect("8 bytes")) as usize;
467    if rest.len() < 21 + grad_len * 4 {
468        return Err("truncated NonBlockGradientPayload gradients".to_string());
469    }
470    let gradients = read_f32_vec(rest, 21, grad_len);
471
472    Ok(WireMessage::NonBlockGradientPayload { step, worker_id, component, gradients })
473}
474
475fn decode_averaged_non_block_gradient(rest: &[u8]) -> Result<WireMessage, String> {
476    // step(8) + component(1) + grad_len(8) = 17
477    if rest.len() < 17 {
478        return Err("truncated AveragedNonBlockGradient header".to_string());
479    }
480    let step = u64::from_le_bytes(rest[..8].try_into().expect("8 bytes"));
481    let component = rest[8];
482    let grad_len = u64::from_le_bytes(rest[9..17].try_into().expect("8 bytes")) as usize;
483    if rest.len() < 17 + grad_len * 4 {
484        return Err("truncated AveragedNonBlockGradient gradients".to_string());
485    }
486    let gradients = read_f32_vec(rest, 17, grad_len);
487
488    Ok(WireMessage::AveragedNonBlockGradient { step, component, gradients })
489}
490
/// Reads `count` consecutive little-endian `f32` values from `data`, starting
/// at byte `offset`.
///
/// # Panics
/// Panics if any of the requested values lies outside `data`; callers are
/// expected to have checked the length beforehand.
fn read_f32_vec(data: &[u8], offset: usize, count: usize) -> Vec<f32> {
    (0..count)
        .map(|idx| {
            let at = offset + idx * 4;
            let word: [u8; 4] = data[at..at + 4].try_into().expect("4 bytes");
            f32::from_le_bytes(word)
        })
        .collect()
}
500
/// Appends `v` to `buf` as a little-endian `u64` element count followed by
/// each value's little-endian bytes.
fn write_f32_vec(buf: &mut Vec<u8>, v: &[f32]) {
    let count = v.len() as u64;
    buf.extend(count.to_le_bytes());
    buf.extend(v.iter().flat_map(|x| x.to_le_bytes()));
}
508
/// Appends `sizes` to `buf` as a little-endian `u32` entry count followed by
/// each size as a little-endian `u32`.
fn write_component_sizes(buf: &mut Vec<u8>, sizes: &[u32]) {
    let count = sizes.len() as u32;
    buf.extend(count.to_le_bytes());
    buf.extend(sizes.iter().flat_map(|sz| sz.to_le_bytes()));
}
516
517/// Serialize a BlockGradientPayload (tag 0x08).
518fn serialize_block_grad(
519    buf: &mut Vec<u8>,
520    tag: u8,
521    step: u64,
522    worker_id: u32,
523    block_idx: u32,
524    num_blocks: u32,
525    gradients: &[f32],
526    component_sizes: &[u32],
527) {
528    buf.push(tag);
529    buf.extend_from_slice(&step.to_le_bytes());
530    buf.extend_from_slice(&worker_id.to_le_bytes());
531    buf.extend_from_slice(&block_idx.to_le_bytes());
532    buf.extend_from_slice(&num_blocks.to_le_bytes());
533    write_component_sizes(buf, component_sizes);
534    write_f32_vec(buf, gradients);
535}
536
537/// Serialize an AveragedBlockGradient (tag 0x09).
538fn serialize_averaged_block(
539    buf: &mut Vec<u8>,
540    step: u64,
541    block_idx: u32,
542    gradients: &[f32],
543    component_sizes: &[u32],
544) {
545    buf.push(0x09);
546    buf.extend_from_slice(&step.to_le_bytes());
547    buf.extend_from_slice(&block_idx.to_le_bytes());
548    write_component_sizes(buf, component_sizes);
549    write_f32_vec(buf, gradients);
550}
551
552/// Serialize a NonBlockGradientPayload (tag 0x0A).
553fn serialize_non_block_grad(
554    buf: &mut Vec<u8>,
555    step: u64,
556    worker_id: u32,
557    component: u8,
558    gradients: &[f32],
559) {
560    buf.push(0x0A);
561    buf.extend_from_slice(&step.to_le_bytes());
562    buf.extend_from_slice(&worker_id.to_le_bytes());
563    buf.push(component);
564    write_f32_vec(buf, gradients);
565}
566
567/// Serialize an AveragedNonBlockGradient (tag 0x0B).
568fn serialize_averaged_non_block(buf: &mut Vec<u8>, step: u64, component: u8, gradients: &[f32]) {
569    buf.push(0x0B);
570    buf.extend_from_slice(&step.to_le_bytes());
571    buf.push(component);
572    write_f32_vec(buf, gradients);
573}
574
/// Appends `s` to `buf` as a little-endian `u32` byte length followed by the
/// string's UTF-8 bytes.
fn write_string(buf: &mut Vec<u8>, s: &str) {
    let byte_len = s.len() as u32;
    buf.extend(byte_len.to_le_bytes());
    buf.extend_from_slice(s.as_bytes());
}
580
/// Reads a length-prefixed UTF-8 string from the front of `data`.
///
/// The prefix is a little-endian `u32` byte count. On success, returns the
/// decoded string together with the bytes that follow it.
///
/// # Errors
/// Returns an error if fewer than four prefix bytes are present, if the
/// declared length exceeds the bytes that remain, or if the bytes are not
/// valid UTF-8.
fn read_string(data: &[u8]) -> Result<(String, &[u8]), String> {
    if data.len() < 4 {
        return Err("truncated string length".to_string());
    }
    let (prefix, body) = data.split_at(4);
    let declared = u32::from_le_bytes(prefix.try_into().expect("4 bytes"));
    // Compare against what is left instead of computing `4 + len`, which can wrap
    // on 32-bit targets and then panic when slicing.
    let len = match usize::try_from(declared) {
        Ok(len) if len <= body.len() => len,
        _ => return Err("truncated string data".to_string()),
    };
    let (text, remainder) = body.split_at(len);
    // Validate in place and copy once, rather than copying before validating.
    let s = std::str::from_utf8(text).map_err(|e| format!("invalid utf8: {e}"))?.to_owned();
    Ok((s, remainder))
}
593
594#[cfg(test)]
595mod tests {
596    #![allow(clippy::unwrap_used)]
597    use super::*;
598
599    #[test]
600    fn test_coordinator_config() {
601        let config = DistributedConfig::coordinator("0.0.0.0:9000".parse().expect("valid"), 3);
602        assert!(config.is_coordinator());
603        assert_eq!(config.role, NodeRole::Coordinator);
604        assert_eq!(config.expect_workers, 3);
605        assert!(config.coordinator_addr.is_none());
606    }
607
608    #[test]
609    fn test_worker_config() {
610        let config = DistributedConfig::worker("192.168.50.100:9000".parse().expect("valid"));
611        assert!(!config.is_coordinator());
612        assert_eq!(config.role, NodeRole::Worker);
613        assert_eq!(config.coordinator_addr, Some("192.168.50.100:9000".parse().expect("valid")));
614    }
615
616    #[test]
617    fn test_default_config() {
618        let config = DistributedConfig::default();
619        assert!(config.is_coordinator());
620        assert_eq!(config.expect_workers, 1);
621    }
622
623    #[test]
624    fn test_node_role_display() {
625        assert_eq!(NodeRole::Coordinator.to_string(), "coordinator");
626        assert_eq!(NodeRole::Worker.to_string(), "worker");
627    }
628
629    #[test]
630    fn test_node_id_not_empty() {
631        let config = DistributedConfig::default();
632        assert!(!config.node_id.is_empty());
633    }
634
635    // ── Wire protocol round-trip tests ───────────────────────────────────
636
637    #[test]
638    fn test_wire_join_request_roundtrip() {
639        let msg = WireMessage::JoinRequest {
640            node_id: "intel-1234".to_string(),
641            gpu_count: 2,
642            backend: "wgpu".to_string(),
643        };
644        let bytes = msg.to_bytes();
645        // Skip 4-byte length prefix
646        let payload = &bytes[4..];
647        let decoded = WireMessage::from_payload(payload).expect("valid");
648        match decoded {
649            WireMessage::JoinRequest { node_id, gpu_count, backend } => {
650                assert_eq!(node_id, "intel-1234");
651                assert_eq!(gpu_count, 2);
652                assert_eq!(backend, "wgpu");
653            }
654            other => panic!("expected JoinRequest, got {other:?}"),
655        }
656    }
657
658    #[test]
659    fn test_wire_join_accepted_roundtrip() {
660        let msg = WireMessage::JoinAccepted { worker_id: 1, total_workers: 3 };
661        let bytes = msg.to_bytes();
662        let decoded = WireMessage::from_payload(&bytes[4..]).expect("valid");
663        match decoded {
664            WireMessage::JoinAccepted { worker_id, total_workers } => {
665                assert_eq!(worker_id, 1);
666                assert_eq!(total_workers, 3);
667            }
668            other => panic!("expected JoinAccepted, got {other:?}"),
669        }
670    }
671
672    #[test]
673    fn test_wire_shard_assignment_roundtrip() {
674        let msg = WireMessage::ShardAssignment { step: 42, shard_start: 100, shard_end: 200 };
675        let bytes = msg.to_bytes();
676        let decoded = WireMessage::from_payload(&bytes[4..]).expect("valid");
677        match decoded {
678            WireMessage::ShardAssignment { step, shard_start, shard_end } => {
679                assert_eq!(step, 42);
680                assert_eq!(shard_start, 100);
681                assert_eq!(shard_end, 200);
682            }
683            other => panic!("expected ShardAssignment, got {other:?}"),
684        }
685    }
686
687    #[test]
688    fn test_wire_gradient_payload_roundtrip() {
689        let grads = vec![1.0f32, 2.0, 3.0, -0.5, 0.0];
690        let msg = WireMessage::GradientPayload {
691            step: 10,
692            worker_id: 2,
693            gradients: grads.clone(),
694            loss: 0.456,
695            correct: 8,
696            total: 10,
697        };
698        let bytes = msg.to_bytes();
699        let decoded = WireMessage::from_payload(&bytes[4..]).expect("valid");
700        match decoded {
701            WireMessage::GradientPayload { step, worker_id, gradients, loss, correct, total } => {
702                assert_eq!(step, 10);
703                assert_eq!(worker_id, 2);
704                assert_eq!(gradients, grads);
705                assert!((loss - 0.456).abs() < 1e-6);
706                assert_eq!(correct, 8);
707                assert_eq!(total, 10);
708            }
709            other => panic!("expected GradientPayload, got {other:?}"),
710        }
711    }
712
713    #[test]
714    fn test_wire_averaged_gradient_roundtrip() {
715        let grads = vec![0.5f32, 1.0, 1.5];
716        let msg =
717            WireMessage::AveragedGradient { step: 5, gradients: grads.clone(), global_loss: 0.789 };
718        let bytes = msg.to_bytes();
719        let decoded = WireMessage::from_payload(&bytes[4..]).expect("valid");
720        match decoded {
721            WireMessage::AveragedGradient { step, gradients, global_loss } => {
722                assert_eq!(step, 5);
723                assert_eq!(gradients, grads);
724                assert!((global_loss - 0.789).abs() < 1e-6);
725            }
726            other => panic!("expected AveragedGradient, got {other:?}"),
727        }
728    }
729
730    #[test]
731    fn test_wire_heartbeat_roundtrip() {
732        let msg = WireMessage::Heartbeat {
733            node_id: "lambda-5678".to_string(),
734            timestamp_ms: 1_709_000_000_000,
735        };
736        let bytes = msg.to_bytes();
737        let decoded = WireMessage::from_payload(&bytes[4..]).expect("valid");
738        match decoded {
739            WireMessage::Heartbeat { node_id, timestamp_ms } => {
740                assert_eq!(node_id, "lambda-5678");
741                assert_eq!(timestamp_ms, 1_709_000_000_000);
742            }
743            other => panic!("expected Heartbeat, got {other:?}"),
744        }
745    }
746
747    #[test]
748    fn test_wire_shutdown_roundtrip() {
749        let msg = WireMessage::Shutdown;
750        let bytes = msg.to_bytes();
751        let decoded = WireMessage::from_payload(&bytes[4..]).expect("valid");
752        assert!(matches!(decoded, WireMessage::Shutdown));
753    }
754
755    #[test]
756    fn test_wire_empty_payload_error() {
757        let result = WireMessage::from_payload(&[]);
758        assert!(result.is_err());
759    }
760
761    #[test]
762    fn test_wire_unknown_tag_error() {
763        let result = WireMessage::from_payload(&[0xFF]);
764        assert!(result.is_err());
765    }
766
767    // ── v2 Wire protocol tests (per-block gradient messages) ──────────
768
769    #[test]
770    fn test_wire_block_gradient_payload_roundtrip() {
771        let component_sizes = vec![100, 50, 50, 100, 200, 200, 200, 10, 10];
772        let total: u32 = component_sizes.iter().sum();
773        let grads: Vec<f32> = (0..total).map(|i| i as f32 * 0.01).collect();
774        let msg = WireMessage::BlockGradientPayload {
775            step: 42,
776            worker_id: 1,
777            block_idx: 5,
778            num_blocks: 24,
779            gradients: grads.clone(),
780            component_sizes: component_sizes.clone(),
781        };
782        let bytes = msg.to_bytes();
783        let decoded = WireMessage::from_payload(&bytes[4..]).expect("valid");
784        match decoded {
785            WireMessage::BlockGradientPayload {
786                step,
787                worker_id,
788                block_idx,
789                num_blocks,
790                gradients,
791                component_sizes: cs,
792            } => {
793                assert_eq!(step, 42);
794                assert_eq!(worker_id, 1);
795                assert_eq!(block_idx, 5);
796                assert_eq!(num_blocks, 24);
797                assert_eq!(gradients, grads);
798                assert_eq!(cs, component_sizes);
799            }
800            other => panic!("expected BlockGradientPayload, got {other:?}"),
801        }
802    }
803
804    #[test]
805    fn test_wire_averaged_block_gradient_roundtrip() {
806        let component_sizes = vec![100, 50, 50, 100, 200, 200, 200, 10, 10];
807        let total: u32 = component_sizes.iter().sum();
808        let grads: Vec<f32> = (0..total).map(|i| i as f32 * -0.005).collect();
809        let msg = WireMessage::AveragedBlockGradient {
810            step: 99,
811            block_idx: 23,
812            gradients: grads.clone(),
813            component_sizes: component_sizes.clone(),
814        };
815        let bytes = msg.to_bytes();
816        let decoded = WireMessage::from_payload(&bytes[4..]).expect("valid");
817        match decoded {
818            WireMessage::AveragedBlockGradient {
819                step,
820                block_idx,
821                gradients,
822                component_sizes: cs,
823            } => {
824                assert_eq!(step, 99);
825                assert_eq!(block_idx, 23);
826                assert_eq!(gradients, grads);
827                assert_eq!(cs, component_sizes);
828            }
829            other => panic!("expected AveragedBlockGradient, got {other:?}"),
830        }
831    }
832
833    #[test]
834    fn test_wire_non_block_gradient_payload_roundtrip() {
835        let grads = vec![1.0f32, -2.0, 3.5, 0.0, f32::MIN_POSITIVE];
836        let msg = WireMessage::NonBlockGradientPayload {
837            step: 10,
838            worker_id: 0,
839            component: 2, // embedding
840            gradients: grads.clone(),
841        };
842        let bytes = msg.to_bytes();
843        let decoded = WireMessage::from_payload(&bytes[4..]).expect("valid");
844        match decoded {
845            WireMessage::NonBlockGradientPayload { step, worker_id, component, gradients } => {
846                assert_eq!(step, 10);
847                assert_eq!(worker_id, 0);
848                assert_eq!(component, 2);
849                assert_eq!(gradients, grads);
850            }
851            other => panic!("expected NonBlockGradientPayload, got {other:?}"),
852        }
853    }
854
855    #[test]
856    fn test_wire_averaged_non_block_gradient_roundtrip() {
857        let grads = vec![0.5f32; 32768];
858        let msg = WireMessage::AveragedNonBlockGradient {
859            step: 50,
860            component: 0, // lm_head
861            gradients: grads.clone(),
862        };
863        let bytes = msg.to_bytes();
864        let decoded = WireMessage::from_payload(&bytes[4..]).expect("valid");
865        match decoded {
866            WireMessage::AveragedNonBlockGradient { step, component, gradients } => {
867                assert_eq!(step, 50);
868                assert_eq!(component, 0);
869                assert_eq!(gradients, grads);
870            }
871            other => panic!("expected AveragedNonBlockGradient, got {other:?}"),
872        }
873    }
874
875    #[test]
876    fn test_wire_block_gradient_truncated_error() {
877        // Tag 0x08 with insufficient header bytes
878        let result = WireMessage::from_payload(&[0x08, 0x00, 0x00, 0x00]);
879        assert!(result.is_err());
880        assert!(result.unwrap_err().contains("truncated"));
881    }
882
883    #[test]
884    fn test_wire_non_block_gradient_special_values() {
885        let grads = vec![f32::NAN, f32::INFINITY, f32::NEG_INFINITY, 0.0, -0.0];
886        let msg = WireMessage::NonBlockGradientPayload {
887            step: 1,
888            worker_id: 0,
889            component: 1,
890            gradients: grads,
891        };
892        let bytes = msg.to_bytes();
893        let decoded = WireMessage::from_payload(&bytes[4..]).expect("valid");
894        match decoded {
895            WireMessage::NonBlockGradientPayload { gradients, .. } => {
896                assert!(gradients[0].is_nan());
897                assert!(gradients[1].is_infinite() && gradients[1].is_sign_positive());
898                assert!(gradients[2].is_infinite() && gradients[2].is_sign_negative());
899                assert_eq!(gradients[3], 0.0);
900                assert_eq!(gradients[4], -0.0);
901            }
902            other => panic!("expected NonBlockGradientPayload, got {other:?}"),
903        }
904    }
905
906    #[test]
907    fn test_wire_large_gradient_roundtrip() {
908        // Simulate real LoRA gradient size: ~1.3M params
909        let grad_len = 1_378_050;
910        let grads: Vec<f32> = (0..grad_len).map(|i| (i as f32) * 0.0001).collect();
911        let msg = WireMessage::GradientPayload {
912            step: 100,
913            worker_id: 0,
914            gradients: grads.clone(),
915            loss: 0.123,
916            correct: 95,
917            total: 100,
918        };
919        let bytes = msg.to_bytes();
920        // Verify size: 4 (len prefix) + 1 (tag) + 8 (step) + 4 (worker_id) +
921        //   8 (grad_len) + grad_len*4 + 4 (loss) + 8 (correct) + 8 (total)
922        let expected_size = 4 + 1 + 8 + 4 + 8 + grad_len * 4 + 4 + 8 + 8;
923        assert_eq!(bytes.len(), expected_size);
924
925        let decoded = WireMessage::from_payload(&bytes[4..]).expect("valid");
926        match decoded {
927            WireMessage::GradientPayload { gradients, loss, .. } => {
928                assert_eq!(gradients.len(), grad_len);
929                assert!((loss - 0.123).abs() < 1e-6);
930            }
931            other => panic!("expected GradientPayload, got {other:?}"),
932        }
933    }
934}