Skip to main content

dynamo_runtime/transports/
zmq.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! ZMQ Transport
5//!
6//! This module provides a ZMQ transport for the [crate::DistributedRuntime].
7//!
8//! Currently, the [Server] consists of a [async_zmq::Router] and the [Client] leverages
9//! a [async_zmq::Dealer].
10//!
11//! The distributed service pattern we will use is based on the Harmony pattern described in
12//! [Chapter 8: A Framework for Distributed Computing](https://zguide.zeromq.org/docs/chapter8/#True-Peer-Connectivity-Harmony-Pattern).
13//!
14//! This is similar to the TCP implementation; however, the TCP implementation used a direct
15//! connection between the client and server per stream. The ZMQ transport will enable the
16//! equivalent of a connection pool per upstream service at the cost of needing an extra internal
17//! routing step per service endpoint.
18
19use anyhow::{Result, anyhow};
20use async_zmq::{Context, Dealer, Router, Sink, SinkExt, StreamExt};
21use bytes::Bytes;
22use derive_getters::Dissolve;
23use futures::TryStreamExt;
24use serde::{Deserialize, Serialize};
25use std::{collections::HashMap, os::fd::FromRawFd, sync::Arc, time::Duration, vec::IntoIter};
26use tokio::{
27    sync::{Mutex, mpsc},
28    task::{JoinError, JoinHandle},
29};
30use tokio_util::sync::CancellationToken;
31
32// Core message types
/// Out-of-band control messages exchanged between client and server, each
/// keyed by the `request_id` of the stream it refers to.
///
/// NOTE(review): only `RouterState::control_channels` and the commented-out
/// client helpers reference this today; semantics below are inferred from the
/// variant names — confirm once the control path is wired up.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum ControlMessage {
    // Request cancellation of the in-flight stream.
    Cancel { request_id: String },
    // Acknowledge that a cancellation was observed.
    CancelAck { request_id: String },
    // Report a request-scoped error to the peer.
    Error { request_id: String, error: String },
    // Signal that the response stream for the request is complete.
    Complete { request_id: String },
}
40
/// Wire-level envelope distinguishing payload bytes from control traffic.
///
/// NOTE(review): currently referenced only by the commented-out `Client`
/// helpers below; the active router path forwards raw frames without this
/// envelope.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum MessageType {
    // Opaque payload bytes.
    Data(Vec<u8>),
    // A control-plane message (cancel, error, complete, ...).
    Control(ControlMessage),
}
46
/// Outcome of attempting to forward a routed frame to a per-request stream.
/// The `usize` payload is the message size in bytes, intended for the metric
/// counters sketched in `Server::run`.
enum StreamAction {
    // Delivered via `try_send` without blocking.
    SendEager(usize),
    // Delivered only after awaiting a full channel (backpressure path).
    SendDelayed(usize),
    // Receiver dropped; the stream should be deregistered.
    Close,
}
52
53// Router state management
/// Routing table shared between the [Server] handle and its event loop,
/// mapping a `request_id` to the channels that deliver traffic for that stream.
struct RouterState {
    // Per-request data channels; routed payload frames are pushed here.
    active_streams: HashMap<String, mpsc::Sender<Bytes>>,
    // Per-request control channels. NOTE(review): populated by
    // `register_stream` but never read by the active code path yet.
    control_channels: HashMap<String, mpsc::Sender<ControlMessage>>,
}
58
59impl RouterState {
60    fn new() -> Self {
61        Self {
62            active_streams: HashMap::new(),
63            control_channels: HashMap::new(),
64        }
65    }
66
67    fn register_stream(
68        &mut self,
69        request_id: String,
70        data_tx: mpsc::Sender<Bytes>,
71        control_tx: mpsc::Sender<ControlMessage>,
72    ) {
73        self.active_streams.insert(request_id.clone(), data_tx);
74        self.control_channels.insert(request_id, control_tx);
75    }
76
77    fn remove_stream(&mut self, request_id: &str) {
78        self.active_streams.remove(request_id);
79        self.control_channels.remove(request_id);
80    }
81}
82
83// Server implementation
/// Cloneable handle to the router half of the transport.
///
/// The actual event loop runs in a background task spawned by [Server::new];
/// this struct only carries the shared state needed to interact with it.
#[derive(Clone, Dissolve)]
pub struct Server {
    // Routing table shared with the event loop ([Server::run]).
    state: Arc<Mutex<RouterState>>,
    // Child token; cancelling it stops the event loop.
    cancel_token: CancellationToken,
    // Raw OS file descriptor of the bound ROUTER socket.
    fd: i32,
}
90
91impl Server {
92    /// Create a new [Server] which is a [async_zmq::Router] with the given [async_zmq::Context] and address to bind
93    /// the ZMQ [async_zmq::Router] socket.
94    ///
95    /// If the event loop processing the router fails with an error, the signal is propagated through the [CancellationToken]
96    /// by issuing a [CancellationToken::cancel].
97    ///
98    /// The [Server] is how you interact with the running instance.
99    ///
100    /// The [ServerExecutionHandle] is the handle for background task executing the [Server].
101    pub async fn new(
102        context: &Context,
103        address: &str,
104        cancel_token: CancellationToken,
105    ) -> Result<(Self, ServerExecutionHandle)> {
106        let router = async_zmq::router(address)?.with_context(context).bind()?;
107        let fd = router.as_raw_socket().get_fd()?;
108        let state = Arc::new(Mutex::new(RouterState::new()));
109
110        // can cancel the router's event loop
111        let child = cancel_token.child_token();
112        let primary_task = tokio::spawn(Self::run(router, state.clone(), child.child_token()));
113
114        // this task captures the primary cancellation token, so if an error occurs, we can cancel the router's event loop
115        // but we also propagate the error to the caller's cancellation token
116        let watch_task = tokio::spawn(async move {
117            let result = primary_task.await.inspect_err(|e| {
118                tracing::error!("zmq server/router task failed: {}", e);
119                cancel_token.cancel();
120            })?;
121            result.inspect_err(|e| {
122                tracing::error!("zmq server/router task failed: {}", e);
123                cancel_token.cancel();
124            })
125        });
126
127        let handle = ServerExecutionHandle {
128            task: watch_task,
129            cancel_token: child.clone(),
130        };
131
132        Ok((
133            Self {
134                state,
135                cancel_token: child,
136                fd,
137            },
138            handle,
139        ))
140    }
141
142    // pub async fn register_stream(&)
143
144    async fn run(
145        router: Router<IntoIter<Vec<u8>>, Vec<u8>>,
146        state: Arc<Mutex<RouterState>>,
147        token: CancellationToken,
148    ) -> Result<()> {
149        let mut router = router;
150
151        // todo - move this into the Server impl to discover the os port being used
152        // let fd = router.as_raw_socket().get_fd()?;
153        // let sock = unsafe { socket2::Socket::from_raw_fd(fd) };
154        // let addr = sock.local_addr()?;
155        // let port = addr.as_socket().map(|s| s.port());
156
157        // if let Some(port) = port {
158        //     tracing::info!("Server listening on port {}", port);
159        // }
160
161        loop {
162            let frames = tokio::select! {
163                biased;
164
165                frames = router.next() => {
166                    match frames {
167                        Some(Ok(frames)) => {
168                            frames
169                        },
170                        Some(Err(e)) => {
171                            tracing::warn!("Error receiving message: {}", e);
172                            continue;
173                        }
174                        None => break,
175                    }
176                }
177
178                _ = token.cancelled() => {
179                    tracing::info!("Server shutting down");
180                    break;
181                }
182            };
183
184            // we should have at least 3 frames
185            // 0: identity
186            // 1: request_id
187            // 2: message type
188
189            // if the contract is broken, we should exit
190            if frames.len() != 3 {
191                anyhow::bail!(
192                    "Fatal Error -- Broken contract -- Expected 3 frames, got {}",
193                    frames.len()
194                );
195            }
196
197            let request_id = String::from_utf8_lossy(&frames[1]).to_string();
198            let message = frames[2].to_vec();
199            let message_size = message.len();
200
201            if let Some(tx) = state.lock().await.active_streams.get(&request_id) {
202                // first we try to send the data eagerly without blocking
203                let action = match tx.try_send(message.into()) {
204                    Ok(_) => {
205                        tracing::trace!(
206                            request_id,
207                            "response data sent eagerly to stream: {} bytes",
208                            message_size
209                        );
210                        StreamAction::SendEager(message_size)
211                    }
212                    Err(e) => match e {
213                        mpsc::error::TrySendError::Closed(_) => {
214                            tracing::info!(request_id, "response stream was closed");
215                            StreamAction::Close
216                        }
217                        mpsc::error::TrySendError::Full(data) => {
218                            tracing::warn!(
219                                request_id,
220                                "response stream is full; backpressure alert"
221                            );
222                            // todo - add timeout - we are blocking all other streams
223                            if (tx.send(data).await).is_err() {
224                                StreamAction::Close
225                            } else {
226                                StreamAction::SendDelayed(message_size)
227                            }
228                        }
229                    },
230                };
231
232                match action {
233                    StreamAction::SendEager(_size) => {
234                        // increment bytes_received
235                        // increment messages_received
236                        // increment eager_messages_received
237                    }
238                    StreamAction::SendDelayed(_size) => {
239                        // increment bytes_received
240                        // increment messages_received
241                        // increment delayed_messages_received
242                    }
243                    StreamAction::Close => {
244                        state.lock().await.active_streams.remove(&request_id);
245                    }
246                }
247            } else {
248                // increment bytes_dropped
249                // increment messages_dropped
250                tracing::trace!(request_id, "no active stream for request_id");
251            }
252        }
253
254        Ok(())
255    }
256}
257
/// The [ServerExecutionHandle] is the handle for the background task executing the [Server].
///
/// You can use this to check if the server is finished or cancelled.
///
/// You can also join on the task to wait for it to finish.
pub struct ServerExecutionHandle {
    // Watcher task that resolves with the event loop's final result.
    task: JoinHandle<Result<()>>,
    // Child token controlling (and reflecting) event-loop cancellation.
    cancel_token: CancellationToken,
}
267
268impl ServerExecutionHandle {
269    /// Check if the task awaiting on the [Server]s background event loop has finished.
270    pub fn is_finished(&self) -> bool {
271        self.task.is_finished()
272    }
273
274    /// Check if the server's event loop has been cancelled.
275    pub fn is_cancelled(&self) -> bool {
276        self.cancel_token.is_cancelled()
277    }
278
279    /// Cancel the server's event loop.
280    ///
281    /// This will signal the server to stop processing requests and exit.
282    ///
283    /// This will not wait for the server to finish, it will exit immediately.
284    ///
285    /// This will not propagate to the [CancellationToken] used to start the [Server]
286    /// unless an error happens during the shutdown process.
287    pub fn cancel(&self) {
288        self.cancel_token.cancel();
289    }
290
291    /// Join on the task awaiting on the [Server]s background event loop.
292    ///
293    /// This will return the result of the [Server]s background event loop.
294    pub async fn join(self) -> Result<()> {
295        self.task.await?
296    }
297}
298
299// Client implementation
/// Client half of the transport: a DEALER socket connected to a server's ROUTER.
pub struct Client {
    // Underlying ZMQ dealer socket used for all sends/receives.
    dealer: Dealer<IntoIter<Vec<u8>>, Vec<u8>>,
}
303
304impl Client {
305    fn new(context: &Context, address: &str) -> Result<Self> {
306        let dealer = async_zmq::dealer(address)?
307            .with_context(context)
308            .connect()?;
309
310        Ok(Self { dealer })
311    }
312
313    fn dealer(&mut self) -> &mut Dealer<IntoIter<Vec<u8>>, Vec<u8>> {
314        &mut self.dealer
315    }
316
317    // async fn send_data(&self, data: Vec<u8>) -> Result<()> {
318    //     let msg_type = MessageType::Data(data);
319    //     let type_bytes = serde_json::to_vec(&msg_type)?;
320
321    //     self.dealer
322    //         .send_multipart(&[type_bytes, self.request_id.as_bytes().to_vec()])
323    //         .await
324    //         .map_err(|e| anyhow!("Failed to send data: {}", e))
325    // }
326
327    // async fn send_control(&self, msg: ControlMessage) -> Result<()> {
328    //     let msg_type = MessageType::Control(msg);
329    //     let type_bytes = serde_json::to_vec(&msg_type)?;
330
331    //     self.dealer
332    //         .send_multipart(&[type_bytes])
333    //         .await
334    //         .map_err(|e| anyhow!("Failed to send control message: {}", e))
335    // }
336
337    // async fn receive(&self) -> Result<MessageType> {
338    //     let frames = self
339    //         .dealer
340    //         .recv_multipart()
341    //         .await
342    //         .map_err(|e| anyhow!("Failed to receive message: {}", e))?;
343
344    //     if frames.is_empty() {
345    //         return Err(anyhow!("Received empty message"));
346    //     }
347
348    //     serde_json::from_slice(&frames[0])
349    //         .map_err(|e| anyhow!("Failed to deserialize message: {}", e))
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::timeout;

    /// End-to-end smoke test: a dealer sends [request_id, payload]; the ROUTER
    /// prepends the identity frame, the server routes the payload to the
    /// manually registered stream, and we read it back.
    #[tokio::test]
    async fn test_basic_communication() -> Result<()> {
        let context = Context::new();
        // NOTE(review): fixed port; this test fails if 1337 is already bound.
        let address = "tcp://127.0.0.1:1337";
        let token = CancellationToken::new();

        // Start server
        let (server, handle) = Server::new(&context, address, token.clone()).await?;
        let state = server.state.clone();

        // Register a stream by hand so the router forwards frames for `id`.
        let id = "test-request".to_string();
        let (tx, mut rx) = tokio::sync::mpsc::channel(512);
        state.lock().await.active_streams.insert(id.clone(), tx);

        // Create client
        let mut client = Client::new(&context, address)?;

        client
            .dealer()
            .send(vec![id.as_bytes().to_vec(), id.as_bytes().to_vec()].into())
            .await?;

        // Bound the wait so a routing bug fails the test instead of hanging it
        // forever (also puts the `timeout` import to use).
        let received = timeout(Duration::from_secs(5), rx.recv())
            .await
            .map_err(|_| anyhow!("timed out waiting for routed message"))?
            .ok_or_else(|| anyhow!("stream closed without delivering a message"))?;

        // The payload frame should round-trip unchanged.
        let received_str = String::from_utf8_lossy(&received).to_string();
        assert_eq!(received_str, "test-request");

        client.dealer().close().await?;

        handle.cancel();
        handle.join().await?;

        Ok(())
    }

    // #[tokio::test]
    // async fn test_multiple_streams() -> Result<()> {
    //     // Similar to above but with multiple clients/streams
    //     Ok(())
    // }

    // #[tokio::test]
    // async fn test_error_handling() -> Result<()> {
    //     // Test various error conditions
    //     Ok(())
    // }
}