Skip to main content

cake_core/cake/
setup.rs

1//! Zero-config setup protocol for master and worker nodes.
2//!
3//! Runs **before** the normal inference lifecycle:
4//!   1. Workers advertise via mDNS, master discovers them.
5//!   2. Master computes layer assignments based on GPU VRAM.
6//!   3. Master connects to each worker, authenticates, assigns layers,
7//!      and pushes model data if the worker doesn't have it cached.
8//!   4. Workers load their assigned layers and signal readiness.
9//!
10//! After setup, both sides proceed with normal `Context::from_args()` / inference.
11
12use std::collections::HashSet;
13use std::io::Write;
14use std::path::{Path, PathBuf};
15use std::time::{Duration, Instant};
16
17use anyhow::Result;
18use tokio::net::{TcpListener, TcpStream};
19
20use super::auth;
21use super::discovery::{self, DiscoveredWorker};
22use super::proto::Message;
23use super::topology::{Node, Topology};
24
25/// Derive the layer name prefix from config.json.
26/// Returns e.g. "model.language_model.layers" for Qwen3.5, "model.layers" otherwise.
27fn layer_prefix_for_config(config_json: &serde_json::Value) -> String {
28    if let Some(archs) = config_json.get("architectures").and_then(|v| v.as_array()) {
29        for arch in archs {
30            if let Some(s) = arch.as_str() {
31                if s == "Qwen3_5ForConditionalGeneration" {
32                    return "model.language_model.layers".to_string();
33                }
34            }
35        }
36    }
37    "model.layers".to_string()
38}
39
/// Maximum number of payload bytes carried by a single model-data chunk
/// message during model transfer (128 MiB). Files larger than this are split
/// into several chunks by `push_model_data`.
const MODEL_DATA_CHUNK_SIZE: usize = 128 * 1024 * 1024;
42
/// Report the free GPU memory, in bytes, summed over every GPU that
/// `nvidia-smi` lists.
///
/// Only compiled in when the `cuda` feature is enabled. Returns 0 when the
/// feature is off, when `nvidia-smi` cannot be launched, or when it exits
/// with a failure status. Output lines that do not parse as an integer are
/// ignored.
fn detect_free_gpu_memory() -> u64 {
    #[cfg(feature = "cuda")]
    {
        let query = std::process::Command::new("nvidia-smi")
            .args(["--query-gpu=memory.free", "--format=csv,noheader,nounits"])
            .output();

        match query {
            Ok(report) if report.status.success() => {
                let text = String::from_utf8_lossy(&report.stdout);
                let mut free_bytes: u64 = 0;
                // One line per GPU; each value is converted with a factor of 1024 * 1024.
                for row in text.lines() {
                    if let Ok(mib) = row.trim().parse::<u64>() {
                        free_bytes += mib * 1024 * 1024;
                    }
                }
                return free_bytes;
            }
            _ => {}
        }
    }
    0
}
64
65/// Read a safetensors header and return per-tensor byte sizes.
66/// The safetensors format stores an 8-byte LE header length followed by a JSON
67/// object mapping tensor names to `{dtype, shape, data_offsets: [start, end]}`.
68fn read_safetensors_tensor_sizes(path: &Path) -> Option<Vec<(String, u64)>> {
69    use std::io::Read;
70    let mut f = std::fs::File::open(path).ok()?;
71    let mut len_buf = [0u8; 8];
72    f.read_exact(&mut len_buf).ok()?;
73    let header_len = u64::from_le_bytes(len_buf) as usize;
74    // Sanity: headers are typically < 1 MB
75    if header_len > 10 * 1024 * 1024 {
76        return None;
77    }
78    let mut header_buf = vec![0u8; header_len];
79    f.read_exact(&mut header_buf).ok()?;
80    let header: serde_json::Value = serde_json::from_slice(&header_buf).ok()?;
81    let obj = header.as_object()?;
82    let mut result = Vec::with_capacity(obj.len());
83    for (name, meta) in obj {
84        if name == "__metadata__" {
85            continue;
86        }
87        if let Some(offsets) = meta.get("data_offsets").and_then(|v| v.as_array()) {
88            if offsets.len() == 2 {
89                let start = offsets[0].as_u64().unwrap_or(0);
90                let end = offsets[1].as_u64().unwrap_or(0);
91                result.push((name.clone(), end.saturating_sub(start)));
92            }
93        }
94    }
95    Some(result)
96}
97
98/// Estimate average transformer layer size in bytes from safetensors files.
99///
100/// For sharded models, reads each shard's header to compute exact per-tensor
101/// byte sizes, then sums only tensors matching `layer_prefix`. This excludes
102/// non-layer weights (visual encoder, MTP heads, embeddings, lm_head) which
103/// can be significant — e.g. Qwen3.5-27B-FP8 has ~6 GB of non-layer data.
104fn estimate_layer_size(model_path: &Path, num_layers: usize, layer_prefix: &str) -> u64 {
105    if num_layers == 0 {
106        return 0;
107    }
108
109    let layer_dot = format!("{}.", layer_prefix);
110
111    // Try sharded model first
112    let index_path = model_path.join("model.safetensors.index.json");
113    if let Ok(data) = std::fs::read_to_string(&index_path) {
114        if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) {
115            if let Some(weight_map) = json.get("weight_map").and_then(|v| v.as_object()) {
116                let shards: HashSet<&str> =
117                    weight_map.values().filter_map(|v| v.as_str()).collect();
118
119                // Try reading safetensors headers for exact tensor sizes
120                let mut layer_bytes: u64 = 0;
121                let mut total_bytes: u64 = 0;
122                let mut headers_ok = true;
123
124                for shard in &shards {
125                    let shard_path = model_path.join(shard);
126                    if let Some(tensors) = read_safetensors_tensor_sizes(&shard_path) {
127                        for (name, size) in &tensors {
128                            total_bytes += size;
129                            if name.starts_with(&layer_dot) {
130                                layer_bytes += size;
131                            }
132                        }
133                    } else {
134                        headers_ok = false;
135                        break;
136                    }
137                }
138
139                if headers_ok && layer_bytes > 0 {
140                    let non_layer = total_bytes - layer_bytes;
141                    if non_layer > 0 {
142                        log::info!(
143                            "model weights: {} total, {} layers, {} non-layer ({:.0}% excluded)",
144                            human_bytes::human_bytes(total_bytes as f64),
145                            human_bytes::human_bytes(layer_bytes as f64),
146                            human_bytes::human_bytes(non_layer as f64),
147                            non_layer as f64 / total_bytes as f64 * 100.0,
148                        );
149                    }
150                    return layer_bytes / num_layers as u64;
151                }
152
153                // Fallback: raw file size division
154                let total: u64 = shards
155                    .iter()
156                    .filter_map(|s| std::fs::metadata(model_path.join(s)).ok())
157                    .map(|m| m.len())
158                    .sum();
159                return total / num_layers as u64;
160            }
161        }
162    }
163
164    // Single safetensors file
165    let single = model_path.join("model.safetensors");
166    if let Ok(single_path) = single.canonicalize() {
167        // Try header-based estimation for single file too
168        if let Some(tensors) = read_safetensors_tensor_sizes(&single_path) {
169            let mut layer_bytes: u64 = 0;
170            for (name, size) in &tensors {
171                if name.starts_with(&layer_dot) {
172                    layer_bytes += size;
173                }
174            }
175            if layer_bytes > 0 {
176                return layer_bytes / num_layers as u64;
177            }
178        }
179        // Fallback
180        if let Ok(m) = std::fs::metadata(&single_path) {
181            return m.len() / num_layers as u64;
182        }
183    }
184
185    0
186}
187
188// ── Layer assignment ────────────────────────────────────────────────────────
189
190/// Compute layer assignments proportional to each worker's estimated TFLOPS,
191/// accounting for the master's own compute so it retains its fair share of layers.
192/// When `layer_size_bytes` > 0, each worker's assignment is capped by its
193/// per-GPU VRAM to prevent out-of-memory errors on multi-GPU nodes.
194/// The master's local layers are also capped by `master_max_layers` to avoid OOM.
195///
196/// Returns a vec of `(worker_index, layer_names)`.
197/// Workers are sorted by TFLOPS descending, and layers are assigned as
198/// contiguous ranges starting from layer 0. Unassigned layers remain on master.
199pub fn compute_layer_assignments(
200    workers: &[DiscoveredWorker],
201    num_layers: usize,
202    master_tflops: f64,
203    layer_size_bytes: u64,
204    master_max_layers: usize,
205    layer_prefix: &str,
206) -> Vec<(usize, Vec<String>)> {
207    if workers.is_empty() || num_layers == 0 {
208        return vec![];
209    }
210
211    // Include master TFLOPS in total so layers are split proportionally
212    let total_tflops: f64 =
213        workers.iter().map(|w| w.total_tflops()).sum::<f64>() + master_tflops;
214
215    if total_tflops <= 0.0 {
216        // No compute info — give half to workers, half to master
217        let worker_layers = num_layers / 2;
218        let per_worker = worker_layers / workers.len();
219        let mut assignments = vec![];
220        let mut offset = 0;
221        for (i, _) in workers.iter().enumerate() {
222            let count = if i == workers.len() - 1 {
223                worker_layers - offset
224            } else {
225                per_worker
226            };
227            let layers: Vec<String> = (offset..offset + count)
228                .map(|l| format!("{layer_prefix}.{l}"))
229                .collect();
230            assignments.push((i, layers));
231            offset += count;
232        }
233        return assignments;
234    }
235
236    // Sort worker indices by TFLOPS descending
237    let mut indices: Vec<usize> = (0..workers.len()).collect();
238    indices.sort_by(|a, b| {
239        workers[*b]
240            .total_tflops()
241            .partial_cmp(&workers[*a].total_tflops())
242            .unwrap_or(std::cmp::Ordering::Equal)
243    });
244
245    // Total layers for all workers combined (master keeps its share)
246    let workers_tflops: f64 = workers.iter().map(|w| w.total_tflops()).sum();
247    let total_worker_layers =
248        (workers_tflops / total_tflops * num_layers as f64).round() as usize;
249    let total_worker_layers = total_worker_layers.min(num_layers);
250
251    log::info!(
252        "master: {:.1} TFLOPS — workers: {:.1} TFLOPS — assigning {} of {} layers to workers",
253        master_tflops,
254        workers_tflops,
255        total_worker_layers,
256        num_layers
257    );
258
259    let mut assignments = vec![];
260    let mut offset = 0;
261    let mut remaining_layers = total_worker_layers;
262    let mut remaining_tflops = workers_tflops;
263
264    for (pos, &worker_idx) in indices.iter().enumerate() {
265        if remaining_layers == 0 {
266            break;
267        }
268
269        let mut count = if pos == indices.len() - 1 {
270            remaining_layers
271        } else {
272            let worker_tflops = workers[worker_idx].total_tflops();
273            let proportional =
274                (worker_tflops / remaining_tflops * remaining_layers as f64).round() as usize;
275            proportional.max(1).min(remaining_layers)
276        };
277
278        // Cap by per-GPU VRAM to avoid OOM on multi-GPU workers
279        if layer_size_bytes > 0 {
280            let max_layers = workers[worker_idx].max_layers_for_size(layer_size_bytes);
281            if count > max_layers {
282                log::info!(
283                    "  {} capped from {} to {} layers (VRAM limit: {} per layer)",
284                    &workers[worker_idx].name,
285                    count,
286                    max_layers,
287                    human_bytes::human_bytes(layer_size_bytes as f64)
288                );
289                count = max_layers;
290            }
291        }
292
293        let layers: Vec<String> = (offset..offset + count)
294            .map(|l| format!("{layer_prefix}.{l}"))
295            .collect();
296
297        assignments.push((worker_idx, layers));
298        offset += count;
299        remaining_layers -= count;
300        remaining_tflops -= workers[worker_idx].total_tflops();
301    }
302
303    // Check if master would be left with more layers than it can hold.
304    // The master keeps all layers from `offset` to `num_layers - 1`.
305    let master_layers = num_layers - offset;
306    if master_max_layers < usize::MAX && master_layers > master_max_layers {
307        let deficit = master_layers - master_max_layers;
308        log::info!(
309            "master has {} local layers but can fit {} — redistributing {} to workers",
310            master_layers,
311            master_max_layers,
312            deficit
313        );
314
315        // Try to push excess layers to workers that have spare VRAM capacity.
316        let mut extra_needed = deficit;
317        for (worker_idx, layers) in assignments.iter_mut() {
318            if extra_needed == 0 {
319                break;
320            }
321            let current = layers.len();
322            let max = if layer_size_bytes > 0 {
323                workers[*worker_idx].max_layers_for_size(layer_size_bytes)
324            } else {
325                usize::MAX
326            };
327            let spare = max.saturating_sub(current);
328            if spare > 0 {
329                let take = spare.min(extra_needed);
330                // Extend this worker's range (layers are at the end)
331                let new_start = offset;
332                for l in new_start..new_start + take {
333                    layers.push(format!("{layer_prefix}.{l}"));
334                }
335                offset += take;
336                extra_needed -= take;
337                log::info!(
338                    "  {} takes {} extra layer(s) ({} → {} total)",
339                    &workers[*worker_idx].name,
340                    take,
341                    current,
342                    layers.len()
343                );
344            }
345        }
346
347        if extra_needed > 0 {
348            log::warn!(
349                "cluster cannot fit all {} layers — {} layer(s) unassignable (master VRAM too small, workers full)",
350                num_layers,
351                extra_needed
352            );
353        }
354    }
355
356    assignments
357}
358
359// ── Master setup ────────────────────────────────────────────────────────────
360
361/// Run the full zero-config master setup.
362///
363/// Discovers workers via mDNS, computes layer assignments based on VRAM,
364/// connects to each worker with mutual authentication, pushes model data
365/// as needed, and returns a `Topology` ready for normal inference.
366pub async fn master_setup(
367    cluster_key: &str,
368    model_path: &Path,
369    discovery_timeout: Duration,
370) -> Result<Topology> {
371    // Read config.json and compute a fingerprint for cache keying
372    let config_path = model_path.join("config.json");
373    let config_data = std::fs::read_to_string(&config_path)
374        .map_err(|e| anyhow!("failed to read {}: {}", config_path.display(), e))?;
375    let model_hash = {
376        use sha2::{Digest, Sha256};
377        let mut hasher = Sha256::new();
378        hasher.update(config_data.as_bytes());
379        let result = hasher.finalize();
380        hex::encode(&result[..4])
381    };
382    let config_json: serde_json::Value = serde_json::from_str(&config_data)?;
383    let num_layers = config_json
384        .get("num_hidden_layers")
385        .and_then(|v| v.as_u64())
386        .or_else(|| {
387            // Some models (e.g. Qwen3.5) nest config under text_config
388            config_json
389                .get("text_config")
390                .and_then(|tc| tc.get("num_hidden_layers"))
391                .and_then(|v| v.as_u64())
392        })
393        .ok_or_else(|| anyhow!("num_hidden_layers not found in config.json"))? as usize;
394
395    // Derive layer naming prefix from architecture (needed early for layer size estimation)
396    let layer_prefix = layer_prefix_for_config(&config_json);
397
398    log::info!(
399        "model has {} transformer layers (prefix: {})",
400        num_layers,
401        &layer_prefix,
402    );
403
404    // Detect master GPU and free VRAM concurrently with the discovery window
405    // (nvidia-smi can take ~1-2s; hide that cost inside the discovery timeout).
406    let master_gpus = discovery::detect_gpus();
407    let master_tflops: f64 = master_gpus.iter().map(|g| g.tflops as f64).sum();
408    let free_gpu_fut = tokio::task::spawn_blocking(detect_free_gpu_memory);
409
410    // Discover workers
411    let workers = discovery::discover_workers(cluster_key, discovery_timeout).await?;
412    if workers.is_empty() {
413        log::warn!("no workers discovered — all layers will be loaded locally");
414        return Ok(Topology::new());
415    }
416
417    // nvidia-smi result is now ready (ran during the discovery window)
418    let master_free_from_smi = free_gpu_fut.await.unwrap_or(0);
419
420    // Estimate per-layer size for VRAM-aware capping.
421    // Uses weight_map tensor-count fractions to exclude non-layer weights
422    // (visual encoder, MTP heads, embeddings, lm_head, FP8 scale_inv, etc.).
423    let layer_size_on_disk = estimate_layer_size(model_path, num_layers, &layer_prefix);
424    if layer_size_on_disk > 0 {
425        log::info!(
426            "estimated layer size (on disk): {}",
427            human_bytes::human_bytes(layer_size_on_disk as f64)
428        );
429    }
430
431    // Estimate non-layer overhead the master must hold (embeddings + lm_head + norm + CUDA runtime).
432    // embeddings = vocab_size * hidden_size * dtype_bytes
433    // lm_head    = vocab_size * hidden_size * dtype_bytes (unless tied)
434    let dtype_bytes: u64 = 2; // F16
435    let vocab_size = config_json
436        .get("vocab_size")
437        .and_then(|v| v.as_u64())
438        .or_else(|| {
439            config_json
440                .get("text_config")
441                .and_then(|tc| tc.get("vocab_size"))
442                .and_then(|v| v.as_u64())
443        })
444        .unwrap_or(32000);
445    let hidden_size = config_json
446        .get("hidden_size")
447        .and_then(|v| v.as_u64())
448        .or_else(|| {
449            config_json
450                .get("text_config")
451                .and_then(|tc| tc.get("hidden_size"))
452                .and_then(|v| v.as_u64())
453        })
454        .unwrap_or(4096);
455    let tie_embeddings = config_json
456        .get("tie_word_embeddings")
457        .and_then(|v| v.as_bool())
458        .unwrap_or(false);
459
460    let embed_size = vocab_size * hidden_size * dtype_bytes;
461    let lm_head_size = if tie_embeddings { 0 } else { embed_size };
462    // Add ~1 GiB for CUDA runtime/context, KV cache, memory fragmentation, and misc overhead
463    let master_overhead = embed_size + lm_head_size + 1024 * 1024 * 1024;
464
465    // FP8 models store weights at 1 byte per element on disk, but after dequantization
466    // they expand to the target dtype (F16 = 2 bytes, BF16 = 2 bytes). Scale the layer
467    // size estimate so VRAM-based capping uses the actual in-memory size.
468    let is_fp8 = crate::utils::fp8::is_fp8_quantized(&config_path);
469    let layer_size_bytes = if is_fp8 && layer_size_on_disk > 0 {
470        let expanded = layer_size_on_disk * dtype_bytes; // FP8 is 1 byte, target is dtype_bytes
471        log::info!(
472            "FP8 model: layer size after dequantization: {} ({}x expansion)",
473            human_bytes::human_bytes(expanded as f64),
474            dtype_bytes,
475        );
476        expanded
477    } else {
478        layer_size_on_disk
479    };
480
481    log::info!(
482        "master overhead: embeddings={} lm_head={} total={}",
483        human_bytes::human_bytes(embed_size as f64),
484        human_bytes::human_bytes(lm_head_size as f64),
485        human_bytes::human_bytes(master_overhead as f64),
486    );
487
488    // Cap master layers by its own GPU VRAM minus the non-layer overhead.
489    // Use actual free VRAM (from nvidia-smi) when available, as total VRAM
490    // overestimates on systems with display servers or other GPU consumers.
491    let master_max_layers = if layer_size_bytes > 0 && !master_gpus.is_empty() {
492        let master_vram: u64 = master_gpus.iter().map(|g| g.vram_bytes).sum();
493        let effective_vram = if master_free_from_smi > 0 && master_free_from_smi < master_vram {
494            log::info!(
495                "master GPU: {} total, {} free",
496                human_bytes::human_bytes(master_vram as f64),
497                human_bytes::human_bytes(master_free_from_smi as f64),
498            );
499            master_free_from_smi
500        } else {
501            master_vram
502        };
503        let available = effective_vram.saturating_sub(master_overhead);
504        let max = (available / layer_size_bytes) as usize;
505        log::info!(
506            "master GPU: {} available for layers — can fit ~{} layers locally",
507            human_bytes::human_bytes(available as f64),
508            max
509        );
510        max
511    } else {
512        usize::MAX
513    };
514
515    // Compute assignments based on TFLOPS, capped by per-GPU VRAM
516    let assignments = compute_layer_assignments(
517        &workers,
518        num_layers,
519        master_tflops,
520        layer_size_bytes,
521        master_max_layers,
522        &layer_prefix,
523    );
524
525    // Summarise layer assignments and estimate per-node weight loads
526    let total_assigned: usize = assignments.iter().map(|(_, l)| l.len()).sum();
527    let master_layers = num_layers - total_assigned;
528    log::info!("layer assignments:");
529    for (worker_idx, layers) in &assignments {
530        let w = &workers[*worker_idx];
531        let range = if layers.is_empty() {
532            "(none)".to_string()
533        } else {
534            format!("{} — {}", layers.first().unwrap(), layers.last().unwrap())
535        };
536        let weight_load = layers.len() as u64 * layer_size_bytes;
537        log::info!(
538            "  {} ({}, {:.1} TFLOPS) → {} layers ({}) [{}]",
539            &w.name,
540            human_bytes::human_bytes(w.total_vram() as f64),
541            w.total_tflops(),
542            layers.len(),
543            human_bytes::human_bytes(weight_load as f64),
544            range,
545        );
546    }
547    log::info!(
548        "  master ({:.1} TFLOPS) → {} layers ({} weights + {} overhead)",
549        master_tflops,
550        master_layers,
551        human_bytes::human_bytes((master_layers as u64 * layer_size_bytes) as f64),
552        human_bytes::human_bytes(master_overhead as f64),
553    );
554    if layer_size_bytes > 0 {
555        log::info!(
556            "total weight read per token: {} ({} per layer × {})",
557            human_bytes::human_bytes((num_layers as u64 * layer_size_bytes) as f64),
558            human_bytes::human_bytes(layer_size_bytes as f64),
559            num_layers,
560        );
561    }
562
563    // Connect to all workers concurrently: authenticate, assign layers, push data
564    let mut handles = Vec::new();
565
566    for (worker_idx, layers) in &assignments {
567        let worker = workers[*worker_idx].clone();
568        if layers.is_empty() {
569            continue;
570        }
571
572        let layers = layers.clone();
573        let cluster_key = cluster_key.to_string();
574        let model_hash = model_hash.clone();
575        let model_path = model_path.to_path_buf();
576        let model_name = model_path
577            .file_name()
578            .unwrap_or_default()
579            .to_string_lossy()
580            .to_string();
581
582        handles.push(tokio::spawn(async move {
583            log::info!(
584                "connecting to worker '{}' at {} ...",
585                &worker.name,
586                &worker.host
587            );
588
589            let mut stream = TcpStream::connect(&worker.host)
590                .await
591                .map_err(|e| anyhow!("can't connect to {}: {}", &worker.host, e))?;
592            let _ = stream.set_nodelay(true);
593
594            // Mutual authentication
595            auth::authenticate_as_master(&mut stream, &cluster_key).await?;
596            log::info!("[{}] authenticated", &worker.name);
597
598            // Send layer assignment
599            let msg = Message::LayerAssignment {
600                layers: layers.clone(),
601                model_hash,
602            };
603            msg.to_writer(&mut stream).await?;
604
605            // Read ack
606            let (_, ack) = Message::from_reader(&mut stream).await?;
607            let needs_data = match ack {
608                Message::LayerAssignmentAck { needs_data } => needs_data,
609                other => {
610                    return Err(anyhow!(
611                        "[{}] unexpected response to LayerAssignment: {:?}",
612                        &worker.name,
613                        other
614                    ))
615                }
616            };
617
618            if needs_data {
619                push_model_data(&mut stream, &model_path, &layers, &worker.name, &model_name).await?;
620            } else {
621                log::info!("[{}] worker has model data cached", &worker.name);
622            }
623
624            // Wait for WorkerReady
625            let (_, ready) = Message::from_reader(&mut stream).await?;
626            if !matches!(ready, Message::WorkerReady) {
627                return Err(anyhow!(
628                    "[{}] expected WorkerReady, got {:?}",
629                    &worker.name,
630                    ready
631                ));
632            }
633            log::info!("[{}] worker ready", &worker.name);
634
635            Ok::<_, anyhow::Error>((worker, layers))
636        }));
637    }
638
639    // Collect results
640    let mut topology = Topology::new();
641    for handle in handles {
642        let (worker, layers) = handle.await??;
643        topology.insert(
644            worker.name.clone(),
645            Node {
646                host: worker.host.clone(),
647                description: Some(
648                    worker
649                        .gpus
650                        .iter()
651                        .map(|g| g.name.clone())
652                        .collect::<Vec<_>>()
653                        .join(", "),
654                ),
655                layers,
656                vram_bytes: worker.total_vram(),
657                tflops: worker.total_tflops(),
658                backend: worker.backend.clone(),
659                hostname: worker.hostname.clone(),
660                os: worker.os.clone(),
661            },
662        );
663    }
664
665    Ok(topology)
666}
667
668/// Push model data files to a worker that doesn't have them cached.
669async fn push_model_data(
670    stream: &mut TcpStream,
671    model_path: &Path,
672    layers: &[String],
673    worker_name: &str,
674    model_name: &str,
675) -> Result<()> {
676    let overall_start = Instant::now();
677    let mut overall_bytes: u64 = 0;
678
679    let layer_range = if layers.is_empty() {
680        "(none)".to_string()
681    } else {
682        format!(
683            "{} — {} ({} layers)",
684            layers.first().unwrap(),
685            layers.last().unwrap(),
686            layers.len()
687        )
688    };
689
690    log::info!(
691        "[{}] pushing {} [{}]",
692        worker_name,
693        model_name,
694        layer_range
695    );
696
697    // Always send config.json and tokenizer.json
698    let mut files_to_send: Vec<PathBuf> = vec![
699        model_path.join("config.json"),
700        model_path.join("tokenizer.json"),
701    ];
702
703    // Determine which safetensors shard files contain the assigned layers
704    let index_path = model_path.join("model.safetensors.index.json");
705    let mut filtered_index: Option<Vec<u8>> = None;
706    if index_path.exists() {
707        files_to_send.push(index_path.clone());
708        let index_data = std::fs::read(&index_path)?;
709        let mut index_json: serde_json::Value = serde_json::from_slice(&index_data)?;
710        let weight_map = index_json
711            .get("weight_map")
712            .and_then(|v| v.as_object())
713            .ok_or_else(|| anyhow!("no weight_map in model.safetensors.index.json"))?
714            .clone();
715
716        // Find shard files that contain tensors for the assigned layers
717        let mut needed_shards: HashSet<String> = HashSet::new();
718        let mut needed_weights: serde_json::Map<String, serde_json::Value> =
719            serde_json::Map::new();
720        for (tensor_name, shard_file) in &weight_map {
721            for layer in layers {
722                if tensor_name.starts_with(&format!("{}.", layer)) {
723                    if let Some(filename) = shard_file.as_str() {
724                        needed_shards.insert(filename.to_string());
725                    }
726                    needed_weights.insert(tensor_name.clone(), shard_file.clone());
727                }
728            }
729        }
730
731        // Build a filtered index.json that only references the pushed shards
732        if let Some(obj) = index_json.as_object_mut() {
733            obj.insert(
734                "weight_map".to_string(),
735                serde_json::Value::Object(needed_weights),
736            );
737        }
738        filtered_index = Some(serde_json::to_vec_pretty(&index_json)?);
739
740        log::info!(
741            "[{}] pushing {} shard file(s) + config + tokenizer + index",
742            worker_name,
743            needed_shards.len()
744        );
745
746        for shard in &needed_shards {
747            files_to_send.push(model_path.join(shard));
748        }
749    } else {
750        // Single safetensors file
751        let single = model_path.join("model.safetensors");
752        if single.exists() {
753            files_to_send.push(single);
754        }
755    }
756
757    // Stream each file
758    for file_path in &files_to_send {
759        let filename = file_path
760            .file_name()
761            .unwrap_or_default()
762            .to_string_lossy()
763            .to_string();
764
765        // Use filtered index if this is the index file
766        let file_data = if filename == "model.safetensors.index.json" {
767            if let Some(ref data) = filtered_index {
768                data.clone()
769            } else {
770                std::fs::read(file_path)
771                    .map_err(|e| anyhow!("failed to read {}: {}", file_path.display(), e))?
772            }
773        } else {
774            std::fs::read(file_path)
775                .map_err(|e| anyhow!("failed to read {}: {}", file_path.display(), e))?
776        };
777        let total_size = file_data.len() as u64;
778        let file_start = Instant::now();
779        let mut offset: u64 = 0;
780
781        log::info!(
782            "[{}] sending {} ({}) ...",
783            worker_name,
784            &filename,
785            human_bytes::human_bytes(total_size as f64)
786        );
787
788        for chunk in file_data.chunks(MODEL_DATA_CHUNK_SIZE) {
789            let msg = Message::ModelDataChunk {
790                filename: filename.clone(),
791                offset,
792                total_size,
793                data: chunk.to_vec(),
794            };
795            msg.to_writer(stream).await?;
796            offset += chunk.len() as u64;
797
798            // Log progress for large files
799            if total_size > MODEL_DATA_CHUNK_SIZE as u64 {
800                let elapsed = file_start.elapsed().as_secs_f64();
801                let speed = offset as f64 / elapsed;
802                let pct = (offset as f64 / total_size as f64) * 100.0;
803                let remaining = total_size - offset;
804                let eta_secs = if speed > 0.0 {
805                    remaining as f64 / speed
806                } else {
807                    0.0
808                };
809                log::info!(
810                    "[{}] {} — {}/{} ({:.1}%) — {}/s — ETA {:.0}s",
811                    worker_name,
812                    &filename,
813                    human_bytes::human_bytes(offset as f64),
814                    human_bytes::human_bytes(total_size as f64),
815                    pct,
816                    human_bytes::human_bytes(speed),
817                    eta_secs
818                );
819            }
820        }
821
822        let file_elapsed = file_start.elapsed();
823        let file_speed = total_size as f64 / file_elapsed.as_secs_f64();
824        overall_bytes += total_size;
825
826        log::info!(
827            "[{}] sent {} ({}) in {:.1}s — {}/s",
828            worker_name,
829            &filename,
830            human_bytes::human_bytes(total_size as f64),
831            file_elapsed.as_secs_f64(),
832            human_bytes::human_bytes(file_speed)
833        );
834    }
835
836    // Signal done
837    Message::ModelDataDone.to_writer(stream).await?;
838
839    let overall_elapsed = overall_start.elapsed();
840    let overall_speed = overall_bytes as f64 / overall_elapsed.as_secs_f64();
841    log::info!(
842        "[{}] transfer complete: {} in {:.1}s — {}/s avg",
843        worker_name,
844        human_bytes::human_bytes(overall_bytes as f64),
845        overall_elapsed.as_secs_f64(),
846        human_bytes::human_bytes(overall_speed)
847    );
848
849    Ok(())
850}
851
852/// Check whether a cache directory contains valid model data for the given layers.
853///
854/// For sharded models, verifies that the cached index's weight_map references all
855/// assigned layers and that the shard files containing those layers exist on disk.
856fn has_valid_model_cache(cache_dir: &Path, layers: &[String]) -> bool {
857    if !cache_dir.join("config.json").exists() {
858        return false;
859    }
860    // Single safetensors file — if it exists, assume it has everything
861    if cache_dir.join("model.safetensors").exists() {
862        return true;
863    }
864    // Sharded model: need index + shard files for all assigned layers
865    let index_path = cache_dir.join("model.safetensors.index.json");
866    if index_path.exists() {
867        if let Ok(data) = std::fs::read_to_string(&index_path) {
868            if let Ok(index) = serde_json::from_str::<serde_json::Value>(&data) {
869                if let Some(weight_map) = index.get("weight_map").and_then(|v| v.as_object()) {
870                    // For each assigned layer, check that at least one tensor exists
871                    // in the weight_map and its shard file is present on disk.
872                    for layer in layers {
873                        let prefix = format!("{}.", layer);
874                        let has_layer = weight_map.iter().any(|(tensor_name, shard_file)| {
875                            tensor_name.starts_with(&prefix)
876                                && shard_file
877                                    .as_str()
878                                    .is_some_and(|f| cache_dir.join(f).exists())
879                        });
880                        if !has_layer {
881                            log::debug!(
882                                "cache miss: {} not found in {}",
883                                layer,
884                                cache_dir.display()
885                            );
886                            return false;
887                        }
888                    }
889                    return true;
890                }
891            }
892        }
893    }
894    false
895}
896
897// ── Worker setup ────────────────────────────────────────────────────────────
898
/// Progress callback for worker setup stages.
///
/// Arguments: (stage, message, progress 0.0–1.0)
/// - stage: "discovery", "connected", "authenticated", "layers", "receiving", "cached", "ready"
/// - message: human-readable status
/// - progress: 0.0–1.0 during "receiving", 1.0 for "cached" and "ready", 0.0 otherwise
pub type SetupProgressFn = dyn Fn(&str, &str, f64) + Send + Sync;

/// Run the zero-config worker setup.
///
/// Advertises the worker (via `discovery::advertise_worker`), waits for the
/// master to connect and assign layers, receives model data if needed, and
/// returns the assigned layers and model cache path.
///
/// The `listener` is returned so it can be reused for inference connections.
///
/// This is a convenience wrapper around [`worker_setup_with_progress`] that
/// passes no progress callback.
pub async fn worker_setup(
    worker_name: &str,
    cluster_key: &str,
    bind_address: &str,
    model_cache_dir: &Path,
) -> Result<(Vec<String>, PathBuf, TcpListener)> {
    worker_setup_with_progress(worker_name, cluster_key, bind_address, model_cache_dir, None).await
}
922
923pub async fn worker_setup_with_progress(
924    worker_name: &str,
925    cluster_key: &str,
926    bind_address: &str,
927    model_cache_dir: &Path,
928    on_progress: Option<&SetupProgressFn>,
929) -> Result<(Vec<String>, PathBuf, TcpListener)> {
930    // Detect GPUs
931    let gpus = discovery::detect_gpus();
932    log::info!("detected {} GPU(s):", gpus.len());
933    for gpu in &gpus {
934        log::info!(
935            "  {} — {} (~{:.1} TFLOPS)",
936            &gpu.name,
937            human_bytes::human_bytes(gpu.vram_bytes as f64),
938            gpu.tflops
939        );
940    }
941
942    // Bind listener
943    let listener = TcpListener::bind(bind_address).await?;
944    let port = listener.local_addr()?.port();
945    log::info!("listening on {} (setup mode)", bind_address);
946
947    // Advertise via UDP broadcast
948    let _discovery = discovery::advertise_worker(worker_name, port, cluster_key, &gpus)?;
949
950    log::info!("waiting for master to connect and assign layers...");
951    if let Some(cb) = &on_progress {
952        cb("discovery", "Waiting for master...", 0.0);
953    }
954
955    // Accept one setup connection from master
956    let (mut stream, client_addr) = listener.accept().await?;
957    let _ = stream.set_nodelay(true);
958    log::info!("[{}] master connected", client_addr);
959    if let Some(cb) = &on_progress {
960        cb("connected", &format!("Master connected ({})", client_addr), 0.0);
961    }
962
963    // Authenticate
964    auth::authenticate_as_worker(&mut stream, cluster_key).await?;
965    log::info!("[{}] authenticated", client_addr);
966    if let Some(cb) = &on_progress {
967        cb("authenticated", "Authenticated with master", 0.0);
968    }
969
970    // Receive layer assignment
971    let (_, msg) = Message::from_reader(&mut stream).await?;
972    let (layers, model_hash) = match msg {
973        Message::LayerAssignment {
974            layers,
975            model_hash,
976        } => (layers, model_hash),
977        other => {
978            return Err(anyhow!(
979                "expected LayerAssignment, got {:?}",
980                other
981            ))
982        }
983    };
984
985    log::info!("assigned {} layers:", layers.len());
986    for layer in &layers {
987        log::info!("  {}", layer);
988    }
989    if let Some(cb) = &on_progress {
990        cb("layers", &format!("Assigned {} layer(s)", layers.len()), 0.0);
991    }
992
993    // Determine cache directory: cluster_hash/model_hash
994    // This ensures switching models invalidates the cache.
995    let cluster_id = discovery::cluster_hash(cluster_key);
996    let cache_dir = if model_hash.is_empty() {
997        // Backwards compat with old masters that don't send model_hash
998        model_cache_dir.join(&cluster_id)
999    } else {
1000        model_cache_dir.join(format!("{}-{}", cluster_id, model_hash))
1001    };
1002    std::fs::create_dir_all(&cache_dir)?;
1003
1004    // Check if we already have a valid model data cache for the assigned layers.
1005    let needs_data = !has_valid_model_cache(&cache_dir, &layers);
1006
1007    let ack = Message::LayerAssignmentAck { needs_data };
1008    ack.to_writer(&mut stream).await?;
1009
1010    if needs_data {
1011        if let Some(cb) = &on_progress {
1012            cb("receiving", "Receiving model data...", 0.0);
1013        }
1014        receive_model_data(&mut stream, &cache_dir, &layers, on_progress).await?;
1015    } else {
1016        log::info!("using cached model data from {}", cache_dir.display());
1017        if let Some(cb) = &on_progress {
1018            cb("cached", "Using cached model data", 1.0);
1019        }
1020    }
1021
1022    // Signal ready
1023    Message::WorkerReady.to_writer(&mut stream).await?;
1024    log::info!("setup complete, ready for inference");
1025    if let Some(cb) = &on_progress {
1026        cb("ready", "Setup complete", 1.0);
1027    }
1028
1029    // Drop the setup connection (stream goes out of scope)
1030    // The listener is returned for reuse
1031    Ok((layers, cache_dir, listener))
1032}
1033
1034/// Receive model data from master and write to the cache directory.
1035async fn receive_model_data(
1036    stream: &mut TcpStream,
1037    cache_dir: &Path,
1038    layers: &[String],
1039    on_progress: Option<&SetupProgressFn>,
1040) -> Result<()> {
1041    let overall_start = Instant::now();
1042    let mut overall_bytes: u64 = 0;
1043    let mut current_file: Option<(String, std::fs::File, Instant, u64)> = None;
1044
1045    let layer_range = if layers.is_empty() {
1046        "(none)".to_string()
1047    } else {
1048        format!(
1049            "{} — {} ({} layers)",
1050            layers.first().unwrap(),
1051            layers.last().unwrap(),
1052            layers.len()
1053        )
1054    };
1055
1056    log::info!("receiving model data [{}] ...", layer_range);
1057
1058    loop {
1059        let (_, msg) = Message::from_reader(stream).await?;
1060
1061        match msg {
1062            Message::ModelDataChunk {
1063                filename,
1064                offset,
1065                total_size,
1066                data,
1067            } => {
1068                // Open new file if needed
1069                let file = if let Some((ref name, ref mut file, _, _)) = current_file {
1070                    if name == &filename {
1071                        file
1072                    } else {
1073                        // Close previous file, log stats
1074                        if let Some((prev_name, _, start, size)) = current_file.take() {
1075                            let elapsed = start.elapsed();
1076                            let speed = size as f64 / elapsed.as_secs_f64();
1077                            log::info!(
1078                                "received {} ({}) — {}/s",
1079                                &prev_name,
1080                                human_bytes::human_bytes(size as f64),
1081                                human_bytes::human_bytes(speed)
1082                            );
1083                            if let Some(cb) = &on_progress {
1084                                cb("receiving", &format!("{} complete", &prev_name), 1.0);
1085                            }
1086                        }
1087                        let path = cache_dir.join(&filename);
1088                        let f = std::fs::File::create(&path)?;
1089                        current_file = Some((filename.clone(), f, Instant::now(), total_size));
1090                        if let Some(cb) = &on_progress {
1091                            cb("receiving", &format!("Receiving {} ({})", &filename, human_bytes::human_bytes(total_size as f64)), 0.0);
1092                        }
1093                        &mut current_file.as_mut().unwrap().1
1094                    }
1095                } else {
1096                    let path = cache_dir.join(&filename);
1097                    let f = std::fs::File::create(&path)?;
1098                    current_file = Some((filename.clone(), f, Instant::now(), total_size));
1099                    if let Some(cb) = &on_progress {
1100                        cb("receiving", &format!("Receiving {} ({})", &filename, human_bytes::human_bytes(total_size as f64)), 0.0);
1101                    }
1102                    &mut current_file.as_mut().unwrap().1
1103                };
1104
1105                file.write_all(&data)?;
1106                overall_bytes += data.len() as u64;
1107
1108                // Progress callback for all files (not just large ones)
1109                let written = offset + data.len() as u64;
1110                if let Some((_, _, ref start, _)) = current_file {
1111                    let elapsed = start.elapsed().as_secs_f64();
1112                    let speed = if elapsed > 0.0 { written as f64 / elapsed } else { 0.0 };
1113                    let pct = if total_size > 0 { (written as f64 / total_size as f64) * 100.0 } else { 0.0 };
1114
1115                    if total_size > MODEL_DATA_CHUNK_SIZE as u64 && written < total_size {
1116                        let remaining = total_size - written;
1117                        let eta_secs = if speed > 0.0 { remaining as f64 / speed } else { 0.0 };
1118                        log::info!(
1119                            "  {} — {}/{} ({:.1}%) — {}/s — ETA {:.0}s",
1120                            &filename,
1121                            human_bytes::human_bytes(written as f64),
1122                            human_bytes::human_bytes(total_size as f64),
1123                            pct,
1124                            human_bytes::human_bytes(speed),
1125                            eta_secs
1126                        );
1127                        if let Some(cb) = &on_progress {
1128                            let msg = format!(
1129                                "{} — {}/{} — {}/s — ETA {:.0}s",
1130                                &filename,
1131                                human_bytes::human_bytes(written as f64),
1132                                human_bytes::human_bytes(total_size as f64),
1133                                human_bytes::human_bytes(speed),
1134                                eta_secs
1135                            );
1136                            cb("receiving", &msg, pct / 100.0);
1137                        }
1138                    }
1139                }
1140            }
1141            Message::ModelDataDone => {
1142                // Log last file
1143                if let Some((name, _, start, size)) = current_file.take() {
1144                    let elapsed = start.elapsed();
1145                    let speed = size as f64 / elapsed.as_secs_f64();
1146                    log::info!(
1147                        "received {} ({}) — {}/s",
1148                        &name,
1149                        human_bytes::human_bytes(size as f64),
1150                        human_bytes::human_bytes(speed)
1151                    );
1152                }
1153                break;
1154            }
1155            other => {
1156                return Err(anyhow!(
1157                    "unexpected message during data transfer: {:?}",
1158                    other
1159                ));
1160            }
1161        }
1162    }
1163
1164    let overall_elapsed = overall_start.elapsed();
1165    let overall_speed = overall_bytes as f64 / overall_elapsed.as_secs_f64();
1166    log::info!(
1167        "model data received: {} in {:.1}s — {}/s avg, cached to {}",
1168        human_bytes::human_bytes(overall_bytes as f64),
1169        overall_elapsed.as_secs_f64(),
1170        human_bytes::human_bytes(overall_speed),
1171        cache_dir.display()
1172    );
1173
1174    Ok(())
1175}