Skip to main content

entrenar/finetune/
gradient_server.rs

1//! TCP gradient server for distributed training (coordinator side)
2//!
3//! The `GradientServer` runs on the coordinator node and:
4//! 1. Accepts worker connections
5//! 2. Assigns shard ranges per training step
6//! 3. Collects gradients from all workers
7//! 4. Computes AllReduce (average) and broadcasts result
8//!
9//! # Contract: F-DP-001 (Weight Consistency)
10//!
11//! After broadcasting averaged gradients, all workers apply the same optimizer
12//! step, maintaining weight consistency.
13//!
14//! # Contract: F-DP-003 (Gradient Stability)
15//!
16//! If any worker sends NaN/Inf gradients, the server halts training (Jidoka).
17
18use super::data_parallel::{average_gradients, has_non_finite};
19use super::distributed::{DistributedConfig, WireMessage};
20use std::io::{Read, Write};
21use std::net::{TcpListener, TcpStream};
22use std::time::Instant;
23
/// Per-worker bookkeeping held by the coordinator after a successful join.
///
/// One value is created for each accepted `JoinRequest`; the socket stays open
/// for the lifetime of the entry and is used for every later exchange with
/// that worker.
#[derive(Debug)]
struct WorkerConnection {
    /// Identifier handed to the worker in `JoinAccepted` (its join order).
    worker_id: u32,
    /// Node name reported by the worker in its `JoinRequest`.
    // Stored for diagnostics (`Debug` output); not read elsewhere in this file.
    #[allow(dead_code)]
    node_id: String,
    /// GPU count reported by the worker in its `JoinRequest`.
    // Stored for diagnostics (`Debug` output); not read elsewhere in this file.
    #[allow(dead_code)]
    gpu_count: u32,
    /// Backend name reported by the worker in its `JoinRequest`.
    // Stored for diagnostics (`Debug` output); not read elsewhere in this file.
    #[allow(dead_code)]
    backend: String,
    /// Connected socket used to exchange wire messages with this worker.
    stream: TcpStream,
}
36
37/// Gradient server running on the coordinator node.
38pub struct GradientServer {
39    config: DistributedConfig,
40    listener: TcpListener,
41    workers: Vec<WorkerConnection>,
42    total_samples: usize,
43}
44
/// Result of one AllReduce step across all workers.
///
/// Derives `PartialEq` so results can be compared directly (for example in
/// tests). Note that float fields follow IEEE semantics: a result containing
/// NaN never compares equal to anything, including itself.
#[derive(Debug, Clone, PartialEq)]
pub struct AllReduceResult {
    /// Averaged gradient vector
    pub avg_gradients: Vec<f32>,
    /// Sample-weighted average loss (0.0 when no samples were reported)
    pub global_loss: f32,
    /// Total correct predictions, summed over workers
    pub total_correct: usize,
    /// Total samples processed, summed over workers
    pub total_samples: usize,
    /// AllReduce wall time in milliseconds
    pub allreduce_ms: f64,
}
59
/// Result of per-block AllReduce for DDP pretraining.
///
/// Derives `PartialEq` so results can be compared directly (for example in
/// tests). Float fields follow IEEE semantics, so a NaN never compares equal.
#[derive(Debug, Clone, PartialEq)]
pub struct BlockAllReduceResult {
    /// Block index
    pub block_idx: u32,
    /// Averaged gradient vector (flattened, same layout as BlockGradientPayload)
    pub avg_gradients: Vec<f32>,
    /// Component sizes (for reconstructing block gradient structure)
    pub component_sizes: Vec<u32>,
    /// AllReduce wall time in milliseconds
    pub allreduce_ms: f64,
}
72
/// Result of non-block AllReduce for DDP pretraining.
///
/// Derives `PartialEq` so results can be compared directly (for example in
/// tests). Float fields follow IEEE semantics, so a NaN never compares equal.
#[derive(Debug, Clone, PartialEq)]
pub struct NonBlockAllReduceResult {
    /// Component ID (0=lm_head, 1=final_norm, 2=embedding)
    pub component: u8,
    /// Averaged gradient vector
    pub avg_gradients: Vec<f32>,
    /// AllReduce wall time in milliseconds
    pub allreduce_ms: f64,
}
83
84impl GradientServer {
85    /// Create and bind the gradient server.
86    ///
87    /// # Errors
88    /// Returns error if binding fails.
89    pub fn bind(config: DistributedConfig) -> Result<Self, String> {
90        let listener = TcpListener::bind(config.bind_addr)
91            .map_err(|e| format!("failed to bind {}: {e}", config.bind_addr))?;
92        eprintln!(
93            "[coordinator] Listening on {} (expecting {} workers)",
94            config.bind_addr, config.expect_workers
95        );
96        Ok(Self { config, listener, workers: Vec::new(), total_samples: 0 })
97    }
98
99    /// Wait for all expected workers to connect.
100    ///
101    /// Blocks until `expect_workers` workers have sent JoinRequest messages.
102    ///
103    /// # Errors
104    /// Returns error if any connection fails or timeout is exceeded.
105    pub fn wait_for_workers(&mut self) -> Result<(), String> {
106        let expected = self.config.expect_workers;
107        eprintln!("[coordinator] Waiting for {expected} workers to connect...");
108
109        while self.workers.len() < expected {
110            let (stream, addr) =
111                self.listener.accept().map_err(|e| format!("accept failed: {e}"))?;
112            eprintln!("[coordinator] Connection from {addr}");
113
114            // Read JoinRequest
115            let msg = read_wire_message(&stream)?;
116            match msg {
117                WireMessage::JoinRequest { node_id, gpu_count, backend } => {
118                    let worker_id = self.workers.len() as u32;
119                    eprintln!(
120                        "[coordinator] Worker {worker_id} joined: {node_id} ({gpu_count} GPUs, {backend})"
121                    );
122
123                    // Send JoinAccepted
124                    let response =
125                        WireMessage::JoinAccepted { worker_id, total_workers: expected as u32 };
126                    send_wire_message(&stream, &response)?;
127
128                    self.workers.push(WorkerConnection {
129                        worker_id,
130                        node_id,
131                        gpu_count,
132                        backend,
133                        stream,
134                    });
135                }
136                other => {
137                    return Err(format!("expected JoinRequest, got {other:?}"));
138                }
139            }
140        }
141
142        eprintln!("[coordinator] All {expected} workers connected");
143        Ok(())
144    }
145
146    /// Set total sample count for sharding
147    pub fn set_total_samples(&mut self, n: usize) {
148        self.total_samples = n;
149    }
150
151    /// Send shard assignments to all workers for a given step.
152    ///
153    /// # Errors
154    /// Returns error if any send fails.
155    pub fn send_shard_assignments(&mut self, step: u64) -> Result<(), String> {
156        let n = self.workers.len();
157        let shard_size = self.total_samples / n;
158
159        for (i, worker) in self.workers.iter().enumerate() {
160            let start = i * shard_size;
161            let end = if i == n - 1 { self.total_samples } else { start + shard_size };
162            let msg = WireMessage::ShardAssignment { step, shard_start: start, shard_end: end };
163            send_wire_message(&worker.stream, &msg)?;
164        }
165        Ok(())
166    }
167
168    /// Collect gradients from all workers and compute AllReduce.
169    ///
170    /// # Contract: F-DP-003
171    ///
172    /// If any gradient contains NaN/Inf, returns an error (Jidoka halt).
173    ///
174    /// # Errors
175    /// Returns error on communication failure or non-finite gradient.
176    pub fn collect_and_reduce(&mut self, step: u64) -> Result<AllReduceResult, String> {
177        let start = Instant::now();
178        let n = self.workers.len();
179        let mut all_grads: Vec<Vec<f32>> = Vec::with_capacity(n);
180        let mut total_loss = 0.0f32;
181        let mut total_correct = 0usize;
182        let mut total_samples = 0usize;
183
184        for worker in &self.workers {
185            let msg = read_wire_message(&worker.stream)?;
186            match msg {
187                WireMessage::GradientPayload {
188                    step: recv_step,
189                    gradients,
190                    loss,
191                    correct,
192                    total,
193                    ..
194                } => {
195                    if recv_step != step {
196                        return Err(format!("step mismatch: expected {step}, got {recv_step}"));
197                    }
198
199                    // Jidoka: halt on NaN/Inf (F-DP-003)
200                    if has_non_finite(&gradients) {
201                        return Err(format!(
202                            "JIDOKA HALT: worker {} sent non-finite gradient at step {step}",
203                            worker.worker_id
204                        ));
205                    }
206
207                    total_loss += loss * total as f32;
208                    total_correct += correct;
209                    total_samples += total;
210                    all_grads.push(gradients);
211                }
212                other => {
213                    return Err(format!(
214                        "expected GradientPayload from worker {}, got {other:?}",
215                        worker.worker_id
216                    ));
217                }
218            }
219        }
220
221        // AllReduce: average gradients (F-DP-001)
222        let avg_gradients = average_gradients(&all_grads);
223        let global_loss = if total_samples > 0 { total_loss / total_samples as f32 } else { 0.0 };
224
225        let allreduce_ms = start.elapsed().as_secs_f64() * 1000.0;
226
227        Ok(AllReduceResult {
228            avg_gradients,
229            global_loss,
230            total_correct,
231            total_samples,
232            allreduce_ms,
233        })
234    }
235
236    /// Broadcast averaged gradients to all workers.
237    ///
238    /// # Errors
239    /// Returns error if any send fails.
240    pub fn broadcast_averaged(
241        &mut self,
242        step: u64,
243        result: &AllReduceResult,
244    ) -> Result<(), String> {
245        let msg = WireMessage::AveragedGradient {
246            step,
247            gradients: result.avg_gradients.clone(),
248            global_loss: result.global_loss,
249        };
250        for worker in &self.workers {
251            send_wire_message(&worker.stream, &msg)?;
252        }
253        Ok(())
254    }
255
256    /// Send shutdown message to all workers.
257    pub fn shutdown_workers(&mut self) {
258        for worker in &self.workers {
259            let _ = send_wire_message(&worker.stream, &WireMessage::Shutdown);
260        }
261    }
262
263    /// Number of connected workers
264    #[must_use]
265    pub fn worker_count(&self) -> usize {
266        self.workers.len()
267    }
268
269    /// Collect and reduce per-block gradients from all workers.
270    ///
271    /// Waits for `BlockGradientPayload` from each worker for the specified
272    /// block index, averages them, and returns the result.
273    ///
274    /// # Contract: C-DDP-001
275    ///
276    /// Output equals arithmetic mean of all workers' block gradients.
277    /// Jidoka halt on NaN/Inf gradients (F-DP-003).
278    ///
279    /// # Errors
280    ///
281    /// Returns error on communication failure, step mismatch, or NaN gradient.
282    pub fn collect_and_reduce_block(
283        &mut self,
284        step: u64,
285        block_idx: u32,
286    ) -> Result<BlockAllReduceResult, String> {
287        let start = Instant::now();
288        let n = self.workers.len();
289        let mut all_grads: Vec<Vec<f32>> = Vec::with_capacity(n);
290        let mut component_sizes = Vec::new();
291
292        for worker in &self.workers {
293            let msg = read_wire_message(&worker.stream)?;
294            match msg {
295                WireMessage::BlockGradientPayload {
296                    step: recv_step,
297                    block_idx: recv_block_idx,
298                    gradients,
299                    component_sizes: cs,
300                    ..
301                } => {
302                    if recv_step != step {
303                        return Err(format!("step mismatch: expected {step}, got {recv_step}"));
304                    }
305                    if recv_block_idx != block_idx {
306                        return Err(format!(
307                            "block_idx mismatch: expected {block_idx}, got {recv_block_idx}"
308                        ));
309                    }
310                    if has_non_finite(&gradients) {
311                        return Err(format!(
312                            "JIDOKA HALT: worker {} sent non-finite block {block_idx} gradient at step {step}",
313                            worker.worker_id
314                        ));
315                    }
316                    if component_sizes.is_empty() {
317                        component_sizes = cs;
318                    }
319                    all_grads.push(gradients);
320                }
321                other => {
322                    return Err(format!(
323                        "expected BlockGradientPayload from worker {}, got {other:?}",
324                        worker.worker_id
325                    ));
326                }
327            }
328        }
329
330        let avg_gradients = average_gradients(&all_grads);
331        let allreduce_ms = start.elapsed().as_secs_f64() * 1000.0;
332
333        Ok(BlockAllReduceResult { block_idx, avg_gradients, component_sizes, allreduce_ms })
334    }
335
336    /// Broadcast averaged block gradient to all workers.
337    ///
338    /// # Errors
339    /// Returns error if any send fails.
340    pub fn broadcast_averaged_block(
341        &mut self,
342        step: u64,
343        result: &BlockAllReduceResult,
344    ) -> Result<(), String> {
345        let msg = WireMessage::AveragedBlockGradient {
346            step,
347            block_idx: result.block_idx,
348            gradients: result.avg_gradients.clone(),
349            component_sizes: result.component_sizes.clone(),
350        };
351        for worker in &self.workers {
352            send_wire_message(&worker.stream, &msg)?;
353        }
354        Ok(())
355    }
356
357    /// Collect and reduce non-block gradient from all workers.
358    ///
359    /// Used for LM head, final norm, and embedding gradients.
360    ///
361    /// # Errors
362    /// Returns error on communication failure or NaN gradient.
363    pub fn collect_and_reduce_non_block(
364        &mut self,
365        step: u64,
366        expected_component: u8,
367    ) -> Result<NonBlockAllReduceResult, String> {
368        let start = Instant::now();
369        let n = self.workers.len();
370        let mut all_grads: Vec<Vec<f32>> = Vec::with_capacity(n);
371
372        for worker in &self.workers {
373            let msg = read_wire_message(&worker.stream)?;
374            match msg {
375                WireMessage::NonBlockGradientPayload {
376                    step: recv_step,
377                    component,
378                    gradients,
379                    ..
380                } => {
381                    if recv_step != step {
382                        return Err(format!("step mismatch: expected {step}, got {recv_step}"));
383                    }
384                    if component != expected_component {
385                        return Err(format!(
386                            "component mismatch: expected {expected_component}, got {component}"
387                        ));
388                    }
389                    if has_non_finite(&gradients) {
390                        return Err(format!(
391                            "JIDOKA HALT: worker {} sent non-finite component {component} gradient at step {step}",
392                            worker.worker_id
393                        ));
394                    }
395                    all_grads.push(gradients);
396                }
397                other => {
398                    return Err(format!(
399                        "expected NonBlockGradientPayload from worker {}, got {other:?}",
400                        worker.worker_id
401                    ));
402                }
403            }
404        }
405
406        let avg_gradients = average_gradients(&all_grads);
407        let allreduce_ms = start.elapsed().as_secs_f64() * 1000.0;
408
409        Ok(NonBlockAllReduceResult { component: expected_component, avg_gradients, allreduce_ms })
410    }
411
412    /// Broadcast averaged non-block gradient to all workers.
413    pub fn broadcast_averaged_non_block(
414        &mut self,
415        step: u64,
416        result: &NonBlockAllReduceResult,
417    ) -> Result<(), String> {
418        let msg = WireMessage::AveragedNonBlockGradient {
419            step,
420            component: result.component,
421            gradients: result.avg_gradients.clone(),
422        };
423        for worker in &self.workers {
424            send_wire_message(&worker.stream, &msg)?;
425        }
426        Ok(())
427    }
428}
429
430// ─── TCP IO helpers ──────────────────────────────────────────────────────────
431
432/// Read a length-prefixed wire message from a TCP stream.
433pub(crate) fn read_wire_message(stream: &TcpStream) -> Result<WireMessage, String> {
434    let mut len_buf = [0u8; 4];
435    (&*stream).read_exact(&mut len_buf).map_err(|e| format!("read length failed: {e}"))?;
436    let len = u32::from_be_bytes(len_buf) as usize;
437
438    if len > 100_000_000 {
439        return Err(format!("message too large: {len} bytes"));
440    }
441
442    let mut payload = vec![0u8; len];
443    (&*stream).read_exact(&mut payload).map_err(|e| format!("read payload failed: {e}"))?;
444
445    WireMessage::from_payload(&payload)
446}
447
448/// Send a wire message to a TCP stream.
449pub(crate) fn send_wire_message(stream: &TcpStream, msg: &WireMessage) -> Result<(), String> {
450    let bytes = msg.to_bytes();
451    (&*stream).write_all(&bytes).map_err(|e| format!("send failed: {e}"))?;
452    (&*stream).flush().map_err(|e| format!("flush failed: {e}"))?;
453    Ok(())
454}
455
456impl GradientServer {
457    /// Get the local address this server is listening on.
458    ///
459    /// Useful when binding to port 0 (OS-assigned) in tests.
460    #[must_use]
461    pub fn local_addr(&self) -> std::net::SocketAddr {
462        self.listener.local_addr().expect("listener has local addr")
463    }
464}
465
466#[cfg(test)]
467mod tests {
468    #![allow(clippy::unwrap_used)]
469    use super::*;
470    use std::net::TcpStream;
471    use std::thread;
472
473    #[test]
474    fn test_server_bind() {
475        // Bind to random port
476        let config = DistributedConfig::coordinator("127.0.0.1:0".parse().expect("valid"), 1);
477        let server = GradientServer::bind(config);
478        assert!(server.is_ok());
479    }
480
481    #[test]
482    fn test_server_worker_count_initially_zero() {
483        let config = DistributedConfig::coordinator("127.0.0.1:0".parse().expect("valid"), 1);
484        let server = GradientServer::bind(config).expect("valid");
485        assert_eq!(server.worker_count(), 0);
486    }
487
488    #[test]
489    fn test_server_accept_worker() {
490        let config = DistributedConfig::coordinator("127.0.0.1:0".parse().expect("valid"), 1);
491        let mut server = GradientServer::bind(config).expect("valid");
492        let addr = server.local_addr();
493
494        // Spawn a worker that sends JoinRequest
495        let handle = thread::spawn(move || {
496            let stream = TcpStream::connect(addr).expect("valid");
497            let join = WireMessage::JoinRequest {
498                node_id: "test-worker".to_string(),
499                gpu_count: 1,
500                backend: "cpu".to_string(),
501            };
502            send_wire_message(&stream, &join).expect("valid");
503
504            // Read JoinAccepted
505            let response = read_wire_message(&stream).expect("valid");
506            match response {
507                WireMessage::JoinAccepted { worker_id, total_workers } => {
508                    assert_eq!(worker_id, 0);
509                    assert_eq!(total_workers, 1);
510                }
511                other => panic!("expected JoinAccepted, got {other:?}"),
512            }
513            stream
514        });
515
516        server.wait_for_workers().expect("valid");
517        assert_eq!(server.worker_count(), 1);
518
519        let _stream = handle.join().expect("valid");
520    }
521
522    #[test]
523    fn test_server_shard_and_reduce() {
524        let config = DistributedConfig::coordinator("127.0.0.1:0".parse().expect("valid"), 2);
525        let mut server = GradientServer::bind(config).expect("valid");
526        let addr = server.local_addr();
527
528        // Spawn 2 workers
529        let handles: Vec<_> = (0..2)
530            .map(|i| {
531                thread::spawn(move || {
532                    let stream = TcpStream::connect(addr).expect("valid");
533                    let join = WireMessage::JoinRequest {
534                        node_id: format!("worker-{i}"),
535                        gpu_count: 1,
536                        backend: "cpu".to_string(),
537                    };
538                    send_wire_message(&stream, &join).expect("valid");
539                    let _ = read_wire_message(&stream).expect("valid"); // JoinAccepted
540
541                    // Read shard assignment
542                    let shard_msg = read_wire_message(&stream).expect("valid");
543                    let (shard_start, shard_end) = match shard_msg {
544                        WireMessage::ShardAssignment { shard_start, shard_end, .. } => {
545                            (shard_start, shard_end)
546                        }
547                        other => panic!("expected ShardAssignment, got {other:?}"),
548                    };
549
550                    // Send gradient
551                    let grad = WireMessage::GradientPayload {
552                        step: 0,
553                        worker_id: i,
554                        gradients: vec![1.0 + i as f32, 2.0 + i as f32],
555                        loss: 0.5 + i as f32 * 0.1,
556                        correct: shard_end - shard_start,
557                        total: shard_end - shard_start,
558                    };
559                    send_wire_message(&stream, &grad).expect("valid");
560
561                    // Read averaged gradient
562                    let avg_msg = read_wire_message(&stream).expect("valid");
563                    match avg_msg {
564                        WireMessage::AveragedGradient { gradients, .. } => {
565                            // Average of [1,2] and [2,3] should be [1.5, 2.5]
566                            assert!((gradients[0] - 1.5).abs() < 1e-5);
567                            assert!((gradients[1] - 2.5).abs() < 1e-5);
568                        }
569                        other => panic!("expected AveragedGradient, got {other:?}"),
570                    }
571
572                    stream
573                })
574            })
575            .collect();
576
577        // Server flow
578        server.wait_for_workers().expect("valid");
579        server.set_total_samples(100);
580        server.send_shard_assignments(0).expect("valid");
581        let result = server.collect_and_reduce(0).expect("valid");
582
583        assert!((result.avg_gradients[0] - 1.5).abs() < 1e-5);
584        assert!((result.avg_gradients[1] - 2.5).abs() < 1e-5);
585        assert_eq!(result.total_samples, 100);
586        assert!(result.allreduce_ms >= 0.0);
587
588        server.broadcast_averaged(0, &result).expect("valid");
589
590        for h in handles {
591            let _stream = h.join().expect("valid");
592        }
593    }
594
595    #[test]
596    fn test_server_jidoka_halt_on_nan() {
597        let config = DistributedConfig::coordinator("127.0.0.1:0".parse().expect("valid"), 1);
598        let mut server = GradientServer::bind(config).expect("valid");
599        let addr = server.local_addr();
600
601        let handle = thread::spawn(move || {
602            let stream = TcpStream::connect(addr).expect("valid");
603            let join = WireMessage::JoinRequest {
604                node_id: "bad-worker".to_string(),
605                gpu_count: 1,
606                backend: "cpu".to_string(),
607            };
608            send_wire_message(&stream, &join).expect("valid");
609            let _ = read_wire_message(&stream).expect("valid");
610
611            // Read shard
612            let _ = read_wire_message(&stream).expect("valid");
613
614            // Send NaN gradient
615            let grad = WireMessage::GradientPayload {
616                step: 0,
617                worker_id: 0,
618                gradients: vec![1.0, f32::NAN, 3.0],
619                loss: 0.5,
620                correct: 5,
621                total: 10,
622            };
623            send_wire_message(&stream, &grad).expect("valid");
624            stream
625        });
626
627        server.wait_for_workers().expect("valid");
628        server.set_total_samples(10);
629        server.send_shard_assignments(0).expect("valid");
630        let result = server.collect_and_reduce(0);
631        assert!(result.is_err());
632        assert!(result.unwrap_err().contains("JIDOKA HALT"));
633
634        let _stream = handle.join().expect("valid");
635    }
636}