use commonware_utils::channel::oneshot::{self, error::RecvError};
use futures::{future::Shared, FutureExt};
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
/// A future that resolves to the `i32` stop value.
///
/// `Open` wraps a live [`Receiver`] that waits for [`Signaler::signal`] to deliver
/// the value. `Closed` carries a value that is already known (see
/// [`Stopper::stopped`]) and resolves immediately when polled.
#[derive(Clone)]
pub enum Signal {
    /// Still waiting on the shared oneshot channel created by [`Signaler::new`].
    Open(Receiver),
    /// Already signaled; polling yields `Ok` with this value right away.
    Closed(i32),
}
impl Future for Signal {
type Output = Result<i32, RecvError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match &mut *self {
Self::Open(live) => Pin::new(&mut live.inner).poll(cx),
Self::Closed(value) => Poll::Ready(Ok(*value)),
}
}
}
/// The live half of an open [`Signal`]: a clonable handle awaiting the stop value.
///
/// All clones share a single `Arc<Guard>`. When the last clone is dropped, the guard
/// sends on the completion channel whose receiver is returned by [`Signaler::signal`].
#[derive(Clone)]
pub struct Receiver {
    /// Wrapped in `Shared` so every clone observes the same result.
    inner: Shared<oneshot::Receiver<i32>>,
    /// Held only for its `Drop` side effect (completion notification).
    _guard: Arc<Guard>,
}
/// Drop guard that sends `()` on the completion channel when it is dropped.
///
/// It is only ever held behind the `Arc` inside `Receiver`, so it drops together
/// with the last `Receiver` clone.
struct Guard {
    /// Completion sender; wrapped in `Option` so `Drop` can move it out with `take`.
    tx: Option<oneshot::Sender<()>>,
}
impl Guard {
pub const fn new(completion_tx: oneshot::Sender<()>) -> Self {
Self {
tx: Some(completion_tx),
}
}
}
impl Drop for Guard {
    /// Notifies the completion channel, at most once.
    fn drop(&mut self) {
        let Some(sender) = self.tx.take() else {
            return;
        };
        // Best-effort: the completion receiver may already have been dropped.
        let _ = sender.send(());
    }
}
/// The sending half created by [`Signaler::new`] alongside its [`Signal`].
///
/// Consumed by [`Signaler::signal`], which delivers a value to the paired signal
/// and returns a receiver that fires when the shared `Guard` held by all open
/// `Receiver` clones is dropped.
pub struct Signaler {
    /// Delivers the stop value to the shared receiver.
    tx: oneshot::Sender<i32>,
    /// Receives `()` from the `Guard`'s `Drop` impl.
    completion_rx: oneshot::Receiver<()>,
}
impl Signaler {
    /// Creates a connected `(Signaler, Signal)` pair.
    ///
    /// The returned signal is `Open`. It shares a fresh completion guard, and that
    /// guard fires once every clone of the signal has been dropped.
    pub fn new() -> (Self, Signal) {
        let (value_tx, value_rx) = oneshot::channel();
        let (done_tx, done_rx) = oneshot::channel();

        let receiver = Receiver {
            inner: value_rx.shared(),
            _guard: Arc::new(Guard::new(done_tx)),
        };

        (
            Self {
                tx: value_tx,
                completion_rx: done_rx,
            },
            Signal::Open(receiver),
        )
    }

    /// Sends `value` to every waiter and returns the completion receiver.
    ///
    /// A failed send (all receivers already gone) is deliberately ignored.
    pub fn signal(self, value: i32) -> oneshot::Receiver<()> {
        let Self { tx, completion_rx } = self;
        let _ = tx.send(value);
        completion_rx
    }
}
/// Tracks whether a stop has been requested and hands out [`Signal`]s to observers.
pub enum Stopper {
    /// No stop has been requested yet.
    Running {
        /// Sender consumed by [`Stopper::stop`]. It is `Some` when built by
        /// [`Stopper::new`] and is `take`n during the transition to `Stopped`.
        signaler: Option<Signaler>,
        /// Open signal cloned by [`Stopper::stopped`] for each observer.
        signal: Signal,
    },
    /// A stop has been requested with `stop_value`.
    Stopped {
        /// Value carried by the `Signal::Closed` returned from [`Stopper::stopped`].
        stop_value: i32,
        /// Completion future returned (cloned) by every call to [`Stopper::stop`].
        completion: Shared<oneshot::Receiver<()>>,
    },
}
impl Stopper {
    /// Creates a `Stopper` in the `Running` state with a fresh signaler/signal pair.
    pub fn new() -> Self {
        let (signaler, signal) = Signaler::new();
        Self::Running {
            signaler: Some(signaler),
            signal,
        }
    }

    /// Returns a [`Signal`] that resolves to the stop value.
    ///
    /// While running, this is a clone of the shared open signal and resolves when
    /// [`Stopper::stop`] is called. After a stop, it is a `Signal::Closed` that
    /// resolves immediately. Each open clone keeps the completion guard alive, so
    /// observers should drop it once they have finished reacting to the stop.
    pub fn stopped(&self) -> Signal {
        match self {
            Self::Running { signal, .. } => signal.clone(),
            Self::Stopped { stop_value, .. } => Signal::Closed(*stop_value),
        }
    }

    /// Requests a stop with `value` and returns a future that completes once every
    /// open [`Signal`] clone has been dropped.
    ///
    /// Only the first call delivers a value. Later calls ignore `value` and return
    /// a clone of the same completion future.
    ///
    /// # Panics
    ///
    /// Panics if `self` is `Running` with `signaler: None`. That state can only be
    /// built by hand (the variant fields are public); [`Stopper::new`] never
    /// produces it.
    pub fn stop(&mut self, value: i32) -> Shared<oneshot::Receiver<()>> {
        let sender = match self {
            Self::Stopped { completion, .. } => return completion.clone(),
            Self::Running { signaler, .. } => signaler
                .take()
                .expect("Stopper::Running must hold its signaler until stop() is called"),
        };

        let completion = sender.signal(value).shared();

        // Overwriting `self` drops the `Running` variant's own `Signal` clone. That
        // clone shares the completion guard, so the completion future cannot
        // resolve while it is still alive.
        *self = Self::Stopped {
            stop_value: value,
            completion: completion.clone(),
        };
        completion
    }
}
impl Default for Stopper {
fn default() -> Self {
Self::new()
}
}