use std::{
    collections::HashMap,
    net::SocketAddr,
    sync::Arc,
    time::{Duration, Instant},
};

use anyhow::{anyhow, Result};
use candle_core::{DType, Device};
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt},
    net::{TcpListener, TcpStream},
};

use super::{Context, Forwarder, Message, WorkerInfo};
use crate::models::Generator;
17
/// Number of op messages between throughput log lines in the client loop.
const NUM_OPS_TO_STATS: usize = 5;
20
/// Per-connection state handed to each master connection handler.
///
/// Cheap to clone: the heavy members (`blocks`, `layer_devices`) are shared
/// behind `Arc`, while `context` carries the per-client mutable state.
#[derive(Clone)]
struct WorkerContext<F> {
    // Primary compute device for this worker (cuda/metal/cpu).
    device: Device,
    // Device ordinal taken from the CLI args (`ctx.args.device`).
    device_idx: usize,
    // Tensor dtype used for inference, reported to the master.
    dtype: DType,
    // Loaded forwarder blocks (model layers), keyed by layer name.
    blocks: Arc<HashMap<String, Box<F>>>,
    // Device each layer was loaded on; differs from `device` per-layer in
    // multi-GPU mode.
    layer_devices: Arc<HashMap<String, Device>>,
    // Shared application context (args, topology, cache, ...).
    context: Context,
}
32
33impl<F: Forwarder> WorkerContext<F> {
34 fn to_info(&self, latency: u128) -> WorkerInfo {
36 WorkerInfo {
37 version: env!("CARGO_PKG_VERSION").to_string(),
38 os: std::env::consts::OS.to_string(),
39 arch: std::env::consts::ARCH.to_string(),
40 device: if self.device.is_cuda() {
41 "cuda".to_string()
42 } else if self.device.is_metal() {
43 "metal".to_string()
44 } else {
45 "cpu".to_string()
46 },
47 device_idx: self.device_idx,
48 latency,
49 dtype: format!("{:?}", self.dtype),
50 }
51 }
52
53 fn get_client_context(&self) -> Self {
55 let cache = self.context.cache.as_ref().map(|cache| cache.as_new());
56
57 let mut cloned_context = self.context.clone();
58 cloned_context.cache = cache;
59
60 WorkerContext {
61 device: self.device.clone(),
62 device_idx: self.device_idx,
63 dtype: self.dtype,
64 blocks: self.blocks.clone(),
65 layer_devices: self.layer_devices.clone(),
66 context: cloned_context,
68 }
69 }
70}
71
/// TCP server that hosts a set of model layers for a `Generator` and serves
/// forward-pass requests from a master node.
pub struct Worker<G: Generator> {
    // Accepting socket for incoming master connections.
    listener: TcpListener,
    // Template context; cloned (with a fresh cache) for each client.
    context: WorkerContext<G::Shardable>,
}
77
78impl<G: Generator + 'static> Worker<G> {
79 fn detect_cuda_device_count() -> usize {
81 #[cfg(feature = "cuda")]
82 {
83 let mut count = 0;
85 while Device::new_cuda(count).is_ok() {
86 count += 1;
87 }
88 count
89 }
90 #[cfg(not(feature = "cuda"))]
91 {
92 0
93 }
94 }
95
    /// Construct a worker: resolve this node's topology entry, load its
    /// assigned layers (split across multiple CUDA devices when available),
    /// and bind the TCP listener the master will connect to.
    ///
    /// Errors if `--name` is missing, the topology file is empty, device
    /// creation fails, any layer fails to load, or the bind fails.
    pub async fn new(ctx: &mut Context) -> Result<Self> {
        // The worker is identified in the topology file by its --name argument.
        let worker_name = if let Some(name) = &ctx.args.name {
            name.to_string()
        } else {
            return Err(anyhow!("no --name provided for worker"));
        };

        // Resolve the layers assigned to this worker. If the exact name is
        // absent, fall back to the first topology entry with a warning.
        let worker_topology = if let Some(node) = ctx.topology.get(&worker_name) {
            node
        } else if !ctx.topology.is_empty() {
            let first = ctx.topology.keys().next().unwrap();
            log::warn!(
                "topology for worker name '{}' not found, using '{}'",
                &worker_name,
                first
            );
            ctx.topology.get(first).unwrap()
        } else {
            return Err(anyhow!(
                "could not find topology for {worker_name} and topology file is empty"
            ));
        };

        // Probe for CUDA devices. `.max(1)` guards the case where the main
        // device is CUDA but the probe reports zero ordinals.
        let num_gpus = if ctx.device.is_cuda() {
            Self::detect_cuda_device_count().max(1)
        } else {
            1
        };

        // Shard only when there is both more than one GPU and more than one
        // layer to place.
        let use_multi_gpu = num_gpus > 1 && worker_topology.layers.len() > 1;

        if use_multi_gpu {
            log::info!(
                "detected {} CUDA devices, splitting {} layers across GPUs",
                num_gpus,
                worker_topology.layers.len()
            );
        }

        let mut blocks = HashMap::new();
        let mut layer_devices: HashMap<String, Device> = HashMap::new();

        if use_multi_gpu {
            let model_index = ctx.data_path.join("model.safetensors.index.json");

            // Partition the layer list into contiguous, near-equal groups,
            // one group per GPU (integer arithmetic keeps groups balanced).
            let mut gpu_layer_groups: Vec<Vec<String>> = vec![vec![]; num_gpus];
            for (i, name) in worker_topology.layers.iter().enumerate() {
                let gpu_idx = i * num_gpus / worker_topology.layers.len();
                gpu_layer_groups[gpu_idx].push(name.clone());
            }

            // Create every device and its weight loader up front, on the
            // current thread, before spawning the loader threads below.
            let mut gpu_devices: Vec<Device> = Vec::new();
            let mut gpu_var_builders: Vec<candle_nn::VarBuilder<'static>> = Vec::new();

            for ordinal in 0..num_gpus {
                let dev = Device::new_cuda(ordinal)?;

                // NOTE(review): disables CUDA event tracking on the device —
                // presumably required for use from multiple threads; confirm
                // against the candle CUDA backend docs.
                #[cfg(feature = "cuda")]
                if let Device::Cuda(cuda_dev) = &dev {
                    unsafe {
                        cuda_dev.disable_event_tracking();
                    }
                }

                // VarBuilder restricted to just this GPU's layer group.
                let vb = crate::utils::load_var_builder_for_specific_layers(
                    model_index.clone(),
                    ctx.dtype,
                    dev.clone(),
                    &gpu_layer_groups[ordinal],
                    ctx.fp8,
                )?;
                log::info!(" GPU {} ready", ordinal);
                gpu_devices.push(dev);
                gpu_var_builders.push(vb);
            }

            // Load each GPU's layers on a dedicated OS thread (blocking
            // work), so GPUs load in parallel.
            let mut handles = Vec::new();
            for gpu_idx in 0..num_gpus {
                let dev = gpu_devices[gpu_idx].clone();
                let vb = gpu_var_builders[gpu_idx].clone();
                let layers = std::mem::take(&mut gpu_layer_groups[gpu_idx]);
                let mut thread_ctx = ctx.clone();
                thread_ctx.device = dev.clone();
                thread_ctx.var_builder = Some(vb);

                handles.push(std::thread::spawn(
                    move || -> Result<Vec<(String, Device, Box<G::Shardable>)>> {
                        // Bind this GPU's CUDA context to the loader thread
                        // before issuing any work on it.
                        #[cfg(feature = "cuda")]
                        if let Device::Cuda(ref cuda_dev) = dev {
                            cuda_dev
                                .cuda_stream()
                                .context()
                                .bind_to_thread()
                                .map_err(|e| {
                                    anyhow!(
                                        "failed to bind CUDA context for GPU {gpu_idx}: {e:?}"
                                    )
                                })?;
                        }

                        let mut results = Vec::new();
                        for layer_name in layers {
                            log::info!("loading {} on cuda:{} ...", &layer_name, gpu_idx);
                            let block =
                                G::Shardable::load(layer_name.clone(), &thread_ctx)?;
                            results.push((layer_name, dev.clone(), block));
                        }
                        Ok(results)
                    },
                ));
            }

            // Join loader threads; any panic or load error aborts startup.
            for handle in handles {
                let results = handle
                    .join()
                    .map_err(|_| anyhow!("GPU loading thread panicked"))??;
                for (name, dev, block) in results {
                    layer_devices.insert(name.clone(), dev);
                    blocks.insert(name, block);
                }
            }
        } else {
            // Single-device path: load every assigned layer on ctx.device.
            for block_layer_name in worker_topology.layers.iter() {
                log::info!("loading {} ...", &block_layer_name);

                let block = G::Shardable::load(block_layer_name.to_string(), ctx)?;
                layer_devices.insert(block_layer_name.to_string(), ctx.device.clone());
                blocks.insert(block_layer_name.to_string(), block);
            }
        }

        let blocks = Arc::new(blocks);
        let layer_devices = Arc::new(layer_devices);

        // Prefer a pre-bound listener if one was injected via the context
        // (presumably for tests/embedding — confirm); otherwise bind fresh.
        let listener = {
            let taken = ctx.listener_override.lock().unwrap().take();
            if let Some(existing) = taken {
                existing
            } else {
                TcpListener::bind(&ctx.args.address).await?
            }
        };

        log::info!(
            "listening on {} (mem:{}) ...",
            &ctx.args.address,
            human_bytes::human_bytes(memory_stats::memory_stats().map(|m| m.physical_mem).unwrap_or(0) as f64)
        );

        let device = ctx.device.clone();
        let dtype = ctx.dtype;
        let device_idx = ctx.args.device;

        let context = WorkerContext {
            device,
            device_idx,
            dtype,
            blocks,
            layer_devices,
            context: ctx.clone(),
        };

        Ok(Self { listener, context })
    }
266
267 async fn read_message_timed<R>(mut socket: R) -> Result<(Duration, usize, Message)>
269 where
270 R: AsyncReadExt + Unpin,
271 {
272 let start = Instant::now();
273 let (size, message) = Message::from_reader(&mut socket).await?;
274 let latency = start.elapsed();
275
276 Ok((latency, size, message))
277 }
278
279 async fn write_message_timed<W>(mut socket: W, message: Message) -> Result<(Duration, usize)>
281 where
282 W: AsyncWriteExt + Unpin,
283 {
284 let start = Instant::now();
285 let size = message.to_writer(&mut socket).await?;
286 let latency = start.elapsed();
287
288 Ok((latency, size))
289 }
290
    /// Serve one master connection for its whole lifetime.
    ///
    /// Protocol (as implemented below): optional cluster-key authentication,
    /// then a first message that is either `Hello` (answered with
    /// `WorkerInfo`, then the op loop starts) or `LayerAssignment` (acked and
    /// the connection closed). The op loop answers `SingleOp`/`Batch` with
    /// the resulting tensor; `Goodbye` clears the cache and re-sends
    /// `WorkerInfo` without closing the connection. The loop exits when a
    /// read fails (disconnect).
    async fn handle_master_client(
        mut socket: TcpStream,
        client: SocketAddr,
        mut context: WorkerContext<G::Shardable>,
    ) -> Result<()> {
        // Authentication only happens when a cluster key is configured.
        if let Some(ref cluster_key) = context.context.args.cluster_key {
            super::auth::authenticate_as_worker(&mut socket, cluster_key)
                .await
                .map_err(|e| anyhow!("[{}] authentication failed: {}", &client, e))?;
            log::debug!("[{}] authenticated", &client);
        }

        // The first message decides the session type.
        let (latency, _size, first_msg) = Self::read_message_timed(&mut socket).await?;
        match first_msg {
            Message::Hello => { }
            // A LayerAssignment here means the master is re-running setup:
            // ack it, report ready, and close this short-lived connection.
            Message::LayerAssignment { ref layers, .. } => {
                log::info!(
                    "[{}] master re-setup: accepting {} layer assignment(s)",
                    &client,
                    layers.len()
                );
                let ack = Message::LayerAssignmentAck { needs_data: false };
                ack.to_writer(&mut socket).await?;
                Message::WorkerReady.to_writer(&mut socket).await?;
                log::info!("[{}] re-setup complete, closing setup connection", &client);
                return Ok(());
            }
            other => {
                return Err(anyhow!(
                    "[{}] unexpected first message (expected Hello): {:?}",
                    &client,
                    other
                ));
            }
        }

        // Answer Hello with worker info; `latency` (time spent reading the
        // Hello) is reported in milliseconds.
        if let Err(e) = Self::write_message_timed(
            &mut socket,
            Message::WorkerInfo(context.to_info(latency.as_millis())),
        )
        .await
        {
            return Err(anyhow!("[{}] could not send worker info: {:?}", &client, e));
        }

        // Rolling throughput accumulators, flushed every NUM_OPS_TO_STATS ops.
        let mut msg_idx = 0;
        let mut avg_ops = 0;
        let mut avg_write = 0;
        let mut avg_read = 0;
        // Reusable (de)serialization buffers to avoid per-message allocation.
        let mut read_buf = Vec::new();
        let mut write_buf = Vec::new();

        // Main op loop: exits (returning Ok) when a read fails.
        while let Ok((read_time, read_size, op_message)) = {
            let start = Instant::now();
            Message::from_reader_buf(&mut socket, &mut read_buf)
                .await
                .map(|(size, msg)| (start.elapsed(), size, msg))
        } {
            // Goodbye: drop cached state and re-send info, but keep the
            // connection open for further work.
            if matches!(op_message, Message::Goodbye) {
                log::debug!("[{}] goodbye", &client);
                // NOTE(review): panics if the worker runs without a cache —
                // confirm a cache is always configured on this path.
                context
                    .context
                    .cache
                    .as_mut()
                    .expect("No cache specified")
                    .clear();

                if let Err(e) = Self::write_message_timed(
                    &mut socket,
                    Message::WorkerInfo(context.to_info(read_time.as_millis())),
                )
                .await
                {
                    return Err(anyhow!("[{}] could not send worker info: {:?}", &client, e));
                }

                continue;
            }

            // Normalize both request shapes into (input tensor, op list).
            let (x, ops) = match op_message {
                Message::SingleOp {
                    layer_name,
                    x,
                    index_pos,
                    block_idx,
                } => (x, vec![(layer_name, index_pos, block_idx)]),
                Message::Batch { x, batch } => (x, batch),
                _ => {
                    return Err(anyhow!(
                        "[{}] unhandled message in loop: {:?}",
                        &client,
                        op_message
                    ));
                }
            };

            // Materialize the incoming tensor on the device of the first
            // layer in the batch, falling back to the worker's main device.
            let load_start = Instant::now();
            let first_device = ops
                .first()
                .and_then(|(name, _, _)| context.layer_devices.get(name))
                .unwrap_or(&context.device);

            #[cfg(feature = "cuda")]
            if let Device::Cuda(cuda_dev) = first_device {
                if let Err(e) = cuda_dev.cuda_stream().context().bind_to_thread() {
                    log::error!("[{client}] failed to bind CUDA context: {:?}", e);
                }
            }

            // Device-load failures are reported to the master but are not
            // fatal for the connection.
            let mut x = match x.to_tensor(first_device) {
                Ok(t) => t,
                Err(e) => {
                    let msg = format!("failed to load tensor to device: {e}");
                    log::error!("[{}] {}", &client, &msg);
                    let _ = Self::write_message_timed(
                        &mut socket,
                        Message::WorkerError { message: msg },
                    )
                    .await;
                    continue;
                }
            };

            let load_elapsed = load_start.elapsed();

            let num_ops = ops.len();
            let start_ops = Instant::now();

            let mut batch_error = false;

            // Run each requested layer in order, threading x through them.
            for (layer_name, index_pos, block_idx) in ops {
                // Multi-GPU: hop the tensor onto this layer's device.
                if let Some(block_device) = context.layer_devices.get(&layer_name) {
                    #[cfg(feature = "cuda")]
                    if let Device::Cuda(cuda_dev) = block_device {
                        if let Err(e) = cuda_dev.cuda_stream().context().bind_to_thread() {
                            log::error!(
                                "[{client}] failed to bind CUDA context for {}: {:?}",
                                &layer_name,
                                e
                            );
                        }
                    }

                    x = match x.to_device(block_device) {
                        Ok(t) => t,
                        Err(e) => {
                            let msg = format!(
                                "failed to move tensor to device for layer {}: {e}",
                                &layer_name
                            );
                            log::error!("[{}] {}", &client, &msg);
                            let _ = Self::write_message_timed(
                                &mut socket,
                                Message::WorkerError { message: msg },
                            )
                            .await;
                            batch_error = true;
                            break;
                        }
                    };
                }

                if let Some(block) = context.blocks.get(&layer_name) {
                    x = match block
                        .forward(&x, index_pos, block_idx, &mut context.context)
                        .await
                    {
                        Ok(t) => {
                            // Metal: synchronize before serializing —
                            // presumably so the buffer is complete when read;
                            // confirm against the candle Metal backend.
                            if t.device().is_metal() {
                                let _ = t.device().synchronize();
                            }
                            t
                        }
                        Err(e) => {
                            let msg = format!(
                                "forward pass failed for layer {} (block_idx={}): {e}",
                                &layer_name, block_idx
                            );
                            log::error!("[{}] {}", &client, &msg);
                            let _ = Self::write_message_timed(
                                &mut socket,
                                Message::WorkerError { message: msg },
                            )
                            .await;
                            batch_error = true;
                            break;
                        }
                    };
                } else {
                    let msg = format!("could not find layer {}", &layer_name);
                    log::error!("[{}] {}", &client, &msg);
                    let _ = Self::write_message_timed(
                        &mut socket,
                        Message::WorkerError { message: msg },
                    )
                    .await;
                    batch_error = true;
                    break;
                }
            }

            // A failed batch was already reported; await the next message.
            if batch_error {
                continue;
            }

            let elaps_ops = start_ops.elapsed();

            let ser_start = Instant::now();
            let resp_msg = Message::from_tensor(&x);
            let ser_elapsed = ser_start.elapsed();

            let write_start = Instant::now();
            match resp_msg.to_writer_buf(&mut socket, &mut write_buf).await {
                Ok(written) => {
                    let elaps_write = write_start.elapsed();
                    log::debug!(
                        "[{}] read={:.1}ms load={:.1}ms fwd={:.1}ms ser={:.1}ms write={:.1}ms ({} ops)",
                        &client,
                        read_time.as_secs_f64() * 1000.0,
                        load_elapsed.as_secs_f64() * 1000.0,
                        elaps_ops.as_secs_f64() * 1000.0,
                        ser_elapsed.as_secs_f64() * 1000.0,
                        elaps_write.as_secs_f64() * 1000.0,
                        num_ops,
                    );

                    let ops_per_sec = (num_ops as f64 / elaps_ops.as_secs_f64()) as usize;
                    let write_bytes_per_sec = (written as f64 / elaps_write.as_secs_f64()) as usize;
                    let read_bytes_per_sec = (read_size as f64 / read_time.as_secs_f64()) as usize;

                    avg_ops += ops_per_sec;
                    avg_write += write_bytes_per_sec;
                    avg_read += read_bytes_per_sec;
                }
                Err(e) => {
                    // A failed response write is fatal for the connection.
                    return Err(anyhow!(
                        "[{}] could not send response tensor: {:?}",
                        &client,
                        e
                    ));
                }
            }

            // NOTE(review): this also fires at msg_idx == 0, so the first
            // stats line averages one sample over NUM_OPS_TO_STATS — confirm
            // whether that underreporting is intended.
            if msg_idx % NUM_OPS_TO_STATS == 0 {
                log::info!(
                    "ops={}/s read={}/s write={}/s",
                    avg_ops / NUM_OPS_TO_STATS,
                    human_bytes::human_bytes(avg_read as f64 / NUM_OPS_TO_STATS as f64),
                    human_bytes::human_bytes(avg_write as f64 / NUM_OPS_TO_STATS as f64)
                );
                avg_ops = 0;
                avg_write = 0;
                avg_read = 0;
            }
            msg_idx += 1;
        }

        Ok(())
    }
573
574 pub async fn run(&mut self) -> Result<()> {
576 while let Ok((socket, client)) = self.listener.accept().await {
577 let _ = socket.set_nodelay(true);
578 log::debug!("{} connected", &client);
579
580 let context = self.context.get_client_context();
581 tokio::spawn(async move {
582 if let Err(e) = Self::handle_master_client(socket, client, context).await {
583 log::error!("{}", e);
584 }
585 });
586 }
587
588 Ok(())
589 }
590}