1use std::collections::HashMap;
4use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket};
5use std::time::Duration;
6
7use anyhow::Result;
8use serde::{Deserialize, Serialize};
9use sha2::{Digest, Sha256};
10use speedy::{Readable, Writable};
11
/// UDP port used by the discovery protocol: `advertise_worker` binds it to
/// listen for queries and `discover_workers` broadcasts queries to it.
const DISCOVERY_PORT: u16 = 10127;

/// Four-byte prefix that `encode_packet` prepends to every datagram and that
/// `decode_packet` requires before accepting a payload.
const MAGIC: &[u8; 4] = b"CAKE";

/// A ten second discovery window.
// NOTE(review): presumably the value callers pass as `timeout` to
// `discover_workers` when they have no preference — confirm against callers.
pub const DEFAULT_DISCOVERY_TIMEOUT: Duration = Duration::from_secs(10);
20
/// Description of one compute device advertised by a worker.
///
/// Produced by `detect_gpus` and carried inside `DiscoveryResponse`.
#[derive(Debug, Clone, Serialize, Deserialize, Readable, Writable)]
pub struct GpuInfo {
    /// Device name. `detect_gpus` fills this with the `nvidia-smi` device name,
    /// the Apple chip string, or `"CPU (<arch>)"`; the capacity heuristics on
    /// `DiscoveredWorker` classify devices by matching on this string.
    pub name: String,
    /// Device memory in bytes. For the Apple and CPU entries produced by
    /// `detect_gpus` this is the value returned by `detect_system_memory`.
    pub vram_bytes: u64,
    /// Estimated throughput in TFLOPS. Falls back to `0.0` when the field is
    /// missing from a deserialized message (`#[serde(default)]`);
    /// `DiscoveredWorker::total_tflops` then estimates it from name and memory.
    #[serde(default)]
    pub tflops: f32,
}
29
/// A worker that answered a discovery query, as collected by `discover_workers`.
#[derive(Debug, Clone)]
pub struct DiscoveredWorker {
    /// Worker name taken from the response; `discover_workers` keeps only the
    /// first response seen for each name.
    pub name: String,
    /// `"<ip>:<port>"`, built from the UDP source address of the response and
    /// the port the worker advertised.
    pub host: String,
    /// Port the worker advertised in its response.
    pub port: u16,
    /// Compute devices reported by the worker.
    pub gpus: Vec<GpuInfo>,
    /// Backend string reported by the worker (see `detect_backend`).
    pub backend: String,
    /// Hostname reported by the worker (see `detect_hostname`).
    pub hostname: String,
    /// Operating system string reported by the worker.
    pub os: String,
}
41
42impl DiscoveredWorker {
43 pub fn total_vram(&self) -> u64 {
45 self.gpus.iter().map(|g| g.vram_bytes).sum()
46 }
47
48 pub fn max_layers_for_size(&self, layer_size_bytes: u64) -> usize {
57 if layer_size_bytes == 0 || self.gpus.is_empty() {
58 return usize::MAX;
59 }
60 self.gpus
61 .iter()
62 .map(|g| {
63 let name_lower = g.name.to_lowercase();
64 let is_cpu = name_lower.starts_with("cpu");
65 let is_unified = name_lower.contains("apple");
66 let usable = if is_cpu {
67 let reserve = (g.vram_bytes as f64 * 0.20) as u64;
71 g.vram_bytes.saturating_sub(reserve)
72 } else if is_unified {
73 let min_reserve = 6u64 * 1024 * 1024 * 1024;
78 let pct_reserve = (g.vram_bytes as f64 * 0.28) as u64;
79 let os_reserve = pct_reserve.max(min_reserve);
80 g.vram_bytes.saturating_sub(os_reserve)
81 } else {
82 let min_reserve = 768u64 * 1024 * 1024;
87 let pct_reserve = (g.vram_bytes as f64 * 0.05) as u64;
88 let reserve = pct_reserve.max(min_reserve);
89 g.vram_bytes.saturating_sub(reserve)
90 };
91 (usable / layer_size_bytes) as usize
92 })
93 .sum()
94 }
95
96 pub fn total_tflops(&self) -> f64 {
99 let reported: f64 = self.gpus.iter().map(|g| g.tflops as f64).sum();
100 if reported > 0.0 {
101 return reported;
102 }
103 self.gpus
105 .iter()
106 .map(|g| {
107 let vram_gb = g.vram_bytes as f64 / (1024.0 * 1024.0 * 1024.0);
108 let name_lower = g.name.to_lowercase();
109 if name_lower.contains("nvidia")
110 || name_lower.contains("geforce")
111 || name_lower.contains("rtx")
112 || name_lower.contains("gtx")
113 || name_lower.contains("tesla")
114 {
115 vram_gb * 3.0 } else if name_lower.contains("apple") || name_lower.contains("silicon") {
117 vram_gb * 0.4 } else {
119 2.0 }
121 })
122 .sum()
123 }
124}
125
126pub fn cluster_hash(cluster_key: &str) -> String {
128 let mut hasher = Sha256::new();
129 hasher.update(cluster_key.as_bytes());
130 let result = hasher.finalize();
131 hex::encode(&result[..4])
132}
133
134pub fn detect_gpus() -> Vec<GpuInfo> {
140 #[cfg(feature = "cuda")]
142 {
143 if let Ok(output) = std::process::Command::new("nvidia-smi")
144 .args([
145 "--query-gpu=name,memory.total,clocks.max.graphics",
146 "--format=csv,noheader,nounits",
147 ])
148 .output()
149 {
150 if output.status.success() {
151 let stdout = String::from_utf8_lossy(&output.stdout);
152 let gpus: Vec<GpuInfo> = stdout
153 .lines()
154 .filter_map(|line| {
155 let parts: Vec<&str> = line.splitn(3, ',').collect();
156 if parts.len() >= 2 {
157 let name = parts[0].trim().to_string();
158 let vram_mb: u64 = parts[1].trim().parse().ok()?;
159 let vram_gb = vram_mb as f32 / 1024.0;
160 let tflops = if parts.len() >= 3 {
162 let clock_mhz: f32 =
163 parts[2].trim().parse().unwrap_or(1500.0);
164 vram_gb * (clock_mhz / 1000.0) * 1.5
165 } else {
166 vram_gb * 3.0
167 };
168 Some(GpuInfo {
169 name,
170 vram_bytes: vram_mb * 1024 * 1024,
171 tflops,
172 })
173 } else {
174 None
175 }
176 })
177 .collect();
178
179 if !gpus.is_empty() {
180 return gpus;
181 }
182 }
183 }
184 }
185
186 #[cfg(all(any(target_os = "macos", target_os = "ios"), feature = "metal"))]
188 {
189 let chip = detect_apple_chip().unwrap_or_else(|| format!("Apple ({})", std::env::consts::ARCH));
190 let vram_bytes = detect_system_memory();
191 let tflops = vram_bytes as f32 / (1024.0 * 1024.0 * 1024.0) * 0.4;
192 return vec![GpuInfo {
193 name: chip,
194 vram_bytes,
195 tflops,
196 }];
197 }
198
199 #[allow(unreachable_code)]
201 {
202 let name = format!("CPU ({})", std::env::consts::ARCH);
203 let vram_bytes = detect_system_memory();
204 vec![GpuInfo {
205 name,
206 vram_bytes,
207 tflops: 2.0,
208 }]
209 }
210}
211
/// Returns the machine's hostname as printed by the `hostname` command.
///
/// Falls back to `"unknown"` when the command cannot be run, exits with a
/// failure status, or prints nothing but whitespace.
pub fn detect_hostname() -> String {
    std::process::Command::new("hostname")
        .output()
        .ok()
        .filter(|out| out.status.success())
        .map(|out| String::from_utf8_lossy(&out.stdout).trim().to_string())
        .filter(|name| !name.is_empty())
        .unwrap_or_else(|| "unknown".to_string())
}
224
/// Returns a human-readable name for the compute backend of this build.
///
/// * With the `cuda` feature: the `"CUDA <version>"` string from
///   `detect_cuda_version`, when one can be determined.
/// * On macOS/iOS with the `metal` feature: the chip name, or `"Metal"` when
///   the chip cannot be identified.
/// * Otherwise: `"CPU"`.
pub fn detect_backend() -> String {
    #[cfg(feature = "cuda")]
    {
        if let Some(cuda) = detect_cuda_version() {
            return cuda;
        }
    }
    #[cfg(all(any(target_os = "macos", target_os = "ios"), feature = "metal"))]
    {
        return detect_apple_chip().unwrap_or_else(|| "Metal".to_string());
    }
    // Only reachable when no accelerated backend returned above.
    #[allow(unreachable_code)]
    {
        String::from("CPU")
    }
}
243
/// Asks `nvcc --version` for the CUDA toolkit release.
///
/// The first output line containing `"release "` whose following text (up to
/// the next comma, trimmed) is non-empty yields `Some("CUDA <that text>")`.
/// Returns `None` when `nvcc` cannot be run, exits with a failure status, or
/// prints no such line.
pub fn detect_cuda_version() -> Option<String> {
    const MARKER: &str = "release ";

    let output = std::process::Command::new("nvcc")
        .arg("--version")
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }

    let text = String::from_utf8_lossy(&output.stdout);
    text.lines().find_map(|line| {
        let (_, after) = line.split_once(MARKER)?;
        let version = after.split(',').next().unwrap_or(after).trim();
        if version.is_empty() {
            None
        } else {
            Some(format!("CUDA {}", version))
        }
    })
}
266
/// Returns the first non-empty value among the `machdep.cpu.brand_string` and
/// `hw.machine` sysctl keys, tried in that order, or `None` if both fail.
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn detect_apple_chip() -> Option<String> {
    ["machdep.cpu.brand_string", "hw.machine"]
        .iter()
        .find_map(|key| sysctl_string(key))
}
278
/// Reads the sysctl value `name` and returns it as a string.
///
/// Returns `None` when `name` contains an interior NUL byte, when either
/// `sysctlbyname` call fails, when the reported size is zero, or when the
/// value (with trailing NUL bytes removed) is empty or not valid UTF-8.
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn sysctl_string(name: &str) -> Option<String> {
    use std::ffi::CString;
    use std::ptr;

    let key = CString::new(name).ok()?;

    // First call passes no output buffer, only a length slot to be filled in.
    let mut size: usize = 0;
    // SAFETY: `key` is a valid NUL-terminated string that outlives the call,
    // the output pointer is null, and `size` is a live, writable `usize`.
    let status = unsafe {
        libc::sysctlbyname(key.as_ptr(), ptr::null_mut(), &mut size, ptr::null_mut(), 0)
    };
    if status != 0 || size == 0 {
        return None;
    }

    let mut bytes = vec![0u8; size];
    // SAFETY: `bytes` owns exactly `size` writable bytes and `size` is passed
    // as the buffer length; `key` is still a valid NUL-terminated string.
    let status = unsafe {
        libc::sysctlbyname(
            key.as_ptr(),
            bytes.as_mut_ptr() as *mut _,
            &mut size,
            ptr::null_mut(),
            0,
        )
    };
    if status != 0 {
        return None;
    }

    // Drop every trailing NUL byte (terminator and any unused zeroed tail).
    let text_len = bytes
        .iter()
        .rposition(|&byte| byte != 0)
        .map_or(0, |last| last + 1);
    bytes.truncate(text_len);

    String::from_utf8(bytes).ok().filter(|text| !text.is_empty())
}
305
/// Reads the sysctl value `name` into a `u64`.
///
/// Returns `None` when `name` contains an interior NUL byte, when
/// `sysctlbyname` fails, or when the value read is zero.
#[cfg(any(target_os = "macos", target_os = "ios"))]
fn sysctl_u64(name: &str) -> Option<u64> {
    use std::ffi::CString;

    let key = CString::new(name).ok()?;
    let mut out: u64 = 0;
    let mut out_len = std::mem::size_of::<u64>();

    // SAFETY: `key` is a valid NUL-terminated string, `out` is a live `u64`
    // and `out_len` is initialised to its size in bytes.
    let status = unsafe {
        libc::sysctlbyname(
            key.as_ptr(),
            &mut out as *mut u64 as *mut _,
            &mut out_len,
            std::ptr::null_mut(),
            0,
        )
    };

    if status != 0 || out == 0 {
        return None;
    }
    Some(out)
}
322
323fn detect_system_memory() -> u64 {
325 #[cfg(any(target_os = "macos", target_os = "ios"))]
327 {
328 if let Some(bytes) = sysctl_u64("hw.memsize") {
329 return bytes;
330 }
331 }
332
333 #[cfg(any(target_os = "linux", target_os = "android"))]
336 {
337 if let Ok(contents) = std::fs::read_to_string("/proc/meminfo") {
338 for line in contents.lines() {
339 if let Some(rest) = line.strip_prefix("MemTotal:") {
340 let rest = rest.trim();
341 if let Some(kb_str) = rest.strip_suffix("kB") {
342 if let Ok(kb) = kb_str.trim().parse::<u64>() {
343 return kb * 1024;
344 }
345 }
346 }
347 }
348 }
349 }
350
351 memory_stats::memory_stats()
353 .map(|s| s.physical_mem as u64)
354 .unwrap_or(0)
355}
356
/// JSON payload of a discovery query, broadcast by `discover_workers`.
#[derive(Serialize, Deserialize)]
struct DiscoveryQuery {
    /// Output of `cluster_hash` for the sender's cluster key; a listening
    /// worker only answers when it equals its own hash.
    cluster_hash: String,
}
364
/// JSON payload a worker sends back in reply to a matching `DiscoveryQuery`.
#[derive(Serialize, Deserialize)]
struct DiscoveryResponse {
    /// Output of `cluster_hash` for the worker's cluster key; the querying
    /// side drops responses whose hash differs from its own.
    cluster_hash: String,
    /// Name the worker advertises; used as the de-duplication key by
    /// `discover_workers`.
    worker_name: String,
    /// Port the worker advertises (the `port` argument of `advertise_worker`).
    port: u16,
    /// Compute devices the worker advertises.
    gpus: Vec<GpuInfo>,
    /// Backend string; empty when missing from the message.
    // NOTE(review): the `#[serde(default)]` fields below presumably exist so
    // that responses from older workers still parse — confirm.
    #[serde(default)]
    backend: String,
    /// Worker hostname; empty when missing from the message.
    #[serde(default)]
    hostname: String,
    /// Worker operating system; empty when missing from the message.
    #[serde(default)]
    os: String,
}
379
380fn encode_packet(payload: &[u8]) -> Vec<u8> {
381 let mut pkt = Vec::with_capacity(4 + payload.len());
382 pkt.extend_from_slice(MAGIC);
383 pkt.extend_from_slice(payload);
384 pkt
385}
386
387fn decode_packet(data: &[u8]) -> Option<&[u8]> {
388 if data.len() > 4 && data[..4] == *MAGIC {
389 Some(&data[4..])
390 } else {
391 None
392 }
393}
394
/// Handle to the background thread started by `advertise_worker`.
///
/// Dropping the handle sets the stop flag; the thread checks it between
/// receive attempts. The thread is not joined on drop.
pub struct DiscoveryListener {
    /// Stop flag shared with the listener thread.
    stop: std::sync::Arc<std::sync::atomic::AtomicBool>,
    /// Join handle of the listener thread; stored but never joined here.
    _handle: std::thread::JoinHandle<()>,
}
404
405impl Drop for DiscoveryListener {
406 fn drop(&mut self) {
407 self.stop
409 .store(true, std::sync::atomic::Ordering::SeqCst);
410 }
411}
412
413pub fn advertise_worker(
417 worker_name: &str,
418 port: u16,
419 cluster_key: &str,
420 gpus: &[GpuInfo],
421) -> Result<DiscoveryListener> {
422 let hash = cluster_hash(cluster_key);
423 let hostname = detect_hostname();
424 let backend = detect_backend();
425 let response = DiscoveryResponse {
426 cluster_hash: hash.clone(),
427 worker_name: worker_name.to_string(),
428 port,
429 gpus: gpus.to_vec(),
430 backend,
431 hostname,
432 os: std::env::consts::OS.to_string(),
433 };
434 let response_json = serde_json::to_vec(&response)?;
435 let response_pkt = encode_packet(&response_json);
436
437 let sock = UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, DISCOVERY_PORT))
438 .map_err(|e| anyhow!("failed to bind discovery UDP socket on port {}: {}", DISCOVERY_PORT, e))?;
439 sock.set_broadcast(true)?;
440 sock.set_read_timeout(Some(Duration::from_secs(1)))?;
441
442 log::info!(
443 "listening for discovery queries on UDP port {}",
444 DISCOVERY_PORT
445 );
446
447 let stop = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
448 let stop_thread = stop.clone();
449
450 let handle = std::thread::spawn(move || {
451 let mut buf = [0u8; 4096];
452 while !stop_thread.load(std::sync::atomic::Ordering::Relaxed) {
453 match sock.recv_from(&mut buf) {
454 Ok((len, src)) => {
455 if let Some(payload) = decode_packet(&buf[..len]) {
456 if let Ok(query) = serde_json::from_slice::<DiscoveryQuery>(payload) {
457 if query.cluster_hash == hash {
458 if let Err(e) = sock.send_to(&response_pkt, src) {
460 log::warn!("failed to send discovery response to {}: {}", src, e);
461 }
462 }
463 }
464 }
465 }
466 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock
467 || e.kind() == std::io::ErrorKind::TimedOut =>
468 {
469 }
471 Err(e) => {
472 log::warn!("discovery listener error: {}", e);
473 break;
474 }
475 }
476 }
477 log::debug!("discovery listener thread exited");
478 });
479
480 Ok(DiscoveryListener { stop, _handle: handle })
481}
482
/// Returns the IPv4 broadcast addresses discovery queries should be sent to.
///
/// On Linux the addresses following `brd ` in `ip -4 addr show` are used, on
/// macOS those following `broadcast ` in `ifconfig`. Failure to run either
/// tool is ignored. The limited broadcast address `255.255.255.255` is always
/// part of the result, and every address appears exactly once, in the order
/// it was first found.
fn get_broadcast_addresses() -> Vec<Ipv4Addr> {
    let mut addrs: Vec<Ipv4Addr> = Vec::new();

    #[cfg(target_os = "linux")]
    {
        if let Ok(output) = std::process::Command::new("ip")
            .args(["-4", "addr", "show"])
            .output()
        {
            let stdout = String::from_utf8_lossy(&output.stdout);
            collect_broadcast_addrs(&stdout, "brd ", &mut addrs);
        }
    }

    #[cfg(target_os = "macos")]
    {
        if let Ok(output) = std::process::Command::new("ifconfig").output() {
            let stdout = String::from_utf8_lossy(&output.stdout);
            collect_broadcast_addrs(&stdout, "broadcast ", &mut addrs);
        }
    }

    // `Vec::dedup` would only drop adjacent repeats, so check explicitly.
    if !addrs.contains(&Ipv4Addr::BROADCAST) {
        addrs.push(Ipv4Addr::BROADCAST);
    }
    addrs
}

/// Appends to `addrs` every new, non-loopback IPv4 address that directly
/// follows `marker` on a line of `text`. Lines without the marker, or whose
/// next whitespace-delimited token is not an IPv4 address, are skipped.
#[cfg(any(target_os = "linux", target_os = "macos"))]
fn collect_broadcast_addrs(text: &str, marker: &str, addrs: &mut Vec<Ipv4Addr>) {
    for line in text.lines() {
        let candidate = line
            .split_once(marker)
            .and_then(|(_, rest)| rest.split_whitespace().next())
            .and_then(|token| token.parse::<Ipv4Addr>().ok());

        if let Some(ip) = candidate {
            if !ip.is_loopback() && !addrs.contains(&ip) {
                addrs.push(ip);
            }
        }
    }
}
538
539pub async fn discover_workers(
545 cluster_key: &str,
546 timeout: Duration,
547) -> Result<Vec<DiscoveredWorker>> {
548 let expected_hash = cluster_hash(cluster_key);
549
550 log::info!(
551 "discovering workers (timeout: {}s)...",
552 timeout.as_secs()
553 );
554
555 let workers = tokio::task::spawn_blocking(move || {
556 let sock = UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
557 .map_err(|e| anyhow!("failed to bind discovery socket: {}", e))?;
558 sock.set_broadcast(true)?;
559 sock.set_read_timeout(Some(Duration::from_millis(500)))?;
560
561 let query = DiscoveryQuery {
562 cluster_hash: expected_hash.clone(),
563 };
564 let query_json = serde_json::to_vec(&query)?;
565 let query_pkt = encode_packet(&query_json);
566
567 let broadcast_addrs = get_broadcast_addresses();
570
571 let mut workers: HashMap<String, DiscoveredWorker> = HashMap::new();
572 let deadline = std::time::Instant::now() + timeout;
573 let mut last_query = std::time::Instant::now() - Duration::from_secs(10);
574 let query_interval = Duration::from_secs(1);
575 let mut buf = [0u8; 65535];
576
577 loop {
578 let now = std::time::Instant::now();
579 if now >= deadline {
580 break;
581 }
582
583 if now.duration_since(last_query) >= query_interval {
585 for addr in &broadcast_addrs {
586 let dest = SocketAddr::V4(SocketAddrV4::new(*addr, DISCOVERY_PORT));
587 let _ = sock.send_to(&query_pkt, dest);
588 }
589 last_query = now;
590 }
591
592 match sock.recv_from(&mut buf) {
594 Ok((len, src)) => {
595 if let Some(payload) = decode_packet(&buf[..len]) {
596 if let Ok(resp) = serde_json::from_slice::<DiscoveryResponse>(payload) {
597 if resp.cluster_hash != expected_hash {
598 continue;
599 }
600
601 let src_ip = match src {
602 SocketAddr::V4(a) => a.ip().to_string(),
603 SocketAddr::V6(a) => a.ip().to_string(),
604 };
605 let host = format!("{}:{}", src_ip, resp.port);
606
607 if !workers.contains_key(&resp.worker_name) {
608 log::info!(
609 "discovered worker '{}' at {} with {} GPU(s)",
610 &resp.worker_name,
611 &host,
612 resp.gpus.len()
613 );
614
615 for gpu in &resp.gpus {
616 log::info!(
617 " {} — {} (~{:.1} TFLOPS)",
618 &gpu.name,
619 human_bytes::human_bytes(gpu.vram_bytes as f64),
620 gpu.tflops
621 );
622 }
623
624 workers.insert(resp.worker_name.clone(), DiscoveredWorker {
625 name: resp.worker_name,
626 host,
627 port: resp.port,
628 gpus: resp.gpus,
629 backend: resp.backend,
630 hostname: resp.hostname,
631 os: resp.os,
632 });
633 }
634 }
635 }
636 }
637 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock
638 || e.kind() == std::io::ErrorKind::TimedOut =>
639 {
640 }
642 Err(e) => {
643 log::warn!("discovery recv error: {}", e);
644 }
645 }
646 }
647
648 Ok::<_, anyhow::Error>(workers)
649 }).await??;
650
651 log::info!("discovery complete: {} worker(s) found", workers.len());
652 Ok(workers.into_values().collect())
653}