use crate::{Ndb, NoteKey, Subscription};
use std::{
pin::Pin,
task::{Context, Poll},
};
use futures::Stream;
use tracing::error;
/// Bookkeeping for one subscription, stored in `Ndb::subs` and keyed by the
/// subscription id.
///
/// `SubscriptionStream::poll_next` consults `done` and parks its task's waker
/// here before returning `Poll::Pending`.
#[derive(Clone, Debug)]
pub(crate) struct SubscriptionState {
    /// Once `true`, the owning stream reports end-of-stream (`Poll::Ready(None)`).
    /// NOTE(review): this flag is set outside this file — confirm where.
    pub done: bool,

    /// Waker of the task that last polled the stream and found no notes, if any.
    pub waker: Option<std::task::Waker>,
}
/// An async [`Stream`] yielding batches of [`NoteKey`]s for one nostrdb
/// subscription.
///
/// Build it with `SubscriptionStream::new` and tune it with the
/// `notes_per_await` / `unsubscribe_on_drop` builder methods.
pub struct SubscriptionStream {
    /// Database handle the subscription lives on.
    ndb: Ndb,
    /// The subscription this stream drains.
    sub_id: Subscription,
    /// Upper bound passed to `Ndb::poll_for_notes` on each poll.
    max_notes: u32,
    /// Whether dropping the stream should tear the subscription down.
    unsubscribe_on_drop: bool,
}
impl SubscriptionStream {
    /// Creates a stream over `sub_id` on `ndb`.
    ///
    /// Defaults: up to 32 notes are requested per poll, and the subscription
    /// is unsubscribed when the stream is dropped.
    #[must_use = "dropping a SubscriptionStream unsubscribes from `sub_id` by default"]
    pub fn new(ndb: Ndb, sub_id: Subscription) -> Self {
        SubscriptionStream {
            ndb,
            sub_id,
            max_notes: 32,
            unsubscribe_on_drop: true,
        }
    }

    /// Sets the maximum number of notes requested from `Ndb::poll_for_notes`
    /// per poll (default 32).
    ///
    /// NOTE(review): a value of `0` presumably makes every poll come back
    /// empty, so the stream would never yield — confirm against `poll_for_notes`.
    #[must_use = "this consumes the stream; dropping the result unsubscribes by default"]
    pub fn notes_per_await(mut self, max_notes: u32) -> Self {
        self.max_notes = max_notes;
        self
    }

    /// Sets whether dropping the stream unsubscribes from `sub_id` and clears
    /// its entry in `Ndb::subs` (default `true`).
    ///
    /// Pass `false` when the subscription must outlive this stream.
    #[must_use = "this consumes the stream; dropping the result unsubscribes by default"]
    pub fn unsubscribe_on_drop(mut self, yes: bool) -> Self {
        self.unsubscribe_on_drop = yes;
        self
    }

    /// Returns the subscription this stream reads from.
    pub fn sub_id(&self) -> Subscription {
        self.sub_id
    }
}

impl Drop for SubscriptionStream {
    /// Removes the subscription's state and unsubscribes, unless
    /// `unsubscribe_on_drop(false)` was requested.
    fn drop(&mut self) {
        // Honour the opt-out. Previously the flag was stored but never read,
        // so `unsubscribe_on_drop(false)` had no effect.
        if !self.unsubscribe_on_drop {
            return;
        }

        // Forget the per-subscription waker/`done` state. Recover from a
        // poisoned lock instead of unwrapping: panicking inside `drop` while
        // already unwinding aborts the process. The guard is a temporary, so
        // the lock is released before `unsubscribe` runs.
        self.ndb
            .subs
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .remove(&self.sub_id);

        if let Err(err) = self.ndb.unsubscribe(self.sub_id) {
            error!(
                "Error unsubscribing from {} in SubscriptionStream Drop: {err}",
                self.sub_id.id()
            );
        }
    }
}
impl Stream for SubscriptionStream {
    /// A non-empty batch of note keys from the subscription.
    type Item = Vec<NoteKey>;

    /// Yields the next batch of notes, or `None` once the subscription's
    /// state is marked `done`.
    ///
    /// The `subs` lock is held across the `done` check, the note poll, and the
    /// waker registration. NOTE(review): presumably the notifier takes the
    /// same lock before waking; if so, this closes the window in which a
    /// notification could arrive between an empty poll and storing the waker
    /// — confirm against the notifier.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Only shared access is needed, so read straight through the pin
        // instead of re-pinning `self` on the stack.
        let this: &SubscriptionStream = &self;

        // Keep this guard alive until the waker is stored (see above).
        let mut subs = this.ndb.subs.lock().unwrap();
        let state = subs
            .entry(this.sub_id)
            .or_insert_with(|| SubscriptionState {
                done: false,
                waker: None,
            });

        if state.done {
            return Poll::Ready(None);
        }

        let notes = this.ndb.poll_for_notes(this.sub_id, this.max_notes);
        if !notes.is_empty() {
            return Poll::Ready(Some(notes));
        }

        // Nothing ready yet: park this task until it is woken through `state.waker`.
        state.waker = Some(cx.waker().clone());
        Poll::Pending
    }
}