dynamo_runtime/transports/zmq.rs
1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! ZMQ Transport
5//!
6//! This module provides a ZMQ transport for the [crate::DistributedRuntime].
7//!
8//! Currently, the [Server] consists of a [tmq::router::Router] and the [Client] leverages
9//! a [tmq::dealer::Dealer].
10//!
11//! The distributed service pattern we will use is based on the Harmony pattern described in
12//! [Chapter 8: A Framework for Distributed Computing](https://zguide.zeromq.org/docs/chapter8/#True-Peer-Connectivity-Harmony-Pattern).
13//!
14//! This is similar to the TCP implementation; however, the TCP implementation used a direct
15//! connection between the client and server per stream. The ZMQ transport will enable the
16//! equivalent of a connection pool per upstream service at the cost of needing an extra internal
17//! routing step per service endpoint.
18
19use anyhow::{Context, Result, anyhow};
20use bytes::Bytes;
21use derive_getters::Dissolve;
22use futures::{SinkExt, StreamExt};
23use serde::{Deserialize, Serialize};
24use std::{collections::HashMap, sync::Arc};
25use tmq::{AsZmqSocket, Context as TmqContext, dealer, router};
26use tokio::{
27 sync::{Mutex, mpsc},
28 task::JoinHandle,
29};
30use tokio_util::sync::CancellationToken;
31
/// Owned byte buffers making up a multipart message (presumably one inner `Vec<u8>` per
/// frame — confirm against callers; nothing in the visible part of this module uses it).
pub type MultipartMessage = Vec<Vec<u8>>;
33
34// Core message types
35#[derive(Debug, Clone, Serialize, Deserialize)]
36enum ControlMessage {
37 Cancel { request_id: String },
38 CancelAck { request_id: String },
39 Error { request_id: String, error: String },
40 Complete { request_id: String },
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44enum MessageType {
45 Data(Vec<u8>),
46 Control(ControlMessage),
47}
48
/// Outcome of forwarding one payload to a registered stream, as decided by the router's
/// event loop. The `usize` carried by the send variants is the payload size in bytes.
enum StreamAction {
    /// The payload was accepted by the stream's channel without waiting.
    SendEager(usize),
    /// The channel was full; the payload was delivered after waiting for capacity.
    SendDelayed(usize),
    /// The receiving side of the channel is gone; the stream should be unregistered.
    Close,
}
54
55// Router state management
56struct RouterState {
57 active_streams: HashMap<String, mpsc::Sender<Bytes>>,
58 control_channels: HashMap<String, mpsc::Sender<ControlMessage>>,
59}
60
61impl RouterState {
62 fn new() -> Self {
63 Self {
64 active_streams: HashMap::new(),
65 control_channels: HashMap::new(),
66 }
67 }
68
69 fn register_stream(
70 &mut self,
71 request_id: String,
72 data_tx: mpsc::Sender<Bytes>,
73 control_tx: mpsc::Sender<ControlMessage>,
74 ) {
75 self.active_streams.insert(request_id.clone(), data_tx);
76 self.control_channels.insert(request_id, control_tx);
77 }
78
79 fn remove_stream(&mut self, request_id: &str) {
80 self.active_streams.remove(request_id);
81 self.control_channels.remove(request_id);
82 }
83}
84
85// Server implementation
86#[derive(Clone, Dissolve)]
87pub struct Server {
88 state: Arc<Mutex<RouterState>>,
89 cancel_token: CancellationToken,
90 fd: i32,
91}
92
93impl Server {
94 /// Create a new [Server] which is a [tmq::router::Router] with the given [tmq::Context]
95 /// and address to bind the ZMQ router socket.
96 ///
97 /// If the event loop processing the router fails with an error, the signal is propagated through the [CancellationToken]
98 /// by issuing a [CancellationToken::cancel].
99 ///
100 /// The [Server] is how you interact with the running instance.
101 ///
102 /// The [ServerExecutionHandle] is the handle for background task executing the [Server].
103 pub async fn new(
104 context: &TmqContext,
105 address: &str,
106 cancel_token: CancellationToken,
107 ) -> Result<(Self, ServerExecutionHandle)> {
108 let router = router(context).bind(address)?;
109 let fd = router.get_socket().get_fd()?;
110 let state = Arc::new(Mutex::new(RouterState::new()));
111
112 // can cancel the router's event loop
113 let child = cancel_token.child_token();
114 let primary_task = tokio::spawn(Self::run(router, state.clone(), child.child_token()));
115
116 // this task captures the primary cancellation token, so if an error occurs, we can cancel the router's event loop
117 // but we also propagate the error to the caller's cancellation token
118 let watch_task = tokio::spawn(async move {
119 let result = primary_task.await.inspect_err(|e| {
120 tracing::error!("zmq server/router task failed: {e}");
121 cancel_token.cancel();
122 })?;
123 result.inspect_err(|e| {
124 tracing::error!("zmq server/router task failed: {e}");
125 cancel_token.cancel();
126 })
127 });
128
129 let handle = ServerExecutionHandle {
130 task: watch_task,
131 cancel_token: child.clone(),
132 };
133
134 Ok((
135 Self {
136 state,
137 cancel_token: child,
138 fd,
139 },
140 handle,
141 ))
142 }
143
144 // pub async fn register_stream(&)
145
146 async fn run(
147 router: tmq::router::Router,
148 state: Arc<Mutex<RouterState>>,
149 token: CancellationToken,
150 ) -> Result<()> {
151 let mut router = router;
152
153 // todo - move this into the Server impl to discover the os port being used
154 // let fd = router.as_raw_socket().get_fd()?;
155 // let sock = unsafe { socket2::Socket::from_raw_fd(fd) };
156 // let addr = sock.local_addr()?;
157 // let port = addr.as_socket().map(|s| s.port());
158
159 // if let Some(port) = port {
160 // tracing::info!("Server listening on port {port}");
161 // }
162
163 loop {
164 let frames = tokio::select! {
165 biased;
166
167 frames = router.next() => {
168 match frames {
169 Some(Ok(frames)) => {
170 frames
171 },
172 Some(Err(e)) => {
173 tracing::warn!("Error receiving message: {e}");
174 continue;
175 }
176 None => break,
177 }
178 }
179
180 _ = token.cancelled() => {
181 tracing::info!("Server shutting down");
182 break;
183 }
184 };
185
186 // we should have at least 3 frames
187 // 0: identity
188 // 1: request_id
189 // 2: message type
190
191 // if the contract is broken, we should exit
192 if frames.len() != 3 {
193 anyhow::bail!(
194 "Fatal Error -- Broken contract -- Expected 3 frames, got {}",
195 frames.len()
196 );
197 }
198
199 let request_id = String::from_utf8_lossy(&frames[1]).to_string();
200 let message = frames[2].to_vec();
201 let message_size = message.len();
202
203 if let Some(tx) = state.lock().await.active_streams.get(&request_id) {
204 // first we try to send the data eagerly without blocking
205 let action = match tx.try_send(message.into()) {
206 Ok(_) => {
207 tracing::trace!(
208 request_id,
209 "response data sent eagerly to stream: {} bytes",
210 message_size
211 );
212 StreamAction::SendEager(message_size)
213 }
214 Err(e) => match e {
215 mpsc::error::TrySendError::Closed(_) => {
216 tracing::info!(request_id, "response stream was closed");
217 StreamAction::Close
218 }
219 mpsc::error::TrySendError::Full(data) => {
220 tracing::warn!(
221 request_id,
222 "response stream is full; backpressure alert"
223 );
224 // todo - add timeout - we are blocking all other streams
225 if (tx.send(data).await).is_err() {
226 StreamAction::Close
227 } else {
228 StreamAction::SendDelayed(message_size)
229 }
230 }
231 },
232 };
233
234 match action {
235 StreamAction::SendEager(_size) => {
236 // increment bytes_received
237 // increment messages_received
238 // increment eager_messages_received
239 }
240 StreamAction::SendDelayed(_size) => {
241 // increment bytes_received
242 // increment messages_received
243 // increment delayed_messages_received
244 }
245 StreamAction::Close => {
246 state.lock().await.active_streams.remove(&request_id);
247 }
248 }
249 } else {
250 // increment bytes_dropped
251 // increment messages_dropped
252 tracing::trace!(request_id, "no active stream for request_id");
253 }
254 }
255
256 Ok(())
257 }
258}
259
260/// The [ServerExecutionHandle] is the handle for background task executing the [Server].
261///
262/// You can use this to check if the server is finished or cancelled.
263///
264/// You can also join on the task to wait for it to finish.
265pub struct ServerExecutionHandle {
266 task: JoinHandle<Result<()>>,
267 cancel_token: CancellationToken,
268}
269
270impl ServerExecutionHandle {
271 /// Check if the task awaiting on the [Server]s background event loop has finished.
272 pub fn is_finished(&self) -> bool {
273 self.task.is_finished()
274 }
275
276 /// Check if the server's event loop has been cancelled.
277 pub fn is_cancelled(&self) -> bool {
278 self.cancel_token.is_cancelled()
279 }
280
281 /// Cancel the server's event loop.
282 ///
283 /// This will signal the server to stop processing requests and exit.
284 ///
285 /// This will not wait for the server to finish, it will exit immediately.
286 ///
287 /// This will not propagate to the [CancellationToken] used to start the [Server]
288 /// unless an error happens during the shutdown process.
289 pub fn cancel(&self) {
290 self.cancel_token.cancel();
291 }
292
293 /// Join on the task awaiting on the [Server]s background event loop.
294 ///
295 /// This will return the result of the [Server]s background event loop.
296 pub async fn join(self) -> Result<()> {
297 self.task.await?
298 }
299}
300
301// Client implementation
302pub struct Client {
303 dealer: tmq::dealer::Dealer,
304}
305
306impl Client {
307 fn new(context: &TmqContext, address: &str) -> Result<Self> {
308 let dealer = dealer(context).connect(address)?;
309
310 Ok(Self { dealer })
311 }
312
313 fn dealer(&mut self) -> &mut tmq::dealer::Dealer {
314 &mut self.dealer
315 }
316
317 // async fn send_data(&self, data: Vec<u8>) -> Result<()> {
318 // let msg_type = MessageType::Data(data);
319 // let type_bytes = serde_json::to_vec(&msg_type)?;
320
321 // self.dealer
322 // .send_multipart(&[type_bytes, self.request_id.as_bytes().to_vec()])
323 // .await
324 // .map_err(|e| anyhow!("Failed to send data: {}", e))
325 // }
326
327 // async fn send_control(&self, msg: ControlMessage) -> Result<()> {
328 // let msg_type = MessageType::Control(msg);
329 // let type_bytes = serde_json::to_vec(&msg_type)?;
330
331 // self.dealer
332 // .send_multipart(&[type_bytes])
333 // .await
334 // .map_err(|e| anyhow!("Failed to send control message: {}", e))
335 // }
336
337 // async fn receive(&self) -> Result<MessageType> {
338 // let frames = self
339 // .dealer
340 // .recv_multipart()
341 // .await
342 // .map_err(|e| anyhow!("Failed to receive message: {}", e))?;
343
344 // if frames.is_empty() {
345 // return Err(anyhow!("Received empty message"));
346 // }
347
348 // serde_json::from_slice(&frames[0])
349 // .map_err(|e| anyhow!("Failed to deserialize message: {}", e))
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::timeout;

    /// Upper bound on how long a test waits for the router to forward a payload, so a
    /// routing failure fails the test instead of hanging the test run.
    const RECV_TIMEOUT: Duration = Duration::from_secs(5);

    /// Sends one message from a dealer client to the router and checks that the payload
    /// frame arrives on the channel registered under the message's request id.
    #[tokio::test]
    async fn test_basic_communication() -> Result<()> {
        let context = TmqContext::new();
        let address = "tcp://127.0.0.1:1337";
        let token = CancellationToken::new();

        // Start server
        let (server, handle) = Server::new(&context, address, token.clone()).await?;
        let state = server.state.clone();

        // Register a stream directly in the routing state
        let id = "test-request".to_string();
        let (tx, mut rx) = tokio::sync::mpsc::channel(512);
        state.lock().await.active_streams.insert(id.clone(), tx);

        // Create client
        let mut client = Client::new(&context, address)?;

        // Frames sent by the dealer: request id, then payload (the id is reused as the
        // payload). The router socket prepends the identity frame on receipt.
        client
            .dealer()
            .send(tmq::Multipart::from(vec![
                id.as_bytes().to_vec(),
                id.as_bytes().to_vec(),
            ]))
            .await?;

        let received = timeout(RECV_TIMEOUT, rx.recv())
            .await
            .context("timed out waiting for the router to forward the payload")?
            .context("response channel closed before a payload arrived")?;

        // convert to string
        let received_str = String::from_utf8_lossy(&received).to_string();
        assert_eq!(received_str, "test-request");

        drop(client);

        handle.cancel();
        handle.join().await?;

        println!("done");

        Ok(())
    }

    // #[tokio::test]
    // async fn test_multiple_streams() -> Result<()> {
    //     // Similar to above but with multiple clients/streams
    //     Ok(())
    // }

    // #[tokio::test]
    // async fn test_error_handling() -> Result<()> {
    //     // Test various error conditions
    //     Ok(())
    // }
}