1use anyhow::Context;
2use core::ffi::c_char;
3use hanzo_ml::{DType, Device};
4pub use hanzo_quant::distributed::{use_nccl, use_ring};
5use hanzo_quant::{RingConfig, ShardedVarBuilder};
6use interprocess::local_socket::traits::{Listener, Stream};
7use interprocess::local_socket::{GenericNamespaced, Name, ToNsName};
8use interprocess::local_socket::{ListenerOptions, Stream as LocalStream};
9use serde::{Deserialize, Serialize};
10use serde_big_array::BigArray;
11use std::env;
12use std::io::{BufRead, BufReader, Write};
13use std::net::TcpStream;
14use std::process::Command;
15use std::str::FromStr;
16use std::sync::Arc;
17use tokio::runtime::Runtime;
18use tokio::sync::mpsc::Sender;
19use tracing::info;
20
21use crate::device_map::DeviceMapper;
22use crate::pipeline::{DeviceMappedModelLoader, IsqModelLoader};
23use crate::utils::varbuilder_utils::{self, DeviceForLoadTensor};
24use crate::{DeviceMapSetting, IsqOrganization, ModelPaths, Request};
25
/// Name of the environment variable that marks a process as a spawned daemon worker.
///
/// `prepare_distributed_mapper` sets it on each child process it spawns, with a
/// JSON-serialized `WorkerTransferData` as the value, and reads it back in the child.
/// `is_daemon` only checks whether the variable is present.
pub(crate) const IS_DAEMON_FLAG: &str = "__HANZO_DAEMON_INTERNAL";
27
28pub fn is_daemon() -> bool {
29 if cfg!(feature = "cuda") && !cfg!(feature = "ring") {
30 std::env::var(IS_DAEMON_FLAG).is_ok()
31 } else if use_ring() {
32 !RingConfig::load().is_master_rank()
33 } else {
34 false
35 }
36}
37
38pub fn nccl_daemon_replicator(request_sender: Sender<Request>) {
39 use std::io::BufRead;
40 use std::io::BufReader;
41
42 std::thread::spawn(move || {
43 let rt = Runtime::new().unwrap();
44 rt.block_on(async move {
45 use interprocess::local_socket::traits::Stream;
46 use interprocess::local_socket::Stream as LocalStream;
47
48 loop {
49 let name = match ipc_name() {
50 Ok(name) => name,
51 Err(e) => {
52 tracing::error!("Failed to get IPC name in daemon: {e}");
53 continue;
54 }
55 };
56 if let Ok(stream) = LocalStream::connect(name) {
57 let mut reader = BufReader::new(stream);
58 let mut buf = String::new();
59 if let Err(e) = reader.read_line(&mut buf) {
60 tracing::error!("Failed to read line from IPC stream: {e}");
61 continue;
62 }
63 let mut req: Request = match serde_json::from_str(&buf) {
64 Ok(req) => req,
65 Err(e) => {
66 tracing::error!("Failed to parse request JSON: {e}");
67 continue;
68 }
69 };
70
71 req = match req {
72 Request::ReIsq(x) => Request::ReIsq(x),
73 Request::Terminate => Request::Terminate,
74 Request::Detokenize(mut x) => {
75 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
76 x.response = sender;
77 let req = Request::Detokenize(x);
78
79 if request_sender.send(req).await.is_err() {
80 tracing::error!("Daemon channel closed for Detokenize request");
81 continue;
82 }
83 match receiver.recv().await {
84 Some(resp) => {
85 if let Err(e) = resp {
86 tracing::error!("Detokenize response error: {e}");
87 }
88 }
89 None => tracing::error!("Detokenize response channel closed"),
90 }
91 continue;
92 }
93 Request::Tokenize(mut x) => {
94 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
95 x.response = sender;
96 let req = Request::Tokenize(x);
97
98 if request_sender.send(req).await.is_err() {
99 tracing::error!("Daemon channel closed for Tokenize request");
100 continue;
101 }
102 match receiver.recv().await {
103 Some(resp) => {
104 if let Err(e) = resp {
105 tracing::error!("Tokenize response error: {e}");
106 }
107 }
108 None => tracing::error!("Tokenize response channel closed"),
109 }
110 continue;
111 }
112 Request::Normal(mut x) => {
113 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
114 x.is_streaming = false;
115 x.response = sender;
116 let req = Request::Normal(x);
117
118 if request_sender.send(req).await.is_err() {
119 tracing::error!("Daemon channel closed for Normal request");
120 continue;
121 }
122 match receiver.recv().await {
123 Some(resp) => {
124 if let Err(e) = resp.as_result() {
125 tracing::error!("Normal response error: {e}");
126 }
127 }
128 None => tracing::error!("Normal response channel closed"),
129 }
130 continue;
131 }
132 Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
133 };
134
135 if request_sender.send(req).await.is_err() {
136 tracing::error!("Daemon channel closed for request");
137 }
138 }
139 }
140 });
141 });
142}
143
144pub fn ring_daemon_replicator(request_sender: Sender<Request>) {
145 use std::io::BufRead;
146 use std::io::BufReader;
147
148 let ring_config = RingConfig::load();
149
150 let master_ip = ring_config.master_ip();
151 let master_port = ring_config.master_port;
152 std::thread::spawn(move || {
153 let rt = Runtime::new().unwrap();
154 rt.block_on(async move {
155 loop {
156 if let Ok(stream) = TcpStream::connect(format!("{master_ip}:{master_port}")) {
157 let mut reader = BufReader::new(stream);
158 let mut buf = String::new();
159 reader.read_line(&mut buf).unwrap();
160 let mut req: Request = serde_json::from_str(&buf).unwrap();
161
162 req = match req {
163 Request::ReIsq(x) => Request::ReIsq(x),
164 Request::Terminate => Request::Terminate,
165 Request::Detokenize(mut x) => {
166 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
167 x.response = sender;
168 let req = Request::Detokenize(x);
169
170 request_sender.send(req).await.unwrap();
171 let resp = receiver.recv().await.unwrap();
172 resp.unwrap();
173 continue;
174 }
175 Request::Tokenize(mut x) => {
176 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
177 x.response = sender;
178 let req = Request::Tokenize(x);
179
180 request_sender.send(req).await.unwrap();
181 let resp = receiver.recv().await.unwrap();
182 resp.unwrap();
183 continue;
184 }
185 Request::Normal(mut x) => {
186 let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
187 x.is_streaming = false;
188 x.response = sender;
189 let req = Request::Normal(x);
190
191 request_sender.send(req).await.unwrap();
192 loop {
193 let resp = receiver.recv().await.unwrap();
194 match resp {
195 crate::Response::AgenticToolCallProgress { .. } => continue,
196 crate::Response::File(_) => continue,
197 other => {
198 other.as_result().unwrap();
199 break;
200 }
201 }
202 }
203 continue;
204 }
205 Request::TerminateAllSeqsNextStep => Request::TerminateAllSeqsNextStep,
206 };
207
208 request_sender.send(req).await.unwrap();
209 }
210 }
211 });
212 });
213}
214
/// Serializable wrapper around a fixed `[c_char; 128]` buffer.
///
/// The array is (de)serialized through `serde_big_array::BigArray`, and
/// `#[serde(transparent)]` makes the wrapper serialize exactly like the array itself.
/// In this file it carries the bytes obtained from `Id::internal()` to spawned workers,
/// which hand them to `Id::uninit`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(transparent)]
pub(crate) struct BigCCharArray(#[serde(with = "BigArray")] pub(crate) [c_char; 128]);
218
/// Payload handed from the master process to a spawned daemon worker.
///
/// `prepare_distributed_mapper` serializes it to JSON, stores it in the
/// `IS_DAEMON_FLAG` environment variable of the child process, and parses it back
/// inside the child.
#[derive(Serialize, Deserialize, Debug)]
pub(crate) enum WorkerTransferData {
    /// Initial handshake data for one worker.
    Init {
        /// Raw communicator id bytes shared by the master.
        id: BigCCharArray,
        /// Zero-based index of the worker; the worker uses `worker_rank + 1` as its
        /// local rank.
        worker_rank: usize,
    },
}
226
227pub(crate) fn ipc_name() -> anyhow::Result<Name<'static>> {
228 let printname = "hanzo_daemon.sock";
229 Ok(printname.to_ns_name::<GenericNamespaced>()?)
230}
231
232#[allow(clippy::too_many_arguments)]
233pub(crate) fn prepare_distributed_mapper<T: DeviceMappedModelLoader + IsqModelLoader + ?Sized>(
234 dtype: DType,
235 device: &Device,
236 available_devices: &[Device],
237 silent: bool,
238 config: &str,
239 loading_isq: bool,
240 from_uqff: bool,
241 organization: IsqOrganization,
242 model: &T,
243 paths: &dyn ModelPaths,
244) -> anyhow::Result<(Box<dyn DeviceMapper + Send + Sync>, ShardedVarBuilder)> {
245 if !(cfg!(feature = "cuda") || cfg!(feature = "ring")) {
246 tracing::warn!(
247 "Distributed support was not included in the build, be sure to build with `--features nccl`."
248 );
249 }
250
251 let local_world_size = available_devices.len();
254 let global_world_size = if let Ok(x) = std::env::var("HANZO_MN_GLOBAL_WORLD_SIZE") {
255 usize::from_str(&x).context("HANZO_MN_GLOBAL_WORLD_SIZE")?
256 } else {
257 std::cmp::max(
259 hanzo_quant::distributed::get_global_tp_size_from_devices()?,
260 local_world_size,
261 )
262 };
263
264 let use_multi_node = std::env::var("HANZO_MN_GLOBAL_WORLD_SIZE").is_ok();
265 if use_multi_node {
266 info!("HANZO_MN_GLOBAL_WORLD_SIZE is set, entering multi-node.");
267 }
268
269 if global_world_size < local_world_size || global_world_size % local_world_size != 0 {
270 anyhow::bail!("Global world size {global_world_size} must both be at least and divide the local world size {local_world_size}");
271 }
272
273 info!("Local tensor parallel world size is {local_world_size}");
274 info!("Global tensor parallel world size is {global_world_size}");
275
276 let name = ipc_name()?;
278 let mut id;
279 let local_rank = if let Ok(payload) = env::var(IS_DAEMON_FLAG) {
280 let payload: WorkerTransferData = serde_json::from_str(&payload)?;
281 let WorkerTransferData::Init {
282 id: new_id,
283 worker_rank,
284 } = payload;
285 id = hanzo_quant::Id::uninit(new_id.0);
286
287 let mut stream = LocalStream::connect(name)?;
288 stream.write_all(b"ready\n")?;
289 worker_rank + 1
290 } else if cfg!(feature = "ring") {
291 id = hanzo_quant::Id::new();
292
293 let config = RingConfig::load();
294
295 config.rank
296 } else {
297 id = hanzo_quant::Id::new();
298 let num_ranks = hanzo_quant::distributed::get_global_tp_size_from_devices()?;
299 let num_workers = num_ranks - 1;
300 let mut children = Vec::new();
301 for worker_rank in 0..num_workers {
302 let exe_path = env::current_exe().expect("Failed to get current exe");
303
304 let args: Vec<String> = env::args().collect();
305
306 let mut cmd = Command::new(exe_path);
307 cmd.args(&args[1..]);
308
309 let data = WorkerTransferData::Init {
310 id: BigCCharArray(*id.internal()),
311 worker_rank,
312 };
313
314 cmd.env(IS_DAEMON_FLAG, serde_json::to_string(&data)?);
315
316 cmd.stdout(std::process::Stdio::null());
317 cmd.stderr(std::process::Stdio::null());
318 cmd.stdin(std::process::Stdio::null());
319
320 children.push(cmd.spawn().expect("Failed to spawn process"));
321 }
322
323 let listener = ListenerOptions::new().name(name).create_sync()?;
324 let mut ready_count = 0;
325
326 while ready_count < num_workers {
327 let stream = listener.accept()?;
328 let mut reader = BufReader::new(stream);
329 let mut message = String::new();
330 reader.read_line(&mut message)?;
331 if message.trim() == "ready" {
332 ready_count += 1;
333 }
334 }
335 info!("All workers have received the ids!");
336
337 0
338 };
339
340 if use_multi_node {
341 if let Ok(n_nodes) = env::var("HANZO_MN_HEAD_NUM_WORKERS") {
342 let n_nodes = usize::from_str(&n_nodes).context("HANZO_MN_HEAD_NUM_WORKERS")?;
343 info!("Head node managing {n_nodes} workers.");
344 let Ok(port) = env::var("HANZO_MN_HEAD_PORT") else {
345 anyhow::bail!("Got HANZO_MN_HEAD_NUM_WORKERS, expected HANZO_MN_HEAD_PORT");
346 };
347 info!("Head node initializing connection on {port}.");
348 let server =
349 hanzo_quant::Server::new(&format!("0.0.0.0:{port}"), n_nodes, local_world_size)?;
350
351 server.broadcast_id(&id)?;
352 } else if let Ok(addr) = env::var("HANZO_MN_WORKER_SERVER_ADDR") {
353 info!("Worker node connecting to {addr}.");
354 let client = hanzo_quant::Client::new(addr.parse()?, local_world_size)?;
355
356 id = client.receive_id()?;
357 }
358 }
359
360 let rank_offset = if env::var("HANZO_MN_WORKER_SERVER_ADDR").is_ok() {
361 let Ok(node_id) = env::var("HANZO_MN_WORKER_ID") else {
362 anyhow::bail!("Got HANZO_MN_WORKER_SERVER_ADDR, expected HANZO_MN_WORKER_ID");
363 };
364 let node_id = usize::from_str(&node_id).context("HANZO_MN_WORKER_ID")?;
365 info!("Worker ID is {node_id}.");
366 (node_id + 1) * local_world_size
367 } else {
368 0
369 };
370
371 let comm =
374 hanzo_quant::Comm::from_device(id, device, local_rank + rank_offset, global_world_size)?;
375
376 let make_dummy_regexes = if loading_isq && from_uqff {
377 Some(std::sync::Arc::new(
379 if matches!(organization, IsqOrganization::MoeExpertsOnly) {
380 model.isq_layer_regexes_moqe(config)?
381 } else {
382 model.isq_layer_regexes(config)?
383 },
384 ))
385 } else {
386 None
387 };
388
389 let sharded_vb = varbuilder_utils::from_mmaped_safetensors(
390 paths.get_weight_filenames().to_vec(),
391 vec![],
392 Some(dtype),
393 &Device::Cpu,
394 vec![],
395 silent,
396 make_dummy_regexes,
397 |_| true,
398 Arc::new(|_| DeviceForLoadTensor::Base),
399 )?;
400
401 info!("Loading all ranks.");
402 let mapper = DeviceMapSetting::Nccl {
404 nm_device: available_devices[0].clone(),
405 comm: Arc::new(comm),
406 }
407 .into_mapper(model.num_layers(config)?, device, None, available_devices)?;
408
409 let sharded_vb = if !loading_isq {
410 sharded_vb.clone().set_device(device.clone())
411 } else {
412 sharded_vb.clone()
413 };
414
415 Ok((mapper, sharded_vb))
416}