1use std::collections::HashMap;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicU64, Ordering};
7
8use std::time::Duration;
9
10use async_trait::async_trait;
11use futures::future::Either;
12use futures_timer::Delay;
13use parking_lot::Mutex;
14use thiserror::Error;
15use tokio::sync::oneshot;
16
17use slim_datapath::api::{ProtoMessage as Message, ProtoName, ProtoSubscriptionAck};
18use slim_datapath::messages::utils::SlimHeaderFlags;
19
20use crate::common::SlimChannelSender;
21
/// Upper bound on how long `SubscriptionManager::await_ack` waits before giving up
/// with `SubscriptionAckError::Timeout` (ten seconds).
const ACK_TIMEOUT: Duration = Duration::from_millis(10_000);
30
31#[derive(Error, Debug)]
32pub enum SubscriptionAckError {
33 #[error("ack rejected by datapath: {message}")]
34 Rejected { message: String },
35 #[error("ack channel closed")]
36 ChannelClosed,
37 #[error("ack timed out")]
38 Timeout,
39}
40
/// Issues subscription and route requests, each paired with a receiver for the
/// request's acknowledgement.
///
/// In this file, `SubscriptionManager` sends real control messages and completes the
/// receivers when acks arrive. `AutoAckManager` (and the test-only
/// `SpySubscriptionManager`) hand back receivers that already hold `Ok(())`.
#[async_trait]
pub trait SubscriptionOps: Clone + Send + Sync + 'static {
    /// Requests a subscription from `source` to `name`.
    ///
    /// Returns the id the request was tagged with and a receiver for its ack result.
    /// When `forward_to` is `Some(conn)`, the request is flagged to be forwarded to
    /// connection `conn` (see the `SubscriptionManager` implementation).
    /// `SubscriptionManager` returns `SubscriptionAckError::ChannelClosed` if the
    /// request cannot be sent.
    async fn subscribe(
        &self,
        source: &ProtoName,
        name: &ProtoName,
        forward_to: Option<u64>,
    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>;

    /// Requests removal of subscription `subscription_id` from `source` to `name`.
    ///
    /// NOTE(review): `subscription_id` is presumably the id returned by `subscribe`;
    /// `SubscriptionManager` also reuses it as the ack id. Returns a receiver for the
    /// ack result.
    async fn unsubscribe(
        &self,
        source: &ProtoName,
        name: &ProtoName,
        subscription_id: u64,
        forward_to: Option<u64>,
    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError>;

    /// Installs a route for `name` through connection `conn`.
    ///
    /// `SubscriptionManager` implements this as a subscribe message flagged with
    /// `recv_from(conn)`. Returns the request id and a receiver for the ack result.
    async fn set_route(
        &self,
        source: &ProtoName,
        name: &ProtoName,
        conn: u64,
    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>;

    /// Removes route `subscription_id` for `name` through connection `conn`.
    ///
    /// `SubscriptionManager` implements this as an unsubscribe message flagged with
    /// `recv_from(conn)`. Returns a receiver for the ack result.
    async fn remove_route(
        &self,
        source: &ProtoName,
        name: &ProtoName,
        subscription_id: u64,
        conn: u64,
    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError>;

    /// Builds an implementation from a SLIM channel sender, if the implementation
    /// supports that. The default returns `None`.
    fn from_slim_tx(_tx: &SlimChannelSender) -> Option<Self>
    where
        Self: Sized,
    {
        None
    }
}
95
/// `SubscriptionOps` implementation that acknowledges every request immediately.
///
/// Clones share the same id counter (it lives behind an `Arc`). `Default` starts the
/// counter at zero, matching the instance built by `from_slim_tx`.
#[derive(Clone, Debug, Default)]
pub struct AutoAckManager {
    /// Source of ids for `subscribe` / `set_route`; shared between clones.
    ack_counter: Arc<AtomicU64>,
}
103
104#[async_trait]
105impl SubscriptionOps for AutoAckManager {
106 async fn subscribe(
107 &self,
108 _source: &ProtoName,
109 _name: &ProtoName,
110 _forward_to: Option<u64>,
111 ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
112 {
113 let id = self.ack_counter.fetch_add(1, Ordering::Relaxed) + 1;
114 let (tx, rx) = oneshot::channel();
115 let _ = tx.send(Ok(()));
116 Ok((id, rx))
117 }
118
119 async fn unsubscribe(
120 &self,
121 _source: &ProtoName,
122 _name: &ProtoName,
123 _subscription_id: u64,
124 _forward_to: Option<u64>,
125 ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
126 let (tx, rx) = oneshot::channel();
127 let _ = tx.send(Ok(()));
128 Ok(rx)
129 }
130
131 async fn set_route(
132 &self,
133 _source: &ProtoName,
134 _name: &ProtoName,
135 _conn: u64,
136 ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
137 {
138 let id = self.ack_counter.fetch_add(1, Ordering::Relaxed) + 1;
139 let (tx, rx) = oneshot::channel();
140 let _ = tx.send(Ok(()));
141 Ok((id, rx))
142 }
143
144 async fn remove_route(
145 &self,
146 _source: &ProtoName,
147 _name: &ProtoName,
148 _subscription_id: u64,
149 _conn: u64,
150 ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
151 let (tx, rx) = oneshot::channel();
152 let _ = tx.send(Ok(()));
153 Ok(rx)
154 }
155
156 fn from_slim_tx(_tx: &SlimChannelSender) -> Option<Self> {
157 Some(AutoAckManager {
158 ack_counter: Arc::new(AtomicU64::new(0)),
159 })
160 }
161}
162
163#[derive(Clone)]
164pub struct SubscriptionManager {
165 pub pending_acks: Arc<Mutex<HashMap<u64, oneshot::Sender<Result<(), SubscriptionAckError>>>>>,
166 ack_counter: Arc<AtomicU64>,
167 tx: SlimChannelSender,
168}
169
170#[async_trait]
171impl SubscriptionOps for SubscriptionManager {
172 async fn subscribe(
173 &self,
174 source: &ProtoName,
175 name: &ProtoName,
176 forward_to: Option<u64>,
177 ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
178 {
179 let source = source.clone();
180 let name = name.clone();
181 self.send_with_receiver(move |ack_id| {
182 let flags = if let Some(conn) = forward_to {
183 SlimHeaderFlags::default().with_forward_to(conn)
184 } else {
185 SlimHeaderFlags::default()
186 };
187 Message::builder()
188 .source(source)
189 .destination(name)
190 .flags(flags)
191 .subscription_id(ack_id)
192 .build_subscribe()
193 .unwrap()
194 })
195 .await
196 }
197
198 async fn unsubscribe(
199 &self,
200 source: &ProtoName,
201 name: &ProtoName,
202 subscription_id: u64,
203 forward_to: Option<u64>,
204 ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
205 let source = source.clone();
206 let name = name.clone();
207 self.send_with_id(subscription_id, move |ack_id| {
208 let flags = if let Some(conn) = forward_to {
209 SlimHeaderFlags::default().with_forward_to(conn)
210 } else {
211 SlimHeaderFlags::default()
212 };
213 Message::builder()
214 .source(source)
215 .destination(name)
216 .flags(flags)
217 .subscription_id(ack_id)
218 .build_unsubscribe()
219 .unwrap()
220 })
221 .await
222 }
223
224 async fn set_route(
225 &self,
226 source: &ProtoName,
227 name: &ProtoName,
228 conn: u64,
229 ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
230 {
231 let source = source.clone();
232 let name = name.clone();
233 self.send_with_receiver(move |ack_id| {
234 Message::builder()
235 .source(source)
236 .destination(name)
237 .flags(SlimHeaderFlags::default().with_recv_from(conn))
238 .subscription_id(ack_id)
239 .build_subscribe()
240 .unwrap()
241 })
242 .await
243 }
244
245 async fn remove_route(
246 &self,
247 source: &ProtoName,
248 name: &ProtoName,
249 subscription_id: u64,
250 conn: u64,
251 ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
252 let source = source.clone();
253 let name = name.clone();
254 self.send_with_id(subscription_id, move |ack_id| {
255 Message::builder()
256 .source(source)
257 .destination(name)
258 .flags(SlimHeaderFlags::default().with_recv_from(conn))
259 .subscription_id(ack_id)
260 .build_unsubscribe()
261 .unwrap()
262 })
263 .await
264 }
265
266 fn from_slim_tx(tx: &SlimChannelSender) -> Option<Self> {
267 Some(SubscriptionManager::new(tx.clone()))
268 }
269}
270
/// Test double for `SubscriptionOps` that reports each call as a `SubscriptionCall`
/// on an unbounded channel and acknowledges it immediately.
#[cfg(test)]
#[derive(Clone)]
pub struct SpySubscriptionManager {
    // Sending half of the call-report channel; the receiving half is returned by `new`.
    tx: Arc<tokio::sync::mpsc::UnboundedSender<SubscriptionCall>>,
}
278
/// Which `SubscriptionOps` method a `SpySubscriptionManager` saw invoked.
#[cfg(test)]
#[derive(Clone, Debug, PartialEq)]
pub enum SubscriptionCall {
    /// `subscribe` was called.
    Subscribe,
    /// `unsubscribe` was called.
    Unsubscribe,
    /// `set_route` was called.
    SetRoute,
    /// `remove_route` was called.
    RemoveRoute,
}
288
#[cfg(test)]
impl SpySubscriptionManager {
    /// Creates a spy together with the receiver on which its recorded calls arrive.
    pub fn new() -> (Self, tokio::sync::mpsc::UnboundedReceiver<SubscriptionCall>) {
        let (call_tx, call_rx) = tokio::sync::mpsc::unbounded_channel();
        let spy = Self {
            tx: Arc::new(call_tx),
        };
        (spy, call_rx)
    }
}
296
#[cfg(test)]
#[async_trait]
impl SubscriptionOps for SpySubscriptionManager {
    /// Records `Subscribe` and returns id 0 with an already-acknowledged receiver.
    async fn subscribe(
        &self,
        _source: &ProtoName,
        _name: &ProtoName,
        _forward_to: Option<u64>,
    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
    {
        let acked = self.record_and_ack(SubscriptionCall::Subscribe);
        Ok((0, acked))
    }

    /// Records `Unsubscribe` and returns an already-acknowledged receiver.
    async fn unsubscribe(
        &self,
        _source: &ProtoName,
        _name: &ProtoName,
        _subscription_id: u64,
        _forward_to: Option<u64>,
    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
        Ok(self.record_and_ack(SubscriptionCall::Unsubscribe))
    }

    /// Records `SetRoute` and returns id 0 with an already-acknowledged receiver.
    async fn set_route(
        &self,
        _source: &ProtoName,
        _name: &ProtoName,
        _conn: u64,
    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
    {
        let acked = self.record_and_ack(SubscriptionCall::SetRoute);
        Ok((0, acked))
    }

    /// Records `RemoveRoute` and returns an already-acknowledged receiver.
    async fn remove_route(
        &self,
        _source: &ProtoName,
        _name: &ProtoName,
        _subscription_id: u64,
        _conn: u64,
    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
        Ok(self.record_and_ack(SubscriptionCall::RemoveRoute))
    }

    /// A spy cannot be built from a channel sender.
    fn from_slim_tx(_tx: &SlimChannelSender) -> Option<Self> {
        None
    }
}

#[cfg(test)]
impl SpySubscriptionManager {
    /// Reports `call` on the spy channel, then returns a receiver that already holds
    /// `Ok(())`.
    fn record_and_ack(
        &self,
        call: SubscriptionCall,
    ) -> oneshot::Receiver<Result<(), SubscriptionAckError>> {
        // A dropped call receiver is tolerated: the report is best-effort.
        let _ = self.tx.send(call);
        let (ack_tx, ack_rx) = oneshot::channel();
        let _ = ack_tx.send(Ok(()));
        ack_rx
    }
}
356
357impl SubscriptionManager {
358 pub fn new(tx: SlimChannelSender) -> Self {
359 Self {
360 pending_acks: Arc::new(Mutex::new(HashMap::new())),
361 ack_counter: Arc::new(AtomicU64::new(rand::random::<u64>())),
362 tx,
363 }
364 }
365
366 fn next_ack_id(&self) -> u64 {
367 self.ack_counter.fetch_add(1, Ordering::Relaxed) + 1
368 }
369
370 async fn send_with_receiver(
371 &self,
372 build_message: impl FnOnce(u64) -> Message,
373 ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
374 {
375 let ack_id = self.next_ack_id();
376 let (ack_tx, ack_rx) = oneshot::channel();
377 {
378 let mut pending = self.pending_acks.lock();
379 pending.insert(ack_id, ack_tx);
380 }
381
382 let msg = build_message(ack_id);
383
384 if self.tx.send(Ok(msg)).await.is_err() {
385 self.pending_acks.lock().remove(&ack_id);
386 return Err(SubscriptionAckError::ChannelClosed);
387 }
388
389 Ok((ack_id, ack_rx))
390 }
391
392 async fn send_with_id(
393 &self,
394 subscription_id: u64,
395 build_message: impl FnOnce(u64) -> Message,
396 ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
397 let ack_rx = self.register_ack_with_id(subscription_id);
398
399 let msg = build_message(subscription_id);
400
401 if self.tx.send(Ok(msg)).await.is_err() {
402 self.pending_acks.lock().remove(&subscription_id);
403 return Err(SubscriptionAckError::ChannelClosed);
404 }
405
406 Ok(ack_rx)
407 }
408
409 pub fn register_ack(&self) -> (u64, oneshot::Receiver<Result<(), SubscriptionAckError>>) {
413 let ack_id = self.next_ack_id();
414 let (ack_tx, ack_rx) = oneshot::channel();
415 {
416 let mut pending = self.pending_acks.lock();
417 pending.insert(ack_id, ack_tx);
418 }
419 (ack_id, ack_rx)
420 }
421
422 pub fn register_ack_with_id(
424 &self,
425 id: u64,
426 ) -> oneshot::Receiver<Result<(), SubscriptionAckError>> {
427 let (ack_tx, ack_rx) = oneshot::channel();
428 self.pending_acks.lock().insert(id, ack_tx);
429 ack_rx
430 }
431
432 pub fn cancel_ack(&self, ack_id: u64) {
434 let mut pending = self.pending_acks.lock();
435 pending.remove(&ack_id);
436 }
437
438 pub async fn await_ack(
444 ack_rx: oneshot::Receiver<Result<(), SubscriptionAckError>>,
445 ) -> Result<(), SubscriptionAckError> {
446 futures::pin_mut!(ack_rx);
447 let delay = Delay::new(ACK_TIMEOUT);
448 futures::pin_mut!(delay);
449
450 match futures::future::select(ack_rx, delay).await {
451 Either::Left((Ok(result), _)) => result,
452 Either::Left((Err(_), _)) => Err(SubscriptionAckError::ChannelClosed),
453 Either::Right(_) => Err(SubscriptionAckError::Timeout),
454 }
455 }
456
457 pub fn resolve_ack(&self, ack: &ProtoSubscriptionAck) {
459 tracing::debug!(ack = %ack.subscription_id, "ack received");
460 let sender = {
461 let mut pending = self.pending_acks.lock();
462 pending.remove(&ack.subscription_id)
463 };
464
465 if let Some(sender) = sender {
466 let _ = sender.send(if ack.success {
467 Ok(())
468 } else {
469 Err(SubscriptionAckError::Rejected {
470 message: if ack.error.is_empty() {
471 "subscription ack failed".to_string()
472 } else {
473 ack.error.clone()
474 },
475 })
476 });
477 } else {
478 tracing::info!(
479 ack_id = %ack.subscription_id,
480 "received subscription ack with no pending waiter"
481 );
482 }
483 }
484}