Skip to main content

dynamo_runtime/transports/
zmq.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! ZMQ Transport
5//!
6//! This module provides a ZMQ transport for the [crate::DistributedRuntime].
7//!
8//! Currently, the [Server] consists of a [tmq::router::Router] and the [Client] leverages
9//! a [tmq::dealer::Dealer].
10//!
11//! The distributed service pattern we will use is based on the Harmony pattern described in
12//! [Chapter 8: A Framework for Distributed Computing](https://zguide.zeromq.org/docs/chapter8/#True-Peer-Connectivity-Harmony-Pattern).
13//!
14//! This is similar to the TCP implementation; however, the TCP implementation used a direct
15//! connection between the client and server per stream. The ZMQ transport will enable the
16//! equivalent of a connection pool per upstream service at the cost of needing an extra internal
17//! routing step per service endpoint.
18
19use anyhow::{Context, Result, anyhow};
20use bytes::Bytes;
21use derive_getters::Dissolve;
22use futures::{SinkExt, StreamExt};
23use serde::{Deserialize, Serialize};
24use std::{collections::HashMap, sync::Arc};
25use tmq::{AsZmqSocket, Context as TmqContext, dealer, router};
26use tokio::{
27    sync::{Mutex, mpsc},
28    task::JoinHandle,
29};
30use tokio_util::sync::CancellationToken;
31
/// An owned multipart message: an ordered list of frames, each frame a raw byte buffer.
///
/// NOTE(review): this alias is not referenced by any other code visible in this file;
/// presumably it is consumed by callers elsewhere — confirm before changing it.
pub type MultipartMessage = Vec<Vec<u8>>;
33
// Core message types

/// Out-of-band control messages; every variant carries the `request_id` of the stream it
/// refers to, and the type is (de)serializable via serde.
///
/// NOTE(review): no variant is constructed or matched by the code visible in this file
/// (only the commented-out `Client` helpers mention it), so the per-variant descriptions
/// below are inferred from the names — confirm against the real producer/consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum ControlMessage {
    /// Presumably: ask the peer to cancel the stream identified by `request_id`.
    Cancel { request_id: String },
    /// Presumably: acknowledge a previously received `Cancel` for `request_id`.
    CancelAck { request_id: String },
    /// Presumably: report that the stream `request_id` failed; `error` is a textual description.
    Error { request_id: String, error: String },
    /// Presumably: signal that the stream `request_id` finished normally.
    Complete { request_id: String },
}
42
/// Serializable envelope distinguishing payload bytes from control messages.
///
/// NOTE(review): in the visible code this type is referenced only by the commented-out
/// `Client` helpers; the router loop forwards the raw frame bytes without decoding them
/// as a `MessageType`.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum MessageType {
    /// Opaque payload bytes.
    Data(Vec<u8>),
    /// A control message (see [ControlMessage]).
    Control(ControlMessage),
}
48
/// Outcome of forwarding one received frame to its per-request channel in `Server::run`.
enum StreamAction {
    /// The non-blocking `try_send` succeeded; the payload is the message size in bytes.
    SendEager(usize),
    /// The channel was full and the message was delivered by awaiting `send`;
    /// the payload is the message size in bytes.
    SendDelayed(usize),
    /// The receiving side of the channel is closed; the stream entry should be removed.
    Close,
}
54
// Router state management

/// Per-request routing tables shared between the [Server] handle and its event loop.
///
/// Both maps are keyed by request id.
struct RouterState {
    /// Data channel per request: the event loop looks up the sender by the request id
    /// carried in a received message and forwards the message bytes to it.
    active_streams: HashMap<String, mpsc::Sender<Bytes>>,
    /// Control channel per request. NOTE(review): entries are inserted and removed by the
    /// `RouterState` helpers, but nothing visible in this file sends on these channels yet.
    control_channels: HashMap<String, mpsc::Sender<ControlMessage>>,
}
60
61impl RouterState {
62    fn new() -> Self {
63        Self {
64            active_streams: HashMap::new(),
65            control_channels: HashMap::new(),
66        }
67    }
68
69    fn register_stream(
70        &mut self,
71        request_id: String,
72        data_tx: mpsc::Sender<Bytes>,
73        control_tx: mpsc::Sender<ControlMessage>,
74    ) {
75        self.active_streams.insert(request_id.clone(), data_tx);
76        self.control_channels.insert(request_id, control_tx);
77    }
78
79    fn remove_stream(&mut self, request_id: &str) {
80        self.active_streams.remove(request_id);
81        self.control_channels.remove(request_id);
82    }
83}
84
// Server implementation

/// Handle to a running ZMQ router; cloning it shares the same routing state.
///
/// NOTE(review): `Dissolve` comes from `derive_getters`; presumably it generates a
/// consuming `dissolve()` that returns the fields as a tuple — confirm against that crate.
#[derive(Clone, Dissolve)]
pub struct Server {
    /// Routing tables shared with the background event loop.
    state: Arc<Mutex<RouterState>>,
    /// Child of the token passed to `Server::new`; the event loop runs on a child of this token.
    cancel_token: CancellationToken,
    /// Value returned by `get_fd()` on the router's underlying ZMQ socket at construction time.
    fd: i32,
}
92
93impl Server {
94    /// Create a new [Server] which is a [tmq::router::Router] with the given [tmq::Context]
95    /// and address to bind the ZMQ router socket.
96    ///
97    /// If the event loop processing the router fails with an error, the signal is propagated through the [CancellationToken]
98    /// by issuing a [CancellationToken::cancel].
99    ///
100    /// The [Server] is how you interact with the running instance.
101    ///
102    /// The [ServerExecutionHandle] is the handle for background task executing the [Server].
103    pub async fn new(
104        context: &TmqContext,
105        address: &str,
106        cancel_token: CancellationToken,
107    ) -> Result<(Self, ServerExecutionHandle)> {
108        let router = router(context).bind(address)?;
109        let fd = router.get_socket().get_fd()?;
110        let state = Arc::new(Mutex::new(RouterState::new()));
111
112        // can cancel the router's event loop
113        let child = cancel_token.child_token();
114        let primary_task = tokio::spawn(Self::run(router, state.clone(), child.child_token()));
115
116        // this task captures the primary cancellation token, so if an error occurs, we can cancel the router's event loop
117        // but we also propagate the error to the caller's cancellation token
118        let watch_task = tokio::spawn(async move {
119            let result = primary_task.await.inspect_err(|e| {
120                tracing::error!("zmq server/router task failed: {e}");
121                cancel_token.cancel();
122            })?;
123            result.inspect_err(|e| {
124                tracing::error!("zmq server/router task failed: {e}");
125                cancel_token.cancel();
126            })
127        });
128
129        let handle = ServerExecutionHandle {
130            task: watch_task,
131            cancel_token: child.clone(),
132        };
133
134        Ok((
135            Self {
136                state,
137                cancel_token: child,
138                fd,
139            },
140            handle,
141        ))
142    }
143
144    // pub async fn register_stream(&)
145
146    async fn run(
147        router: tmq::router::Router,
148        state: Arc<Mutex<RouterState>>,
149        token: CancellationToken,
150    ) -> Result<()> {
151        let mut router = router;
152
153        // todo - move this into the Server impl to discover the os port being used
154        // let fd = router.as_raw_socket().get_fd()?;
155        // let sock = unsafe { socket2::Socket::from_raw_fd(fd) };
156        // let addr = sock.local_addr()?;
157        // let port = addr.as_socket().map(|s| s.port());
158
159        // if let Some(port) = port {
160        //     tracing::info!("Server listening on port {port}");
161        // }
162
163        loop {
164            let frames = tokio::select! {
165                biased;
166
167                frames = router.next() => {
168                    match frames {
169                        Some(Ok(frames)) => {
170                            frames
171                        },
172                        Some(Err(e)) => {
173                            tracing::warn!("Error receiving message: {e}");
174                            continue;
175                        }
176                        None => break,
177                    }
178                }
179
180                _ = token.cancelled() => {
181                    tracing::info!("Server shutting down");
182                    break;
183                }
184            };
185
186            // we should have at least 3 frames
187            // 0: identity
188            // 1: request_id
189            // 2: message type
190
191            // if the contract is broken, we should exit
192            if frames.len() != 3 {
193                anyhow::bail!(
194                    "Fatal Error -- Broken contract -- Expected 3 frames, got {}",
195                    frames.len()
196                );
197            }
198
199            let request_id = String::from_utf8_lossy(&frames[1]).to_string();
200            let message = frames[2].to_vec();
201            let message_size = message.len();
202
203            if let Some(tx) = state.lock().await.active_streams.get(&request_id) {
204                // first we try to send the data eagerly without blocking
205                let action = match tx.try_send(message.into()) {
206                    Ok(_) => {
207                        tracing::trace!(
208                            request_id,
209                            "response data sent eagerly to stream: {} bytes",
210                            message_size
211                        );
212                        StreamAction::SendEager(message_size)
213                    }
214                    Err(e) => match e {
215                        mpsc::error::TrySendError::Closed(_) => {
216                            tracing::info!(request_id, "response stream was closed");
217                            StreamAction::Close
218                        }
219                        mpsc::error::TrySendError::Full(data) => {
220                            tracing::warn!(
221                                request_id,
222                                "response stream is full; backpressure alert"
223                            );
224                            // todo - add timeout - we are blocking all other streams
225                            if (tx.send(data).await).is_err() {
226                                StreamAction::Close
227                            } else {
228                                StreamAction::SendDelayed(message_size)
229                            }
230                        }
231                    },
232                };
233
234                match action {
235                    StreamAction::SendEager(_size) => {
236                        // increment bytes_received
237                        // increment messages_received
238                        // increment eager_messages_received
239                    }
240                    StreamAction::SendDelayed(_size) => {
241                        // increment bytes_received
242                        // increment messages_received
243                        // increment delayed_messages_received
244                    }
245                    StreamAction::Close => {
246                        state.lock().await.active_streams.remove(&request_id);
247                    }
248                }
249            } else {
250                // increment bytes_dropped
251                // increment messages_dropped
252                tracing::trace!(request_id, "no active stream for request_id");
253            }
254        }
255
256        Ok(())
257    }
258}
259
/// The [ServerExecutionHandle] is the handle for background task executing the [Server].
///
/// You can use this to check if the server is finished or cancelled.
///
/// You can also join on the task to wait for it to finish.
pub struct ServerExecutionHandle {
    /// The watcher task spawned in `Server::new`; it resolves to the result of the
    /// router's event loop.
    task: JoinHandle<Result<()>>,
    /// Child of the token passed to `Server::new`; cancelling it stops the event loop
    /// without cancelling the caller's token.
    cancel_token: CancellationToken,
}
269
270impl ServerExecutionHandle {
271    /// Check if the task awaiting on the [Server]s background event loop has finished.
272    pub fn is_finished(&self) -> bool {
273        self.task.is_finished()
274    }
275
276    /// Check if the server's event loop has been cancelled.
277    pub fn is_cancelled(&self) -> bool {
278        self.cancel_token.is_cancelled()
279    }
280
281    /// Cancel the server's event loop.
282    ///
283    /// This will signal the server to stop processing requests and exit.
284    ///
285    /// This will not wait for the server to finish, it will exit immediately.
286    ///
287    /// This will not propagate to the [CancellationToken] used to start the [Server]
288    /// unless an error happens during the shutdown process.
289    pub fn cancel(&self) {
290        self.cancel_token.cancel();
291    }
292
293    /// Join on the task awaiting on the [Server]s background event loop.
294    ///
295    /// This will return the result of the [Server]s background event loop.
296    pub async fn join(self) -> Result<()> {
297        self.task.await?
298    }
299}
300
// Client implementation

/// Client side of the transport: a thin wrapper around a connected [tmq::dealer::Dealer] socket.
pub struct Client {
    /// The connected dealer socket used to exchange multipart messages with the router.
    dealer: tmq::dealer::Dealer,
}
305
306impl Client {
307    fn new(context: &TmqContext, address: &str) -> Result<Self> {
308        let dealer = dealer(context).connect(address)?;
309
310        Ok(Self { dealer })
311    }
312
313    fn dealer(&mut self) -> &mut tmq::dealer::Dealer {
314        &mut self.dealer
315    }
316
317    // async fn send_data(&self, data: Vec<u8>) -> Result<()> {
318    //     let msg_type = MessageType::Data(data);
319    //     let type_bytes = serde_json::to_vec(&msg_type)?;
320
321    //     self.dealer
322    //         .send_multipart(&[type_bytes, self.request_id.as_bytes().to_vec()])
323    //         .await
324    //         .map_err(|e| anyhow!("Failed to send data: {}", e))
325    // }
326
327    // async fn send_control(&self, msg: ControlMessage) -> Result<()> {
328    //     let msg_type = MessageType::Control(msg);
329    //     let type_bytes = serde_json::to_vec(&msg_type)?;
330
331    //     self.dealer
332    //         .send_multipart(&[type_bytes])
333    //         .await
334    //         .map_err(|e| anyhow!("Failed to send control message: {}", e))
335    // }
336
337    // async fn receive(&self) -> Result<MessageType> {
338    //     let frames = self
339    //         .dealer
340    //         .recv_multipart()
341    //         .await
342    //         .map_err(|e| anyhow!("Failed to receive message: {}", e))?;
343
344    //     if frames.is_empty() {
345    //         return Err(anyhow!("Received empty message"));
346    //     }
347
348    //     serde_json::from_slice(&frames[0])
349    //         .map_err(|e| anyhow!("Failed to deserialize message: {}", e))
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, timeout};

    /// Upper bound for every wait in these tests, so a routing or shutdown regression
    /// fails the test instead of hanging the whole test run.
    const WAIT_LIMIT: Duration = Duration::from_secs(5);

    /// A frame sent by a dealer is routed by the server to the channel registered under
    /// the request id carried in the message.
    #[tokio::test]
    async fn test_basic_communication() -> Result<()> {
        let context = TmqContext::new();
        // NOTE: fixed port; the test fails at bind time if another process is using it
        let address = "tcp://127.0.0.1:1337";
        let token = CancellationToken::new();

        // Start server
        let (server, handle) = Server::new(&context, address, token.clone()).await?;
        let state = server.state.clone();

        let id = "test-request".to_string();
        let (tx, mut rx) = tokio::sync::mpsc::channel::<Bytes>(512);
        state.lock().await.active_streams.insert(id.clone(), tx);

        // Create client
        let mut client = Client::new(&context, address)?;

        // The dealer sends [request_id, payload]; the router prepends the identity frame,
        // so the server sees the 3 frames it expects.
        client
            .dealer()
            .send(tmq::Multipart::from(vec![
                id.as_bytes().to_vec(),
                id.as_bytes().to_vec(),
            ]))
            .await?;

        let received = timeout(WAIT_LIMIT, rx.recv())
            .await
            .context("timed out waiting for the routed frame")?
            .context("response stream closed before a frame arrived")?;

        // convert to string
        let received_str = String::from_utf8_lossy(&received).to_string();
        assert_eq!(received_str, "test-request");

        drop(client);

        handle.cancel();
        timeout(WAIT_LIMIT, handle.join())
            .await
            .context("timed out waiting for the server to shut down")??;

        println!("done");

        Ok(())
    }

    // #[tokio::test]
    // async fn test_multiple_streams() -> Result<()> {
    //     // Similar to above but with multiple clients/streams
    //     Ok(())
    // }

    // #[tokio::test]
    // async fn test_error_handling() -> Result<()> {
    //     // Test various error conditions
    //     Ok(())
    // }
}