Skip to main content

hanzo_engine/
distributed.rs

1use anyhow::Context;
2use core::ffi::c_char;
3use hanzo_ml::{DType, Device};
4pub use hanzo_quant::distributed::{use_nccl, use_ring};
5use hanzo_quant::{RingConfig, ShardedVarBuilder};
6use interprocess::local_socket::traits::{Listener, Stream};
7use interprocess::local_socket::{GenericNamespaced, Name, ToNsName};
8use interprocess::local_socket::{ListenerOptions, Stream as LocalStream};
9use serde::{Deserialize, Serialize};
10use serde_big_array::BigArray;
11use std::env;
12use std::io::{BufRead, BufReader, Write};
13use std::net::TcpStream;
14use std::process::Command;
15use std::str::FromStr;
16use std::sync::Arc;
17use tokio::runtime::Runtime;
18use tokio::sync::mpsc::Sender;
19use tracing::info;
20
21use crate::device_map::DeviceMapper;
22use crate::pipeline::{DeviceMappedModelLoader, IsqModelLoader};
23use crate::utils::varbuilder_utils::{self, DeviceForLoadTensor};
24use crate::{DeviceMapSetting, IsqOrganization, ModelPaths, Request};
25
26pub(crate) const IS_DAEMON_FLAG: &str = "__HANZO_DAEMON_INTERNAL";
27
28pub fn is_daemon() -> bool {
29    if cfg!(feature = "cuda") && !cfg!(feature = "ring") {
30        std::env::var(IS_DAEMON_FLAG).is_ok()
31    } else if use_ring() {
32        !RingConfig::load().is_master_rank()
33    } else {
34        false
35    }
36}
37
38pub fn nccl_daemon_replicator(request_sender: Sender<Request>) {
39    use std::io::BufRead;
40    use std::io::BufReader;
41
42    std::thread::spawn(move || {
43        let rt = Runtime::new().unwrap();
44        rt.block_on(async move {
45            use interprocess::local_socket::traits::Stream;
46            use interprocess::local_socket::Stream as LocalStream;
47
48            loop {
49                let name = match ipc_name() {
50                    Ok(name) => name,
51                    Err(e) => {
52                        tracing::error!("Failed to get IPC name in daemon: {e}");
53                        continue;
54                    }
55                };
56                if let Ok(stream) = LocalStream::connect(name) {
57                    let mut reader = BufReader::new(stream);
58                    let mut buf = String::new();
59                    if let Err(e) = reader.read_line(&mut buf) {
60                        tracing::error!("Failed to read line from IPC stream: {e}");
61                        continue;
62                    }
63                    let mut req: Request = match serde_json::from_str(&buf) {
64                        Ok(req) => req,
65                        Err(e) => {
66                            tracing::error!("Failed to parse request JSON: {e}");
67                            continue;
68                        }
69                    };
70
71                    req = match req {
72                        Request::ReIsq(x) => Request::ReIsq(x),
73                        Request::Terminate => Request::Terminate,
74                        Request::Detokenize(mut x) => {
75                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
76                            x.response = sender;
77                            let req = Request::Detokenize(x);
78
79                            if request_sender.send(req).await.is_err() {
80                                tracing::error!("Daemon channel closed for Detokenize request");
81                                continue;
82                            }
83                            match receiver.recv().await {
84                                Some(resp) => {
85                                    if let Err(e) = resp {
86                                        tracing::error!("Detokenize response error: {e}");
87                                    }
88                                }
89                                None => tracing::error!("Detokenize response channel closed"),
90                            }
91                            continue;
92                        }
93                        Request::Tokenize(mut x) => {
94                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
95                            x.response = sender;
96                            let req = Request::Tokenize(x);
97
98                            if request_sender.send(req).await.is_err() {
99                                tracing::error!("Daemon channel closed for Tokenize request");
100                                continue;
101                            }
102                            match receiver.recv().await {
103                                Some(resp) => {
104                                    if let Err(e) = resp {
105                                        tracing::error!("Tokenize response error: {e}");
106                                    }
107                                }
108                                None => tracing::error!("Tokenize response channel closed"),
109                            }
110                            continue;
111                        }
112                        Request::Normal(mut x) => {
113                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
114                            x.is_streaming = false;
115                            x.response = sender;
116                            let req = Request::Normal(x);
117
118                            if request_sender.send(req).await.is_err() {
119                                tracing::error!("Daemon channel closed for Normal request");
120                                continue;
121                            }
122                            match receiver.recv().await {
123                                Some(resp) => {
124                                    if let Err(e) = resp.as_result() {
125                                        tracing::error!("Normal response error: {e}");
126                                    }
127                                }
128                                None => tracing::error!("Normal response channel closed"),
129                            }
130                            continue;
131                        }
132                        Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
133                    };
134
135                    if request_sender.send(req).await.is_err() {
136                        tracing::error!("Daemon channel closed for request");
137                    }
138                }
139            }
140        });
141    });
142}
143
144pub fn ring_daemon_replicator(request_sender: Sender<Request>) {
145    use std::io::BufRead;
146    use std::io::BufReader;
147
148    let ring_config = RingConfig::load();
149
150    let master_ip = ring_config.master_ip();
151    let master_port = ring_config.master_port;
152    std::thread::spawn(move || {
153        let rt = Runtime::new().unwrap();
154        rt.block_on(async move {
155            loop {
156                if let Ok(stream) = TcpStream::connect(format!("{master_ip}:{master_port}")) {
157                    let mut reader = BufReader::new(stream);
158                    let mut buf = String::new();
159                    reader.read_line(&mut buf).unwrap();
160                    let mut req: Request = serde_json::from_str(&buf).unwrap();
161
162                    req = match req {
163                        Request::ReIsq(x) => Request::ReIsq(x),
164                        Request::Terminate => Request::Terminate,
165                        Request::Detokenize(mut x) => {
166                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
167                            x.response = sender;
168                            let req = Request::Detokenize(x);
169
170                            request_sender.send(req).await.unwrap();
171                            let resp = receiver.recv().await.unwrap();
172                            resp.unwrap();
173                            continue;
174                        }
175                        Request::Tokenize(mut x) => {
176                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
177                            x.response = sender;
178                            let req = Request::Tokenize(x);
179
180                            request_sender.send(req).await.unwrap();
181                            let resp = receiver.recv().await.unwrap();
182                            resp.unwrap();
183                            continue;
184                        }
185                        Request::Normal(mut x) => {
186                            let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
187                            x.is_streaming = false;
188                            x.response = sender;
189                            let req = Request::Normal(x);
190
191                            request_sender.send(req).await.unwrap();
192                            loop {
193                                let resp = receiver.recv().await.unwrap();
194                                match resp {
195                                    crate::Response::AgenticToolCallProgress { .. } => continue,
196                                    crate::Response::File(_) => continue,
197                                    other => {
198                                        other.as_result().unwrap();
199                                        break;
200                                    }
201                                }
202                            }
203                            continue;
204                        }
205                        Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
206                    };
207
208                    request_sender.send(req).await.unwrap();
209                }
210            }
211        });
212    });
213}
214
215#[derive(Serialize, Deserialize, Debug)]
216#[serde(transparent)]
217pub(crate) struct BigCCharArray(#[serde(with = "BigArray")] pub(crate) [c_char; 128]);
218
219#[derive(Serialize, Deserialize, Debug)]
220pub(crate) enum WorkerTransferData {
221    Init {
222        id: BigCCharArray,
223        worker_rank: usize,
224    },
225}
226
227pub(crate) fn ipc_name() -> anyhow::Result<Name<'static>> {
228    let printname = "hanzo_daemon.sock";
229    Ok(printname.to_ns_name::<GenericNamespaced>()?)
230}
231
232#[allow(clippy::too_many_arguments)]
233pub(crate) fn prepare_distributed_mapper<T: DeviceMappedModelLoader + IsqModelLoader + ?Sized>(
234    dtype: DType,
235    device: &Device,
236    available_devices: &[Device],
237    silent: bool,
238    config: &str,
239    loading_isq: bool,
240    from_uqff: bool,
241    organization: IsqOrganization,
242    model: &T,
243    paths: &dyn ModelPaths,
244) -> anyhow::Result<(Box<dyn DeviceMapper + Send + Sync>, ShardedVarBuilder)> {
245    if !(cfg!(feature = "cuda") || cfg!(feature = "ring")) {
246        tracing::warn!(
247            "Distributed support was not included in the build, be sure to build with `--features nccl`."
248        );
249    }
250
251    // NCCL case!
252
253    let local_world_size = available_devices.len();
254    let global_world_size = if let Ok(x) = std::env::var("HANZO_MN_GLOBAL_WORLD_SIZE") {
255        usize::from_str(&x).context("HANZO_MN_GLOBAL_WORLD_SIZE")?
256    } else {
257        // global world size is always >= local world size
258        std::cmp::max(
259            hanzo_quant::distributed::get_global_tp_size_from_devices()?,
260            local_world_size,
261        )
262    };
263
264    let use_multi_node = std::env::var("HANZO_MN_GLOBAL_WORLD_SIZE").is_ok();
265    if use_multi_node {
266        info!("HANZO_MN_GLOBAL_WORLD_SIZE is set, entering multi-node.");
267    }
268
269    if global_world_size < local_world_size || global_world_size % local_world_size != 0 {
270        anyhow::bail!("Global world size {global_world_size} must both be at least and divide the local world size {local_world_size}");
271    }
272
273    info!("Local tensor parallel world size is {local_world_size}");
274    info!("Global tensor parallel world size is {global_world_size}");
275
276    // TP uses parallel pipelines.
277    let name = ipc_name()?;
278    let mut id;
279    let local_rank = if let Ok(payload) = env::var(IS_DAEMON_FLAG) {
280        let payload: WorkerTransferData = serde_json::from_str(&payload)?;
281        let WorkerTransferData::Init {
282            id: new_id,
283            worker_rank,
284        } = payload;
285        id = hanzo_quant::Id::uninit(new_id.0);
286
287        let mut stream = LocalStream::connect(name)?;
288        stream.write_all(b"ready\n")?;
289        worker_rank + 1
290    } else if cfg!(feature = "ring") {
291        id = hanzo_quant::Id::new();
292
293        let config = RingConfig::load();
294
295        config.rank
296    } else {
297        id = hanzo_quant::Id::new();
298        let num_ranks = hanzo_quant::distributed::get_global_tp_size_from_devices()?;
299        let num_workers = num_ranks - 1;
300        let mut children = Vec::new();
301        for worker_rank in 0..num_workers {
302            let exe_path = env::current_exe().expect("Failed to get current exe");
303
304            let args: Vec<String> = env::args().collect();
305
306            let mut cmd = Command::new(exe_path);
307            cmd.args(&args[1..]);
308
309            let data = WorkerTransferData::Init {
310                id: BigCCharArray(*id.internal()),
311                worker_rank,
312            };
313
314            cmd.env(IS_DAEMON_FLAG, serde_json::to_string(&data)?);
315
316            cmd.stdout(std::process::Stdio::null());
317            cmd.stderr(std::process::Stdio::null());
318            cmd.stdin(std::process::Stdio::null());
319
320            children.push(cmd.spawn().expect("Failed to spawn process"));
321        }
322
323        let listener = ListenerOptions::new().name(name).create_sync()?;
324        let mut ready_count = 0;
325
326        while ready_count < num_workers {
327            let stream = listener.accept()?;
328            let mut reader = BufReader::new(stream);
329            let mut message = String::new();
330            reader.read_line(&mut message)?;
331            if message.trim() == "ready" {
332                ready_count += 1;
333            }
334        }
335        info!("All workers have received the ids!");
336
337        0
338    };
339
340    if use_multi_node {
341        if let Ok(n_nodes) = env::var("HANZO_MN_HEAD_NUM_WORKERS") {
342            let n_nodes = usize::from_str(&n_nodes).context("HANZO_MN_HEAD_NUM_WORKERS")?;
343            info!("Head node managing {n_nodes} workers.");
344            let Ok(port) = env::var("HANZO_MN_HEAD_PORT") else {
345                anyhow::bail!("Got HANZO_MN_HEAD_NUM_WORKERS, expected HANZO_MN_HEAD_PORT");
346            };
347            info!("Head node initializing connection on {port}.");
348            let server =
349                hanzo_quant::Server::new(&format!("0.0.0.0:{port}"), n_nodes, local_world_size)?;
350
351            server.broadcast_id(&id)?;
352        } else if let Ok(addr) = env::var("HANZO_MN_WORKER_SERVER_ADDR") {
353            info!("Worker node connecting to {addr}.");
354            let client = hanzo_quant::Client::new(addr.parse()?, local_world_size)?;
355
356            id = client.receive_id()?;
357        }
358    }
359
360    let rank_offset = if env::var("HANZO_MN_WORKER_SERVER_ADDR").is_ok() {
361        let Ok(node_id) = env::var("HANZO_MN_WORKER_ID") else {
362            anyhow::bail!("Got HANZO_MN_WORKER_SERVER_ADDR, expected HANZO_MN_WORKER_ID");
363        };
364        let node_id = usize::from_str(&node_id).context("HANZO_MN_WORKER_ID")?;
365        info!("Worker ID is {node_id}.");
366        (node_id + 1) * local_world_size
367    } else {
368        0
369    };
370
371    // They each block on each other
372    // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html?ncclcomminitrank#ncclcomminitrank
373    let comm =
374        hanzo_quant::Comm::from_device(id, device, local_rank + rank_offset, global_world_size)?;
375
376    let make_dummy_regexes = if loading_isq && from_uqff {
377        // Dummy weights for the layers which will be overwritten...
378        Some(std::sync::Arc::new(
379            if matches!(organization, IsqOrganization::MoeExpertsOnly) {
380                model.isq_layer_regexes_moqe(config)?
381            } else {
382                model.isq_layer_regexes(config)?
383            },
384        ))
385    } else {
386        None
387    };
388
389    let sharded_vb = varbuilder_utils::from_mmaped_safetensors(
390        paths.get_weight_filenames().to_vec(),
391        vec![],
392        Some(dtype),
393        &Device::Cpu,
394        vec![],
395        silent,
396        make_dummy_regexes,
397        |_| true,
398        Arc::new(|_| DeviceForLoadTensor::Base),
399    )?;
400
401    info!("Loading all ranks.");
402    // The mapper is specific to this pipeline
403    let mapper = DeviceMapSetting::Nccl {
404        nm_device: available_devices[0].clone(),
405        comm: Arc::new(comm),
406    }
407    .into_mapper(model.num_layers(config)?, device, None, available_devices)?;
408
409    let sharded_vb = if !loading_isq {
410        sharded_vb.clone().set_device(device.clone())
411    } else {
412        sharded_vb.clone()
413    };
414
415    Ok((mapper, sharded_vb))
416}