Skip to main content

dynamo_runtime/transports/event_plane/
dynamic_subscriber.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Dynamic subscriber that watches discovery and manages connections to multiple publishers.
//!
//! This module connects automatically to new publishers as they appear in discovery, and
//! cleans up the connections to publishers that are removed.

use anyhow::Result;
use bytes::Bytes;
use futures::stream::StreamExt;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
use tokio_util::sync::CancellationToken;

use super::transport::{EventTransportRx, WireStream};
use super::zmq_transport::ZmqSubTransport;
use crate::discovery::{
    Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
    EventTransport,
};

24/// Manages dynamic subscriptions to multiple publishers.
25pub struct DynamicSubscriber {
26    discovery: Arc<dyn Discovery>,
27    query: DiscoveryQuery,
28    topic: String,
29    cancel_token: CancellationToken,
30}
31
32impl DynamicSubscriber {
33    pub fn new(discovery: Arc<dyn Discovery>, query: DiscoveryQuery, topic: String) -> Self {
34        Self {
35            discovery,
36            query,
37            topic,
38            cancel_token: CancellationToken::new(),
39        }
40    }
41
42    /// Start watching discovery and create a merged stream of events.
43    pub async fn start_zmq(self: Arc<Self>) -> Result<WireStream> {
44        let (event_tx, event_rx) = mpsc::unbounded_channel::<Bytes>();
45
46        // Track active endpoint connections with instance ID to endpoint mapping
47        let active_endpoints: Arc<RwLock<HashMap<String, (String, CancellationToken)>>> =
48            Arc::new(RwLock::new(HashMap::new()));
49
50        // Clone self for the spawned task
51        let subscriber_clone = Arc::clone(&self);
52
53        // Spawn background task to watch discovery
54        let discovery = Arc::clone(&self.discovery);
55        let query = self.query.clone();
56        // Use the actual topic for ZMQ native filtering (avoids decoding irrelevant messages)
57        let zmq_topic = self.topic.clone();
58        let cancel_token = self.cancel_token.clone();
59        let endpoints = Arc::clone(&active_endpoints);
60
61        tokio::spawn(async move {
62            tracing::debug!(
63                ?query,
64                cancel_token_cancelled = cancel_token.is_cancelled(),
65                "Attempting to start discovery watch"
66            );
67
68            // Don't pass the cancel token to list_and_watch - we'll handle cancellation ourselves
69            let mut watch_stream = match discovery.list_and_watch(query.clone(), None).await {
70                Ok(stream) => {
71                    tracing::debug!("Successfully obtained discovery watch stream");
72                    stream
73                }
74                Err(e) => {
75                    tracing::error!(error = %e, "Failed to start discovery watch");
76                    return;
77                }
78            };
79
80            tracing::info!(?query, "Started dynamic discovery watch for ZMQ publishers");
81
82            while let Some(event_result) = watch_stream.next().await {
83                tracing::debug!("Received discovery event: {:?}", event_result);
84                if cancel_token.is_cancelled() {
85                    tracing::info!("Dynamic subscriber cancelled, stopping watch");
86                    break;
87                }
88
89                match event_result {
90                    Ok(DiscoveryEvent::Added(instance)) => {
91                        tracing::info!(instance = ?instance, "Discovery Added event received");
92                        let instance_id = instance.instance_id().to_string();
93
94                        // Extract ZMQ endpoint from the instance
95                        if let Some(endpoint) = Self::extract_zmq_endpoint(&instance) {
96                            let mut endpoints_guard = endpoints.write().await;
97
98                            // Skip if instance already tracked
99                            if endpoints_guard.contains_key(&instance_id) {
100                                tracing::debug!(endpoint = %endpoint, instance_id = %instance_id, "Already connected to ZMQ publisher");
101                                continue;
102                            }
103
104                            tracing::info!(endpoint = %endpoint, instance_id = %instance_id, "Connecting to new ZMQ publisher");
105
106                            // Create cancellation token for this endpoint's stream
107                            let endpoint_cancel = CancellationToken::new();
108                            endpoints_guard.insert(
109                                instance_id.clone(),
110                                (endpoint.clone(), endpoint_cancel.clone()),
111                            );
112                            drop(endpoints_guard);
113
114                            // Spawn task to handle this endpoint's stream
115                            let event_tx_clone = event_tx.clone();
116                            let zmq_topic_clone = zmq_topic.clone();
117                            let endpoint_clone = endpoint.clone();
118                            let endpoints_clone = Arc::clone(&endpoints);
119                            let instance_id_clone = instance_id.clone();
120
121                            tokio::spawn(async move {
122                                if let Err(e) = Self::consume_endpoint_stream(
123                                    &endpoint_clone,
124                                    &zmq_topic_clone,
125                                    event_tx_clone,
126                                    endpoint_cancel,
127                                )
128                                .await
129                                {
130                                    tracing::warn!(
131                                        endpoint = %endpoint_clone,
132                                        error = %e,
133                                        "Error consuming ZMQ endpoint stream"
134                                    );
135                                }
136                                // Clean up on stream termination
137                                endpoints_clone.write().await.remove(&instance_id_clone);
138                            });
139                        } else {
140                            tracing::warn!(
141                                instance = ?instance,
142                                "Discovery Added event did not contain a ZMQ endpoint"
143                            );
144                        }
145                    }
146                    Ok(DiscoveryEvent::Removed(instance_id)) => {
147                        let id_str = instance_id.instance_id().to_string();
148                        tracing::info!(
149                            instance_id = %id_str,
150                            "ZMQ publisher removed from discovery, cancelling endpoint stream"
151                        );
152
153                        // Cancel the endpoint's stream via its CancellationToken
154                        if let Some((_endpoint, cancel)) = endpoints.write().await.remove(&id_str) {
155                            cancel.cancel();
156                            tracing::info!(instance_id = %id_str, "Cancelled endpoint stream");
157                        } else {
158                            tracing::warn!(instance_id = %id_str, "No active endpoint found for removed stream instance");
159                        }
160                    }
161                    Err(e) => {
162                        tracing::error!(error = %e, "Discovery watch error");
163                        break;
164                    }
165                }
166            }
167
168            // Cancel all active endpoints on shutdown
169            let endpoints_guard = endpoints.write().await;
170            for (_id, (_endpoint, cancel)) in endpoints_guard.iter() {
171                cancel.cancel();
172            }
173            tracing::info!("Discovery watch stream ended");
174        });
175
176        // Return a stream that reads from the merged channel
177        let stream = async_stream::stream! {
178            // Keep subscriber_clone alive by capturing it in the stream
179            let _subscriber = subscriber_clone;
180            let mut rx = event_rx;
181            while let Some(bytes) = rx.recv().await {
182                yield Ok(bytes);
183            }
184        };
185
186        Ok(Box::pin(stream))
187    }
188
189    /// Extract ZMQ endpoint from a discovery instance.
190    fn extract_zmq_endpoint(instance: &DiscoveryInstance) -> Option<String> {
191        if let DiscoveryInstance::EventChannel { transport, .. } = instance
192            && let EventTransport::Zmq { endpoint } = transport
193        {
194            return Some(endpoint.clone());
195        }
196        None
197    }
198
199    /// Consume events from a single endpoint and forward to the merged channel.
200    async fn consume_endpoint_stream(
201        endpoint: &str,
202        zmq_topic: &str,
203        event_tx: mpsc::UnboundedSender<Bytes>,
204        cancel_token: CancellationToken,
205    ) -> Result<()> {
206        // Connect to the endpoint
207        let sub_transport = ZmqSubTransport::connect(endpoint, zmq_topic).await?;
208        let mut stream = sub_transport.subscribe(zmq_topic).await?;
209
210        tracing::info!(endpoint = %endpoint, topic = %zmq_topic, "Started consuming ZMQ endpoint stream");
211
212        loop {
213            tokio::select! {
214                _ = cancel_token.cancelled() => {
215                    tracing::info!(endpoint = %endpoint, "Endpoint stream cancelled");
216                    break;
217                }
218
219                event = stream.next() => {
220                    match event {
221                        Some(Ok(bytes)) => {
222                            if event_tx.send(bytes).is_err() {
223                                tracing::warn!(endpoint = %endpoint, "Event channel closed, stopping endpoint stream");
224                                break;
225                            }
226                        }
227                        Some(Err(e)) => {
228                            tracing::error!(
229                                endpoint = %endpoint,
230                                error = %e,
231                                "Error receiving from ZMQ endpoint"
232                            );
233                            break;
234                        }
235                        None => {
236                            tracing::info!(endpoint = %endpoint, "ZMQ endpoint stream ended");
237                            break;
238                        }
239                    }
240                }
241            }
242        }
243
244        Ok(())
245    }
246
247    /// Stop watching and disconnect from all endpoints.
248    pub fn cancel(&self) {
249        self.cancel_token.cancel();
250    }
251}
252
253impl Drop for DynamicSubscriber {
254    fn drop(&mut self) {
255        self.cancel_token.cancel();
256    }
257}