use std::{
fmt::Debug,
pin::Pin,
task::{Context, Poll},
};
use crate::Reply;
use async_broadcast::{
InactiveReceiver, Receiver as BroadcastReceiver, Sender as BroadcastSender, broadcast,
};
use pin_project_lite::pin_project;
/// Holds a value and broadcasts each update to the streams created from it.
///
/// NOTE(review): the derived `Clone` copies `value` but clones `tx`/`inactive_rx`,
/// which refer to the same broadcast channel. A `set` on one clone therefore
/// notifies streams created from the other, while that other clone's `get` keeps
/// returning its own (stale) `value`. Confirm this is intended.
#[derive(Debug, Clone)]
pub struct State<T, ReplyParams> {
    /// The current value; `get` returns a clone of it.
    value: T,
    /// Sending half of the broadcast channel that `set` publishes updates on.
    tx: BroadcastSender<ReplyParams>,
    /// Inactive receiver from which each stream's active receiver is cloned.
    /// It receives no messages itself.
    inactive_rx: InactiveReceiver<ReplyParams>,
}
impl<T, ReplyParams> zlink_core::notified::State<T, ReplyParams> for State<T, ReplyParams>
where
    T: Into<ReplyParams> + Clone + Debug + Send,
    ReplyParams: Clone + Send + 'static + Debug,
{
    type Stream = Stream<ReplyParams>;

    /// Creates a state holding `value`, backed by a single-slot broadcast channel.
    ///
    /// The sender is configured with `set_await_active(false)` and `set_overflow(true)`.
    /// A send therefore neither waits for an active receiver nor blocks on a full slot;
    /// the oldest queued message is dropped instead. Only an inactive receiver is kept
    /// on `self`.
    fn new(value: T) -> Self {
        let (mut sender, receiver) = broadcast(1);
        sender.set_await_active(false);
        sender.set_overflow(true);

        Self {
            value,
            tx: sender,
            inactive_rx: receiver.deactivate(),
        }
    }

    /// Replaces the stored value and broadcasts its `ReplyParams` form.
    ///
    /// The stored copy is updated before the broadcast is sent.
    ///
    /// # Panics
    ///
    /// Panics if the broadcast fails. With the configuration from `new`, async_broadcast
    /// reports a failure only when the channel has been closed.
    async fn set(&mut self, value: T) {
        self.value = value.clone();
        let params: ReplyParams = value.into();

        let outcome = self.tx.broadcast_direct(params).await;
        outcome.expect("Failed to broadcast value");
    }

    /// Returns a clone of the current value.
    fn get(&self) -> T {
        T::clone(&self.value)
    }

    /// Returns a stream backed by a freshly activated receiver on the broadcast channel.
    ///
    /// The stream starts with nothing cached and in continuous (not one-shot) mode.
    fn stream(&self) -> Stream<ReplyParams> {
        let inner = self.inactive_rx.activate_cloned();
        Stream {
            inner,
            cached: None,
            once: false,
        }
    }

    /// Returns a one-shot stream pre-loaded with the current value.
    fn stream_once(&self) -> Stream<ReplyParams> {
        let mut stream = self.stream();
        stream.cached = Some(self.get().into());
        stream.once = true;
        stream
    }
}
pin_project! {
    /// Stream of replies returned by `State::stream` and `State::stream_once`.
    #[derive(Debug)]
    pub struct Stream<ReplyParams> {
        // Active broadcast receiver for values published by `State::set`.
        // It is not polled when `once` is true.
        #[pin]
        inner: BroadcastReceiver<ReplyParams>,
        // Value emitted as the final reply (with `continues = Some(false)`).
        // For one-shot streams this is the value captured at creation. Otherwise it
        // is the most recent value received from `inner`.
        cached: Option<ReplyParams>,
        // When true, the stream emits `cached` once and then ends without reading `inner`.
        once: bool,
    }
}
impl<ReplyParams> futures_util::Stream for Stream<ReplyParams>
where
    ReplyParams: Clone + Send + 'static,
{
    type Item = Reply<ReplyParams>;

    /// Polls for the next reply.
    ///
    /// In continuous mode:
    /// - Each value received from the broadcast channel is remembered and emitted with
    ///   `continues = Some(true)`.
    /// - When the channel reports closure, the remembered value (if any) is emitted once
    ///   more with `continues = Some(false)`, and the stream then yields `None`.
    ///
    /// In one-shot mode the channel is never polled. The cached value is emitted with
    /// `continues = Some(false)`, followed by `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        if !*this.once {
            match this.inner.poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Some(params)) => {
                    *this.cached = Some(params.clone());
                    let reply = Reply::new(Some(params)).set_continues(Some(true));
                    return Poll::Ready(Some(reply));
                }
                // Channel closed: fall through and flush the cached value as the last reply.
                Poll::Ready(None) => {}
            }
        }

        let last = this.cached.take();
        Poll::Ready(last.map(|params| Reply::new(Some(params)).set_continues(Some(false))))
    }
}