use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use async_trait::async_trait;
use futures::future::Either;
use futures_timer::Delay;
use parking_lot::Mutex;
use thiserror::Error;
use tokio::sync::oneshot;
use slim_datapath::api::{ProtoMessage as Message, ProtoName, ProtoSubscriptionAck};
use slim_datapath::messages::utils::SlimHeaderFlags;
use crate::common::SlimChannelSender;
/// Upper bound on how long `SubscriptionManager::await_ack` waits for an ack.
const ACK_TIMEOUT: Duration = Duration::from_millis(10_000);
#[derive(Error, Debug)]
pub enum SubscriptionAckError {
#[error("ack rejected by datapath: {message}")]
Rejected { message: String },
#[error("ack channel closed")]
ChannelClosed,
#[error("ack timed out")]
Timeout,
}
/// Subscription and route management, where every operation hands back a
/// oneshot receiver that yields the outcome of its acknowledgement.
///
/// `from_slim_tx` already requires `Self: Sized` through the `Clone`
/// supertrait, so no extra `where` clause is needed on it.
#[async_trait]
pub trait SubscriptionOps: Clone + Send + Sync + 'static {
    /// Requests a subscription to `name` on behalf of `source`, optionally
    /// tagged with the connection id `forward_to`.
    ///
    /// Returns an id for the request together with the ack receiver. The id is
    /// presumably what `unsubscribe` later expects as `subscription_id`;
    /// confirm against the implementations.
    async fn subscribe(
        &self,
        source: &ProtoName,
        name: &ProtoName,
        forward_to: Option<u64>,
    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>;

    /// Withdraws the subscription identified by `subscription_id`.
    async fn unsubscribe(
        &self,
        source: &ProtoName,
        name: &ProtoName,
        subscription_id: u64,
        forward_to: Option<u64>,
    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError>;

    /// Installs a route for `name` associated with connection `conn`.
    /// Returns an id for the route plus the ack receiver.
    async fn set_route(
        &self,
        source: &ProtoName,
        name: &ProtoName,
        conn: u64,
    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>;

    /// Removes the route identified by `subscription_id` on connection `conn`.
    async fn remove_route(
        &self,
        source: &ProtoName,
        name: &ProtoName,
        subscription_id: u64,
        conn: u64,
    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError>;

    /// Optionally builds an implementation from the datapath sender.
    /// The default implementation declines and returns `None`.
    fn from_slim_tx(_tx: &SlimChannelSender) -> Option<Self> {
        None
    }
}
/// `SubscriptionOps` implementation that never talks to a datapath: every
/// request is acknowledged as successful immediately.
///
/// Clones share one id counter. `Default` creates a counter starting at zero,
/// which is the same state `from_slim_tx` produces, so callers no longer need
/// a channel sender just to build one.
#[derive(Clone, Debug, Default)]
pub struct AutoAckManager {
    /// Shared source of the ids returned by `subscribe` and `set_route`.
    ack_counter: Arc<AtomicU64>,
}
#[async_trait]
impl SubscriptionOps for AutoAckManager {
async fn subscribe(
&self,
_source: &ProtoName,
_name: &ProtoName,
_forward_to: Option<u64>,
) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
{
let id = self.ack_counter.fetch_add(1, Ordering::Relaxed) + 1;
let (tx, rx) = oneshot::channel();
let _ = tx.send(Ok(()));
Ok((id, rx))
}
async fn unsubscribe(
&self,
_source: &ProtoName,
_name: &ProtoName,
_subscription_id: u64,
_forward_to: Option<u64>,
) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
let (tx, rx) = oneshot::channel();
let _ = tx.send(Ok(()));
Ok(rx)
}
async fn set_route(
&self,
_source: &ProtoName,
_name: &ProtoName,
_conn: u64,
) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
{
let id = self.ack_counter.fetch_add(1, Ordering::Relaxed) + 1;
let (tx, rx) = oneshot::channel();
let _ = tx.send(Ok(()));
Ok((id, rx))
}
async fn remove_route(
&self,
_source: &ProtoName,
_name: &ProtoName,
_subscription_id: u64,
_conn: u64,
) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
let (tx, rx) = oneshot::channel();
let _ = tx.send(Ok(()));
Ok(rx)
}
fn from_slim_tx(_tx: &SlimChannelSender) -> Option<Self> {
Some(AutoAckManager {
ack_counter: Arc::new(AtomicU64::new(0)),
})
}
}
/// Sends subscribe/unsubscribe control messages on the datapath channel and
/// keeps track of the callers waiting for the corresponding acks.
///
/// Every field is reference-counted or a channel handle, so clones operate on
/// the same pending-ack table and id counter.
#[derive(Clone)]
pub struct SubscriptionManager {
    /// Outstanding ack waiters, keyed by the ack id carried in the message.
    pub pending_acks: Arc<Mutex<HashMap<u64, oneshot::Sender<Result<(), SubscriptionAckError>>>>>,
    /// Generator for fresh ack ids.
    ack_counter: Arc<AtomicU64>,
    /// Outgoing channel on which control messages are sent.
    tx: SlimChannelSender,
}
/// Each operation builds a SLIM subscribe/unsubscribe message and registers an
/// ack waiter before sending it.
///
/// - `subscribe`/`unsubscribe` may tag the message with `forward_to`.
/// - `set_route`/`remove_route` tag it with `recv_from(conn)`.
#[async_trait]
impl SubscriptionOps for SubscriptionManager {
    async fn subscribe(
        &self,
        source: &ProtoName,
        name: &ProtoName,
        forward_to: Option<u64>,
    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
    {
        let (src, dst) = (source.clone(), name.clone());
        self.send_with_receiver(move |ack_id| {
            let header_flags = forward_to.map_or_else(SlimHeaderFlags::default, |conn| {
                SlimHeaderFlags::default().with_forward_to(conn)
            });
            Message::builder()
                .source(src)
                .destination(dst)
                .flags(header_flags)
                .subscription_id(ack_id)
                .build_subscribe()
                .unwrap()
        })
        .await
    }

    async fn unsubscribe(
        &self,
        source: &ProtoName,
        name: &ProtoName,
        subscription_id: u64,
        forward_to: Option<u64>,
    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
        let (src, dst) = (source.clone(), name.clone());
        // The subscription id doubles as the ack id for the unsubscribe.
        self.send_with_id(subscription_id, move |ack_id| {
            let header_flags = forward_to.map_or_else(SlimHeaderFlags::default, |conn| {
                SlimHeaderFlags::default().with_forward_to(conn)
            });
            Message::builder()
                .source(src)
                .destination(dst)
                .flags(header_flags)
                .subscription_id(ack_id)
                .build_unsubscribe()
                .unwrap()
        })
        .await
    }

    async fn set_route(
        &self,
        source: &ProtoName,
        name: &ProtoName,
        conn: u64,
    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
    {
        let (src, dst) = (source.clone(), name.clone());
        self.send_with_receiver(move |ack_id| {
            let header_flags = SlimHeaderFlags::default().with_recv_from(conn);
            Message::builder()
                .source(src)
                .destination(dst)
                .flags(header_flags)
                .subscription_id(ack_id)
                .build_subscribe()
                .unwrap()
        })
        .await
    }

    async fn remove_route(
        &self,
        source: &ProtoName,
        name: &ProtoName,
        subscription_id: u64,
        conn: u64,
    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
        let (src, dst) = (source.clone(), name.clone());
        self.send_with_id(subscription_id, move |ack_id| {
            let header_flags = SlimHeaderFlags::default().with_recv_from(conn);
            Message::builder()
                .source(src)
                .destination(dst)
                .flags(header_flags)
                .subscription_id(ack_id)
                .build_unsubscribe()
                .unwrap()
        })
        .await
    }

    /// Always succeeds: wraps a clone of the sender in a new manager.
    fn from_slim_tx(tx: &SlimChannelSender) -> Option<Self> {
        let manager = SubscriptionManager::new(tx.clone());
        Some(manager)
    }
}
/// Test double for `SubscriptionOps`.
///
/// It reports each invoked operation as a `SubscriptionCall` on an unbounded
/// channel, so tests can assert which methods were called.
#[cfg(test)]
#[derive(Clone)]
pub struct SpySubscriptionManager {
    /// Sending half of the call log; shared between clones.
    tx: Arc<tokio::sync::mpsc::UnboundedSender<SubscriptionCall>>,
}
/// Which `SubscriptionOps` method the spy observed.
#[cfg(test)]
#[derive(Clone, Debug, PartialEq)]
pub enum SubscriptionCall {
    /// `subscribe` was invoked.
    Subscribe,
    /// `unsubscribe` was invoked.
    Unsubscribe,
    /// `set_route` was invoked.
    SetRoute,
    /// `remove_route` was invoked.
    RemoveRoute,
}
#[cfg(test)]
impl SpySubscriptionManager {
    /// Builds a spy and returns it alongside the receiver on which its
    /// recorded calls arrive.
    pub fn new() -> (Self, tokio::sync::mpsc::UnboundedReceiver<SubscriptionCall>) {
        let (call_tx, call_rx) = tokio::sync::mpsc::unbounded_channel();
        let spy = Self {
            tx: Arc::new(call_tx),
        };
        (spy, call_rx)
    }
}
/// Records each call and then returns an ack that has already succeeded.
/// Ids handed out by the spy are always 0.
///
/// Failures to record (receiver dropped) and to pre-send the ack are
/// deliberately ignored.
#[cfg(test)]
#[async_trait]
impl SubscriptionOps for SpySubscriptionManager {
    async fn subscribe(
        &self,
        _source: &ProtoName,
        _name: &ProtoName,
        _forward_to: Option<u64>,
    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
    {
        let _ = self.tx.send(SubscriptionCall::Subscribe);
        let (done_tx, done_rx) = oneshot::channel();
        let _ = done_tx.send(Ok(()));
        Ok((0, done_rx))
    }

    async fn unsubscribe(
        &self,
        _source: &ProtoName,
        _name: &ProtoName,
        _subscription_id: u64,
        _forward_to: Option<u64>,
    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
        let _ = self.tx.send(SubscriptionCall::Unsubscribe);
        let (done_tx, done_rx) = oneshot::channel();
        let _ = done_tx.send(Ok(()));
        Ok(done_rx)
    }

    async fn set_route(
        &self,
        _source: &ProtoName,
        _name: &ProtoName,
        _conn: u64,
    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
    {
        let _ = self.tx.send(SubscriptionCall::SetRoute);
        let (done_tx, done_rx) = oneshot::channel();
        let _ = done_tx.send(Ok(()));
        Ok((0, done_rx))
    }

    async fn remove_route(
        &self,
        _source: &ProtoName,
        _name: &ProtoName,
        _subscription_id: u64,
        _conn: u64,
    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
        let _ = self.tx.send(SubscriptionCall::RemoveRoute);
        let (done_tx, done_rx) = oneshot::channel();
        let _ = done_tx.send(Ok(()));
        Ok(done_rx)
    }

    /// A spy cannot be derived from a datapath sender; use `new` instead.
    fn from_slim_tx(_tx: &SlimChannelSender) -> Option<Self> {
        None
    }
}
impl SubscriptionManager {
    /// Creates a manager that writes control messages to `tx`.
    ///
    /// The ack-id counter starts from a random value. NOTE(review): presumably
    /// this makes ids from separate manager instances unlikely to collide;
    /// confirm intent.
    pub fn new(tx: SlimChannelSender) -> Self {
        Self {
            pending_acks: Arc::new(Mutex::new(HashMap::new())),
            ack_counter: Arc::new(AtomicU64::new(rand::random::<u64>())),
            tx,
        }
    }

    /// Returns a fresh ack id.
    ///
    /// `Relaxed` is sufficient: only the uniqueness of each `fetch_add` result
    /// matters, not ordering with other memory. `fetch_add` wraps on overflow,
    /// so the `+ 1` must wrap too. The counter is randomly seeded, and a plain
    /// `+ 1` would panic in debug builds when the previous value was `u64::MAX`.
    fn next_ack_id(&self) -> u64 {
        self.ack_counter
            .fetch_add(1, Ordering::Relaxed)
            .wrapping_add(1)
    }

    /// Registers a waiter under a fresh ack id, then sends the message built
    /// for that id.
    ///
    /// Returns the id and the ack receiver. If the control channel is closed,
    /// the waiter is removed and `ChannelClosed` is returned.
    async fn send_with_receiver(
        &self,
        build_message: impl FnOnce(u64) -> Message,
    ) -> Result<(u64, oneshot::Receiver<Result<(), SubscriptionAckError>>), SubscriptionAckError>
    {
        let (ack_id, ack_rx) = self.register_ack();
        self.send_or_cancel(ack_id, build_message(ack_id)).await?;
        Ok((ack_id, ack_rx))
    }

    /// Like `send_with_receiver`, but uses the caller-supplied
    /// `subscription_id` as the ack id.
    async fn send_with_id(
        &self,
        subscription_id: u64,
        build_message: impl FnOnce(u64) -> Message,
    ) -> Result<oneshot::Receiver<Result<(), SubscriptionAckError>>, SubscriptionAckError> {
        let ack_rx = self.register_ack_with_id(subscription_id);
        self.send_or_cancel(subscription_id, build_message(subscription_id))
            .await?;
        Ok(ack_rx)
    }

    /// Sends `msg` on the control channel.
    ///
    /// On failure, drops the waiter registered under `ack_id` so it does not
    /// linger in `pending_acks`, and reports `ChannelClosed`. No lock is held
    /// across the `.await`.
    async fn send_or_cancel(&self, ack_id: u64, msg: Message) -> Result<(), SubscriptionAckError> {
        if self.tx.send(Ok(msg)).await.is_err() {
            self.cancel_ack(ack_id);
            return Err(SubscriptionAckError::ChannelClosed);
        }
        Ok(())
    }

    /// Allocates a fresh ack id and registers a waiter for it.
    pub fn register_ack(&self) -> (u64, oneshot::Receiver<Result<(), SubscriptionAckError>>) {
        let ack_id = self.next_ack_id();
        let ack_rx = self.register_ack_with_id(ack_id);
        (ack_id, ack_rx)
    }

    /// Registers a waiter for the ack carrying `id` and returns its receiver.
    ///
    /// If a waiter is already registered for `id`, it is replaced. Its sender
    /// is dropped, so anyone awaiting the old receiver through `await_ack`
    /// gets `ChannelClosed`. That case is logged rather than silently ignored.
    pub fn register_ack_with_id(
        &self,
        id: u64,
    ) -> oneshot::Receiver<Result<(), SubscriptionAckError>> {
        let (ack_tx, ack_rx) = oneshot::channel();
        // The lock guard is a temporary and is released at the end of this statement.
        let displaced = self.pending_acks.lock().insert(id, ack_tx);
        if displaced.is_some() {
            tracing::warn!(
                ack_id = %id,
                "replaced an existing pending subscription ack waiter"
            );
        }
        ack_rx
    }

    /// Forgets the waiter registered under `ack_id`, if any.
    pub fn cancel_ack(&self, ack_id: u64) {
        self.pending_acks.lock().remove(&ack_id);
    }

    /// Waits up to `ACK_TIMEOUT` for the ack delivered through `ack_rx`.
    ///
    /// # Errors
    /// - The error carried by the ack itself (e.g. `Rejected`).
    /// - `ChannelClosed` if the sender was dropped without sending.
    /// - `Timeout` if nothing arrived in time.
    ///
    /// This function has no access to `pending_acks`. On timeout the
    /// registered sender therefore stays in the map; callers that give up
    /// should call `cancel_ack` with the matching id to release it.
    pub async fn await_ack(
        ack_rx: oneshot::Receiver<Result<(), SubscriptionAckError>>,
    ) -> Result<(), SubscriptionAckError> {
        futures::pin_mut!(ack_rx);
        let delay = Delay::new(ACK_TIMEOUT);
        futures::pin_mut!(delay);
        match futures::future::select(ack_rx, delay).await {
            Either::Left((Ok(result), _)) => result,
            Either::Left((Err(_), _)) => Err(SubscriptionAckError::ChannelClosed),
            Either::Right(_) => Err(SubscriptionAckError::Timeout),
        }
    }

    /// Delivers an incoming ack to the waiter registered under its
    /// `subscription_id`.
    ///
    /// An unsuccessful ack becomes `Rejected`, with a fallback message when
    /// the ack's error text is empty. Acks with no registered waiter are
    /// logged and otherwise ignored.
    pub fn resolve_ack(&self, ack: &ProtoSubscriptionAck) {
        tracing::debug!(ack = %ack.subscription_id, "ack received");
        // Take the sender out under the lock, but send after releasing it.
        let sender = self.pending_acks.lock().remove(&ack.subscription_id);
        if let Some(sender) = sender {
            let outcome = if ack.success {
                Ok(())
            } else {
                let message = if ack.error.is_empty() {
                    "subscription ack failed".to_string()
                } else {
                    ack.error.clone()
                };
                Err(SubscriptionAckError::Rejected { message })
            };
            // The waiter may already have given up and dropped its receiver.
            let _ = sender.send(outcome);
        } else {
            tracing::info!(
                ack_id = %ack.subscription_id,
                "received subscription ack with no pending waiter"
            );
        }
    }
}