Skip to main content

slim_session/
subscription_manager.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7
8use std::time::Duration;
9
10use async_trait::async_trait;
11use futures::future::Either;
12use futures_timer::Delay;
13use parking_lot::Mutex;
14use thiserror::Error;
15use tokio::sync::oneshot;
16
17use slim_datapath::api::{ProtoMessage as Message, ProtoName, ProtoSubscriptionAck};
18use slim_datapath::messages::utils::SlimHeaderFlags;
19
20use crate::common::SlimChannelSender;
21
/// How long to wait for a subscription ACK before giving up.
///
/// This is the deadline applied by `SubscriptionManager::await_ack`; when it
/// expires the caller receives `SubscriptionAckError::Timeout`.
///
/// The datapath retry loop runs `0..=MAX_RETRIES` attempts (currently 4
/// attempts in total) with a per-attempt timeout of `TIMEOUT` (currently 2 s),
/// for a maximum of `TIMEOUT * (MAX_RETRIES + 1) = 8 s`.  This deadline must be
/// at least that large so every retry attempt has a chance to succeed before
/// the session considers the operation lost; the remaining 2 s act as a margin.
const ACK_TIMEOUT: Duration = Duration::from_secs(10);
30
/// Ways a subscription/route request or its ACK can fail.
#[derive(Error, Debug)]
pub enum SubscriptionAckError {
    /// The datapath answered with a negative ACK.  `message` carries the
    /// datapath's error text, or a generic fallback when the ACK had none.
    #[error("ack rejected by datapath: {message}")]
    Rejected { message: String },
    /// The request message could not be sent, or the ACK sender was dropped
    /// before it delivered a result.
    #[error("ack channel closed")]
    ChannelClosed,
    /// No ACK arrived before the `ACK_TIMEOUT` deadline.
    #[error("ack timed out")]
    Timeout,
}
40
/// Trait that abstracts subscription and route management operations.
///
/// Every method sends the request with an ack_id and returns the
/// [`oneshot::Receiver`] for that ACK.  The caller decides whether to await
/// the receiver immediately (blocking until confirmed) or drop it (fire and
/// forget while the datapath still tracks the operation).
///
/// Each method has two layers of result.  The outer `Result` says whether the
/// request could be issued at all; for [`SubscriptionManager`] that error is
/// [`SubscriptionAckError::ChannelClosed`], returned when the message cannot
/// be sent.  The receiver later yields the datapath's answer.
#[async_trait]
pub trait SubscriptionOps: Clone + Send + Sync + 'static {
    /// Subscribe (forward_to): register interest in `name`, optionally routing
    /// through a specific connection.
    ///
    /// In [`SubscriptionManager`], `source` becomes the message source, `name`
    /// the destination, and `forward_to` (when `Some`) the connection ID put
    /// in the forward_to header flag.
    ///
    /// Returns the ID assigned to the request together with its ACK receiver.
    /// [`SubscriptionManager`] stamps this ID on the message as its
    /// `subscription_id`, so it is presumably the value to pass back to
    /// [`SubscriptionOps::unsubscribe`].
    async fn subscribe(
        &self,
        source: &ProtoName,
        name: &ProtoName,
        forward_to: Option<u64>,
    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>;

    /// Unsubscribe (forward_to): de-register interest in `name`.
    ///
    /// `subscription_id` identifies the request.  [`SubscriptionManager`]
    /// reuses it as the ACK ID instead of allocating a new one.
    async fn unsubscribe(
        &self,
        source: &ProtoName,
        name: &ProtoName,
        subscription_id: u64,
        forward_to: Option<u64>,
    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError>;

    /// Set a recv_from route for `name` on connection `conn`.
    ///
    /// Returns the ID assigned to the request together with its ACK receiver,
    /// in the same way as [`SubscriptionOps::subscribe`].
    async fn set_route(
        &self,
        source: &ProtoName,
        name: &ProtoName,
        conn: u64,
    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>;

    /// Remove a recv_from route for `name` on connection `conn`.
    ///
    /// [`SubscriptionManager`] reuses `subscription_id` as the ACK ID.  It is
    /// presumably the ID returned by [`SubscriptionOps::set_route`].
    async fn remove_route(
        &self,
        source: &ProtoName,
        name: &ProtoName,
        subscription_id: u64,
        conn: u64,
    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError>;

    /// Called during session stack construction to create a default instance
    /// from the SLIM channel sender.  Returns `None` if this type requires
    /// explicit construction (caller must call `with_subscription_manager` on
    /// the builder).
    fn from_slim_tx(_tx: &SlimChannelSender) -> Option<Self>
    where
        Self: Sized,
    {
        None
    }
}
95
96/// A no-op subscription manager for tests that do not run a real SLIM
97/// datapath.  Every operation immediately succeeds without sending any
98/// messages.
99#[derive(Clone)]
100pub struct AutoAckManager {
101    ack_counter: Arc<AtomicU64>,
102}
103
104#[async_trait]
105impl SubscriptionOps for AutoAckManager {
106    async fn subscribe(
107        &self,
108        _source: &ProtoName,
109        _name: &ProtoName,
110        _forward_to: Option<u64>,
111    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
112    {
113        let id = self.ack_counter.fetch_add(1, Ordering::Relaxed) + 1;
114        let (tx, rx) = oneshot::channel();
115        let _ = tx.send(Ok(()));
116        Ok((id, rx))
117    }
118
119    async fn unsubscribe(
120        &self,
121        _source: &ProtoName,
122        _name: &ProtoName,
123        _subscription_id: u64,
124        _forward_to: Option<u64>,
125    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
126        let (tx, rx) = oneshot::channel();
127        let _ = tx.send(Ok(()));
128        Ok(rx)
129    }
130
131    async fn set_route(
132        &self,
133        _source: &ProtoName,
134        _name: &ProtoName,
135        _conn: u64,
136    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
137    {
138        let id = self.ack_counter.fetch_add(1, Ordering::Relaxed) + 1;
139        let (tx, rx) = oneshot::channel();
140        let _ = tx.send(Ok(()));
141        Ok((id, rx))
142    }
143
144    async fn remove_route(
145        &self,
146        _source: &ProtoName,
147        _name: &ProtoName,
148        _subscription_id: u64,
149        _conn: u64,
150    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
151        let (tx, rx) = oneshot::channel();
152        let _ = tx.send(Ok(()));
153        Ok(rx)
154    }
155
156    fn from_slim_tx(_tx: &SlimChannelSender) -> Option<Self> {
157        Some(AutoAckManager {
158            ack_counter: Arc::new(AtomicU64::new(0)),
159        })
160    }
161}
162
/// Subscription manager that sends real subscribe/unsubscribe messages.
///
/// Requests go out over `tx`.  Their ACKs are delivered to waiting receivers
/// through `resolve_ack`.  Clones share the pending map and the ID counter,
/// since both are behind `Arc`.
#[derive(Clone)]
pub struct SubscriptionManager {
    /// Senders for outstanding ACKs, keyed by the ID the request was tagged
    /// with.  `resolve_ack` removes and completes an entry when its ACK arrives.
    pub pending_acks: Arc<Mutex<HashMap<u64, oneshot::Sender<Result<(), SubscriptionAckError>>>>>,
    /// Counter from which ACK IDs are drawn; seeded with a random value in `new`.
    ack_counter: Arc<AtomicU64>,
    /// Channel the request messages are sent on.
    tx: SlimChannelSender,
}
169
170#[async_trait]
171impl SubscriptionOps for SubscriptionManager {
172    async fn subscribe(
173        &self,
174        source: &ProtoName,
175        name: &ProtoName,
176        forward_to: Option<u64>,
177    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
178    {
179        let source = source.clone();
180        let name = name.clone();
181        self.send_with_receiver(move |ack_id| {
182            let flags = if let Some(conn) = forward_to {
183                SlimHeaderFlags::default().with_forward_to(conn)
184            } else {
185                SlimHeaderFlags::default()
186            };
187            Message::builder()
188                .source(source)
189                .destination(name)
190                .flags(flags)
191                .subscription_id(ack_id)
192                .build_subscribe()
193                .unwrap()
194        })
195        .await
196    }
197
198    async fn unsubscribe(
199        &self,
200        source: &ProtoName,
201        name: &ProtoName,
202        subscription_id: u64,
203        forward_to: Option<u64>,
204    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
205        let source = source.clone();
206        let name = name.clone();
207        self.send_with_id(subscription_id, move |ack_id| {
208            let flags = if let Some(conn) = forward_to {
209                SlimHeaderFlags::default().with_forward_to(conn)
210            } else {
211                SlimHeaderFlags::default()
212            };
213            Message::builder()
214                .source(source)
215                .destination(name)
216                .flags(flags)
217                .subscription_id(ack_id)
218                .build_unsubscribe()
219                .unwrap()
220        })
221        .await
222    }
223
224    async fn set_route(
225        &self,
226        source: &ProtoName,
227        name: &ProtoName,
228        conn: u64,
229    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
230    {
231        let source = source.clone();
232        let name = name.clone();
233        self.send_with_receiver(move |ack_id| {
234            Message::builder()
235                .source(source)
236                .destination(name)
237                .flags(SlimHeaderFlags::default().with_recv_from(conn))
238                .subscription_id(ack_id)
239                .build_subscribe()
240                .unwrap()
241        })
242        .await
243    }
244
245    async fn remove_route(
246        &self,
247        source: &ProtoName,
248        name: &ProtoName,
249        subscription_id: u64,
250        conn: u64,
251    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
252        let source = source.clone();
253        let name = name.clone();
254        self.send_with_id(subscription_id, move |ack_id| {
255            Message::builder()
256                .source(source)
257                .destination(name)
258                .flags(SlimHeaderFlags::default().with_recv_from(conn))
259                .subscription_id(ack_id)
260                .build_unsubscribe()
261                .unwrap()
262        })
263        .await
264    }
265
266    fn from_slim_tx(tx: &SlimChannelSender) -> Option<Self> {
267        Some(SubscriptionManager::new(tx.clone()))
268    }
269}
270
/// Spy subscription manager for tests.
///
/// Every operation pushes a [`SubscriptionCall`] onto an unbounded channel so
/// tests can assert on the sequence of operations.  It then returns an ACK
/// receiver that already holds `Ok(())`.  IDs handed out are always `0`.
#[cfg(test)]
#[derive(Clone)]
pub struct SpySubscriptionManager {
    tx: Arc<tokio::sync::mpsc::UnboundedSender<SubscriptionCall>>,
}

/// Individual subscription operation recorded by [`SpySubscriptionManager`].
#[cfg(test)]
#[derive(Debug, Clone, PartialEq)]
pub enum SubscriptionCall {
    /// A call to `subscribe`.
    Subscribe,
    /// A call to `unsubscribe`.
    Unsubscribe,
    /// A call to `set_route`.
    SetRoute,
    /// A call to `remove_route`.
    RemoveRoute,
}

#[cfg(test)]
impl SpySubscriptionManager {
    /// Creates a spy together with the receiving end of its call log.
    pub fn new() -> (Self, tokio::sync::mpsc::UnboundedReceiver<SubscriptionCall>) {
        let (log_tx, log_rx) = tokio::sync::mpsc::unbounded_channel();
        let spy = Self {
            tx: Arc::new(log_tx),
        };
        (spy, log_rx)
    }

    /// Logs `call`, then returns an ACK receiver that is already resolved.
    fn record(&self, call: SubscriptionCall) -> oneshot::Receiver<Result<(), SubscriptionAckError>> {
        // The test may have dropped the log receiver; that is not an error.
        let _ = self.tx.send(call);
        let (ack_tx, ack_rx) = oneshot::channel();
        let _ = ack_tx.send(Ok(()));
        ack_rx
    }
}

#[cfg(test)]
#[async_trait]
impl SubscriptionOps for SpySubscriptionManager {
    async fn subscribe(
        &self,
        _source: &ProtoName,
        _name: &ProtoName,
        _forward_to: Option<u64>,
    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
    {
        Ok((0, self.record(SubscriptionCall::Subscribe)))
    }

    async fn unsubscribe(
        &self,
        _source: &ProtoName,
        _name: &ProtoName,
        _subscription_id: u64,
        _forward_to: Option<u64>,
    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
        Ok(self.record(SubscriptionCall::Unsubscribe))
    }

    async fn set_route(
        &self,
        _source: &ProtoName,
        _name: &ProtoName,
        _conn: u64,
    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
    {
        Ok((0, self.record(SubscriptionCall::SetRoute)))
    }

    async fn remove_route(
        &self,
        _source: &ProtoName,
        _name: &ProtoName,
        _subscription_id: u64,
        _conn: u64,
    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
        Ok(self.record(SubscriptionCall::RemoveRoute))
    }

    /// The spy must be constructed explicitly via `new` so the test keeps the
    /// call-log receiver.
    fn from_slim_tx(_tx: &SlimChannelSender) -> Option<Self> {
        None
    }
}
356
357impl SubscriptionManager {
358    pub fn new(tx: SlimChannelSender) -> Self {
359        Self {
360            pending_acks: Arc::new(Mutex::new(HashMap::new())),
361            ack_counter: Arc::new(AtomicU64::new(rand::random::<u64>())),
362            tx,
363        }
364    }
365
366    fn next_ack_id(&self) -> u64 {
367        self.ack_counter.fetch_add(1, Ordering::Relaxed) + 1
368    }
369
370    async fn send_with_receiver(
371        &self,
372        build_message: impl FnOnce(u64) -> Message,
373    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
374    {
375        let ack_id = self.next_ack_id();
376        let (ack_tx, ack_rx) = oneshot::channel();
377        {
378            let mut pending = self.pending_acks.lock();
379            pending.insert(ack_id, ack_tx);
380        }
381
382        let msg = build_message(ack_id);
383
384        if self.tx.send(Ok(msg)).await.is_err() {
385            self.pending_acks.lock().remove(&ack_id);
386            return Err(SubscriptionAckError::ChannelClosed);
387        }
388
389        Ok((ack_id, ack_rx))
390    }
391
392    async fn send_with_id(
393        &self,
394        subscription_id: u64,
395        build_message: impl FnOnce(u64) -> Message,
396    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
397        let ack_rx = self.register_ack_with_id(subscription_id);
398
399        let msg = build_message(subscription_id);
400
401        if self.tx.send(Ok(msg)).await.is_err() {
402            self.pending_acks.lock().remove(&subscription_id);
403            return Err(SubscriptionAckError::ChannelClosed);
404        }
405
406        Ok(ack_rx)
407    }
408
409    /// Register a pending ACK entry and return the ack_id and receiver.
410    /// The caller is responsible for building and sending the message with this ack_id.
411    /// If sending fails, call `cancel_ack` to clean up.
412    pub fn register_ack(&self) -> (u64, oneshot::Receiver<Result<(), SubscriptionAckError>>) {
413        let ack_id = self.next_ack_id();
414        let (ack_tx, ack_rx) = oneshot::channel();
415        {
416            let mut pending = self.pending_acks.lock();
417            pending.insert(ack_id, ack_tx);
418        }
419        (ack_id, ack_rx)
420    }
421
422    /// Register a pending ACK entry under a caller-provided ID and return the receiver.
423    pub fn register_ack_with_id(
424        &self,
425        id: u64,
426    ) -> oneshot::Receiver<Result<(), SubscriptionAckError>> {
427        let (ack_tx, ack_rx) = oneshot::channel();
428        self.pending_acks.lock().insert(id, ack_tx);
429        ack_rx
430    }
431
432    /// Remove a previously registered pending ACK (call on send failure).
433    pub fn cancel_ack(&self, ack_id: u64) {
434        let mut pending = self.pending_acks.lock();
435        pending.remove(&ack_id);
436    }
437
438    /// Await a previously registered ACK receiver, with a deadline of [`ACK_TIMEOUT`].
439    ///
440    /// Uses [`futures_timer::Delay`] rather than `tokio::time::timeout` so that
441    /// this function works correctly outside a Tokio runtime with the time driver
442    /// enabled (e.g. when called from UniFFI async bindings).
443    pub async fn await_ack(
444        ack_rx: oneshot::Receiver<Result<(), SubscriptionAckError>>,
445    ) -> Result<(), SubscriptionAckError> {
446        futures::pin_mut!(ack_rx);
447        let delay = Delay::new(ACK_TIMEOUT);
448        futures::pin_mut!(delay);
449
450        match futures::future::select(ack_rx, delay).await {
451            Either::Left((Ok(result), _)) => result,
452            Either::Left((Err(_), _)) => Err(SubscriptionAckError::ChannelClosed),
453            Either::Right(_) => Err(SubscriptionAckError::Timeout),
454        }
455    }
456
457    /// Called by the App message loop to complete a waiting future for an ACK.
458    pub fn resolve_ack(&self, ack: &ProtoSubscriptionAck) {
459        tracing::debug!(ack = %ack.subscription_id, "ack received");
460        let sender = {
461            let mut pending = self.pending_acks.lock();
462            pending.remove(&ack.subscription_id)
463        };
464
465        if let Some(sender) = sender {
466            let _ = sender.send(if ack.success {
467                Ok(())
468            } else {
469                Err(SubscriptionAckError::Rejected {
470                    message: if ack.error.is_empty() {
471                        "subscription ack failed".to_string()
472                    } else {
473                        ack.error.clone()
474                    },
475                })
476            });
477        } else {
478            tracing::info!(
479                ack_id = %ack.subscription_id,
480                "received subscription ack with no pending waiter"
481            );
482        }
483    }
484}