// dynamo_runtime/transports/zmq.rs
1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! ZMQ Transport
5//!
6//! This module provides a ZMQ transport for the [crate::DistributedRuntime].
7//!
8//! Currently, the [Server] consists of a [async_zmq::Router] and the [Client] leverages
9//! a [async_zmq::Dealer].
10//!
11//! The distributed service pattern we will use is based on the Harmony pattern described in
12//! [Chapter 8: A Framework for Distributed Computing](https://zguide.zeromq.org/docs/chapter8/#True-Peer-Connectivity-Harmony-Pattern).
13//!
14//! This is similar to the TCP implementation; however, the TCP implementation used a direct
15//! connection between the client and server per stream. The ZMQ transport will enable the
16//! equivalent of a connection pool per upstream service at the cost of needing an extra internal
17//! routing step per service endpoint.
18
19use anyhow::{Result, anyhow};
20use async_zmq::{Context, Dealer, Router, Sink, SinkExt, StreamExt};
21use bytes::Bytes;
22use derive_getters::Dissolve;
23use futures::TryStreamExt;
24use serde::{Deserialize, Serialize};
25use std::{collections::HashMap, os::fd::FromRawFd, sync::Arc, time::Duration, vec::IntoIter};
26use tokio::{
27 sync::{Mutex, mpsc},
28 task::{JoinError, JoinHandle},
29};
30use tokio_util::sync::CancellationToken;
31
// Core message types
/// Control-plane messages exchanged between client and server, each keyed by
/// the `request_id` of the stream it refers to.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum ControlMessage {
    /// Request cancellation of an in-flight stream.
    Cancel { request_id: String },
    /// Acknowledge that a cancel was processed.
    CancelAck { request_id: String },
    /// Report a stream-level failure with a human-readable reason.
    Error { request_id: String, error: String },
    /// Signal that the stream completed normally.
    Complete { request_id: String },
}
40
/// Wire-level envelope: either an opaque data payload or a [ControlMessage].
///
/// NOTE(review): currently only referenced by the commented-out client
/// helpers below — confirm it is still needed before removing.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum MessageType {
    /// Raw response bytes destined for a stream.
    Data(Vec<u8>),
    /// Control-plane message (cancel/ack/error/complete).
    Control(ControlMessage),
}
46
/// Outcome of forwarding one routed payload to a stream's channel; used by
/// the router event loop to decide on metrics accounting and cleanup.
enum StreamAction {
    /// Payload accepted by `try_send` without blocking (payload size in bytes).
    SendEager(usize),
    /// Channel was full; payload delivered via a blocking `send` (size in bytes).
    SendDelayed(usize),
    /// Receiver side is gone; the stream should be dropped from router state.
    Close,
}
52
// Router state management
/// Routing table mapping a `request_id` to the channels used to deliver
/// data and control messages to that stream. Shared behind `Arc<Mutex<_>>`
/// between the [Server] handle and the router event loop.
struct RouterState {
    /// Data channel per active request; routed payload frames go here.
    active_streams: HashMap<String, mpsc::Sender<Bytes>>,
    /// Control channel per active request.
    /// NOTE(review): populated by `register_stream` but not yet read by the
    /// event loop in this file — presumably reserved for cancel/ack routing.
    control_channels: HashMap<String, mpsc::Sender<ControlMessage>>,
}
58
59impl RouterState {
60 fn new() -> Self {
61 Self {
62 active_streams: HashMap::new(),
63 control_channels: HashMap::new(),
64 }
65 }
66
67 fn register_stream(
68 &mut self,
69 request_id: String,
70 data_tx: mpsc::Sender<Bytes>,
71 control_tx: mpsc::Sender<ControlMessage>,
72 ) {
73 self.active_streams.insert(request_id.clone(), data_tx);
74 self.control_channels.insert(request_id, control_tx);
75 }
76
77 fn remove_stream(&mut self, request_id: &str) {
78 self.active_streams.remove(request_id);
79 self.control_channels.remove(request_id);
80 }
81}
82
// Server implementation
/// Handle to a running ZMQ ROUTER.
///
/// Cloneable; all clones share the same routing state and cancellation token.
#[derive(Clone, Dissolve)]
pub struct Server {
    /// Routing table shared with the background event loop.
    state: Arc<Mutex<RouterState>>,
    /// Child token controlling only the router's event loop.
    cancel_token: CancellationToken,
    /// Raw OS file descriptor of the router socket, captured at bind time
    /// (intended for future local-port discovery).
    fd: i32,
}
90
91impl Server {
92 /// Create a new [Server] which is a [async_zmq::Router] with the given [async_zmq::Context] and address to bind
93 /// the ZMQ [async_zmq::Router] socket.
94 ///
95 /// If the event loop processing the router fails with an error, the signal is propagated through the [CancellationToken]
96 /// by issuing a [CancellationToken::cancel].
97 ///
98 /// The [Server] is how you interact with the running instance.
99 ///
100 /// The [ServerExecutionHandle] is the handle for background task executing the [Server].
101 pub async fn new(
102 context: &Context,
103 address: &str,
104 cancel_token: CancellationToken,
105 ) -> Result<(Self, ServerExecutionHandle)> {
106 let router = async_zmq::router(address)?.with_context(context).bind()?;
107 let fd = router.as_raw_socket().get_fd()?;
108 let state = Arc::new(Mutex::new(RouterState::new()));
109
110 // can cancel the router's event loop
111 let child = cancel_token.child_token();
112 let primary_task = tokio::spawn(Self::run(router, state.clone(), child.child_token()));
113
114 // this task captures the primary cancellation token, so if an error occurs, we can cancel the router's event loop
115 // but we also propagate the error to the caller's cancellation token
116 let watch_task = tokio::spawn(async move {
117 let result = primary_task.await.inspect_err(|e| {
118 tracing::error!("zmq server/router task failed: {}", e);
119 cancel_token.cancel();
120 })?;
121 result.inspect_err(|e| {
122 tracing::error!("zmq server/router task failed: {}", e);
123 cancel_token.cancel();
124 })
125 });
126
127 let handle = ServerExecutionHandle {
128 task: watch_task,
129 cancel_token: child.clone(),
130 };
131
132 Ok((
133 Self {
134 state,
135 cancel_token: child,
136 fd,
137 },
138 handle,
139 ))
140 }
141
142 // pub async fn register_stream(&)
143
144 async fn run(
145 router: Router<IntoIter<Vec<u8>>, Vec<u8>>,
146 state: Arc<Mutex<RouterState>>,
147 token: CancellationToken,
148 ) -> Result<()> {
149 let mut router = router;
150
151 // todo - move this into the Server impl to discover the os port being used
152 // let fd = router.as_raw_socket().get_fd()?;
153 // let sock = unsafe { socket2::Socket::from_raw_fd(fd) };
154 // let addr = sock.local_addr()?;
155 // let port = addr.as_socket().map(|s| s.port());
156
157 // if let Some(port) = port {
158 // tracing::info!("Server listening on port {}", port);
159 // }
160
161 loop {
162 let frames = tokio::select! {
163 biased;
164
165 frames = router.next() => {
166 match frames {
167 Some(Ok(frames)) => {
168 frames
169 },
170 Some(Err(e)) => {
171 tracing::warn!("Error receiving message: {}", e);
172 continue;
173 }
174 None => break,
175 }
176 }
177
178 _ = token.cancelled() => {
179 tracing::info!("Server shutting down");
180 break;
181 }
182 };
183
184 // we should have at least 3 frames
185 // 0: identity
186 // 1: request_id
187 // 2: message type
188
189 // if the contract is broken, we should exit
190 if frames.len() != 3 {
191 anyhow::bail!(
192 "Fatal Error -- Broken contract -- Expected 3 frames, got {}",
193 frames.len()
194 );
195 }
196
197 let request_id = String::from_utf8_lossy(&frames[1]).to_string();
198 let message = frames[2].to_vec();
199 let message_size = message.len();
200
201 if let Some(tx) = state.lock().await.active_streams.get(&request_id) {
202 // first we try to send the data eagerly without blocking
203 let action = match tx.try_send(message.into()) {
204 Ok(_) => {
205 tracing::trace!(
206 request_id,
207 "response data sent eagerly to stream: {} bytes",
208 message_size
209 );
210 StreamAction::SendEager(message_size)
211 }
212 Err(e) => match e {
213 mpsc::error::TrySendError::Closed(_) => {
214 tracing::info!(request_id, "response stream was closed");
215 StreamAction::Close
216 }
217 mpsc::error::TrySendError::Full(data) => {
218 tracing::warn!(
219 request_id,
220 "response stream is full; backpressure alert"
221 );
222 // todo - add timeout - we are blocking all other streams
223 if (tx.send(data).await).is_err() {
224 StreamAction::Close
225 } else {
226 StreamAction::SendDelayed(message_size)
227 }
228 }
229 },
230 };
231
232 match action {
233 StreamAction::SendEager(_size) => {
234 // increment bytes_received
235 // increment messages_received
236 // increment eager_messages_received
237 }
238 StreamAction::SendDelayed(_size) => {
239 // increment bytes_received
240 // increment messages_received
241 // increment delayed_messages_received
242 }
243 StreamAction::Close => {
244 state.lock().await.active_streams.remove(&request_id);
245 }
246 }
247 } else {
248 // increment bytes_dropped
249 // increment messages_dropped
250 tracing::trace!(request_id, "no active stream for request_id");
251 }
252 }
253
254 Ok(())
255 }
256}
257
/// The [ServerExecutionHandle] is the handle for background task executing the [Server].
///
/// You can use this to check if the server is finished or cancelled.
///
/// You can also join on the task to wait for it to finish.
pub struct ServerExecutionHandle {
    /// Watch task that awaits the router event loop and propagates errors.
    task: JoinHandle<Result<()>>,
    /// Child token that cancels only the server's event loop.
    cancel_token: CancellationToken,
}
267
268impl ServerExecutionHandle {
269 /// Check if the task awaiting on the [Server]s background event loop has finished.
270 pub fn is_finished(&self) -> bool {
271 self.task.is_finished()
272 }
273
274 /// Check if the server's event loop has been cancelled.
275 pub fn is_cancelled(&self) -> bool {
276 self.cancel_token.is_cancelled()
277 }
278
279 /// Cancel the server's event loop.
280 ///
281 /// This will signal the server to stop processing requests and exit.
282 ///
283 /// This will not wait for the server to finish, it will exit immediately.
284 ///
285 /// This will not propagate to the [CancellationToken] used to start the [Server]
286 /// unless an error happens during the shutdown process.
287 pub fn cancel(&self) {
288 self.cancel_token.cancel();
289 }
290
291 /// Join on the task awaiting on the [Server]s background event loop.
292 ///
293 /// This will return the result of the [Server]s background event loop.
294 pub async fn join(self) -> Result<()> {
295 self.task.await?
296 }
297}
298
// Client implementation
/// Client side of the transport: a ZMQ DEALER socket connected to a
/// [Server]'s ROUTER endpoint.
pub struct Client {
    // Typed dealer socket; multipart messages are built from Vec<Vec<u8>>.
    dealer: Dealer<IntoIter<Vec<u8>>, Vec<u8>>,
}
303
impl Client {
    /// Connect a new DEALER socket to `address` within the shared `context`.
    fn new(context: &Context, address: &str) -> Result<Self> {
        let dealer = async_zmq::dealer(address)?
            .with_context(context)
            .connect()?;

        Ok(Self { dealer })
    }

    /// Mutable access to the underlying dealer socket (used by tests and the
    /// eventual send/receive helpers).
    fn dealer(&mut self) -> &mut Dealer<IntoIter<Vec<u8>>, Vec<u8>> {
        &mut self.dealer
    }

    // NOTE(review): the helpers below are sketches of the eventual client API;
    // they reference a `self.request_id` field and dealer methods that do not
    // exist yet — confirm the intended API before reviving them.

    // async fn send_data(&self, data: Vec<u8>) -> Result<()> {
    //     let msg_type = MessageType::Data(data);
    //     let type_bytes = serde_json::to_vec(&msg_type)?;

    //     self.dealer
    //         .send_multipart(&[type_bytes, self.request_id.as_bytes().to_vec()])
    //         .await
    //         .map_err(|e| anyhow!("Failed to send data: {}", e))
    // }

    // async fn send_control(&self, msg: ControlMessage) -> Result<()> {
    //     let msg_type = MessageType::Control(msg);
    //     let type_bytes = serde_json::to_vec(&msg_type)?;

    //     self.dealer
    //         .send_multipart(&[type_bytes])
    //         .await
    //         .map_err(|e| anyhow!("Failed to send control message: {}", e))
    // }

    // async fn receive(&self) -> Result<MessageType> {
    //     let frames = self
    //         .dealer
    //         .recv_multipart()
    //         .await
    //         .map_err(|e| anyhow!("Failed to receive message: {}", e))?;

    //     if frames.is_empty() {
    //         return Err(anyhow!("Received empty message"));
    //     }

    //     serde_json::from_slice(&frames[0])
    //         .map_err(|e| anyhow!("Failed to deserialize message: {}", e))
}
351
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::timeout;

    /// End-to-end smoke test: a dealer sends `[request_id, payload]` and the
    /// router forwards the payload to the stream registered for that id.
    #[tokio::test]
    async fn test_basic_communication() -> Result<()> {
        let context = Context::new();
        // NOTE(review): a fixed port can collide with parallel tests or other
        // processes; consider an ephemeral or inproc endpoint.
        let address = "tcp://127.0.0.1:1337";
        let token = CancellationToken::new();

        // Start server
        let (server, handle) = Server::new(&context, address, token.clone()).await?;
        let state = server.state.clone();

        // Register a stream directly in router state so routed frames land in rx.
        let id = "test-request".to_string();
        let (tx, mut rx) = tokio::sync::mpsc::channel(512);
        state.lock().await.active_streams.insert(id.clone(), tx);

        // Create client
        let mut client = Client::new(&context, address)?;

        // Dealer sends [request_id, payload]; the ROUTER prepends the identity frame.
        client
            .dealer()
            .send(vec![id.as_bytes().to_vec(), id.as_bytes().to_vec()].into())
            .await?;

        // Bound the wait so a routing failure fails the test instead of
        // hanging it forever (rx.recv() alone never times out).
        let received = timeout(Duration::from_secs(5), rx.recv())
            .await
            .expect("timed out waiting for routed message")
            .expect("stream channel closed unexpectedly");

        // convert to string
        let received_str = String::from_utf8_lossy(&received).to_string();
        assert_eq!(received_str, "test-request");

        client.dealer().close().await?;

        handle.cancel();
        handle.join().await?;

        println!("done");

        Ok(())
    }

    // #[tokio::test]
    // async fn test_multiple_streams() -> Result<()> {
    //     // Similar to above but with multiple clients/streams
    //     Ok(())
    // }

    // #[tokio::test]
    // async fn test_error_handling() -> Result<()> {
    //     // Test various error conditions
    //     Ok(())
    // }
}