//! Cake worker node (`cake_core/cake/worker.rs`): loads this worker's
//! assigned model layers and serves forward-pass requests from the
//! master over TCP.

use std::{
    collections::HashMap,
    net::SocketAddr,
    sync::Arc,
    time::{Duration, Instant},
};

use anyhow::{anyhow, Result};
use candle_core::{DType, Device};
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt},
    net::{TcpListener, TcpStream},
};

use super::{Context, Forwarder, Message, WorkerInfo};
use crate::models::Generator;
17
/// Determines how often worker statistics are calculated and printed:
/// once every this many processed op messages per connection.
const NUM_OPS_TO_STATS: usize = 5;
20
/// A single worker state.
///
/// Cloned once per client connection (see `get_client_context`); the loaded
/// blocks and the layer/device map are shared via `Arc`, while `context`
/// (and its kv-cache) is per-clone.
#[derive(Clone)]
struct WorkerContext<F> {
    /// Default device for this worker; fallback when a layer has no entry
    /// in `layer_devices`.
    device: Device,
    /// Device ordinal, taken from the command-line arguments.
    device_idx: usize,
    /// Tensor dtype reported to the master.
    dtype: DType,
    /// Loaded model shards, keyed by layer name. Shared across all clients.
    blocks: Arc<HashMap<String, Box<F>>>,
    /// Maps each layer name to the device it was loaded on.
    layer_devices: Arc<HashMap<String, Device>>,
    /// Per-connection context (args, kv-cache, var builder, ...).
    context: Context,
}
32
33impl<F: Forwarder> WorkerContext<F> {
34    /// Create a WorkerInfo structure to be sent to the master.
35    fn to_info(&self, latency: u128) -> WorkerInfo {
36        WorkerInfo {
37            version: env!("CARGO_PKG_VERSION").to_string(),
38            os: std::env::consts::OS.to_string(),
39            arch: std::env::consts::ARCH.to_string(),
40            device: if self.device.is_cuda() {
41                "cuda".to_string()
42            } else if self.device.is_metal() {
43                "metal".to_string()
44            } else {
45                "cpu".to_string()
46            },
47            device_idx: self.device_idx,
48            latency,
49            dtype: format!("{:?}", self.dtype),
50        }
51    }
52
53    /// Create a copy of self with new kv-cache.
54    fn get_client_context(&self) -> Self {
55        let cache = self.context.cache.as_ref().map(|cache| cache.as_new());
56
57        let mut cloned_context = self.context.clone();
58        cloned_context.cache = cache;
59
60        WorkerContext {
61            device: self.device.clone(),
62            device_idx: self.device_idx,
63            dtype: self.dtype,
64            blocks: self.blocks.clone(),
65            layer_devices: self.layer_devices.clone(),
66            // each client loop gets a new cache
67            context: cloned_context,
68        }
69    }
70}
71
/// Cake worker node.
///
/// Owns the TCP listener and the template worker state; `run` derives a
/// per-client copy of the state (with a fresh kv-cache) for every
/// accepted master connection.
pub struct Worker<G: Generator> {
    /// Accepts incoming master connections.
    listener: TcpListener,
    /// Template state cloned per connection via `get_client_context`.
    context: WorkerContext<G::Shardable>,
}
77
78impl<G: Generator + 'static> Worker<G> {
79    /// Detect how many CUDA devices are available.
80    fn detect_cuda_device_count() -> usize {
81        #[cfg(feature = "cuda")]
82        {
83            // Try creating devices until one fails
84            let mut count = 0;
85            while Device::new_cuda(count).is_ok() {
86                count += 1;
87            }
88            count
89        }
90        #[cfg(not(feature = "cuda"))]
91        {
92            0
93        }
94    }
95
    /// Create a new Worker from the context.
    ///
    /// Resolves this worker's topology entry, loads every assigned layer
    /// (split across multiple CUDA devices when more than one is available),
    /// and binds the TCP listener the master will connect to.
    ///
    /// Errors if no `--name` was provided, if the topology is empty, or if
    /// any layer fails to load.
    pub async fn new(ctx: &mut Context) -> Result<Self> {
        // Workers are identified by --name; without it we cannot look up a
        // topology entry.
        let worker_name = if let Some(name) = &ctx.args.name {
            name.to_string()
        } else {
            return Err(anyhow!("no --name provided for worker"));
        };

        // Resolve the topology node: exact name match first, otherwise fall
        // back (with a warning) to the first entry in the file.
        let worker_topology = if let Some(node) = ctx.topology.get(&worker_name) {
            node
        } else if !ctx.topology.is_empty() {
            let first = ctx.topology.keys().next().unwrap();
            log::warn!(
                "topology for worker name '{}' not found, using '{}'",
                &worker_name,
                first
            );
            ctx.topology.get(first).unwrap()
        } else {
            return Err(anyhow!(
                "could not find topology for {worker_name} and topology file is empty"
            ));
        };

        // Detect available GPUs for multi-GPU support
        let num_gpus = if ctx.device.is_cuda() {
            Self::detect_cuda_device_count().max(1)
        } else {
            1
        };

        // Only shard when there is more than one GPU AND more than one layer
        // to place.
        let use_multi_gpu = num_gpus > 1 && worker_topology.layers.len() > 1;

        if use_multi_gpu {
            log::info!(
                "detected {} CUDA devices, splitting {} layers across GPUs",
                num_gpus,
                worker_topology.layers.len()
            );
        }

        let mut blocks = HashMap::new();
        let mut layer_devices: HashMap<String, Device> = HashMap::new();

        if use_multi_gpu {
            let model_index = ctx.data_path.join("model.safetensors.index.json");

            // Group layers by GPU assignment: contiguous ranges, layer i goes
            // to GPU `i * num_gpus / n`.
            let mut gpu_layer_groups: Vec<Vec<String>> = vec![vec![]; num_gpus];
            for (i, name) in worker_topology.layers.iter().enumerate() {
                let gpu_idx = i * num_gpus / worker_topology.layers.len();
                gpu_layer_groups[gpu_idx].push(name.clone());
            }

            // Create per-GPU devices and VarBuilders (filtered to each GPU's layers)
            let mut gpu_devices: Vec<Device> = Vec::new();
            let mut gpu_var_builders: Vec<candle_nn::VarBuilder<'static>> = Vec::new();

            for ordinal in 0..num_gpus {
                let dev = Device::new_cuda(ordinal)?;

                #[cfg(feature = "cuda")]
                if let Device::Cuda(cuda_dev) = &dev {
                    // NOTE(review): `disable_event_tracking` is an unsafe
                    // candle API; confirm its soundness conditions hold at
                    // this single-threaded setup point.
                    unsafe {
                        cuda_dev.disable_event_tracking();
                    }
                }

                let vb = crate::utils::load_var_builder_for_specific_layers(
                    model_index.clone(),
                    ctx.dtype,
                    dev.clone(),
                    &gpu_layer_groups[ordinal],
                    ctx.fp8,
                )?;
                log::info!("  GPU {} ready", ordinal);
                gpu_devices.push(dev);
                gpu_var_builders.push(vb);
            }

            // Load layers in parallel across GPUs: one OS thread per GPU.
            // `join` below blocks this async fn, which is acceptable for
            // one-time startup work.
            let mut handles = Vec::new();
            for gpu_idx in 0..num_gpus {
                let dev = gpu_devices[gpu_idx].clone();
                let vb = gpu_var_builders[gpu_idx].clone();
                let layers = std::mem::take(&mut gpu_layer_groups[gpu_idx]);
                let mut thread_ctx = ctx.clone();
                thread_ctx.device = dev.clone();
                thread_ctx.var_builder = Some(vb);

                handles.push(std::thread::spawn(
                    move || -> Result<Vec<(String, Device, Box<G::Shardable>)>> {
                        // Bind this thread to the target GPU's CUDA context
                        // before issuing any CUDA work on it.
                        #[cfg(feature = "cuda")]
                        if let Device::Cuda(ref cuda_dev) = dev {
                            cuda_dev
                                .cuda_stream()
                                .context()
                                .bind_to_thread()
                                .map_err(|e| {
                                    anyhow!(
                                        "failed to bind CUDA context for GPU {gpu_idx}: {e:?}"
                                    )
                                })?;
                        }

                        let mut results = Vec::new();
                        for layer_name in layers {
                            log::info!("loading {} on cuda:{} ...", &layer_name, gpu_idx);
                            let block =
                                G::Shardable::load(layer_name.clone(), &thread_ctx)?;
                            results.push((layer_name, dev.clone(), block));
                        }
                        Ok(results)
                    },
                ));
            }

            // Collect results from all GPU threads
            for handle in handles {
                let results = handle
                    .join()
                    .map_err(|_| anyhow!("GPU loading thread panicked"))??;
                for (name, dev, block) in results {
                    layer_devices.insert(name.clone(), dev);
                    blocks.insert(name, block);
                }
            }
        } else {
            // Single-device path: load every assigned layer on ctx.device.
            for block_layer_name in worker_topology.layers.iter() {
                log::info!("loading {} ...", &block_layer_name);

                let block = G::Shardable::load(block_layer_name.to_string(), ctx)?;
                layer_devices.insert(block_layer_name.to_string(), ctx.device.clone());
                blocks.insert(block_layer_name.to_string(), block);
            }
        }

        let blocks = Arc::new(blocks);
        let layer_devices = Arc::new(layer_devices);

        // An already-bound listener in `listener_override` takes precedence
        // over binding args.address (presumably injected by embedding
        // code/tests — verify against callers).
        let listener = {
            // Guard is a temporary dropped at the end of this statement, so
            // the std mutex is never held across an await point.
            let taken = ctx.listener_override.lock().unwrap().take();
            if let Some(existing) = taken {
                existing
            } else {
                TcpListener::bind(&ctx.args.address).await?
            }
        };

        log::info!(
            "listening on {} (mem:{}) ...",
            &ctx.args.address,
            human_bytes::human_bytes(memory_stats::memory_stats().map(|m| m.physical_mem).unwrap_or(0) as f64)
        );

        let device = ctx.device.clone();
        let dtype = ctx.dtype;
        let device_idx = ctx.args.device;

        let context = WorkerContext {
            device,
            device_idx,
            dtype,
            blocks,
            layer_devices,
            context: ctx.clone(),
        };

        Ok(Self { listener, context })
    }
266
267    /// Read a message from the socket and return elapsed time, message size and message.
268    async fn read_message_timed<R>(mut socket: R) -> Result<(Duration, usize, Message)>
269    where
270        R: AsyncReadExt + Unpin,
271    {
272        let start = Instant::now();
273        let (size, message) = Message::from_reader(&mut socket).await?;
274        let latency = start.elapsed();
275
276        Ok((latency, size, message))
277    }
278
279    /// Write a message to the socket and return the elapsed time with written size.
280    async fn write_message_timed<W>(mut socket: W, message: Message) -> Result<(Duration, usize)>
281    where
282        W: AsyncWriteExt + Unpin,
283    {
284        let start = Instant::now();
285        let size = message.to_writer(&mut socket).await?;
286        let latency = start.elapsed();
287
288        Ok((latency, size))
289    }
290
291    /// Main loop handling communication with the master.
292    async fn handle_master_client(
293        mut socket: TcpStream,
294        client: SocketAddr,
295        mut context: WorkerContext<G::Shardable>,
296    ) -> Result<()> {
297        // Authenticate if cluster key is set
298        if let Some(ref cluster_key) = context.context.args.cluster_key {
299            super::auth::authenticate_as_worker(&mut socket, cluster_key)
300                .await
301                .map_err(|e| anyhow!("[{}] authentication failed: {}", &client, e))?;
302            log::debug!("[{}] authenticated", &client);
303        }
304
305        // read first message: expect Hello, but handle LayerAssignment for master restarts
306        let (latency, _size, first_msg) = Self::read_message_timed(&mut socket).await?;
307        match first_msg {
308            Message::Hello => { /* normal inference handshake, continue below */ }
309            Message::LayerAssignment { ref layers, .. } => {
310                // Master restarted and is re-running setup against an already-running worker.
311                // Ack the assignment (we already have cached data) and signal ready,
312                // then close this connection so the master can reconnect for inference.
313                log::info!(
314                    "[{}] master re-setup: accepting {} layer assignment(s)",
315                    &client,
316                    layers.len()
317                );
318                let ack = Message::LayerAssignmentAck { needs_data: false };
319                ack.to_writer(&mut socket).await?;
320                Message::WorkerReady.to_writer(&mut socket).await?;
321                log::info!("[{}] re-setup complete, closing setup connection", &client);
322                return Ok(());
323            }
324            other => {
325                return Err(anyhow!(
326                    "[{}] unexpected first message (expected Hello): {:?}",
327                    &client,
328                    other
329                ));
330            }
331        }
332
333        // send info
334        if let Err(e) = Self::write_message_timed(
335            &mut socket,
336            Message::WorkerInfo(context.to_info(latency.as_millis())),
337        )
338        .await
339        {
340            return Err(anyhow!("[{}] could not send worker info: {:?}", &client, e));
341        }
342
343        let mut msg_idx = 0;
344        let mut avg_ops = 0;
345        let mut avg_write = 0;
346        let mut avg_read = 0;
347        let mut read_buf = Vec::new();
348        let mut write_buf = Vec::new();
349
350        // keep reading messages
351        while let Ok((read_time, read_size, op_message)) = {
352            let start = Instant::now();
353            Message::from_reader_buf(&mut socket, &mut read_buf)
354                .await
355                .map(|(size, msg)| (start.elapsed(), size, msg))
356        } {
357            if matches!(op_message, Message::Goodbye) {
358                log::debug!("[{}] goodbye", &client);
359                context
360                    .context
361                    .cache
362                    .as_mut()
363                    .expect("No cache specified")
364                    .clear();
365
366                // send info
367                if let Err(e) = Self::write_message_timed(
368                    &mut socket,
369                    Message::WorkerInfo(context.to_info(read_time.as_millis())),
370                )
371                .await
372                {
373                    return Err(anyhow!("[{}] could not send worker info: {:?}", &client, e));
374                }
375
376                continue;
377            }
378
379            let (x, ops) = match op_message {
380                // single block operation
381                Message::SingleOp {
382                    layer_name,
383                    x,
384                    index_pos,
385                    block_idx,
386                } => (x, vec![(layer_name, index_pos, block_idx)]),
387                // batched
388                Message::Batch { x, batch } => (x, batch),
389                _ => {
390                    return Err(anyhow!(
391                        "[{}] unhandled message in loop: {:?}",
392                        &client,
393                        op_message
394                    ));
395                }
396            };
397
398            // load raw tensor to the first block's device
399            let load_start = Instant::now();
400            let first_device = ops
401                .first()
402                .and_then(|(name, _, _)| context.layer_devices.get(name))
403                .unwrap_or(&context.device);
404
405            // Ensure the CUDA context for the target device is active on this thread.
406            #[cfg(feature = "cuda")]
407            if let Device::Cuda(cuda_dev) = first_device {
408                if let Err(e) = cuda_dev.cuda_stream().context().bind_to_thread() {
409                    log::error!("[{client}] failed to bind CUDA context: {:?}", e);
410                }
411            }
412
413            let mut x = match x.to_tensor(first_device) {
414                Ok(t) => t,
415                Err(e) => {
416                    let msg = format!("failed to load tensor to device: {e}");
417                    log::error!("[{}] {}", &client, &msg);
418                    let _ = Self::write_message_timed(
419                        &mut socket,
420                        Message::WorkerError { message: msg },
421                    )
422                    .await;
423                    continue;
424                }
425            };
426
427            let load_elapsed = load_start.elapsed();
428
429            let num_ops = ops.len();
430            let start_ops = Instant::now();
431
432            let mut batch_error = false;
433
434            // for each element in the ops batch
435            for (layer_name, index_pos, block_idx) in ops {
436                // move tensor to the block's device if needed (multi-GPU)
437                if let Some(block_device) = context.layer_devices.get(&layer_name) {
438                    // Bind CUDA context before cross-device transfer
439                    #[cfg(feature = "cuda")]
440                    if let Device::Cuda(cuda_dev) = block_device {
441                        if let Err(e) = cuda_dev.cuda_stream().context().bind_to_thread() {
442                            log::error!(
443                                "[{client}] failed to bind CUDA context for {}: {:?}",
444                                &layer_name,
445                                e
446                            );
447                        }
448                    }
449
450                    x = match x.to_device(block_device) {
451                        Ok(t) => t,
452                        Err(e) => {
453                            let msg = format!(
454                                "failed to move tensor to device for layer {}: {e}",
455                                &layer_name
456                            );
457                            log::error!("[{}] {}", &client, &msg);
458                            let _ = Self::write_message_timed(
459                                &mut socket,
460                                Message::WorkerError { message: msg },
461                            )
462                            .await;
463                            batch_error = true;
464                            break;
465                        }
466                    };
467                }
468
469                // get layer block by name
470                if let Some(block) = context.blocks.get(&layer_name) {
471                    // run forward pass
472                    x = match block
473                        .forward(&x, index_pos, block_idx, &mut context.context)
474                        .await
475                    {
476                        Ok(t) => {
477                            // Metal requires per-layer sync to prevent command buffer
478                            // accumulation which causes catastrophic performance degradation.
479                            if t.device().is_metal() {
480                                let _ = t.device().synchronize();
481                            }
482                            t
483                        }
484                        Err(e) => {
485                            let msg = format!(
486                                "forward pass failed for layer {} (block_idx={}): {e}",
487                                &layer_name, block_idx
488                            );
489                            log::error!("[{}] {}", &client, &msg);
490                            let _ = Self::write_message_timed(
491                                &mut socket,
492                                Message::WorkerError { message: msg },
493                            )
494                            .await;
495                            batch_error = true;
496                            break;
497                        }
498                    };
499                } else {
500                    let msg = format!("could not find layer {}", &layer_name);
501                    log::error!("[{}] {}", &client, &msg);
502                    let _ = Self::write_message_timed(
503                        &mut socket,
504                        Message::WorkerError { message: msg },
505                    )
506                    .await;
507                    batch_error = true;
508                    break;
509                }
510            }
511
512            if batch_error {
513                continue;
514            }
515
516            let elaps_ops = start_ops.elapsed();
517
518            // serialize response tensor (includes GPU sync for data readback)
519            let ser_start = Instant::now();
520            let resp_msg = Message::from_tensor(&x);
521            let ser_elapsed = ser_start.elapsed();
522
523            // send response tensor (reuse write buffer)
524            let write_start = Instant::now();
525            match resp_msg.to_writer_buf(&mut socket, &mut write_buf).await {
526                Ok(written) => {
527                    let elaps_write = write_start.elapsed();
528                    log::debug!(
529                        "[{}] read={:.1}ms load={:.1}ms fwd={:.1}ms ser={:.1}ms write={:.1}ms ({} ops)",
530                        &client,
531                        read_time.as_secs_f64() * 1000.0,
532                        load_elapsed.as_secs_f64() * 1000.0,
533                        elaps_ops.as_secs_f64() * 1000.0,
534                        ser_elapsed.as_secs_f64() * 1000.0,
535                        elaps_write.as_secs_f64() * 1000.0,
536                        num_ops,
537                    );
538
539                    let ops_per_sec = (num_ops as f64 / elaps_ops.as_secs_f64()) as usize;
540                    let write_bytes_per_sec = (written as f64 / elaps_write.as_secs_f64()) as usize;
541                    let read_bytes_per_sec = (read_size as f64 / read_time.as_secs_f64()) as usize;
542
543                    avg_ops += ops_per_sec;
544                    avg_write += write_bytes_per_sec;
545                    avg_read += read_bytes_per_sec;
546                }
547                Err(e) => {
548                    return Err(anyhow!(
549                        "[{}] could not send response tensor: {:?}",
550                        &client,
551                        e
552                    ));
553                }
554            }
555
556            // compute and print stats every NUM_OPS_TO_STATS operations to avoid spamming stdout
557            if msg_idx % NUM_OPS_TO_STATS == 0 {
558                log::info!(
559                    "ops={}/s read={}/s write={}/s",
560                    avg_ops / NUM_OPS_TO_STATS,
561                    human_bytes::human_bytes(avg_read as f64 / NUM_OPS_TO_STATS as f64),
562                    human_bytes::human_bytes(avg_write as f64 / NUM_OPS_TO_STATS as f64)
563                );
564                avg_ops = 0;
565                avg_write = 0;
566                avg_read = 0;
567            }
568            msg_idx += 1;
569        }
570
571        Ok(())
572    }
573
574    /// Run the worker server accept loop.
575    pub async fn run(&mut self) -> Result<()> {
576        while let Ok((socket, client)) = self.listener.accept().await {
577            let _ = socket.set_nodelay(true);
578            log::debug!("{} connected", &client);
579
580            let context = self.context.get_client_context();
581            tokio::spawn(async move {
582                if let Err(e) = Self::handle_master_client(socket, client, context).await {
583                    log::error!("{}", e);
584                }
585            });
586        }
587
588        Ok(())
589    }
590}