use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use crate::event::{BarrierDecision, BarrierDecisionMessage, BarrierId};
/// Result of waiting on a barrier through a `BarrierSink`.
#[derive(Debug, Clone)]
pub enum BarrierOutcome {
/// A decision was obtained for the barrier.
Decision(BarrierDecision),
/// The timeout passed to `wait_decision` elapsed before any other event.
TimedOut,
/// The wait ended without a decision. In this file `ChannelBarrierSink`
/// returns this for a cancel signal, a closed decision channel, and a
/// decision message addressed to a different barrier or node.
Cancelled,
}
/// Source of decisions for barriers.
///
/// Implementors must be `Send + Sync`. The method returns a boxed, pinned,
/// `Send` future instead of being an `async fn` (presumably so the trait can
/// be used as `dyn BarrierSink` — confirm against callers).
pub trait BarrierSink: Send + Sync {
/// Waits for the outcome of the barrier identified by `barrier_id`.
///
/// * `barrier_id` - barrier to wait on; implementations may ignore it.
/// * `timeout` - optional upper bound on the wait; implementations may
///   ignore it. `ChannelBarrierSink` treats `None` as "no timeout".
///
/// By lifetime elision the returned future may borrow `self` (`'_`), but
/// not `barrier_id`.
fn wait_decision(
&self,
barrier_id: &BarrierId,
timeout: Option<Duration>,
) -> Pin<Box<dyn std::future::Future<Output = BarrierOutcome> + Send + '_>>;
}
/// Stateless sink whose `BarrierSink` implementation approves every barrier
/// immediately, ignoring the barrier id and the timeout.
pub struct NoopBarrierSink;
impl BarrierSink for NoopBarrierSink {
    /// Resolves on the first poll with `BarrierDecision::Approve`.
    ///
    /// Both `_barrier_id` and `_timeout` are ignored.
    fn wait_decision(
        &self,
        _barrier_id: &BarrierId,
        _timeout: Option<Duration>,
    ) -> Pin<Box<dyn std::future::Future<Output = BarrierOutcome> + Send + '_>> {
        let approved = BarrierOutcome::Decision(BarrierDecision::Approve);
        // An already-completed future: nothing to wait for.
        Box::pin(std::future::ready(approved))
    }
}
/// Sink that answers every wait with one fixed, preconfigured decision.
pub struct MockBarrierSink {
/// Decision handed out (cloned) by every `wait_decision` call.
pub decision: BarrierDecision,
}
impl MockBarrierSink {
pub fn new(decision: BarrierDecision) -> Self {
Self { decision }
}
}
impl BarrierSink for MockBarrierSink {
    /// Resolves on the first poll with a clone of the configured decision.
    ///
    /// Both `_barrier_id` and `_timeout` are ignored.
    fn wait_decision(
        &self,
        _barrier_id: &BarrierId,
        _timeout: Option<Duration>,
    ) -> Pin<Box<dyn std::future::Future<Output = BarrierOutcome> + Send + '_>> {
        // Clone at call time so the future owns its result and does not
        // borrow `self`.
        let outcome = BarrierOutcome::Decision(self.decision.clone());
        Box::pin(async move { outcome })
    }
}
/// Cloneable handle to a single `tokio::sync::mpsc::Receiver`, shared behind
/// an `Arc` and an async mutex.
struct SharedReceiver<T: Send> {
// The one underlying receiver; every clone of this handle points at it.
inner: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<T>>>,
}
impl<T: Send> SharedReceiver<T> {
    /// Wraps `rx` so that it can be shared between clones of the handle.
    fn new(rx: tokio::sync::mpsc::Receiver<T>) -> Self {
        let inner = Arc::new(tokio::sync::Mutex::new(rx));
        Self { inner }
    }

    /// Locks the shared receiver and awaits its next message.
    ///
    /// The mutex guard stays alive for the whole wait, so concurrent callers
    /// are served one at a time. Returns whatever the underlying `recv`
    /// returns.
    async fn recv(&self) -> Option<T> {
        self.inner.lock().await.recv().await
    }
}
impl<T: Send> Clone for SharedReceiver<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
/// `BarrierSink` driven by two mpsc channels: one delivering decision
/// messages and one delivering cancellation signals.
pub struct ChannelBarrierSink {
// Incoming decision messages (`Exact` or `Wildcard`).
decision_rx: SharedReceiver<BarrierDecisionMessage>,
// Incoming cancellation signals.
cancel_rx: SharedReceiver<()>,
// Token that `wait_decision` cancels when the cancel channel yields.
cancel: Arc<tokio_util::sync::CancellationToken>,
// Decisions taken from matching wildcard messages, keyed by node id;
// consulted by `wait_decision` before it waits on the channels.
wildcard_cache: Arc<tokio::sync::RwLock<std::collections::HashMap<String, BarrierDecision>>>,
}
impl ChannelBarrierSink {
    /// Builds a sink from its two receiving channel ends and a cancellation
    /// token.
    ///
    /// * `decision_rx` - receiver of decision messages.
    /// * `cancel_rx` - receiver of cancellation signals.
    /// * `cancel` - token the sink cancels when a cancellation is received.
    ///
    /// The wildcard cache starts out empty.
    pub(crate) fn new(
        decision_rx: tokio::sync::mpsc::Receiver<BarrierDecisionMessage>,
        cancel_rx: tokio::sync::mpsc::Receiver<()>,
        cancel: tokio_util::sync::CancellationToken,
    ) -> Self {
        let decision_rx = SharedReceiver::new(decision_rx);
        let cancel_rx = SharedReceiver::new(cancel_rx);
        let cancel = Arc::new(cancel);
        let wildcard_cache = Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new()));
        Self {
            decision_rx,
            cancel_rx,
            cancel,
            wildcard_cache,
        }
    }
}
impl BarrierSink for ChannelBarrierSink {
    /// Waits for a decision on `barrier_id`, fed by the decision and cancel
    /// channels.
    ///
    /// Resolution order:
    /// 1. If a wildcard decision is already cached for `barrier_id.node_id`,
    ///    it is returned without touching the channels.
    /// 2. Otherwise the first ready event wins, checked in this fixed
    ///    priority (`biased`): cancel signal, `timeout`, next decision message.
    ///
    /// Outcomes:
    /// - `Decision(..)`: cached wildcard, an `Exact` message whose id equals
    ///   `barrier_id`, or a `Wildcard` message for the same node (which is
    ///   then stored in the cache).
    /// - `TimedOut`: `timeout` is `Some` and elapsed first. `None` means the
    ///   wait has no time limit.
    /// - `Cancelled`: the cancel channel yielded (the token is cancelled as
    ///   well), the decision channel is closed, or the received message was
    ///   addressed to a different barrier/node.
    ///
    /// The returned future owns clones of everything it needs and does not
    /// borrow `self` or `barrier_id`.
    fn wait_decision(
        &self,
        barrier_id: &BarrierId,
        timeout: Option<Duration>,
    ) -> Pin<Box<dyn std::future::Future<Output = BarrierOutcome> + Send + '_>> {
        let decision_rx = self.decision_rx.clone();
        let cancel_rx = self.cancel_rx.clone();
        let cancel = self.cancel.clone();
        let wildcard_cache = self.wildcard_cache.clone();
        let barrier_id = barrier_id.clone();

        Box::pin(async move {
            // Await the read lock rather than using `try_read`: `try_read`
            // fails while another waiter is inserting into the cache, which
            // would skip the lookup and leave this barrier blocked on the
            // channel even though a wildcard decision already covers it.
            // The guard is a temporary and is released at the end of this
            // statement, before any further await.
            let cached = wildcard_cache
                .read()
                .await
                .get(&barrier_id.node_id)
                .cloned();
            if let Some(decision) = cached {
                return BarrierOutcome::Decision(decision);
            }

            // A single select serves both modes: without a timeout the
            // deadline future simply never completes.
            let deadline = async move {
                match timeout {
                    Some(dur) => tokio::time::sleep(dur).await,
                    None => std::future::pending::<()>().await,
                }
            };

            tokio::select! {
                biased;
                // `_` also matches `None`, so a closed cancel channel is
                // treated like an explicit cancel signal.
                // NOTE(review): confirm that closing the channel is meant to
                // cancel the token.
                _ = cancel_rx.recv() => {
                    cancel.cancel();
                    BarrierOutcome::Cancelled
                }
                _ = deadline => BarrierOutcome::TimedOut,
                msg = decision_rx.recv() => match msg {
                    Some(BarrierDecisionMessage::Exact { barrier_id: bid, decision }) => {
                        if bid == barrier_id {
                            BarrierOutcome::Decision(decision)
                        } else {
                            // NOTE(review): a message for another barrier is
                            // consumed here and reported as `Cancelled`;
                            // verify this is the intended protocol.
                            BarrierOutcome::Cancelled
                        }
                    }
                    Some(BarrierDecisionMessage::Wildcard { node_id, decision }) => {
                        if node_id == barrier_id.node_id {
                            // Remember the decision so later barriers on the
                            // same node resolve from the cache.
                            wildcard_cache.write().await.insert(node_id, decision.clone());
                            BarrierOutcome::Decision(decision)
                        } else {
                            // NOTE(review): wildcards for other nodes are
                            // dropped without being cached.
                            BarrierOutcome::Cancelled
                        }
                    }
                    None => BarrierOutcome::Cancelled,
                },
            }
        })
    }
}