use std::{error::Error, fmt::Display};
use crate::{atomic_try_update, bits::FlagU64, Atom};
/// Outcome of `ShutdownBarrier::wait`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShutdownBarrierWaitResult {
    /// `true` when the wait ended because the barrier was cancelled.
    cancelled: bool,
}

/// Outcome of a successful `ShutdownBarrier::done` call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShutdownBarrierDoneResult {
    /// `true` when the barrier had already been cancelled.
    cancelled: bool,
    /// `true` when this call was the one that brought the participant count to zero.
    shutdown_leader: bool,
}

impl ShutdownBarrierWaitResult {
    /// Returns `true` if the barrier was cancelled rather than shut down normally.
    pub fn is_cancelled(&self) -> bool {
        self.cancelled
    }
}

impl ShutdownBarrierDoneResult {
    /// Returns `true` if the barrier had been cancelled when `done` was called.
    pub fn is_cancelled(&self) -> bool {
        self.cancelled
    }

    /// Returns `true` if this `done` call finished the last participant and
    /// therefore triggered the shutdown notification.
    pub fn is_leader(&self) -> bool {
        self.shutdown_leader
    }
}
/// Coordinates shutdown among a changing set of participants.
///
/// Participants are added with `spawn` and removed with `done`; the barrier can
/// also be cancelled with `cancel`. Tasks call `wait` to learn whether the
/// barrier shut down normally or was cancelled.
pub struct ShutdownBarrier {
/// State word holding a flag (set by `cancel`) and a participant count
/// (incremented by `spawn`, decremented by `done`).
state: Atom<FlagU64, u64>,
/// Notifies subscribed waiters: `true` is sent on cancellation, `false` when
/// `done` brings the participant count to zero.
broadcast: tokio::sync::broadcast::Sender<bool>,
}
/// Snapshot of the barrier state as observed by `ShutdownBarrier::wait`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WaitResult {
    /// Participants remain and the barrier is not cancelled; the waiter must block.
    StillRunning,
    /// The participant count has reached zero.
    Shutdown,
    /// The barrier has been cancelled.
    Cancelled,
}

/// Classification of a `ShutdownBarrier::done` call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DoneResult {
    /// The barrier had already been cancelled.
    Cancelled,
    /// The count was already zero (and not cancelled): the call was unbalanced.
    AlreadyDone,
    /// This call brought the participant count from one to zero.
    ShutdownLeader,
    /// Other participants are still registered.
    Running,
}
impl Default for ShutdownBarrier {
    /// Builds a barrier whose participant count starts at one, so the creator
    /// counts as the first participant.
    fn default() -> Self {
        // Only the sending half is kept; waiters create their own receivers
        // through `subscribe`, so the initial receiver is discarded right away.
        let (broadcast, _) = tokio::sync::broadcast::channel(1);
        let barrier = Self {
            state: Default::default(),
            broadcast,
        };
        // SAFETY: NOTE(review): the contract of `atomic_try_update` is defined
        // outside this file — confirm it is upheld. The closure only writes the
        // state value it is handed and has no other side effects.
        unsafe {
            atomic_try_update(&barrier.state, |word| {
                word.set_val(1);
                (true, ())
            })
        };
        barrier
    }
}
/// Error returned by `ShutdownBarrier` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownBarrierError {
    /// `spawn`/`cancel`: the barrier was already cancelled or its participant
    /// count had reached zero. `done`: the count was already zero.
    /// `wait`: the notification channel reported an error.
    AlreadyShutdown,
}

impl Error for ShutdownBarrierError {}

impl Display for ShutdownBarrierError {
    /// Writes the variant name (identical to the `Debug` output).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Match explicitly so each future variant must choose its own text.
        match self {
            Self::AlreadyShutdown => f.write_str("AlreadyShutdown"),
        }
    }
}
impl ShutdownBarrier {
    /// Registers one more participant.
    ///
    /// Each successful call should be balanced by exactly one later call to
    /// [`ShutdownBarrier::done`].
    ///
    /// # Errors
    ///
    /// Returns [`ShutdownBarrierError::AlreadyShutdown`] if the barrier has been
    /// cancelled or its participant count has already reached zero.
    pub fn spawn(&self) -> Result<(), ShutdownBarrierError> {
        // SAFETY: NOTE(review): the contract of `atomic_try_update` is defined
        // outside this file — confirm it is upheld. Every closure passed to it in
        // this impl only reads/writes the state value it is handed and has no
        // other side effects.
        let already_shutdown = unsafe {
            atomic_try_update(&self.state, |s| {
                let count = s.get_val();
                if s.get_flag() || count == 0 {
                    // Cancelled or fully shut down: leave the state untouched.
                    (false, true)
                } else {
                    s.set_val(count + 1);
                    (true, false)
                }
            })
        };
        if already_shutdown {
            Err(ShutdownBarrierError::AlreadyShutdown)
        } else {
            Ok(())
        }
    }

    /// Cancels the barrier and wakes every task blocked in
    /// [`ShutdownBarrier::wait`] with a cancelled result.
    ///
    /// # Errors
    ///
    /// Returns [`ShutdownBarrierError::AlreadyShutdown`] if the barrier was
    /// already cancelled or its participant count has already reached zero.
    pub fn cancel(&self) -> Result<(), ShutdownBarrierError> {
        // SAFETY: see the NOTE(review) in `spawn`.
        let already_shutdown = unsafe {
            atomic_try_update(&self.state, |s| {
                let count = s.get_val();
                if s.get_flag() || count == 0 {
                    (false, true)
                } else {
                    s.set_flag(true);
                    (true, false)
                }
            })
        };
        if already_shutdown {
            Err(ShutdownBarrierError::AlreadyShutdown)
        } else {
            // The flag is stored before broadcasting, so a waiter either sees it
            // when it reads the state or is already subscribed for this message.
            // `send` fails only when nobody is subscribed, which is fine to ignore.
            _ = self.broadcast.send(true);
            Ok(())
        }
    }

    /// Marks one participant as finished.
    ///
    /// If the barrier has not been cancelled, the call that brings the count
    /// from one to zero becomes the shutdown leader and wakes every task
    /// blocked in [`ShutdownBarrier::wait`] with a non-cancelled result.
    ///
    /// # Errors
    ///
    /// Returns [`ShutdownBarrierError::AlreadyShutdown`] if the count is already
    /// zero and the barrier was not cancelled (an unbalanced call).
    pub fn done(&self) -> Result<ShutdownBarrierDoneResult, ShutdownBarrierError> {
        // SAFETY: see the NOTE(review) in `spawn`.
        let done_result = unsafe {
            atomic_try_update(&self.state, |s| {
                let count = s.get_val();
                // Check for zero *before* decrementing so an unbalanced call can
                // never underflow the counter (`0 - 1` panics in debug builds and
                // wraps in release builds).
                if s.get_flag() {
                    if count == 0 {
                        (false, DoneResult::Cancelled)
                    } else {
                        s.set_val(count - 1);
                        (true, DoneResult::Cancelled)
                    }
                } else {
                    match count {
                        0 => (false, DoneResult::AlreadyDone),
                        1 => {
                            s.set_val(0);
                            (true, DoneResult::ShutdownLeader)
                        }
                        _ => {
                            s.set_val(count - 1);
                            (true, DoneResult::Running)
                        }
                    }
                }
            })
        };
        match done_result {
            DoneResult::Cancelled => Ok(ShutdownBarrierDoneResult {
                cancelled: true,
                shutdown_leader: false,
            }),
            DoneResult::ShutdownLeader => {
                // Last participant out: the zero count is already stored, now wake
                // the waiters. A send error only means nobody is subscribed.
                _ = self.broadcast.send(false);
                Ok(ShutdownBarrierDoneResult {
                    cancelled: false,
                    shutdown_leader: true,
                })
            }
            DoneResult::Running => Ok(ShutdownBarrierDoneResult {
                cancelled: false,
                shutdown_leader: false,
            }),
            DoneResult::AlreadyDone => Err(ShutdownBarrierError::AlreadyShutdown),
        }
    }

    /// Waits until the barrier is cancelled or its participant count reaches zero.
    ///
    /// Returns immediately if either has already happened. Use
    /// [`ShutdownBarrierWaitResult::is_cancelled`] to tell the two outcomes apart.
    ///
    /// # Errors
    ///
    /// Returns [`ShutdownBarrierError::AlreadyShutdown`] if receiving the
    /// notification fails.
    pub async fn wait(&self) -> Result<ShutdownBarrierWaitResult, ShutdownBarrierError> {
        // Subscribe *before* reading the state. `cancel` and `done` store the new
        // state before broadcasting, so either the read below already observes it
        // or this receiver is registered in time to get the message.
        let mut rx = self.broadcast.subscribe();
        // SAFETY: see the NOTE(review) in `spawn`.
        let wait_result = unsafe {
            atomic_try_update(&self.state, |s| {
                let count = s.get_val();
                // Read-only probe: nothing is modified, so nothing is committed.
                if s.get_flag() {
                    (false, WaitResult::Cancelled)
                } else if count == 0 {
                    (false, WaitResult::Shutdown)
                } else {
                    (false, WaitResult::StillRunning)
                }
            })
        };
        match wait_result {
            WaitResult::StillRunning => {
                let cancelled = rx
                    .recv()
                    .await
                    .map_err(|_| ShutdownBarrierError::AlreadyShutdown)?;
                Ok(ShutdownBarrierWaitResult { cancelled })
            }
            WaitResult::Shutdown => Ok(ShutdownBarrierWaitResult { cancelled: false }),
            WaitResult::Cancelled => Ok(ShutdownBarrierWaitResult { cancelled: true }),
        }
    }

    /// Creates a barrier; equivalent to `ShutdownBarrier::default()`.
    pub fn new() -> Self {
        Default::default()
    }
}