// dynamo_runtime/transports/zmq.rs
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! ZMQ Transport
//!
//! This module provides a ZMQ transport for the [crate::DistributedRuntime].
//!
//! Currently, the [Server] consists of a [async_zmq::Router] and the [Client] leverages
//! a [async_zmq::Dealer].
//!
//! The distributed service pattern we will use is based on the Harmony pattern described in
//! [Chapter 8: A Framework for Distributed Computing](https://zguide.zeromq.org/docs/chapter8/#True-Peer-Connectivity-Harmony-Pattern).
//!
//! This is similar to the TCP implementation; however, the TCP implementation used a direct
//! connection between the client and server per stream. The ZMQ transport will enable the
//! equivalent of a connection pool per upstream service at the cost of needing an extra internal
//! routing step per service endpoint.
19use anyhow::{Result, anyhow};
20use async_zmq::{Context, Dealer, Router, Sink, SinkExt, StreamExt};
21use bytes::Bytes;
22use derive_getters::Dissolve;
23use futures::TryStreamExt;
24use serde::{Deserialize, Serialize};
25use std::{collections::HashMap, os::fd::FromRawFd, sync::Arc, time::Duration, vec::IntoIter};
26use tokio::{
27 sync::{Mutex, mpsc},
28 task::{JoinError, JoinHandle},
29};
30use tokio_util::sync::CancellationToken;
31use tracing as log;
32
// Core message types exchanged between [Client] and [Server].

/// Out-of-band control signals for an in-flight request stream.
///
/// NOTE(review): these variants are not yet produced or consumed by the
/// router event loop — only the commented-out client helpers reference them.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum ControlMessage {
    /// Ask the peer to cancel the stream identified by `request_id`.
    Cancel { request_id: String },
    /// Acknowledge that a previously requested cancellation took effect.
    CancelAck { request_id: String },
    /// Report a stream-level failure with a human-readable reason.
    Error { request_id: String, error: String },
    /// Signal that the stream completed normally.
    Complete { request_id: String },
}
41
/// Wire envelope distinguishing payload frames from control frames.
///
/// NOTE(review): currently referenced only by the commented-out client
/// helpers below; the router loop does not yet decode this envelope.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum MessageType {
    /// Opaque application payload bytes.
    Data(Vec<u8>),
    /// An out-of-band [ControlMessage].
    Control(ControlMessage),
}
47
/// Outcome of forwarding one routed payload to a stream's mpsc channel,
/// used by the event loop to decide on metrics/cleanup.
enum StreamAction {
    /// Payload handed off via `try_send` without blocking; value is payload size in bytes.
    SendEager(usize),
    /// Channel was full; payload delivered after awaiting; value is payload size in bytes.
    SendDelayed(usize),
    /// Receiver is gone; the stream should be unregistered.
    Close,
}
53
// Router state management

/// Routing table shared between the [Server] handle and its event loop.
struct RouterState {
    /// Per-request data channels; the event loop forwards payload frames here.
    active_streams: HashMap<String, mpsc::Sender<Bytes>>,
    /// Per-request control channels (cancel/ack/error/complete signals).
    control_channels: HashMap<String, mpsc::Sender<ControlMessage>>,
}
59
60impl RouterState {
61 fn new() -> Self {
62 Self {
63 active_streams: HashMap::new(),
64 control_channels: HashMap::new(),
65 }
66 }
67
68 fn register_stream(
69 &mut self,
70 request_id: String,
71 data_tx: mpsc::Sender<Bytes>,
72 control_tx: mpsc::Sender<ControlMessage>,
73 ) {
74 self.active_streams.insert(request_id.clone(), data_tx);
75 self.control_channels.insert(request_id, control_tx);
76 }
77
78 fn remove_stream(&mut self, request_id: &str) {
79 self.active_streams.remove(request_id);
80 self.control_channels.remove(request_id);
81 }
82}
83
// Server implementation

/// Handle to a running ZMQ ROUTER server; cloneable, shares state with the
/// background event loop spawned by [Server::new].
#[derive(Clone, Dissolve)]
pub struct Server {
    /// Routing table shared with the event loop.
    state: Arc<Mutex<RouterState>>,
    /// Child token; cancelling it stops this server's event loop.
    cancel_token: CancellationToken,
    /// Raw OS file descriptor of the underlying ROUTER socket.
    fd: i32,
}
91
92impl Server {
93 /// Create a new [Server] which is a [async_zmq::Router] with the given [async_zmq::Context] and address to bind
94 /// the ZMQ [async_zmq::Router] socket.
95 ///
96 /// If the event loop processing the router fails with an error, the signal is propagated through the [CancellationToken]
97 /// by issuing a [CancellationToken::cancel].
98 ///
99 /// The [Server] is how you interact with the running instance.
100 ///
101 /// The [ServerExecutionHandle] is the handle for background task executing the [Server].
102 pub async fn new(
103 context: &Context,
104 address: &str,
105 cancel_token: CancellationToken,
106 ) -> Result<(Self, ServerExecutionHandle)> {
107 let router = async_zmq::router(address)?.with_context(context).bind()?;
108 let fd = router.as_raw_socket().get_fd()?;
109 let state = Arc::new(Mutex::new(RouterState::new()));
110
111 // can cancel the router's event loop
112 let child = cancel_token.child_token();
113 let primary_task = tokio::spawn(Self::run(router, state.clone(), child.child_token()));
114
115 // this task captures the primary cancellation token, so if an error occurs, we can cancel the router's event loop
116 // but we also propagate the error to the caller's cancellation token
117 let watch_task = tokio::spawn(async move {
118 let result = primary_task.await.inspect_err(|e| {
119 log::error!("zmq server/router task failed: {}", e);
120 cancel_token.cancel();
121 })?;
122 result.inspect_err(|e| {
123 log::error!("zmq server/router task failed: {}", e);
124 cancel_token.cancel();
125 })
126 });
127
128 let handle = ServerExecutionHandle {
129 task: watch_task,
130 cancel_token: child.clone(),
131 };
132
133 Ok((
134 Self {
135 state,
136 cancel_token: child,
137 fd,
138 },
139 handle,
140 ))
141 }
142
143 // pub async fn register_stream(&)
144
145 async fn run(
146 router: Router<IntoIter<Vec<u8>>, Vec<u8>>,
147 state: Arc<Mutex<RouterState>>,
148 token: CancellationToken,
149 ) -> Result<()> {
150 let mut router = router;
151
152 // todo - move this into the Server impl to discover the os port being used
153 // let fd = router.as_raw_socket().get_fd()?;
154 // let sock = unsafe { socket2::Socket::from_raw_fd(fd) };
155 // let addr = sock.local_addr()?;
156 // let port = addr.as_socket().map(|s| s.port());
157
158 // if let Some(port) = port {
159 // log::info!("Server listening on port {}", port);
160 // }
161
162 loop {
163 let frames = tokio::select! {
164 biased;
165
166 frames = router.next() => {
167 match frames {
168 Some(Ok(frames)) => {
169 frames
170 },
171 Some(Err(e)) => {
172 log::warn!("Error receiving message: {}", e);
173 continue;
174 }
175 None => break,
176 }
177 }
178
179 _ = token.cancelled() => {
180 log::info!("Server shutting down");
181 break;
182 }
183 };
184
185 // we should have at least 3 frames
186 // 0: identity
187 // 1: request_id
188 // 2: message type
189
190 // if the contract is broken, we should exit
191 if frames.len() != 3 {
192 anyhow::bail!(
193 "Fatal Error -- Broken contract -- Expected 3 frames, got {}",
194 frames.len()
195 );
196 }
197
198 let request_id = String::from_utf8_lossy(&frames[1]).to_string();
199 let message = frames[2].to_vec();
200 let message_size = message.len();
201
202 if let Some(tx) = state.lock().await.active_streams.get(&request_id) {
203 // first we try to send the data eagerly without blocking
204 let action = match tx.try_send(message.into()) {
205 Ok(_) => {
206 log::trace!(
207 request_id,
208 "response data sent eagerly to stream: {} bytes",
209 message_size
210 );
211 StreamAction::SendEager(message_size)
212 }
213 Err(e) => match e {
214 mpsc::error::TrySendError::Closed(_) => {
215 log::info!(request_id, "response stream was closed");
216 StreamAction::Close
217 }
218 mpsc::error::TrySendError::Full(data) => {
219 log::warn!(request_id, "response stream is full; backpressue alert");
220 // todo - add timeout - we are blocking all other streams
221 if (tx.send(data).await).is_err() {
222 StreamAction::Close
223 } else {
224 StreamAction::SendDelayed(message_size)
225 }
226 }
227 },
228 };
229
230 match action {
231 StreamAction::SendEager(_size) => {
232 // increment bytes_received
233 // increment messages_received
234 // increment eager_messages_received
235 }
236 StreamAction::SendDelayed(_size) => {
237 // increment bytes_received
238 // increment messages_received
239 // increment delayed_messages_received
240 }
241 StreamAction::Close => {
242 state.lock().await.active_streams.remove(&request_id);
243 }
244 }
245 } else {
246 // increment bytes_dropped
247 // increment messages_dropped
248 log::trace!(request_id, "no active stream for request_id");
249 }
250 }
251
252 Ok(())
253 }
254}
255
/// The [ServerExecutionHandle] is the handle for background task executing the [Server].
///
/// You can use this to check if the server is finished or cancelled.
///
/// You can also join on the task to wait for it to finish.
pub struct ServerExecutionHandle {
    /// Watcher task that awaits the event loop and propagates failures.
    task: JoinHandle<Result<()>>,
    /// Child token controlling (and reporting) cancellation of the event loop.
    cancel_token: CancellationToken,
}
265
266impl ServerExecutionHandle {
267 /// Check if the task awaiting on the [Server]s background event loop has finished.
268 pub fn is_finished(&self) -> bool {
269 self.task.is_finished()
270 }
271
272 /// Check if the server's event loop has been cancelled.
273 pub fn is_cancelled(&self) -> bool {
274 self.cancel_token.is_cancelled()
275 }
276
277 /// Cancel the server's event loop.
278 ///
279 /// This will signal the server to stop processing requests and exit.
280 ///
281 /// This will not wait for the server to finish, it will exit immediately.
282 ///
283 /// This will not propagate to the [CancellationToken] used to start the [Server]
284 /// unless an error happens during the shutdown process.
285 pub fn cancel(&self) {
286 self.cancel_token.cancel();
287 }
288
289 /// Join on the task awaiting on the [Server]s background event loop.
290 ///
291 /// This will return the result of the [Server]s background event loop.
292 pub async fn join(self) -> Result<()> {
293 self.task.await?
294 }
295}
296
// Client implementation

/// Client side of the transport: a DEALER socket connected to a [Server]'s ROUTER.
pub struct Client {
    /// Underlying DEALER socket used to exchange multipart frames.
    dealer: Dealer<IntoIter<Vec<u8>>, Vec<u8>>,
}
301
302impl Client {
303 fn new(context: &Context, address: &str) -> Result<Self> {
304 let dealer = async_zmq::dealer(address)?
305 .with_context(context)
306 .connect()?;
307
308 Ok(Self { dealer })
309 }
310
311 fn dealer(&mut self) -> &mut Dealer<IntoIter<Vec<u8>>, Vec<u8>> {
312 &mut self.dealer
313 }
314
315 // async fn send_data(&self, data: Vec<u8>) -> Result<()> {
316 // let msg_type = MessageType::Data(data);
317 // let type_bytes = serde_json::to_vec(&msg_type)?;
318
319 // self.dealer
320 // .send_multipart(&[type_bytes, self.request_id.as_bytes().to_vec()])
321 // .await
322 // .map_err(|e| anyhow!("Failed to send data: {}", e))
323 // }
324
325 // async fn send_control(&self, msg: ControlMessage) -> Result<()> {
326 // let msg_type = MessageType::Control(msg);
327 // let type_bytes = serde_json::to_vec(&msg_type)?;
328
329 // self.dealer
330 // .send_multipart(&[type_bytes])
331 // .await
332 // .map_err(|e| anyhow!("Failed to send control message: {}", e))
333 // }
334
335 // async fn receive(&self) -> Result<MessageType> {
336 // let frames = self
337 // .dealer
338 // .recv_multipart()
339 // .await
340 // .map_err(|e| anyhow!("Failed to receive message: {}", e))?;
341
342 // if frames.is_empty() {
343 // return Err(anyhow!("Received empty message"));
344 // }
345
346 // serde_json::from_slice(&frames[0])
347 // .map_err(|e| anyhow!("Failed to deserialize message: {}", e))
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::timeout;

    /// End-to-end smoke test: a DEALER client sends [request_id, payload] and the
    /// ROUTER event loop forwards the payload to the registered per-request stream.
    #[tokio::test]
    async fn test_basic_communication() -> Result<()> {
        let context = Context::new();
        // NOTE(review): a fixed port can collide with other tests/processes;
        // switch to an ephemeral port once fd -> port discovery lands.
        let address = "tcp://127.0.0.1:1337";
        let token = CancellationToken::new();

        // Start server
        let (server, handle) = Server::new(&context, address, token.clone()).await?;
        let state = server.state.clone();

        // Register a data stream so the router has a destination for this id.
        let id = "test-request".to_string();
        let (tx, mut rx) = tokio::sync::mpsc::channel(512);
        state.lock().await.active_streams.insert(id.clone(), tx);

        // Create client
        let mut client = Client::new(&context, address)?;

        // Frames on the wire: [request_id, payload]; the ROUTER prepends the identity,
        // yielding the 3-frame contract the server expects.
        client
            .dealer()
            .send(vec![id.as_bytes().to_vec(), id.as_bytes().to_vec()].into())
            .await?;

        // Bound the wait so a routing regression fails the test instead of hanging.
        let received = timeout(Duration::from_secs(5), rx.recv())
            .await
            .expect("timed out waiting for routed message")
            .expect("data channel closed without delivering a message");

        // The payload frame carried the request id itself.
        let received_str = String::from_utf8_lossy(&received).to_string();
        assert_eq!(received_str, "test-request");

        client.dealer().close().await?;

        handle.cancel();
        handle.join().await?;

        Ok(())
    }

    // #[tokio::test]
    // async fn test_multiple_streams() -> Result<()> {
    //     // Similar to above but with multiple clients/streams
    //     Ok(())
    // }

    // #[tokio::test]
    // async fn test_error_handling() -> Result<()> {
    //     // Test various error conditions
    //     Ok(())
    // }
}