Skip to main content

entrenar/finetune/
worker_client.rs

//! TCP worker client for distributed training (worker side).
//!
//! The `WorkerClient` runs on each worker node and:
//! 1. Connects to the coordinator
//! 2. Receives shard assignments per training step
//! 3. Computes forward/backward locally
//! 4. Sends gradients to the coordinator
//! 5. Receives averaged gradients and applies the optimizer step
//!
//! # Contract: F-DP-004 (Backend Fallback)
//!
//! The worker uses `forward_hidden_dispatch()`, which falls back in the order
//! CUDA → wgpu → CPU. Training always proceeds regardless of GPU availability.
14
15use super::distributed::{DistributedConfig, WireMessage};
16use super::gradient_server::{read_wire_message, send_wire_message};
17use std::net::TcpStream;
18
19/// Worker client that connects to the coordinator.
20pub struct WorkerClient {
21    config: DistributedConfig,
22    stream: TcpStream,
23    worker_id: u32,
24    total_workers: u32,
25}
26
/// Shard assignment received from coordinator.
///
/// Comparable by value (`PartialEq`/`Eq`) so callers and tests can check a
/// received assignment against an expected one directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShardAssignment {
    /// Training step this assignment belongs to.
    pub step: u64,
    /// Start of the sample range assigned to this worker.
    pub shard_start: usize,
    /// End of the sample range assigned to this worker.
    /// NOTE(review): appears to be an exclusive bound — confirm against the coordinator.
    pub shard_end: usize,
}
34
/// Averaged gradient received from coordinator after AllReduce.
///
/// Derives `PartialEq` so results can be compared by value. Because the
/// payload is floating point, equality is exact (bitwise-equal values, and
/// `NaN != NaN`); use a tolerance when comparing computed gradients.
#[derive(Debug, Clone, PartialEq)]
pub struct AveragedResult {
    /// Training step the averaged gradients belong to.
    pub step: u64,
    /// Averaged gradient vector as sent by the coordinator.
    pub gradients: Vec<f32>,
    /// Loss value sent by the coordinator alongside the averaged gradients.
    pub global_loss: f32,
}
42
/// Averaged block gradient received from coordinator (v2 per-block DDP).
///
/// Derives `PartialEq` so results can be compared by value. Gradient equality
/// is exact floating-point equality; use a tolerance for computed values.
#[derive(Debug, Clone, PartialEq)]
pub struct AveragedBlockResult {
    /// Training step the averaged gradients belong to.
    pub step: u64,
    /// Transformer block index (0-based) these gradients apply to.
    pub block_idx: u32,
    /// Flattened gradient vector with the block's components concatenated.
    pub gradients: Vec<f32>,
    /// Element count of each component inside `gradients`.
    pub component_sizes: Vec<u32>,
}
51
/// Averaged non-block gradient received from coordinator (v2 DDP).
///
/// Derives `PartialEq` so results can be compared by value. Gradient equality
/// is exact floating-point equality; use a tolerance for computed values.
#[derive(Debug, Clone, PartialEq)]
pub struct AveragedNonBlockResult {
    /// Training step the averaged gradients belong to.
    pub step: u64,
    /// Component selector: 0=lm_head, 1=final_norm, 2=embedding.
    pub component: u8,
    /// Averaged gradient vector for this component.
    pub gradients: Vec<f32>,
}
59
60impl WorkerClient {
61    /// Connect to the coordinator and complete the join handshake.
62    ///
63    /// # Arguments
64    /// * `config` - Worker configuration with coordinator address
65    /// * `gpu_count` - Number of GPUs this worker has
66    /// * `backend` - Backend name (e.g., "wgpu", "cuda", "cpu")
67    ///
68    /// # Errors
69    /// Returns error if connection or handshake fails.
70    pub fn connect(
71        config: DistributedConfig,
72        gpu_count: u32,
73        backend: &str,
74    ) -> Result<Self, String> {
75        let coord_addr = config
76            .coordinator_addr
77            .ok_or_else(|| "worker config must have coordinator_addr".to_string())?;
78
79        eprintln!("[worker {}] Connecting to coordinator at {coord_addr}...", config.node_id);
80
81        let stream = TcpStream::connect(coord_addr)
82            .map_err(|e| format!("failed to connect to {coord_addr}: {e}"))?;
83
84        // Send JoinRequest
85        let join = WireMessage::JoinRequest {
86            node_id: config.node_id.clone(),
87            gpu_count,
88            backend: backend.to_string(),
89        };
90        send_wire_message(&stream, &join)?;
91
92        // Read JoinAccepted
93        let response = read_wire_message(&stream)?;
94        match response {
95            WireMessage::JoinAccepted { worker_id, total_workers } => {
96                eprintln!(
97                    "[worker {}] Joined as worker {worker_id}/{total_workers}",
98                    config.node_id
99                );
100                Ok(Self { config, stream, worker_id, total_workers })
101            }
102            other => Err(format!("expected JoinAccepted, got {other:?}")),
103        }
104    }
105
106    /// Receive shard assignment for the next training step.
107    ///
108    /// Returns `None` if the coordinator sends a Shutdown message.
109    ///
110    /// # Errors
111    /// Returns error on communication failure.
112    pub fn receive_shard(&self) -> Result<Option<ShardAssignment>, String> {
113        let msg = read_wire_message(&self.stream)?;
114        match msg {
115            WireMessage::ShardAssignment { step, shard_start, shard_end } => {
116                Ok(Some(ShardAssignment { step, shard_start, shard_end }))
117            }
118            WireMessage::Shutdown => {
119                eprintln!("[worker {}] Received shutdown from coordinator", self.config.node_id);
120                Ok(None)
121            }
122            other => Err(format!("expected ShardAssignment or Shutdown, got {other:?}")),
123        }
124    }
125
126    /// Send computed gradients to the coordinator.
127    ///
128    /// # Arguments
129    /// * `step` - Training step number
130    /// * `gradients` - Gradient vector (flattened LoRA params + classifier head)
131    /// * `loss` - Average loss for this shard
132    /// * `correct` - Number of correct predictions
133    /// * `total` - Total samples in shard
134    ///
135    /// # Errors
136    /// Returns error on send failure.
137    pub fn send_gradients(
138        &self,
139        step: u64,
140        gradients: Vec<f32>,
141        loss: f32,
142        correct: usize,
143        total: usize,
144    ) -> Result<(), String> {
145        let msg = WireMessage::GradientPayload {
146            step,
147            worker_id: self.worker_id,
148            gradients,
149            loss,
150            correct,
151            total,
152        };
153        send_wire_message(&self.stream, &msg)
154    }
155
156    /// Receive averaged gradients from coordinator after AllReduce.
157    ///
158    /// # Errors
159    /// Returns error on communication failure.
160    pub fn receive_averaged(&self) -> Result<AveragedResult, String> {
161        let msg = read_wire_message(&self.stream)?;
162        match msg {
163            WireMessage::AveragedGradient { step, gradients, global_loss } => {
164                Ok(AveragedResult { step, gradients, global_loss })
165            }
166            WireMessage::Shutdown => Err("shutdown during AllReduce".to_string()),
167            other => Err(format!("expected AveragedGradient, got {other:?}")),
168        }
169    }
170
171    // --- v2 per-block DDP methods ---
172
173    /// Send per-block gradient to coordinator for AllReduce (v2 DDP).
174    ///
175    /// # Arguments
176    /// * `step` - Training step number
177    /// * `block_idx` - Transformer block index (0-based)
178    /// * `num_blocks` - Total number of transformer blocks
179    /// * `gradients` - Flattened gradient vector (9 components concatenated)
180    /// * `component_sizes` - Element count for each of the 9 components
181    pub fn send_block_gradient(
182        &self,
183        step: u64,
184        block_idx: u32,
185        num_blocks: u32,
186        gradients: Vec<f32>,
187        component_sizes: Vec<u32>,
188    ) -> Result<(), String> {
189        let msg = WireMessage::BlockGradientPayload {
190            step,
191            worker_id: self.worker_id,
192            block_idx,
193            num_blocks,
194            gradients,
195            component_sizes,
196        };
197        send_wire_message(&self.stream, &msg)
198    }
199
200    /// Receive averaged block gradient from coordinator after AllReduce (v2 DDP).
201    pub fn receive_averaged_block(&self) -> Result<AveragedBlockResult, String> {
202        let msg = read_wire_message(&self.stream)?;
203        match msg {
204            WireMessage::AveragedBlockGradient { step, block_idx, gradients, component_sizes } => {
205                Ok(AveragedBlockResult { step, block_idx, gradients, component_sizes })
206            }
207            WireMessage::Shutdown => Err("shutdown during block AllReduce".to_string()),
208            other => Err(format!("expected AveragedBlockGradient, got {other:?}")),
209        }
210    }
211
212    /// Send non-block gradient to coordinator for AllReduce (v2 DDP).
213    ///
214    /// # Arguments
215    /// * `step` - Training step number
216    /// * `component` - 0=lm_head, 1=final_norm, 2=embedding
217    /// * `gradients` - Gradient vector for this component
218    pub fn send_non_block_gradient(
219        &self,
220        step: u64,
221        component: u8,
222        gradients: Vec<f32>,
223    ) -> Result<(), String> {
224        let msg = WireMessage::NonBlockGradientPayload {
225            step,
226            worker_id: self.worker_id,
227            component,
228            gradients,
229        };
230        send_wire_message(&self.stream, &msg)
231    }
232
233    /// Receive averaged non-block gradient from coordinator after AllReduce (v2 DDP).
234    pub fn receive_averaged_non_block(&self) -> Result<AveragedNonBlockResult, String> {
235        let msg = read_wire_message(&self.stream)?;
236        match msg {
237            WireMessage::AveragedNonBlockGradient { step, component, gradients } => {
238                Ok(AveragedNonBlockResult { step, component, gradients })
239            }
240            WireMessage::Shutdown => Err("shutdown during non-block AllReduce".to_string()),
241            other => Err(format!("expected AveragedNonBlockGradient, got {other:?}")),
242        }
243    }
244
245    /// This worker's assigned ID
246    #[must_use]
247    pub fn worker_id(&self) -> u32 {
248        self.worker_id
249    }
250
251    /// Total number of workers in the cluster
252    #[must_use]
253    pub fn total_workers(&self) -> u32 {
254        self.total_workers
255    }
256}
257
/// Loopback integration tests: each test binds a `GradientServer` on an
/// ephemeral port, runs the worker(s) on spawned threads, and drives the
/// coordinator side from the test thread.
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::super::distributed::DistributedConfig;
    use super::super::gradient_server::GradientServer;
    use super::*;
    use std::thread;

    /// A single worker joins and is told it is worker 0 of 1.
    #[test]
    fn test_worker_connect_and_join() {
        let server_config =
            DistributedConfig::coordinator("127.0.0.1:0".parse().expect("valid"), 1);
        let mut server = GradientServer::bind(server_config).expect("valid");
        let addr = server.local_addr();

        let handle = thread::spawn(move || {
            let worker_config = DistributedConfig::worker(addr);
            let client = WorkerClient::connect(worker_config, 1, "cpu").expect("valid");
            assert_eq!(client.worker_id(), 0);
            assert_eq!(client.total_workers(), 1);
            client
        });

        server.wait_for_workers().expect("valid");
        // Keep the client alive until the handshake has completed on both sides.
        let _client = handle.join().expect("valid");
    }

    /// A block gradient sent by a single worker comes back unchanged.
    #[test]
    fn test_worker_block_gradient_roundtrip() {
        let server_config =
            DistributedConfig::coordinator("127.0.0.1:0".parse().expect("valid"), 1);
        let mut server = GradientServer::bind(server_config).expect("valid");
        let addr = server.local_addr();

        let component_sizes = vec![4, 2, 2, 4, 8, 8, 8, 1, 1];
        let total: u32 = component_sizes.iter().sum();
        let grads: Vec<f32> = (0..total).map(|i| i as f32 * 0.1).collect();

        let grads_clone = grads.clone();
        let sizes_clone = component_sizes.clone();
        let handle = thread::spawn(move || {
            let worker_config = DistributedConfig::worker(addr);
            let client = WorkerClient::connect(worker_config, 1, "cuda").expect("valid");

            // Send block gradient
            client.send_block_gradient(0, 5, 24, grads_clone, sizes_clone).expect("valid");

            // Receive averaged block gradient
            let avg = client.receive_averaged_block().expect("valid");
            assert_eq!(avg.step, 0);
            assert_eq!(avg.block_idx, 5);
            // Single worker: averaged == original
            assert_eq!(avg.gradients.len(), total as usize);
            avg
        });

        server.wait_for_workers().expect("valid");
        let result = server.collect_and_reduce_block(0, 5).expect("valid");
        assert_eq!(result.block_idx, 5);
        assert_eq!(result.avg_gradients.len(), total as usize);
        server.broadcast_averaged_block(0, &result).expect("valid");

        let avg = handle.join().expect("valid");
        // Single worker: averaged gradients should equal original
        for (a, b) in avg.gradients.iter().zip(grads.iter()) {
            assert!((a - b).abs() < 1e-6, "gradient mismatch: {a} != {b}");
        }
    }

    /// A non-block gradient sent by a single worker comes back unchanged.
    #[test]
    fn test_worker_non_block_gradient_roundtrip() {
        let server_config =
            DistributedConfig::coordinator("127.0.0.1:0".parse().expect("valid"), 1);
        let mut server = GradientServer::bind(server_config).expect("valid");
        let addr = server.local_addr();

        let grads = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let grads_clone = grads.clone();
        let handle = thread::spawn(move || {
            let worker_config = DistributedConfig::worker(addr);
            let client = WorkerClient::connect(worker_config, 1, "cuda").expect("valid");

            // Send non-block gradient (component=0 = lm_head)
            client.send_non_block_gradient(0, 0, grads_clone).expect("valid");

            // Receive averaged
            let avg = client.receive_averaged_non_block().expect("valid");
            assert_eq!(avg.step, 0);
            assert_eq!(avg.component, 0);
            avg
        });

        server.wait_for_workers().expect("valid");
        let result = server.collect_and_reduce_non_block(0, 0).expect("valid");
        assert_eq!(result.component, 0);
        server.broadcast_averaged_non_block(0, &result).expect("valid");

        let avg = handle.join().expect("valid");
        // `zip` stops at the shorter side, so without this check an empty or
        // truncated reply would pass the element-wise comparison vacuously.
        assert_eq!(avg.gradients.len(), grads.len(), "averaged gradient length mismatch");
        for (a, b) in avg.gradients.iter().zip(grads.iter()) {
            assert!((a - b).abs() < 1e-6, "gradient mismatch: {a} != {b}");
        }
    }

    /// Two workers sending 1.0s and 3.0s both receive the element-wise mean 2.0.
    #[test]
    fn test_two_worker_block_allreduce() {
        let server_config =
            DistributedConfig::coordinator("127.0.0.1:0".parse().expect("valid"), 2);
        let mut server = GradientServer::bind(server_config).expect("valid");
        let addr = server.local_addr();

        let component_sizes = vec![2, 1, 1, 2, 2, 2, 2, 1, 1];
        let total: u32 = component_sizes.iter().sum();

        // Worker 0: gradients = [1.0, 1.0, ...]
        let sizes0 = component_sizes.clone();
        let h0 = thread::spawn(move || {
            let cfg = DistributedConfig::worker(addr);
            let client = WorkerClient::connect(cfg, 1, "cuda").expect("valid");
            let grads = vec![1.0f32; total as usize];
            client.send_block_gradient(0, 0, 1, grads, sizes0).expect("valid");
            client.receive_averaged_block().expect("valid")
        });

        // Worker 1: gradients = [3.0, 3.0, ...]
        let sizes1 = component_sizes.clone();
        let h1 = thread::spawn(move || {
            let cfg = DistributedConfig::worker(addr);
            let client = WorkerClient::connect(cfg, 1, "cuda").expect("valid");
            let grads = vec![3.0f32; total as usize];
            client.send_block_gradient(0, 0, 1, grads, sizes1).expect("valid");
            client.receive_averaged_block().expect("valid")
        });

        server.wait_for_workers().expect("valid");
        let result = server.collect_and_reduce_block(0, 0).expect("valid");
        server.broadcast_averaged_block(0, &result).expect("valid");

        let avg0 = h0.join().expect("valid");
        let avg1 = h1.join().expect("valid");

        // Guard the value loops below: iterating an empty vector would
        // otherwise let the test pass without checking anything.
        assert_eq!(avg0.gradients.len(), total as usize, "worker 0 gradient length mismatch");
        assert_eq!(avg1.gradients.len(), total as usize, "worker 1 gradient length mismatch");

        // Average of [1.0, 1.0, ...] and [3.0, 3.0, ...] = [2.0, 2.0, ...]
        for g in &avg0.gradients {
            assert!((g - 2.0).abs() < 1e-6, "expected 2.0, got {g}");
        }
        for g in &avg1.gradients {
            assert!((g - 2.0).abs() < 1e-6, "expected 2.0, got {g}");
        }
    }

    /// Full v1 step: shard assignment, gradient upload, averaged result.
    #[test]
    fn test_worker_full_training_step() {
        let server_config =
            DistributedConfig::coordinator("127.0.0.1:0".parse().expect("valid"), 1);
        let mut server = GradientServer::bind(server_config).expect("valid");
        let addr = server.local_addr();

        let handle = thread::spawn(move || {
            let worker_config = DistributedConfig::worker(addr);
            let client = WorkerClient::connect(worker_config, 1, "cpu").expect("valid");

            // Receive shard
            let shard = client.receive_shard().expect("valid").expect("should get shard");
            assert_eq!(shard.step, 0);
            assert_eq!(shard.shard_start, 0);
            assert_eq!(shard.shard_end, 50);

            // Send gradients
            client.send_gradients(0, vec![1.0, 2.0, 3.0], 0.5, 48, 50).expect("valid");

            // Receive averaged
            let avg = client.receive_averaged().expect("valid");
            assert_eq!(avg.step, 0);
            assert_eq!(avg.gradients, vec![1.0, 2.0, 3.0]); // Single worker, no averaging
            assert!((avg.global_loss - 0.5).abs() < 1e-5);

            client
        });

        server.wait_for_workers().expect("valid");
        server.set_total_samples(50);
        server.send_shard_assignments(0).expect("valid");
        let result = server.collect_and_reduce(0).expect("valid");
        server.broadcast_averaged(0, &result).expect("valid");

        let _client = handle.join().expect("valid");
    }
}