use std::sync::Arc;
use futures::stream::BoxStream;
use futures::{Stream, StreamExt};
use parking_lot::RwLock;
use tokio::sync::watch;
use tokio_stream::wrappers::WatchStream;
use super::query::{ContentNeedle, MemoriesFilterSpec, OrderBy};
use super::state::MemoriesState;
use super::types::{Memory, MemoryId};
/// A live query over [`MemoriesState`].
///
/// Configured with chainable builder methods (`where_*`, `created_*`, `updated_*`,
/// `order_by`, `limit`), each of which consumes the watcher and returns the updated one,
/// and then consumed by `stream` to obtain a stream of result sets.
#[must_use = "a MemoriesWatcher does nothing unless it is used, e.g. via `stream`"]
pub struct MemoriesWatcher {
    /// Shared store the query is evaluated against.
    state: Arc<RwLock<MemoriesState>>,
    /// Change notifications; each item triggers a re-evaluation (the value is not inspected).
    changes: BoxStream<'static, u64>,
    /// Accumulated filter / ordering / limit settings.
    spec: MemoriesFilterSpec,
}
impl MemoriesWatcher {
pub(super) fn new(state: Arc<RwLock<MemoriesState>>, changes: BoxStream<'static, u64>) -> Self {
Self {
state,
changes,
spec: MemoriesFilterSpec::default(),
}
}
pub fn where_id_in(mut self, ids: impl IntoIterator<Item = MemoryId>) -> Self {
self.spec.id_in = Some(ids.into_iter().collect());
self
}
pub fn where_source(mut self, source: impl Into<String>) -> Self {
self.spec.source = Some(source.into());
self
}
pub fn content_contains(mut self, needle: impl Into<String>) -> Self {
self.spec.content_contains = Some(ContentNeedle::new(needle));
self
}
pub fn where_tag(mut self, tag: impl Into<String>) -> Self {
self.spec.require_tag = Some(tag.into());
self
}
pub fn where_any_tag(mut self, tags: impl IntoIterator<Item = String>) -> Self {
self.spec.require_any_tag = Some(tags.into_iter().collect());
self
}
pub fn where_all_tags(mut self, tags: impl IntoIterator<Item = String>) -> Self {
self.spec.require_all_tags = Some(tags.into_iter().collect());
self
}
pub fn where_pinned(mut self, pinned: bool) -> Self {
self.spec.only_pinned = Some(pinned);
self
}
pub fn created_after(mut self, ns: u64) -> Self {
self.spec.created_after_ns = Some(ns);
self
}
pub fn created_before(mut self, ns: u64) -> Self {
self.spec.created_before_ns = Some(ns);
self
}
pub fn updated_after(mut self, ns: u64) -> Self {
self.spec.updated_after_ns = Some(ns);
self
}
pub fn updated_before(mut self, ns: u64) -> Self {
self.spec.updated_before_ns = Some(ns);
self
}
pub fn order_by(mut self, order: OrderBy) -> Self {
self.spec.order_by = Some(order);
self
}
pub fn limit(mut self, n: usize) -> Self {
self.spec.limit = Some(n);
self
}
pub(super) fn spec_for_snapshot(&self) -> MemoriesFilterSpec {
let mut spec = self.spec.clone();
if spec.order_by.is_none() {
spec.order_by = Some(OrderBy::IdAsc);
}
spec
}
pub fn stream(self) -> impl Stream<Item = Vec<std::sync::Arc<Memory>>> + Send + 'static {
let MemoriesWatcher {
state,
mut changes,
mut spec,
} = self;
if spec.order_by.is_none() {
spec.order_by = Some(OrderBy::IdAsc);
}
let initial = {
let guard = state.read();
spec.execute(&guard)
};
let (tx, rx) = watch::channel(initial.clone());
tokio::spawn(async move {
let mut last = initial;
loop {
tokio::select! {
_ = tx.closed() => return,
maybe_seq = changes.next() => {
let Some(_seq) = maybe_seq else { return };
use futures::future::FutureExt;
let mut stream_ended = false;
while let Some(maybe_more) = changes.next().now_or_never() {
if maybe_more.is_none() {
stream_ended = true;
break;
}
}
let current = {
let guard = state.read();
spec.execute(&guard)
};
if current != last {
if tx.send(current.clone()).is_err() {
return;
}
last = current;
}
if stream_ended {
return;
}
}
}
}
});
WatchStream::new(rx)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A burst of notifications whose drain runs into the end of the change stream must
    /// still cause a final re-evaluation, delivered before the output stream closes.
    ///
    /// The runtime flavor is pinned to `current_thread` on purpose. On that runtime the
    /// spawned watcher task cannot run until this test awaits a pending future. That
    /// guarantees the insert below lands after the initial snapshot (taken synchronously
    /// inside `stream()`) but before the task re-evaluates.
    ///
    /// On a multi-threaded runtime the task could drain every notification before the
    /// insert, and the test would become flaky.
    #[tokio::test(flavor = "current_thread")]
    async fn burst_drain_hitting_end_of_stream_still_delivers_final_state() {
        let state = Arc::new(RwLock::new(MemoriesState::new()));
        // Three notifications that are all immediately ready, followed by end-of-stream.
        let changes = StreamExt::boxed(futures::stream::iter(vec![0u64, 1, 2]));
        let watcher = MemoriesWatcher::new(state.clone(), changes);
        // Takes the (empty) initial snapshot now and spawns the background task.
        let mut out = Box::pin(watcher.stream());
        state.write().memories.insert(
            7,
            Arc::new(Memory {
                id: 7,
                content: "final value".into(),
                tags: vec!["t".into()],
                source: "test".into(),
                created_ns: 1,
                updated_ns: 1,
                pinned: false,
            }),
        );
        let initial = out.next().await.unwrap();
        assert!(initial.is_empty(), "initial snapshot precedes the insert");
        let last = out
            .next()
            .await
            .expect("final state must be delivered before the stream ends");
        assert_eq!(last.len(), 1);
        assert_eq!(last[0].id, 7);
        assert!(
            out.next().await.is_none(),
            "stream ends after the final delivery"
        );
    }
}