use std::collections::HashSet;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::time::{Duration, Instant};

use anyhow::{anyhow, Result};
use tokio::net::{TcpListener, TcpStream};

use super::auth;
use super::discovery::{self, DiscoveredWorker};
use super::proto::Message;
use super::topology::{Node, Topology};
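/// Returns the weight-name prefix under which the model's transformer layers
/// live. Most decoder-only checkpoints use `model.layers`; Qwen3.5
/// conditional-generation checkpoints nest them under
/// `model.language_model.layers`. Illustrative use:
///
/// ```ignore
/// let cfg = serde_json::json!({ "architectures": ["Qwen3_5ForConditionalGeneration"] });
/// assert_eq!(layer_prefix_for_config(&cfg), "model.language_model.layers");
/// ```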
fn layer_prefix_for_config(config_json: &serde_json::Value) -> String {
    if let Some(archs) = config_json.get("architectures").and_then(|v| v.as_array()) {
        for arch in archs {
            if let Some(s) = arch.as_str() {
                if s == "Qwen3_5ForConditionalGeneration" {
                    return "model.language_model.layers".to_string();
                }
            }
        }
    }
    "model.layers".to_string()
}
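/// Model weights are streamed to workers in chunks of this size (128 MiB per
/// `ModelDataChunk` message).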
const MODEL_DATA_CHUNK_SIZE: usize = 128 * 1024 * 1024;
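/// Queries `nvidia-smi` for the free memory of every visible GPU and returns
/// the sum in bytes. Returns 0 on non-CUDA builds or when `nvidia-smi` is
/// unavailable, in which case callers fall back to total VRAM.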
fn detect_free_gpu_memory() -> u64 {
    #[cfg(feature = "cuda")]
    {
        if let Ok(output) = std::process::Command::new("nvidia-smi")
            .args(["--query-gpu=memory.free", "--format=csv,noheader,nounits"])
            .output()
        {
            if output.status.success() {
                let stdout = String::from_utf8_lossy(&output.stdout);
                return stdout
                    .lines()
                    .filter_map(|line| line.trim().parse::<u64>().ok())
                    .map(|mb| mb * 1024 * 1024)
                    .sum();
            }
        }
    }
    0
}
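/// Parses just the safetensors header of `path` and returns `(tensor_name,
/// size_in_bytes)` for every tensor, without reading the tensor data itself.
///
/// A safetensors file starts with an 8-byte little-endian length `N`,
/// followed by an `N`-byte JSON header, followed by the raw tensor data.
/// Each header entry carries `data_offsets: [begin, end)` into that data
/// section, so `end - begin` is the tensor's byte size. Illustrative header:
///
/// ```text
/// {
///   "__metadata__": { "format": "pt" },
///   "model.layers.0.mlp.down_proj.weight": {
///     "dtype": "BF16",
///     "shape": [4096, 11008],
///     "data_offsets": [0, 90177536]
///   }
/// }
/// ```
///
/// Headers larger than 10 MiB are treated as corrupt and rejected.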
fn read_safetensors_tensor_sizes(path: &Path) -> Option<Vec<(String, u64)>> {
    use std::io::Read;
    let mut f = std::fs::File::open(path).ok()?;
    let mut len_buf = [0u8; 8];
    f.read_exact(&mut len_buf).ok()?;
    let header_len = u64::from_le_bytes(len_buf) as usize;
    if header_len > 10 * 1024 * 1024 {
        return None;
    }
    let mut header_buf = vec![0u8; header_len];
    f.read_exact(&mut header_buf).ok()?;
    let header: serde_json::Value = serde_json::from_slice(&header_buf).ok()?;
    let obj = header.as_object()?;
    let mut result = Vec::with_capacity(obj.len());
    for (name, meta) in obj {
        if name == "__metadata__" {
            continue;
        }
        if let Some(offsets) = meta.get("data_offsets").and_then(|v| v.as_array()) {
            if offsets.len() == 2 {
                let start = offsets[0].as_u64().unwrap_or(0);
                let end = offsets[1].as_u64().unwrap_or(0);
                result.push((name.clone(), end.saturating_sub(start)));
            }
        }
    }
    Some(result)
}
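/// Estimates the on-disk size of a single transformer layer.
///
/// Preferred path: read every shard header listed in
/// `model.safetensors.index.json`, sum the bytes of tensors whose names start
/// with `{layer_prefix}.`, and divide by `num_layers`. This excludes
/// embeddings, the LM head, and the final norm from the per-layer figure.
/// Fallbacks: total shard file size / `num_layers`, then the single
/// `model.safetensors` file, then 0 if nothing can be read.
///
/// Illustrative numbers (hypothetical 32-layer model): 16 GiB of weights, of
/// which 14 GiB sit under `model.layers.*`, gives 14 GiB / 32 = 448 MiB per
/// layer rather than the 512 MiB a naive total/32 split would report.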
fn estimate_layer_size(model_path: &Path, num_layers: usize, layer_prefix: &str) -> u64 {
    if num_layers == 0 {
        return 0;
    }

    let layer_dot = format!("{}.", layer_prefix);

    let index_path = model_path.join("model.safetensors.index.json");
    if let Ok(data) = std::fs::read_to_string(&index_path) {
        if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) {
            if let Some(weight_map) = json.get("weight_map").and_then(|v| v.as_object()) {
                let shards: HashSet<&str> =
                    weight_map.values().filter_map(|v| v.as_str()).collect();

                let mut layer_bytes: u64 = 0;
                let mut total_bytes: u64 = 0;
                let mut headers_ok = true;

                for shard in &shards {
                    let shard_path = model_path.join(shard);
                    if let Some(tensors) = read_safetensors_tensor_sizes(&shard_path) {
                        for (name, size) in &tensors {
                            total_bytes += size;
                            if name.starts_with(&layer_dot) {
                                layer_bytes += size;
                            }
                        }
                    } else {
                        headers_ok = false;
                        break;
                    }
                }

                if headers_ok && layer_bytes > 0 {
                    let non_layer = total_bytes - layer_bytes;
                    if non_layer > 0 {
                        log::info!(
                            "model weights: {} total, {} layers, {} non-layer ({:.0}% excluded)",
                            human_bytes::human_bytes(total_bytes as f64),
                            human_bytes::human_bytes(layer_bytes as f64),
                            human_bytes::human_bytes(non_layer as f64),
                            non_layer as f64 / total_bytes as f64 * 100.0,
                        );
                    }
                    return layer_bytes / num_layers as u64;
                }

                let total: u64 = shards
                    .iter()
                    .filter_map(|s| std::fs::metadata(model_path.join(s)).ok())
                    .map(|m| m.len())
                    .sum();
                return total / num_layers as u64;
            }
        }
    }

    let single = model_path.join("model.safetensors");
    if let Ok(single_path) = single.canonicalize() {
        if let Some(tensors) = read_safetensors_tensor_sizes(&single_path) {
            let mut layer_bytes: u64 = 0;
            for (name, size) in &tensors {
                if name.starts_with(&layer_dot) {
                    layer_bytes += size;
                }
            }
            if layer_bytes > 0 {
                return layer_bytes / num_layers as u64;
            }
        }
        if let Ok(m) = std::fs::metadata(&single_path) {
            return m.len() / num_layers as u64;
        }
    }

    0
}
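/// Splits the model's transformer layers between discovered workers and the
/// master, proportionally to compute throughput (TFLOPS), then caps each
/// worker at what its VRAM can hold and pushes any layers the master cannot
/// fit back onto workers with spare capacity. Returns `(worker_index,
/// layer_names)` pairs; layers not assigned to any worker stay on the master.
///
/// Worked example (hypothetical): 48 layers, master at 40 TFLOPS, two workers
/// at 30 and 10 TFLOPS. Workers contribute 40 of 80 total TFLOPS, so they get
/// round(40/80 * 48) = 24 layers, split 18 / 6 by relative throughput; the
/// master keeps the remaining 24 (assuming no VRAM caps kick in).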
pub fn compute_layer_assignments(
    workers: &[DiscoveredWorker],
    num_layers: usize,
    master_tflops: f64,
    layer_size_bytes: u64,
    master_max_layers: usize,
    layer_prefix: &str,
) -> Vec<(usize, Vec<String>)> {
    if workers.is_empty() || num_layers == 0 {
        return vec![];
    }

    let total_tflops: f64 =
        workers.iter().map(|w| w.total_tflops()).sum::<f64>() + master_tflops;

    if total_tflops <= 0.0 {
        let worker_layers = num_layers / 2;
        let per_worker = worker_layers / workers.len();
        let mut assignments = vec![];
        let mut offset = 0;
        for (i, _) in workers.iter().enumerate() {
            let count = if i == workers.len() - 1 {
                worker_layers - offset
            } else {
                per_worker
            };
            let layers: Vec<String> = (offset..offset + count)
                .map(|l| format!("{layer_prefix}.{l}"))
                .collect();
            assignments.push((i, layers));
            offset += count;
        }
        return assignments;
    }

    let mut indices: Vec<usize> = (0..workers.len()).collect();
    indices.sort_by(|a, b| {
        workers[*b]
            .total_tflops()
            .partial_cmp(&workers[*a].total_tflops())
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let workers_tflops: f64 = workers.iter().map(|w| w.total_tflops()).sum();
    let total_worker_layers =
        (workers_tflops / total_tflops * num_layers as f64).round() as usize;
    let total_worker_layers = total_worker_layers.min(num_layers);

    log::info!(
        "master: {:.1} TFLOPS — workers: {:.1} TFLOPS — assigning {} of {} layers to workers",
        master_tflops,
        workers_tflops,
        total_worker_layers,
        num_layers
    );

    let mut assignments = vec![];
    let mut offset = 0;
    let mut remaining_layers = total_worker_layers;
    let mut remaining_tflops = workers_tflops;

    for (pos, &worker_idx) in indices.iter().enumerate() {
        if remaining_layers == 0 {
            break;
        }

        let mut count = if pos == indices.len() - 1 {
            remaining_layers
        } else {
            let worker_tflops = workers[worker_idx].total_tflops();
            let proportional =
                (worker_tflops / remaining_tflops * remaining_layers as f64).round() as usize;
            proportional.max(1).min(remaining_layers)
        };

        if layer_size_bytes > 0 {
            let max_layers = workers[worker_idx].max_layers_for_size(layer_size_bytes);
            if count > max_layers {
                log::info!(
                    " {} capped from {} to {} layers (VRAM limit: {} per layer)",
                    &workers[worker_idx].name,
                    count,
                    max_layers,
                    human_bytes::human_bytes(layer_size_bytes as f64)
                );
                count = max_layers;
            }
        }

        let layers: Vec<String> = (offset..offset + count)
            .map(|l| format!("{layer_prefix}.{l}"))
            .collect();

        assignments.push((worker_idx, layers));
        offset += count;
        remaining_layers -= count;
        remaining_tflops -= workers[worker_idx].total_tflops();
    }

    let master_layers = num_layers - offset;
    if master_max_layers < usize::MAX && master_layers > master_max_layers {
        let deficit = master_layers - master_max_layers;
        log::info!(
            "master has {} local layers but can fit {} — redistributing {} to workers",
            master_layers,
            master_max_layers,
            deficit
        );

        let mut extra_needed = deficit;
        for (worker_idx, layers) in assignments.iter_mut() {
            if extra_needed == 0 {
                break;
            }
            let current = layers.len();
            let max = if layer_size_bytes > 0 {
                workers[*worker_idx].max_layers_for_size(layer_size_bytes)
            } else {
                usize::MAX
            };
            let spare = max.saturating_sub(current);
            if spare > 0 {
                let take = spare.min(extra_needed);
                let new_start = offset;
                for l in new_start..new_start + take {
                    layers.push(format!("{layer_prefix}.{l}"));
                }
                offset += take;
                extra_needed -= take;
                log::info!(
                    " {} takes {} extra layer(s) ({} → {} total)",
                    &workers[*worker_idx].name,
                    take,
                    current,
                    layers.len()
                );
            }
        }

        if extra_needed > 0 {
            log::warn!(
                "cluster cannot fit all {} layers — {} layer(s) unassignable (master VRAM too small, workers full)",
                num_layers,
                extra_needed
            );
        }
    }

    assignments
}
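/// Master-side cluster bring-up.
///
/// Reads `config.json` to learn the layer count and prefix, discovers workers
/// on the network, estimates per-layer VRAM cost, computes the layer split,
/// then for each worker (in parallel): connect, authenticate, send
/// `LayerAssignment`, push model files if the worker's cache misses
/// (`LayerAssignmentAck { needs_data: true }`), and wait for `WorkerReady`.
/// Returns the resulting [`Topology`]; an empty topology means no workers were
/// found and all layers load locally.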
pub async fn master_setup(
    cluster_key: &str,
    model_path: &Path,
    discovery_timeout: Duration,
) -> Result<Topology> {
    let config_path = model_path.join("config.json");
    let config_data = std::fs::read_to_string(&config_path)
        .map_err(|e| anyhow!("failed to read {}: {}", config_path.display(), e))?;
    let model_hash = {
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(config_data.as_bytes());
        let result = hasher.finalize();
        hex::encode(&result[..4])
    };
    let config_json: serde_json::Value = serde_json::from_str(&config_data)?;
    let num_layers = config_json
        .get("num_hidden_layers")
        .and_then(|v| v.as_u64())
        .or_else(|| {
            config_json
                .get("text_config")
                .and_then(|tc| tc.get("num_hidden_layers"))
                .and_then(|v| v.as_u64())
        })
        .ok_or_else(|| anyhow!("num_hidden_layers not found in config.json"))? as usize;

    let layer_prefix = layer_prefix_for_config(&config_json);

    log::info!(
        "model has {} transformer layers (prefix: {})",
        num_layers,
        &layer_prefix,
    );

    let master_gpus = discovery::detect_gpus();
    let master_tflops: f64 = master_gpus.iter().map(|g| g.tflops as f64).sum();
    let free_gpu_fut = tokio::task::spawn_blocking(detect_free_gpu_memory);

    let workers = discovery::discover_workers(cluster_key, discovery_timeout).await?;
    if workers.is_empty() {
        log::warn!("no workers discovered — all layers will be loaded locally");
        return Ok(Topology::new());
    }

    let master_free_from_smi = free_gpu_fut.await.unwrap_or(0);

    let layer_size_on_disk = estimate_layer_size(model_path, num_layers, &layer_prefix);
    if layer_size_on_disk > 0 {
        log::info!(
            "estimated layer size (on disk): {}",
            human_bytes::human_bytes(layer_size_on_disk as f64)
        );
    }

    let dtype_bytes: u64 = 2;
    let vocab_size = config_json
        .get("vocab_size")
        .and_then(|v| v.as_u64())
        .or_else(|| {
            config_json
                .get("text_config")
                .and_then(|tc| tc.get("vocab_size"))
                .and_then(|v| v.as_u64())
        })
        .unwrap_or(32000);
    let hidden_size = config_json
        .get("hidden_size")
        .and_then(|v| v.as_u64())
        .or_else(|| {
            config_json
                .get("text_config")
                .and_then(|tc| tc.get("hidden_size"))
                .and_then(|v| v.as_u64())
        })
        .unwrap_or(4096);
    let tie_embeddings = config_json
        .get("tie_word_embeddings")
        .and_then(|v| v.as_bool())
        .unwrap_or(false);

    let embed_size = vocab_size * hidden_size * dtype_bytes;
    let lm_head_size = if tie_embeddings { 0 } else { embed_size };
    let master_overhead = embed_size + lm_head_size + 1024 * 1024 * 1024;

    let is_fp8 = crate::utils::fp8::is_fp8_quantized(&config_path);
    let layer_size_bytes = if is_fp8 && layer_size_on_disk > 0 {
        let expanded = layer_size_on_disk * dtype_bytes;
        log::info!(
            "FP8 model: layer size after dequantization: {} ({}x expansion)",
            human_bytes::human_bytes(expanded as f64),
            dtype_bytes,
        );
        expanded
    } else {
        layer_size_on_disk
    };

    log::info!(
        "master overhead: embeddings={} lm_head={} total={}",
        human_bytes::human_bytes(embed_size as f64),
        human_bytes::human_bytes(lm_head_size as f64),
        human_bytes::human_bytes(master_overhead as f64),
    );

    let master_max_layers = if layer_size_bytes > 0 && !master_gpus.is_empty() {
        let master_vram: u64 = master_gpus.iter().map(|g| g.vram_bytes).sum();
        let effective_vram = if master_free_from_smi > 0 && master_free_from_smi < master_vram {
            log::info!(
                "master GPU: {} total, {} free",
                human_bytes::human_bytes(master_vram as f64),
                human_bytes::human_bytes(master_free_from_smi as f64),
            );
            master_free_from_smi
        } else {
            master_vram
        };
        let available = effective_vram.saturating_sub(master_overhead);
        let max = (available / layer_size_bytes) as usize;
        log::info!(
            "master GPU: {} available for layers — can fit ~{} layers locally",
            human_bytes::human_bytes(available as f64),
            max
        );
        max
    } else {
        usize::MAX
    };

    let assignments = compute_layer_assignments(
        &workers,
        num_layers,
        master_tflops,
        layer_size_bytes,
        master_max_layers,
        &layer_prefix,
    );

    let total_assigned: usize = assignments.iter().map(|(_, l)| l.len()).sum();
    let master_layers = num_layers - total_assigned;
    log::info!("layer assignments:");
    for (worker_idx, layers) in &assignments {
        let w = &workers[*worker_idx];
        let range = if layers.is_empty() {
            "(none)".to_string()
        } else {
            format!("{} — {}", layers.first().unwrap(), layers.last().unwrap())
        };
        let weight_load = layers.len() as u64 * layer_size_bytes;
        log::info!(
            " {} ({}, {:.1} TFLOPS) → {} layers ({}) [{}]",
            &w.name,
            human_bytes::human_bytes(w.total_vram() as f64),
            w.total_tflops(),
            layers.len(),
            human_bytes::human_bytes(weight_load as f64),
            range,
        );
    }
    log::info!(
        " master ({:.1} TFLOPS) → {} layers ({} weights + {} overhead)",
        master_tflops,
        master_layers,
        human_bytes::human_bytes((master_layers as u64 * layer_size_bytes) as f64),
        human_bytes::human_bytes(master_overhead as f64),
    );
    if layer_size_bytes > 0 {
        log::info!(
            "total weight read per token: {} ({} per layer × {})",
            human_bytes::human_bytes((num_layers as u64 * layer_size_bytes) as f64),
            human_bytes::human_bytes(layer_size_bytes as f64),
            num_layers,
        );
    }

    let mut handles = Vec::new();

    for (worker_idx, layers) in &assignments {
        let worker = workers[*worker_idx].clone();
        if layers.is_empty() {
            continue;
        }

        let layers = layers.clone();
        let cluster_key = cluster_key.to_string();
        let model_hash = model_hash.clone();
        let model_path = model_path.to_path_buf();
        let model_name = model_path
            .file_name()
            .unwrap_or_default()
            .to_string_lossy()
            .to_string();

        handles.push(tokio::spawn(async move {
            log::info!(
                "connecting to worker '{}' at {} ...",
                &worker.name,
                &worker.host
            );

            let mut stream = TcpStream::connect(&worker.host)
                .await
                .map_err(|e| anyhow!("can't connect to {}: {}", &worker.host, e))?;
            let _ = stream.set_nodelay(true);

            auth::authenticate_as_master(&mut stream, &cluster_key).await?;
            log::info!("[{}] authenticated", &worker.name);

            let msg = Message::LayerAssignment {
                layers: layers.clone(),
                model_hash,
            };
            msg.to_writer(&mut stream).await?;

            let (_, ack) = Message::from_reader(&mut stream).await?;
            let needs_data = match ack {
                Message::LayerAssignmentAck { needs_data } => needs_data,
                other => {
                    return Err(anyhow!(
                        "[{}] unexpected response to LayerAssignment: {:?}",
                        &worker.name,
                        other
                    ))
                }
            };

            if needs_data {
                push_model_data(&mut stream, &model_path, &layers, &worker.name, &model_name).await?;
            } else {
                log::info!("[{}] worker has model data cached", &worker.name);
            }

            let (_, ready) = Message::from_reader(&mut stream).await?;
            if !matches!(ready, Message::WorkerReady) {
                return Err(anyhow!(
                    "[{}] expected WorkerReady, got {:?}",
                    &worker.name,
                    ready
                ));
            }
            log::info!("[{}] worker ready", &worker.name);

            Ok::<_, anyhow::Error>((worker, layers))
        }));
    }

    let mut topology = Topology::new();
    for handle in handles {
        let (worker, layers) = handle.await??;
        topology.insert(
            worker.name.clone(),
            Node {
                host: worker.host.clone(),
                description: Some(
                    worker
                        .gpus
                        .iter()
                        .map(|g| g.name.clone())
                        .collect::<Vec<_>>()
                        .join(", "),
                ),
                layers,
                vram_bytes: worker.total_vram(),
                tflops: worker.total_tflops(),
                backend: worker.backend.clone(),
                hostname: worker.hostname.clone(),
                os: worker.os.clone(),
            },
        );
    }

    Ok(topology)
}
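/// Streams the files a worker needs over the already-authenticated `stream`:
/// `config.json`, `tokenizer.json`, a filtered copy of
/// `model.safetensors.index.json` whose `weight_map` only references the
/// assigned layers, and only the shard files those layers live in. Each file
/// is sent as a sequence of `ModelDataChunk { filename, offset, total_size,
/// data }` messages of at most `MODEL_DATA_CHUNK_SIZE` bytes, and the whole
/// transfer is terminated with `ModelDataDone`.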
async fn push_model_data(
    stream: &mut TcpStream,
    model_path: &Path,
    layers: &[String],
    worker_name: &str,
    model_name: &str,
) -> Result<()> {
    let overall_start = Instant::now();
    let mut overall_bytes: u64 = 0;

    let layer_range = if layers.is_empty() {
        "(none)".to_string()
    } else {
        format!(
            "{} — {} ({} layers)",
            layers.first().unwrap(),
            layers.last().unwrap(),
            layers.len()
        )
    };

    log::info!(
        "[{}] pushing {} [{}]",
        worker_name,
        model_name,
        layer_range
    );

    let mut files_to_send: Vec<PathBuf> = vec![
        model_path.join("config.json"),
        model_path.join("tokenizer.json"),
    ];

    let index_path = model_path.join("model.safetensors.index.json");
    let mut filtered_index: Option<Vec<u8>> = None;
    if index_path.exists() {
        files_to_send.push(index_path.clone());
        let index_data = std::fs::read(&index_path)?;
        let mut index_json: serde_json::Value = serde_json::from_slice(&index_data)?;
        let weight_map = index_json
            .get("weight_map")
            .and_then(|v| v.as_object())
            .ok_or_else(|| anyhow!("no weight_map in model.safetensors.index.json"))?
            .clone();

        let mut needed_shards: HashSet<String> = HashSet::new();
        let mut needed_weights: serde_json::Map<String, serde_json::Value> =
            serde_json::Map::new();
        for (tensor_name, shard_file) in &weight_map {
            for layer in layers {
                if tensor_name.starts_with(&format!("{}.", layer)) {
                    if let Some(filename) = shard_file.as_str() {
                        needed_shards.insert(filename.to_string());
                    }
                    needed_weights.insert(tensor_name.clone(), shard_file.clone());
                }
            }
        }

        if let Some(obj) = index_json.as_object_mut() {
            obj.insert(
                "weight_map".to_string(),
                serde_json::Value::Object(needed_weights),
            );
        }
        filtered_index = Some(serde_json::to_vec_pretty(&index_json)?);

        log::info!(
            "[{}] pushing {} shard file(s) + config + tokenizer + index",
            worker_name,
            needed_shards.len()
        );

        for shard in &needed_shards {
            files_to_send.push(model_path.join(shard));
        }
    } else {
        let single = model_path.join("model.safetensors");
        if single.exists() {
            files_to_send.push(single);
        }
    }

    for file_path in &files_to_send {
        let filename = file_path
            .file_name()
            .unwrap_or_default()
            .to_string_lossy()
            .to_string();

        let file_data = if filename == "model.safetensors.index.json" {
            if let Some(ref data) = filtered_index {
                data.clone()
            } else {
                std::fs::read(file_path)
                    .map_err(|e| anyhow!("failed to read {}: {}", file_path.display(), e))?
            }
        } else {
            std::fs::read(file_path)
                .map_err(|e| anyhow!("failed to read {}: {}", file_path.display(), e))?
        };
        let total_size = file_data.len() as u64;
        let file_start = Instant::now();
        let mut offset: u64 = 0;

        log::info!(
            "[{}] sending {} ({}) ...",
            worker_name,
            &filename,
            human_bytes::human_bytes(total_size as f64)
        );

        for chunk in file_data.chunks(MODEL_DATA_CHUNK_SIZE) {
            let msg = Message::ModelDataChunk {
                filename: filename.clone(),
                offset,
                total_size,
                data: chunk.to_vec(),
            };
            msg.to_writer(stream).await?;
            offset += chunk.len() as u64;

            if total_size > MODEL_DATA_CHUNK_SIZE as u64 {
                let elapsed = file_start.elapsed().as_secs_f64();
                let speed = offset as f64 / elapsed;
                let pct = (offset as f64 / total_size as f64) * 100.0;
                let remaining = total_size - offset;
                let eta_secs = if speed > 0.0 {
                    remaining as f64 / speed
                } else {
                    0.0
                };
                log::info!(
                    "[{}] {} — {}/{} ({:.1}%) — {}/s — ETA {:.0}s",
                    worker_name,
                    &filename,
                    human_bytes::human_bytes(offset as f64),
                    human_bytes::human_bytes(total_size as f64),
                    pct,
                    human_bytes::human_bytes(speed),
                    eta_secs
                );
            }
        }

        let file_elapsed = file_start.elapsed();
        let file_speed = total_size as f64 / file_elapsed.as_secs_f64();
        overall_bytes += total_size;

        log::info!(
            "[{}] sent {} ({}) in {:.1}s — {}/s",
            worker_name,
            &filename,
            human_bytes::human_bytes(total_size as f64),
            file_elapsed.as_secs_f64(),
            human_bytes::human_bytes(file_speed)
        );
    }

    Message::ModelDataDone.to_writer(stream).await?;

    let overall_elapsed = overall_start.elapsed();
    let overall_speed = overall_bytes as f64 / overall_elapsed.as_secs_f64();
    log::info!(
        "[{}] transfer complete: {} in {:.1}s — {}/s avg",
        worker_name,
        human_bytes::human_bytes(overall_bytes as f64),
        overall_elapsed.as_secs_f64(),
        human_bytes::human_bytes(overall_speed)
    );

    Ok(())
}
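/// Returns true if `cache_dir` already holds usable weights for `layers`:
/// `config.json` must exist, and either a monolithic `model.safetensors` is
/// present or the cached index maps every assigned layer to a shard file that
/// exists on disk.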
fn has_valid_model_cache(cache_dir: &Path, layers: &[String]) -> bool {
    if !cache_dir.join("config.json").exists() {
        return false;
    }
    if cache_dir.join("model.safetensors").exists() {
        return true;
    }
    let index_path = cache_dir.join("model.safetensors.index.json");
    if index_path.exists() {
        if let Ok(data) = std::fs::read_to_string(&index_path) {
            if let Ok(index) = serde_json::from_str::<serde_json::Value>(&data) {
                if let Some(weight_map) = index.get("weight_map").and_then(|v| v.as_object()) {
                    for layer in layers {
                        let prefix = format!("{}.", layer);
                        let has_layer = weight_map.iter().any(|(tensor_name, shard_file)| {
                            tensor_name.starts_with(&prefix)
                                && shard_file
                                    .as_str()
                                    .is_some_and(|f| cache_dir.join(f).exists())
                        });
                        if !has_layer {
                            log::debug!(
                                "cache miss: {} not found in {}",
                                layer,
                                cache_dir.display()
                            );
                            return false;
                        }
                    }
                    return true;
                }
            }
        }
    }
    false
}
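/// Progress callback used during worker setup: `(stage, message, progress)`
/// where `stage` is a short machine-readable phase name (e.g. "discovery",
/// "receiving", "ready"), `message` is human-readable, and `progress` is in
/// `0.0..=1.0`.
///
/// Illustrative caller (hypothetical UI hook):
///
/// ```ignore
/// let on_progress: Box<SetupProgressFn> =
///     Box::new(|stage, msg, p| println!("[{stage}] {msg} ({:.0}%)", p * 100.0));
/// ```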
pub type SetupProgressFn = dyn Fn(&str, &str, f64) + Send + Sync;

pub async fn worker_setup(
    worker_name: &str,
    cluster_key: &str,
    bind_address: &str,
    model_cache_dir: &Path,
) -> Result<(Vec<String>, PathBuf, TcpListener)> {
    worker_setup_with_progress(worker_name, cluster_key, bind_address, model_cache_dir, None).await
}
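/// Worker-side counterpart of [`master_setup`]: advertise this worker on the
/// network, accept the master's connection, authenticate, receive the layer
/// assignment, answer the cache check (`LayerAssignmentAck`), receive model
/// data if needed, and reply `WorkerReady`. Returns the assigned layers, the
/// model cache directory to load from, and the still-bound listener for the
/// inference phase. `on_progress`, when given, is invoked at each stage.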
pub async fn worker_setup_with_progress(
    worker_name: &str,
    cluster_key: &str,
    bind_address: &str,
    model_cache_dir: &Path,
    on_progress: Option<&SetupProgressFn>,
) -> Result<(Vec<String>, PathBuf, TcpListener)> {
    let gpus = discovery::detect_gpus();
    log::info!("detected {} GPU(s):", gpus.len());
    for gpu in &gpus {
        log::info!(
            " {} — {} (~{:.1} TFLOPS)",
            &gpu.name,
            human_bytes::human_bytes(gpu.vram_bytes as f64),
            gpu.tflops
        );
    }

    let listener = TcpListener::bind(bind_address).await?;
    let port = listener.local_addr()?.port();
    log::info!("listening on {} (setup mode)", bind_address);

    let _discovery = discovery::advertise_worker(worker_name, port, cluster_key, &gpus)?;

    log::info!("waiting for master to connect and assign layers...");
    if let Some(cb) = &on_progress {
        cb("discovery", "Waiting for master...", 0.0);
    }

    let (mut stream, client_addr) = listener.accept().await?;
    let _ = stream.set_nodelay(true);
    log::info!("[{}] master connected", client_addr);
    if let Some(cb) = &on_progress {
        cb("connected", &format!("Master connected ({})", client_addr), 0.0);
    }

    auth::authenticate_as_worker(&mut stream, cluster_key).await?;
    log::info!("[{}] authenticated", client_addr);
    if let Some(cb) = &on_progress {
        cb("authenticated", "Authenticated with master", 0.0);
    }

    let (_, msg) = Message::from_reader(&mut stream).await?;
    let (layers, model_hash) = match msg {
        Message::LayerAssignment {
            layers,
            model_hash,
        } => (layers, model_hash),
        other => {
            return Err(anyhow!(
                "expected LayerAssignment, got {:?}",
                other
            ))
        }
    };

    log::info!("assigned {} layers:", layers.len());
    for layer in &layers {
        log::info!("  {}", layer);
    }
    if let Some(cb) = &on_progress {
        cb("layers", &format!("Assigned {} layer(s)", layers.len()), 0.0);
    }

    let cluster_id = discovery::cluster_hash(cluster_key);
    let cache_dir = if model_hash.is_empty() {
        model_cache_dir.join(&cluster_id)
    } else {
        model_cache_dir.join(format!("{}-{}", cluster_id, model_hash))
    };
    std::fs::create_dir_all(&cache_dir)?;

    let needs_data = !has_valid_model_cache(&cache_dir, &layers);

    let ack = Message::LayerAssignmentAck { needs_data };
    ack.to_writer(&mut stream).await?;

    if needs_data {
        if let Some(cb) = &on_progress {
            cb("receiving", "Receiving model data...", 0.0);
        }
        receive_model_data(&mut stream, &cache_dir, &layers, on_progress).await?;
    } else {
        log::info!("using cached model data from {}", cache_dir.display());
        if let Some(cb) = &on_progress {
            cb("cached", "Using cached model data", 1.0);
        }
    }

    Message::WorkerReady.to_writer(&mut stream).await?;
    log::info!("setup complete, ready for inference");
    if let Some(cb) = &on_progress {
        cb("ready", "Setup complete", 1.0);
    }

    Ok((layers, cache_dir, listener))
}
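/// Worker-side receive loop for `push_model_data`: consumes `ModelDataChunk`
/// messages, appending each chunk to a file named `filename` inside
/// `cache_dir` (a new file is started whenever the filename changes), logs
/// throughput and ETA for large files, and stops at `ModelDataDone`.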
async fn receive_model_data(
    stream: &mut TcpStream,
    cache_dir: &Path,
    layers: &[String],
    on_progress: Option<&SetupProgressFn>,
) -> Result<()> {
    let overall_start = Instant::now();
    let mut overall_bytes: u64 = 0;
    let mut current_file: Option<(String, std::fs::File, Instant, u64)> = None;

    let layer_range = if layers.is_empty() {
        "(none)".to_string()
    } else {
        format!(
            "{} — {} ({} layers)",
            layers.first().unwrap(),
            layers.last().unwrap(),
            layers.len()
        )
    };

    log::info!("receiving model data [{}] ...", layer_range);

    loop {
        let (_, msg) = Message::from_reader(stream).await?;

        match msg {
            Message::ModelDataChunk {
                filename,
                offset,
                total_size,
                data,
            } => {
                let file = if let Some((ref name, ref mut file, _, _)) = current_file {
                    if name == &filename {
                        file
                    } else {
                        if let Some((prev_name, _, start, size)) = current_file.take() {
                            let elapsed = start.elapsed();
                            let speed = size as f64 / elapsed.as_secs_f64();
                            log::info!(
                                "received {} ({}) — {}/s",
                                &prev_name,
                                human_bytes::human_bytes(size as f64),
                                human_bytes::human_bytes(speed)
                            );
                            if let Some(cb) = &on_progress {
                                cb("receiving", &format!("{} complete", &prev_name), 1.0);
                            }
                        }
                        let path = cache_dir.join(&filename);
                        let f = std::fs::File::create(&path)?;
                        current_file = Some((filename.clone(), f, Instant::now(), total_size));
                        if let Some(cb) = &on_progress {
                            cb("receiving", &format!("Receiving {} ({})", &filename, human_bytes::human_bytes(total_size as f64)), 0.0);
                        }
                        &mut current_file.as_mut().unwrap().1
                    }
                } else {
                    let path = cache_dir.join(&filename);
                    let f = std::fs::File::create(&path)?;
                    current_file = Some((filename.clone(), f, Instant::now(), total_size));
                    if let Some(cb) = &on_progress {
                        cb("receiving", &format!("Receiving {} ({})", &filename, human_bytes::human_bytes(total_size as f64)), 0.0);
                    }
                    &mut current_file.as_mut().unwrap().1
                };

                file.write_all(&data)?;
                overall_bytes += data.len() as u64;

                let written = offset + data.len() as u64;
                if let Some((_, _, ref start, _)) = current_file {
                    let elapsed = start.elapsed().as_secs_f64();
                    let speed = if elapsed > 0.0 { written as f64 / elapsed } else { 0.0 };
                    let pct = if total_size > 0 { (written as f64 / total_size as f64) * 100.0 } else { 0.0 };

                    if total_size > MODEL_DATA_CHUNK_SIZE as u64 && written < total_size {
                        let remaining = total_size - written;
                        let eta_secs = if speed > 0.0 { remaining as f64 / speed } else { 0.0 };
                        log::info!(
                            " {} — {}/{} ({:.1}%) — {}/s — ETA {:.0}s",
                            &filename,
                            human_bytes::human_bytes(written as f64),
                            human_bytes::human_bytes(total_size as f64),
                            pct,
                            human_bytes::human_bytes(speed),
                            eta_secs
                        );
                        if let Some(cb) = &on_progress {
                            let msg = format!(
                                "{} — {}/{} — {}/s — ETA {:.0}s",
                                &filename,
                                human_bytes::human_bytes(written as f64),
                                human_bytes::human_bytes(total_size as f64),
                                human_bytes::human_bytes(speed),
                                eta_secs
                            );
                            cb("receiving", &msg, pct / 100.0);
                        }
                    }
                }
            }
            Message::ModelDataDone => {
                if let Some((name, _, start, size)) = current_file.take() {
                    let elapsed = start.elapsed();
                    let speed = size as f64 / elapsed.as_secs_f64();
                    log::info!(
                        "received {} ({}) — {}/s",
                        &name,
                        human_bytes::human_bytes(size as f64),
                        human_bytes::human_bytes(speed)
                    );
                }
                break;
            }
            other => {
                return Err(anyhow!(
                    "unexpected message during data transfer: {:?}",
                    other
                ));
            }
        }
    }

    let overall_elapsed = overall_start.elapsed();
    let overall_speed = overall_bytes as f64 / overall_elapsed.as_secs_f64();
    log::info!(
        "model data received: {} in {:.1}s — {}/s avg, cached to {}",
        human_bytes::human_bytes(overall_bytes as f64),
        overall_elapsed.as_secs_f64(),
        human_bytes::human_bytes(overall_speed),
        cache_dir.display()
    );

    Ok(())
}