// cake_core/cake/discovery.rs
1//! UDP broadcast service discovery and GPU detection for zero-config clustering.
2
use std::collections::HashMap;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket};
use std::time::Duration;

use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use speedy::{Readable, Writable};
11
/// UDP broadcast port for Cake discovery. Workers bind and listen on this
/// port; masters broadcast queries to it (responses return to the master's
/// ephemeral port), so it must match across the whole cluster.
const DISCOVERY_PORT: u16 = 10127;

/// Magic bytes prefixed to every discovery packet to distinguish Cake
/// traffic from unrelated UDP datagrams on the same port.
const MAGIC: &[u8; 4] = b"CAKE";

/// Default discovery timeout.
pub const DEFAULT_DISCOVERY_TIMEOUT: Duration = Duration::from_secs(10);
20
/// GPU information advertised by a worker.
#[derive(Debug, Clone, Serialize, Deserialize, Readable, Writable)]
pub struct GpuInfo {
    /// Human-readable device name (e.g. from nvidia-smi, sysctl, or "CPU (arch)").
    pub name: String,
    /// Device memory in bytes. For CPU and unified-memory (Apple) workers
    /// this is total system RAM rather than dedicated VRAM.
    pub vram_bytes: u64,
    /// Estimated FP16 TFLOPS. `#[serde(default)]` keeps JSON from older
    /// workers (which did not send this field) deserializable — it reads
    /// as 0.0 and callers fall back to a VRAM-based estimate.
    #[serde(default)]
    pub tflops: f32,
}
29
/// A worker discovered via broadcast.
#[derive(Debug, Clone)]
pub struct DiscoveredWorker {
    /// Worker name as self-reported in its discovery response.
    pub name: String,
    /// "ip:port" — the response packet's source IP combined with the
    /// worker's advertised service port.
    pub host: String,
    /// The worker's advertised service port (also embedded in `host`).
    pub port: u16,
    /// Compute devices advertised by the worker.
    pub gpus: Vec<GpuInfo>,
    /// Backend description (e.g. "CUDA 12.4", an Apple chip name, or "CPU").
    pub backend: String,
    /// The worker's system hostname ("unknown" if it could not be detected).
    pub hostname: String,
    /// The worker's operating system (`std::env::consts::OS` on the worker).
    pub os: String,
}
41
42impl DiscoveredWorker {
43    /// Total VRAM across all GPUs.
44    pub fn total_vram(&self) -> u64 {
45        self.gpus.iter().map(|g| g.vram_bytes).sum()
46    }
47
48    /// Maximum number of layers this worker can fit, based on per-GPU VRAM.
49    ///
50    /// For dedicated GPUs (CUDA), reserves ~5% for driver/runtime overhead
51    /// (typically 200–600 MiB for CUDA context + cuBLAS workspace).
52    /// For unified-memory devices (Apple Silicon), reserves 28% of total
53    /// (minimum 6 GiB) for macOS + inference working memory, since model
54    /// weights compete with the OS for the same physical RAM and insufficient
55    /// headroom causes catastrophic memory-compressor thrashing.
56    pub fn max_layers_for_size(&self, layer_size_bytes: u64) -> usize {
57        if layer_size_bytes == 0 || self.gpus.is_empty() {
58            return usize::MAX;
59        }
60        self.gpus
61            .iter()
62            .map(|g| {
63                let name_lower = g.name.to_lowercase();
64                let is_cpu = name_lower.starts_with("cpu");
65                let is_unified = name_lower.contains("apple");
66                let usable = if is_cpu {
67                    // CPU / mobile worker: reported vram_bytes is system RAM.
68                    // Reserve 20% for OS + runtime; no large fixed minimum since
69                    // mobile devices may have only 2–4 GiB total.
70                    let reserve = (g.vram_bytes as f64 * 0.20) as u64;
71                    g.vram_bytes.saturating_sub(reserve)
72                } else if is_unified {
73                    // Unified memory: reserve 28% of total (min 6 GiB) for OS +
74                    // Metal working memory. At 30 layers on a 36 GiB M3 Pro,
75                    // only 8 GiB remained and macOS memory compressor caused
76                    // 100+ sec/forward-pass thrashing; 28% keeps ~10 GiB free.
77                    let min_reserve = 6u64 * 1024 * 1024 * 1024;
78                    let pct_reserve = (g.vram_bytes as f64 * 0.28) as u64;
79                    let os_reserve = pct_reserve.max(min_reserve);
80                    g.vram_bytes.saturating_sub(os_reserve)
81                } else {
82                    // Dedicated VRAM: reserve max(5%, 768 MiB) for CUDA context,
83                    // cuBLAS workspace, and memory fragmentation. The percentage
84                    // works for large GPUs (24+ GB), but on 12 GB GPUs 5% = 600 MB
85                    // leaves only ~50 MB headroom after filling layers, causing OOM.
86                    let min_reserve = 768u64 * 1024 * 1024;
87                    let pct_reserve = (g.vram_bytes as f64 * 0.05) as u64;
88                    let reserve = pct_reserve.max(min_reserve);
89                    g.vram_bytes.saturating_sub(reserve)
90                };
91                (usable / layer_size_bytes) as usize
92            })
93            .sum()
94    }
95
96    /// Total estimated TFLOPS across all GPUs.
97    /// Falls back to a VRAM-based estimate when workers report 0 (old binaries).
98    pub fn total_tflops(&self) -> f64 {
99        let reported: f64 = self.gpus.iter().map(|g| g.tflops as f64).sum();
100        if reported > 0.0 {
101            return reported;
102        }
103        // Fallback: estimate from VRAM and device name
104        self.gpus
105            .iter()
106            .map(|g| {
107                let vram_gb = g.vram_bytes as f64 / (1024.0 * 1024.0 * 1024.0);
108                let name_lower = g.name.to_lowercase();
109                if name_lower.contains("nvidia")
110                    || name_lower.contains("geforce")
111                    || name_lower.contains("rtx")
112                    || name_lower.contains("gtx")
113                    || name_lower.contains("tesla")
114                {
115                    vram_gb * 3.0 // CUDA GPU fallback
116                } else if name_lower.contains("apple") || name_lower.contains("silicon") {
117                    vram_gb * 0.4 // Metal fallback
118                } else {
119                    2.0 // CPU fallback
120                }
121            })
122            .sum()
123    }
124}
125
126/// Compute the first 8 hex chars of SHA-256(cluster_key) for filtering.
127pub fn cluster_hash(cluster_key: &str) -> String {
128    let mut hasher = Sha256::new();
129    hasher.update(cluster_key.as_bytes());
130    let result = hasher.finalize();
131    hex::encode(&result[..4])
132}
133
134/// Detect available compute devices on this system.
135///
136/// Only reports NVIDIA GPUs when the `cuda` feature is compiled in,
137/// and Metal on macOS when the `metal` feature is compiled in.
138/// Otherwise falls back to CPU with system RAM.
139pub fn detect_gpus() -> Vec<GpuInfo> {
140    // Only probe NVIDIA GPUs if built with CUDA support
141    #[cfg(feature = "cuda")]
142    {
143        if let Ok(output) = std::process::Command::new("nvidia-smi")
144            .args([
145                "--query-gpu=name,memory.total,clocks.max.graphics",
146                "--format=csv,noheader,nounits",
147            ])
148            .output()
149        {
150            if output.status.success() {
151                let stdout = String::from_utf8_lossy(&output.stdout);
152                let gpus: Vec<GpuInfo> = stdout
153                    .lines()
154                    .filter_map(|line| {
155                        let parts: Vec<&str> = line.splitn(3, ',').collect();
156                        if parts.len() >= 2 {
157                            let name = parts[0].trim().to_string();
158                            let vram_mb: u64 = parts[1].trim().parse().ok()?;
159                            let vram_gb = vram_mb as f32 / 1024.0;
160                            // Estimate FP16 TFLOPS from VRAM tier and max clock
161                            let tflops = if parts.len() >= 3 {
162                                let clock_mhz: f32 =
163                                    parts[2].trim().parse().unwrap_or(1500.0);
164                                vram_gb * (clock_mhz / 1000.0) * 1.5
165                            } else {
166                                vram_gb * 3.0
167                            };
168                            Some(GpuInfo {
169                                name,
170                                vram_bytes: vram_mb * 1024 * 1024,
171                                tflops,
172                            })
173                        } else {
174                            None
175                        }
176                    })
177                    .collect();
178
179                if !gpus.is_empty() {
180                    return gpus;
181                }
182            }
183        }
184    }
185
186    // Report Metal on macOS/iOS when built with metal support
187    #[cfg(all(any(target_os = "macos", target_os = "ios"), feature = "metal"))]
188    {
189        let chip = detect_apple_chip().unwrap_or_else(|| format!("Apple ({})", std::env::consts::ARCH));
190        let vram_bytes = detect_system_memory();
191        let tflops = vram_bytes as f32 / (1024.0 * 1024.0 * 1024.0) * 0.4;
192        return vec![GpuInfo {
193            name: chip,
194            vram_bytes,
195            tflops,
196        }];
197    }
198
199    // Fallback: CPU with system RAM
200    #[allow(unreachable_code)]
201    {
202        let name = format!("CPU ({})", std::env::consts::ARCH);
203        let vram_bytes = detect_system_memory();
204        vec![GpuInfo {
205            name,
206            vram_bytes,
207            tflops: 2.0,
208        }]
209    }
210}
211
/// Detect the system hostname.
///
/// Shells out to `hostname`; returns "unknown" when the command is missing,
/// fails, or prints nothing.
pub fn detect_hostname() -> String {
    std::process::Command::new("hostname")
        .output()
        .ok()
        .filter(|out| out.status.success())
        .map(|out| String::from_utf8_lossy(&out.stdout).trim().to_string())
        .filter(|name| !name.is_empty())
        .unwrap_or_else(|| "unknown".to_string())
}
224
/// Detect the compute backend description for this node.
///
/// Prefers "CUDA <version>" on cuda builds, an Apple chip name (or plain
/// "Metal") on metal builds, and falls back to "CPU" everywhere else.
pub fn detect_backend() -> String {
    #[cfg(feature = "cuda")]
    {
        if let Some(cuda) = detect_cuda_version() {
            return cuda;
        }
    }
    #[cfg(all(any(target_os = "macos", target_os = "ios"), feature = "metal"))]
    {
        return detect_apple_chip().unwrap_or_else(|| "Metal".to_string());
    }
    #[allow(unreachable_code)]
    "CPU".to_string()
}
243
/// Detect the CUDA toolkit version via nvcc.
///
/// Parses output like "Cuda compilation tools, release 12.4, V12.4.131"
/// and returns e.g. "CUDA 12.4"; `None` when nvcc is absent or fails.
pub fn detect_cuda_version() -> Option<String> {
    let out = std::process::Command::new("nvcc")
        .arg("--version")
        .output()
        .ok()?;
    if !out.status.success() {
        return None;
    }
    let text = String::from_utf8_lossy(&out.stdout);
    text.lines().find_map(|line| {
        // The version follows the literal "release " and ends at the next comma.
        let tail = &line[line.find("release ")? + "release ".len()..];
        let version = tail.split(',').next().unwrap_or(tail).trim();
        if version.is_empty() {
            None
        } else {
            Some(format!("CUDA {}", version))
        }
    })
}
266
/// Detect the Apple chip model (e.g. "Apple M2 Max" on macOS, "iPad8,3" on iOS).
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn detect_apple_chip() -> Option<String> {
    // macOS exposes the marketing name under machdep.cpu.brand_string;
    // iOS only provides the machine identifier under hw.machine.
    ["machdep.cpu.brand_string", "hw.machine"]
        .iter()
        .copied()
        .find_map(sysctl_string)
}
278
/// Read a sysctl string value using the C API (works in iOS sandbox unlike subprocess).
///
/// Returns `None` when the key does not exist, the kernel call fails, or the
/// value is empty / not valid UTF-8.
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn sysctl_string(name: &str) -> Option<String> {
    use std::ffi::CString;
    let c_name = CString::new(name).ok()?;
    let mut len: usize = 0;
    // First call to get buffer size
    // SAFETY: `c_name` is a valid NUL-terminated C string; passing a null
    // output buffer with a len pointer asks sysctl only for the required size.
    let ret = unsafe {
        libc::sysctlbyname(c_name.as_ptr(), std::ptr::null_mut(), &mut len, std::ptr::null_mut(), 0)
    };
    if ret != 0 || len == 0 {
        return None;
    }
    let mut buf = vec![0u8; len];
    // SAFETY: `buf` holds exactly `len` bytes as reported by the sizing call
    // above, and `len` is passed in/out so the kernel never writes past it.
    let ret = unsafe {
        libc::sysctlbyname(c_name.as_ptr(), buf.as_mut_ptr() as *mut _, &mut len, std::ptr::null_mut(), 0)
    };
    if ret != 0 {
        return None;
    }
    // Strip trailing null bytes
    while buf.last() == Some(&0) {
        buf.pop();
    }
    String::from_utf8(buf).ok().filter(|s| !s.is_empty())
}
305
/// Read a sysctl u64 value using the C API.
///
/// Returns `None` on kernel-call failure or when the reported value is zero
/// (callers treat 0 as "unknown").
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn sysctl_u64(name: &str) -> Option<u64> {
    use std::ffi::CString;
    let c_name = CString::new(name).ok()?;
    let mut value: u64 = 0;
    let mut len = std::mem::size_of::<u64>();
    // SAFETY: `value` is a valid, writable u64 and `len` is exactly its size,
    // so the kernel writes at most 8 bytes into the buffer.
    let ret = unsafe {
        libc::sysctlbyname(c_name.as_ptr(), &mut value as *mut u64 as *mut _, &mut len, std::ptr::null_mut(), 0)
    };
    if ret == 0 && value > 0 {
        Some(value)
    } else {
        None
    }
}
322
323/// Detect total system memory in bytes.
324fn detect_system_memory() -> u64 {
325    // On macOS/iOS, use sysctl C API (works in iOS sandbox)
326    #[cfg(any(target_os = "macos", target_os = "ios"))]
327    {
328        if let Some(bytes) = sysctl_u64("hw.memsize") {
329            return bytes;
330        }
331    }
332
333    // On Linux and Android, read /proc/meminfo
334    // Note: Android uses target_os = "android", not "linux", so both are listed.
335    #[cfg(any(target_os = "linux", target_os = "android"))]
336    {
337        if let Ok(contents) = std::fs::read_to_string("/proc/meminfo") {
338            for line in contents.lines() {
339                if let Some(rest) = line.strip_prefix("MemTotal:") {
340                    let rest = rest.trim();
341                    if let Some(kb_str) = rest.strip_suffix("kB") {
342                        if let Ok(kb) = kb_str.trim().parse::<u64>() {
343                            return kb * 1024;
344                        }
345                    }
346                }
347            }
348        }
349    }
350
351    // Fallback to memory_stats
352    memory_stats::memory_stats()
353        .map(|s| s.physical_mem as u64)
354        .unwrap_or(0)
355}
356
357// ── Discovery packet format ────────────────────────────────────────────────
358
/// A discovery query broadcast by the master.
#[derive(Serialize, Deserialize)]
struct DiscoveryQuery {
    /// First 8 hex chars of SHA-256(cluster key); workers only answer
    /// queries whose hash matches their own cluster key.
    cluster_hash: String,
}

/// A discovery response sent by workers (unicast back to master).
///
/// The `#[serde(default)]` fields keep responses from older workers (which
/// did not send them) deserializable; they decode as empty strings.
#[derive(Serialize, Deserialize)]
struct DiscoveryResponse {
    /// Echo of the cluster hash so the master can filter foreign clusters.
    cluster_hash: String,
    /// Worker's self-chosen name, used by the master to dedupe responses.
    worker_name: String,
    /// Port the worker's service listens on; combined by the master with the
    /// packet's source IP to form the reachable address.
    port: u16,
    /// Compute devices this worker advertises.
    gpus: Vec<GpuInfo>,
    #[serde(default)]
    backend: String,
    #[serde(default)]
    hostname: String,
    #[serde(default)]
    os: String,
}
379
380fn encode_packet(payload: &[u8]) -> Vec<u8> {
381    let mut pkt = Vec::with_capacity(4 + payload.len());
382    pkt.extend_from_slice(MAGIC);
383    pkt.extend_from_slice(payload);
384    pkt
385}
386
387fn decode_packet(data: &[u8]) -> Option<&[u8]> {
388    if data.len() > 4 && data[..4] == *MAGIC {
389        Some(&data[4..])
390    } else {
391        None
392    }
393}
394
395// ── Worker advertisement (listen for queries, respond) ─────────────────────
396
/// Handle for a running discovery listener.
/// Must be kept alive for the worker to respond to discovery queries.
/// Dropping this handle signals the listener thread to exit.
pub struct DiscoveryListener {
    // Shared flag the listener thread polls between receives.
    stop: std::sync::Arc<std::sync::atomic::AtomicBool>,
    // Keeps the thread handle alive with the listener; it is never joined,
    // so dropping the handle returns immediately while the thread winds
    // down on its own within ~1s.
    _handle: std::thread::JoinHandle<()>,
}

impl Drop for DiscoveryListener {
    fn drop(&mut self) {
        // Signal the listener thread to exit on its next recv_from timeout (~1s).
        self.stop
            .store(true, std::sync::atomic::Ordering::SeqCst);
    }
}
412
413/// Start listening for discovery queries and responding with worker info.
414///
415/// Spawns a background thread. Returns a handle that must be kept alive.
416pub fn advertise_worker(
417    worker_name: &str,
418    port: u16,
419    cluster_key: &str,
420    gpus: &[GpuInfo],
421) -> Result<DiscoveryListener> {
422    let hash = cluster_hash(cluster_key);
423    let hostname = detect_hostname();
424    let backend = detect_backend();
425    let response = DiscoveryResponse {
426        cluster_hash: hash.clone(),
427        worker_name: worker_name.to_string(),
428        port,
429        gpus: gpus.to_vec(),
430        backend,
431        hostname,
432        os: std::env::consts::OS.to_string(),
433    };
434    let response_json = serde_json::to_vec(&response)?;
435    let response_pkt = encode_packet(&response_json);
436
437    let sock = UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, DISCOVERY_PORT))
438        .map_err(|e| anyhow!("failed to bind discovery UDP socket on port {}: {}", DISCOVERY_PORT, e))?;
439    sock.set_broadcast(true)?;
440    sock.set_read_timeout(Some(Duration::from_secs(1)))?;
441
442    log::info!(
443        "listening for discovery queries on UDP port {}",
444        DISCOVERY_PORT
445    );
446
447    let stop = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
448    let stop_thread = stop.clone();
449
450    let handle = std::thread::spawn(move || {
451        let mut buf = [0u8; 4096];
452        while !stop_thread.load(std::sync::atomic::Ordering::Relaxed) {
453            match sock.recv_from(&mut buf) {
454                Ok((len, src)) => {
455                    if let Some(payload) = decode_packet(&buf[..len]) {
456                        if let Ok(query) = serde_json::from_slice::<DiscoveryQuery>(payload) {
457                            if query.cluster_hash == hash {
458                                // Respond directly to the master
459                                if let Err(e) = sock.send_to(&response_pkt, src) {
460                                    log::warn!("failed to send discovery response to {}: {}", src, e);
461                                }
462                            }
463                        }
464                    }
465                }
466                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock
467                    || e.kind() == std::io::ErrorKind::TimedOut =>
468                {
469                    // Normal timeout — loop and check stop flag
470                }
471                Err(e) => {
472                    log::warn!("discovery listener error: {}", e);
473                    break;
474                }
475            }
476        }
477        log::debug!("discovery listener thread exited");
478    });
479
480    Ok(DiscoveryListener { stop, _handle: handle })
481}
482
483// ── Interface enumeration ──────────────────────────────────────────────────
484
/// Get directed broadcast addresses for all local IPv4 interfaces.
/// Falls back to 255.255.255.255 if enumeration fails.
///
/// Fixes over the previous version: `Vec::dedup` only removes *consecutive*
/// duplicates, so interfaces sharing a broadcast address could be returned
/// twice — we now check membership before pushing. The Linux parser also
/// required a space after the `brd` address and silently dropped one that
/// ended the line; token-based parsing handles both forms.
fn get_broadcast_addresses() -> Vec<Ipv4Addr> {
    let mut addrs: Vec<Ipv4Addr> = Vec::new();

    // Parse `ip addr` on Linux to find per-interface broadcast addresses.
    #[cfg(target_os = "linux")]
    {
        if let Ok(output) = std::process::Command::new("ip")
            .args(["-4", "addr", "show"])
            .output()
        {
            let stdout = String::from_utf8_lossy(&output.stdout);
            for line in stdout.lines() {
                // Lines like: "    inet 192.168.50.199/24 brd 192.168.50.255 scope global ..."
                if let Some(brd_idx) = line.find("brd ") {
                    let rest = &line[brd_idx + 4..];
                    // Take the next whitespace-delimited token, which works
                    // whether or not anything follows the address.
                    if let Some(tok) = rest.split_whitespace().next() {
                        if let Ok(ip) = tok.parse::<Ipv4Addr>() {
                            if !ip.is_loopback() && !addrs.contains(&ip) {
                                addrs.push(ip);
                            }
                        }
                    }
                }
            }
        }
    }

    // Parse `ifconfig` on macOS for the same information.
    #[cfg(target_os = "macos")]
    {
        if let Ok(output) = std::process::Command::new("ifconfig").output() {
            let stdout = String::from_utf8_lossy(&output.stdout);
            for line in stdout.lines() {
                // Lines like: "	inet 192.168.50.32 netmask 0xffffff00 broadcast 192.168.50.255"
                if let Some(brd_idx) = line.find("broadcast ") {
                    let rest = &line[brd_idx + 10..];
                    let addr_str = rest.split_whitespace().next().unwrap_or("");
                    if let Ok(ip) = addr_str.parse::<Ipv4Addr>() {
                        if !ip.is_loopback() && !addrs.contains(&ip) {
                            addrs.push(ip);
                        }
                    }
                }
            }
        }
    }

    // Always include the limited broadcast as a fallback.
    if !addrs.contains(&Ipv4Addr::BROADCAST) {
        addrs.push(Ipv4Addr::BROADCAST);
    }
    addrs
}
538
539// ── Master browsing (broadcast query, collect responses) ──────────────────
540
541/// Browse for workers on the network matching the given cluster key.
542///
543/// Sends periodic UDP broadcast queries, collects responses until timeout.
544pub async fn discover_workers(
545    cluster_key: &str,
546    timeout: Duration,
547) -> Result<Vec<DiscoveredWorker>> {
548    let expected_hash = cluster_hash(cluster_key);
549
550    log::info!(
551        "discovering workers (timeout: {}s)...",
552        timeout.as_secs()
553    );
554
555    let workers = tokio::task::spawn_blocking(move || {
556        let sock = UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
557            .map_err(|e| anyhow!("failed to bind discovery socket: {}", e))?;
558        sock.set_broadcast(true)?;
559        sock.set_read_timeout(Some(Duration::from_millis(500)))?;
560
561        let query = DiscoveryQuery {
562            cluster_hash: expected_hash.clone(),
563        };
564        let query_json = serde_json::to_vec(&query)?;
565        let query_pkt = encode_packet(&query_json);
566
567        // Collect broadcast addresses: directed subnet broadcasts are more
568        // reliable than 255.255.255.255 which may not cross interfaces.
569        let broadcast_addrs = get_broadcast_addresses();
570
571        let mut workers: HashMap<String, DiscoveredWorker> = HashMap::new();
572        let deadline = std::time::Instant::now() + timeout;
573        let mut last_query = std::time::Instant::now() - Duration::from_secs(10);
574        let query_interval = Duration::from_secs(1);
575        let mut buf = [0u8; 65535];
576
577        loop {
578            let now = std::time::Instant::now();
579            if now >= deadline {
580                break;
581            }
582
583            // Send periodic broadcast queries to all known broadcast addresses
584            if now.duration_since(last_query) >= query_interval {
585                for addr in &broadcast_addrs {
586                    let dest = SocketAddr::V4(SocketAddrV4::new(*addr, DISCOVERY_PORT));
587                    let _ = sock.send_to(&query_pkt, dest);
588                }
589                last_query = now;
590            }
591
592            // Listen for responses
593            match sock.recv_from(&mut buf) {
594                Ok((len, src)) => {
595                    if let Some(payload) = decode_packet(&buf[..len]) {
596                        if let Ok(resp) = serde_json::from_slice::<DiscoveryResponse>(payload) {
597                            if resp.cluster_hash != expected_hash {
598                                continue;
599                            }
600
601                            let src_ip = match src {
602                                SocketAddr::V4(a) => a.ip().to_string(),
603                                SocketAddr::V6(a) => a.ip().to_string(),
604                            };
605                            let host = format!("{}:{}", src_ip, resp.port);
606
607                            if !workers.contains_key(&resp.worker_name) {
608                                log::info!(
609                                    "discovered worker '{}' at {} with {} GPU(s)",
610                                    &resp.worker_name,
611                                    &host,
612                                    resp.gpus.len()
613                                );
614
615                                for gpu in &resp.gpus {
616                                    log::info!(
617                                        "  {} — {} (~{:.1} TFLOPS)",
618                                        &gpu.name,
619                                        human_bytes::human_bytes(gpu.vram_bytes as f64),
620                                        gpu.tflops
621                                    );
622                                }
623
624                                workers.insert(resp.worker_name.clone(), DiscoveredWorker {
625                                    name: resp.worker_name,
626                                    host,
627                                    port: resp.port,
628                                    gpus: resp.gpus,
629                                    backend: resp.backend,
630                                    hostname: resp.hostname,
631                                    os: resp.os,
632                                });
633                            }
634                        }
635                    }
636                }
637                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock
638                    || e.kind() == std::io::ErrorKind::TimedOut =>
639                {
640                    // Normal timeout, loop again
641                }
642                Err(e) => {
643                    log::warn!("discovery recv error: {}", e);
644                }
645            }
646        }
647
648        Ok::<_, anyhow::Error>(workers)
649    }).await??;
650
651    log::info!("discovery complete: {} worker(s) found", workers.len());
652    Ok(workers.into_values().collect())
653}