dynamo_runtime/transports/
zmq.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! ZMQ Transport
5//!
6//! This module provides a ZMQ transport for the [crate::DistributedRuntime].
7//!
8//! Currently, the [Server] consists of a [async_zmq::Router] and the [Client] leverages
9//! a [async_zmq::Dealer].
10//!
11//! The distributed service pattern we will use is based on the Harmony pattern described in
12//! [Chapter 8: A Framework for Distributed Computing](https://zguide.zeromq.org/docs/chapter8/#True-Peer-Connectivity-Harmony-Pattern).
13//!
14//! This is similar to the TCP implementation; however, the TCP implementation used a direct
15//! connection between the client and server per stream. The ZMQ transport will enable the
16//! equivalent of a connection pool per upstream service at the cost of needing an extra internal
17//! routing step per service endpoint.
18
19use anyhow::{Result, anyhow};
20use async_zmq::{Context, Dealer, Router, Sink, SinkExt, StreamExt};
21use bytes::Bytes;
22use derive_getters::Dissolve;
23use futures::TryStreamExt;
24use serde::{Deserialize, Serialize};
25use std::{collections::HashMap, os::fd::FromRawFd, sync::Arc, time::Duration, vec::IntoIter};
26use tokio::{
27    sync::{Mutex, mpsc},
28    task::{JoinError, JoinHandle},
29};
30use tokio_util::sync::CancellationToken;
31use tracing as log;
32
33// Core message types
/// Control-plane messages intended to flow between [Client] and [Server]
/// alongside data frames.
///
/// NOTE(review): currently only referenced by the commented-out client helpers
/// below — the wire usage is not yet established; confirm once those land.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum ControlMessage {
    /// Ask the peer to cancel the in-flight request identified by `request_id`.
    Cancel { request_id: String },
    /// Acknowledge that a cancel request was processed.
    CancelAck { request_id: String },
    /// Report a request-scoped error to the peer.
    Error { request_id: String, error: String },
    /// Signal that the stream for `request_id` finished cleanly.
    Complete { request_id: String },
}
41
/// Envelope distinguishing opaque payload bytes from transport-level control
/// messages (serde-encoded in the commented-out client helpers below).
#[derive(Debug, Clone, Serialize, Deserialize)]
enum MessageType {
    /// Opaque application payload.
    Data(Vec<u8>),
    /// Transport-level control signal.
    Control(ControlMessage),
}
47
/// Outcome of forwarding a received frame to a per-request stream; used by
/// [Server::run] to decide on metrics accounting and cleanup.
enum StreamAction {
    /// Frame delivered without blocking via `try_send` (payload size in bytes).
    SendEager(usize),
    /// Channel was full; frame delivered via a blocking `send` (payload size in bytes).
    SendDelayed(usize),
    /// Receiver side was dropped; the stream should be unregistered.
    Close,
}
53
// Router state management
/// Per-request routing table: maps a request id to the channels used to
/// forward incoming data frames and control messages to the matching handler.
struct RouterState {
    // data frames destined for each in-flight request
    active_streams: HashMap<String, mpsc::Sender<Bytes>>,
    // control-plane channel per request (not yet consumed by the event loop)
    control_channels: HashMap<String, mpsc::Sender<ControlMessage>>,
}
59
60impl RouterState {
61    fn new() -> Self {
62        Self {
63            active_streams: HashMap::new(),
64            control_channels: HashMap::new(),
65        }
66    }
67
68    fn register_stream(
69        &mut self,
70        request_id: String,
71        data_tx: mpsc::Sender<Bytes>,
72        control_tx: mpsc::Sender<ControlMessage>,
73    ) {
74        self.active_streams.insert(request_id.clone(), data_tx);
75        self.control_channels.insert(request_id, control_tx);
76    }
77
78    fn remove_stream(&mut self, request_id: &str) {
79        self.active_streams.remove(request_id);
80        self.control_channels.remove(request_id);
81    }
82}
83
// Server implementation
/// Handle to a running ZMQ router server.
///
/// Cloning is cheap: the routing state is shared behind an [Arc].
#[derive(Clone, Dissolve)]
pub struct Server {
    // shared with the background event loop in [Server::run]
    state: Arc<Mutex<RouterState>>,
    // child token controlling the event loop's lifetime
    cancel_token: CancellationToken,
    // raw fd of the router socket (intended for discovering the bound OS port)
    fd: i32,
}
91
92impl Server {
93    /// Create a new [Server] which is a [async_zmq::Router] with the given [async_zmq::Context] and address to bind
94    /// the ZMQ [async_zmq::Router] socket.
95    ///
96    /// If the event loop processing the router fails with an error, the signal is propagated through the [CancellationToken]
97    /// by issuing a [CancellationToken::cancel].
98    ///
99    /// The [Server] is how you interact with the running instance.
100    ///
101    /// The [ServerExecutionHandle] is the handle for background task executing the [Server].
102    pub async fn new(
103        context: &Context,
104        address: &str,
105        cancel_token: CancellationToken,
106    ) -> Result<(Self, ServerExecutionHandle)> {
107        let router = async_zmq::router(address)?.with_context(context).bind()?;
108        let fd = router.as_raw_socket().get_fd()?;
109        let state = Arc::new(Mutex::new(RouterState::new()));
110
111        // can cancel the router's event loop
112        let child = cancel_token.child_token();
113        let primary_task = tokio::spawn(Self::run(router, state.clone(), child.child_token()));
114
115        // this task captures the primary cancellation token, so if an error occurs, we can cancel the router's event loop
116        // but we also propagate the error to the caller's cancellation token
117        let watch_task = tokio::spawn(async move {
118            let result = primary_task.await.inspect_err(|e| {
119                log::error!("zmq server/router task failed: {}", e);
120                cancel_token.cancel();
121            })?;
122            result.inspect_err(|e| {
123                log::error!("zmq server/router task failed: {}", e);
124                cancel_token.cancel();
125            })
126        });
127
128        let handle = ServerExecutionHandle {
129            task: watch_task,
130            cancel_token: child.clone(),
131        };
132
133        Ok((
134            Self {
135                state,
136                cancel_token: child,
137                fd,
138            },
139            handle,
140        ))
141    }
142
143    // pub async fn register_stream(&)
144
145    async fn run(
146        router: Router<IntoIter<Vec<u8>>, Vec<u8>>,
147        state: Arc<Mutex<RouterState>>,
148        token: CancellationToken,
149    ) -> Result<()> {
150        let mut router = router;
151
152        // todo - move this into the Server impl to discover the os port being used
153        // let fd = router.as_raw_socket().get_fd()?;
154        // let sock = unsafe { socket2::Socket::from_raw_fd(fd) };
155        // let addr = sock.local_addr()?;
156        // let port = addr.as_socket().map(|s| s.port());
157
158        // if let Some(port) = port {
159        //     log::info!("Server listening on port {}", port);
160        // }
161
162        loop {
163            let frames = tokio::select! {
164                biased;
165
166                frames = router.next() => {
167                    match frames {
168                        Some(Ok(frames)) => {
169                            frames
170                        },
171                        Some(Err(e)) => {
172                            log::warn!("Error receiving message: {}", e);
173                            continue;
174                        }
175                        None => break,
176                    }
177                }
178
179                _ = token.cancelled() => {
180                    log::info!("Server shutting down");
181                    break;
182                }
183            };
184
185            // we should have at least 3 frames
186            // 0: identity
187            // 1: request_id
188            // 2: message type
189
190            // if the contract is broken, we should exit
191            if frames.len() != 3 {
192                anyhow::bail!(
193                    "Fatal Error -- Broken contract -- Expected 3 frames, got {}",
194                    frames.len()
195                );
196            }
197
198            let request_id = String::from_utf8_lossy(&frames[1]).to_string();
199            let message = frames[2].to_vec();
200            let message_size = message.len();
201
202            if let Some(tx) = state.lock().await.active_streams.get(&request_id) {
203                // first we try to send the data eagerly without blocking
204                let action = match tx.try_send(message.into()) {
205                    Ok(_) => {
206                        log::trace!(
207                            request_id,
208                            "response data sent eagerly to stream: {} bytes",
209                            message_size
210                        );
211                        StreamAction::SendEager(message_size)
212                    }
213                    Err(e) => match e {
214                        mpsc::error::TrySendError::Closed(_) => {
215                            log::info!(request_id, "response stream was closed");
216                            StreamAction::Close
217                        }
218                        mpsc::error::TrySendError::Full(data) => {
219                            log::warn!(request_id, "response stream is full; backpressue alert");
220                            // todo - add timeout - we are blocking all other streams
221                            if (tx.send(data).await).is_err() {
222                                StreamAction::Close
223                            } else {
224                                StreamAction::SendDelayed(message_size)
225                            }
226                        }
227                    },
228                };
229
230                match action {
231                    StreamAction::SendEager(_size) => {
232                        // increment bytes_received
233                        // increment messages_received
234                        // increment eager_messages_received
235                    }
236                    StreamAction::SendDelayed(_size) => {
237                        // increment bytes_received
238                        // increment messages_received
239                        // increment delayed_messages_received
240                    }
241                    StreamAction::Close => {
242                        state.lock().await.active_streams.remove(&request_id);
243                    }
244                }
245            } else {
246                // increment bytes_dropped
247                // increment messages_dropped
248                log::trace!(request_id, "no active stream for request_id");
249            }
250        }
251
252        Ok(())
253    }
254}
255
/// The [ServerExecutionHandle] is the handle for background task executing the [Server].
///
/// You can use this to check if the server is finished or cancelled.
///
/// You can also join on the task to wait for it to finish.
pub struct ServerExecutionHandle {
    // watcher task wrapping the router event loop; yields its final result
    task: JoinHandle<Result<()>>,
    // child token that cancels only the server's event loop
    cancel_token: CancellationToken,
}
265
266impl ServerExecutionHandle {
267    /// Check if the task awaiting on the [Server]s background event loop has finished.
268    pub fn is_finished(&self) -> bool {
269        self.task.is_finished()
270    }
271
272    /// Check if the server's event loop has been cancelled.
273    pub fn is_cancelled(&self) -> bool {
274        self.cancel_token.is_cancelled()
275    }
276
277    /// Cancel the server's event loop.
278    ///
279    /// This will signal the server to stop processing requests and exit.
280    ///
281    /// This will not wait for the server to finish, it will exit immediately.
282    ///
283    /// This will not propagate to the [CancellationToken] used to start the [Server]
284    /// unless an error happens during the shutdown process.
285    pub fn cancel(&self) {
286        self.cancel_token.cancel();
287    }
288
289    /// Join on the task awaiting on the [Server]s background event loop.
290    ///
291    /// This will return the result of the [Server]s background event loop.
292    pub async fn join(self) -> Result<()> {
293        self.task.await?
294    }
295}
296
// Client implementation
/// A ZMQ dealer socket connected to a [Server]s router socket.
pub struct Client {
    dealer: Dealer<IntoIter<Vec<u8>>, Vec<u8>>,
}
301
302impl Client {
303    fn new(context: &Context, address: &str) -> Result<Self> {
304        let dealer = async_zmq::dealer(address)?
305            .with_context(context)
306            .connect()?;
307
308        Ok(Self { dealer })
309    }
310
311    fn dealer(&mut self) -> &mut Dealer<IntoIter<Vec<u8>>, Vec<u8>> {
312        &mut self.dealer
313    }
314
315    // async fn send_data(&self, data: Vec<u8>) -> Result<()> {
316    //     let msg_type = MessageType::Data(data);
317    //     let type_bytes = serde_json::to_vec(&msg_type)?;
318
319    //     self.dealer
320    //         .send_multipart(&[type_bytes, self.request_id.as_bytes().to_vec()])
321    //         .await
322    //         .map_err(|e| anyhow!("Failed to send data: {}", e))
323    // }
324
325    // async fn send_control(&self, msg: ControlMessage) -> Result<()> {
326    //     let msg_type = MessageType::Control(msg);
327    //     let type_bytes = serde_json::to_vec(&msg_type)?;
328
329    //     self.dealer
330    //         .send_multipart(&[type_bytes])
331    //         .await
332    //         .map_err(|e| anyhow!("Failed to send control message: {}", e))
333    // }
334
335    // async fn receive(&self) -> Result<MessageType> {
336    //     let frames = self
337    //         .dealer
338    //         .recv_multipart()
339    //         .await
340    //         .map_err(|e| anyhow!("Failed to receive message: {}", e))?;
341
342    //     if frames.is_empty() {
343    //         return Err(anyhow!("Received empty message"));
344    //     }
345
346    //     serde_json::from_slice(&frames[0])
347    //         .map_err(|e| anyhow!("Failed to deserialize message: {}", e))
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::timeout;

    /// End-to-end smoke test: dealer sends [request_id, payload]; the router's
    /// event loop forwards the payload to the registered stream channel.
    #[tokio::test]
    async fn test_basic_communication() -> Result<()> {
        let context = Context::new();
        // note: a fixed port can collide with other processes on the host;
        // switch to an ephemeral port (tcp://127.0.0.1:0) once fd -> port
        // discovery lands in Server::run
        let address = "tcp://127.0.0.1:1337";
        let token = CancellationToken::new();

        // Start server
        let (server, handle) = Server::new(&context, address, token.clone()).await?;
        let state = server.state.clone();

        // Register a data stream for the request id the client will send
        let id = "test-request".to_string();
        let (tx, mut rx) = tokio::sync::mpsc::channel(512);
        state.lock().await.active_streams.insert(id.clone(), tx);

        // Create client (same context, so the connection is in-process friendly)
        let mut client = Client::new(&context, address)?;

        client
            .dealer()
            .send(vec![id.as_bytes().to_vec(), id.as_bytes().to_vec()].into())
            .await?;

        // bound the wait so a routing failure fails the test instead of hanging it
        let received = timeout(Duration::from_secs(5), rx.recv())
            .await?
            .ok_or_else(|| anyhow!("data channel closed before a frame arrived"))?;

        // the payload frame should round-trip unchanged
        let received_str = String::from_utf8_lossy(&received).to_string();
        assert_eq!(received_str, "test-request");

        client.dealer().close().await?;

        handle.cancel();
        handle.join().await?;

        Ok(())
    }

    // #[tokio::test]
    // async fn test_multiple_streams() -> Result<()> {
    //     // Similar to above but with multiple clients/streams
    //     Ok(())
    // }

    // #[tokio::test]
    // async fn test_error_handling() -> Result<()> {
    //     // Test various error conditions
    //     Ok(())
    // }
}