use std::future::Future;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use futures_util::Stream;
use futures_util::StreamExt;
use futures_util::TryStreamExt;
use crate::domain::SharedDomainData;
use crate::stream_status::StreamStatus;
/// Shared, thread-safe callback that stores a [`StreamStatus`] snapshot of the
/// accumulator `T` into the domain value `D`.
type StreamStatusSetter<D, T> = Arc<dyn Fn(&mut D, StreamStatus<T>) + Sync + Send>;
/// Shared, thread-safe callback that folds one stream item `I` into the accumulator `T`.
type StreamItemCallback<T, I> = Arc<dyn Fn(&mut T, I) + Sync + Send>;
/// Handle to a stream-processing task spawned by `StreamSource::go` or `StreamConfig::go`.
///
/// Awaiting it (through its `Future` impl) yields the spawned task's `JoinHandle` result.
pub struct StreamHandle {
/// Sender half of the watch channel whose receiver the spawned task's `select!` loop
/// polls first. Dropping this sender (i.e. dropping the handle) also makes the task's
/// `changed()` future resolve, which the loop treats as cancellation.
cancel: tokio::sync::watch::Sender<bool>,
/// Join handle of the spawned tokio task.
inner: tokio::task::JoinHandle<()>,
}
impl StreamHandle {
    /// Requests cancellation of the spawned stream task.
    ///
    /// This only signals the task's watch channel; it does not hard-abort the tokio task.
    /// The spawned loops check that channel first (`biased;` select), so the task stops at
    /// its next poll. A task started by `StreamConfig::go` therefore gets to publish
    /// `StreamStatus::aborted` with the current accumulator before exiting. Awaiting the
    /// handle afterwards normally resolves to `Ok(())` rather than a cancelled `JoinError`.
    ///
    /// A hard abort is deliberately not issued here. `JoinHandle::abort` on a task that is
    /// not currently running drops its future without polling it again, so the task would
    /// almost never observe the cancellation signal and the `aborted` status would be lost.
    /// Callers that really want a hard abort can still reach `JoinHandle::abort` through
    /// `Deref`, e.g. `(*handle).abort()`.
    pub fn abort(&self) {
        // `send` only fails when every receiver is gone, i.e. the task has already
        // finished; there is nothing left to cancel in that case.
        let _ = self.cancel.send(true);
    }
}
impl Future for StreamHandle {
type Output = Result<(), tokio::task::JoinError>;
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let inner = unsafe { self.map_unchecked_mut(|s| &mut s.inner) };
inner.poll(cx)
}
}
impl Deref for StreamHandle {
type Target = tokio::task::JoinHandle<()>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
/// Entry point for wiring a stream into shared domain data.
///
/// Carries the domain handle and the runtime handle that `from_stream` /
/// `try_from_stream` pass on to the resulting `StreamSource`.
pub struct StreamExecutor<D> {
/// Shared domain state that stream results are eventually written into.
pub(crate) domain: SharedDomainData<D>,
/// Runtime handle on which the stream task is later spawned.
pub(crate) rt: tokio::runtime::Handle,
}
impl<D> StreamExecutor<D>
where
D: Clone + Send + Sync + 'static,
{
pub fn from_stream<St, I>(self, stream: St) -> StreamSource<D, I>
where
St: Stream<Item = I> + Send + 'static,
I: Send + 'static,
{
StreamSource {
domain: self.domain,
rt: self.rt,
stream: StreamKind::Infallible(Box::pin(stream)),
}
}
pub fn try_from_stream<St, I, E>(self, stream: St) -> StreamSource<D, I>
where
St: Stream<Item = Result<I, E>> + Send + 'static,
I: Send + 'static,
E: std::fmt::Display + Send + 'static,
{
let mapped = stream.map_err(|e: E| e.to_string());
StreamSource {
domain: self.domain,
rt: self.rt,
stream: StreamKind::Fallible(Box::pin(mapped)),
}
}
}
/// Type-erased, boxed stream that remembers whether its items can fail.
///
/// For the fallible variant the error type has already been converted to `String`
/// (see `StreamExecutor::try_from_stream`).
enum StreamKind<I> {
/// Stream whose items are plain values (always treated as successes).
Infallible(Pin<Box<dyn Stream<Item = I> + Send>>),
/// Stream of `Result`s whose errors have been stringified.
Fallible(Pin<Box<dyn Stream<Item = Result<I, String>> + Send>>),
}
impl<I> StreamKind<I> {
async fn next(&mut self) -> Option<Result<I, String>> {
match self {
StreamKind::Infallible(s) => s.next().await.map(Ok),
StreamKind::Fallible(s) => s.next().await,
}
}
}
/// A stream bound to a domain and runtime, not yet configured or spawned.
///
/// `into` turns it into a `StreamConfig` that folds items into domain state; `go` drains
/// it without writing to the domain.
pub struct StreamSource<D, I> {
/// Shared domain state, handed on to `StreamConfig` by `into`.
domain: SharedDomainData<D>,
/// Runtime handle the stream task is spawned on.
rt: tokio::runtime::Handle,
/// The type-erased stream itself.
stream: StreamKind<I>,
}
impl<D, I> StreamSource<D, I>
where
    D: Clone + Send + Sync + 'static,
    I: Send + 'static,
{
    /// Configures how stream items are folded into an accumulator and published.
    ///
    /// * `setter` — receives the domain value and each [`StreamStatus`] snapshot of the
    ///   accumulator.
    /// * `initial` — the accumulator's starting value.
    /// * `on_item` — folds one stream item into the accumulator.
    ///
    /// Nothing is spawned here; call `.go()` on the returned [`StreamConfig`].
    #[must_use = "the builder must be consumed with `.go()` to spawn the stream"]
    pub fn into<T>(
        self,
        setter: impl Fn(&mut D, StreamStatus<T>) + Send + Sync + 'static,
        initial: T,
        on_item: impl Fn(&mut T, I) + Send + Sync + 'static,
    ) -> StreamConfig<D, I, T>
    where
        T: Clone + Send + 'static,
    {
        StreamConfig {
            domain: self.domain,
            rt: self.rt,
            stream: self.stream,
            setter: Arc::new(setter),
            on_item: Arc::new(on_item),
            initial,
            batch_interval: None,
        }
    }

    /// Spawns a task on `rt` that drains the stream, discarding every item.
    ///
    /// The task stops at the first error item, at end of stream, or when cancellation is
    /// signalled. The domain is never written. Cancellation is observed when `true` is
    /// sent on the handle's watch channel, or when the returned [`StreamHandle`] is
    /// dropped: dropping it drops the watch sender, which makes `cancel_rx.changed()`
    /// resolve. Discarding the result of `go()` therefore stops the stream immediately.
    #[must_use = "dropping the returned `StreamHandle` cancels the stream"]
    pub fn go(self) -> StreamHandle {
        let (cancel_tx, mut cancel_rx) = tokio::sync::watch::channel(false);
        // Move the stream out explicitly; the domain is unused here.
        let mut stream = self.stream;
        let inner = self.rt.spawn(async move {
            loop {
                let keep_going = tokio::select! {
                    // Cancellation is checked first so it wins over an always-ready stream.
                    biased;
                    _ = cancel_rx.changed() => false,
                    item = stream.next() => matches!(item, Some(Ok(_))),
                };
                if !keep_going {
                    break;
                }
            }
        });
        StreamHandle {
            cancel: cancel_tx,
            inner,
        }
    }
}
/// Fully configured stream pipeline, ready to be spawned with `go`.
pub struct StreamConfig<D, I, T> {
/// Shared domain state the setter writes into.
domain: SharedDomainData<D>,
/// Runtime handle the stream task is spawned on.
rt: tokio::runtime::Handle,
/// The type-erased stream being consumed.
stream: StreamKind<I>,
/// Writes a `StreamStatus` snapshot of the accumulator into the domain.
setter: StreamStatusSetter<D, T>,
/// Folds each successful stream item into the accumulator.
on_item: StreamItemCallback<T, I>,
/// Starting value of the accumulator.
initial: T,
/// Minimum time between intermediate `streaming` publications; `None` publishes
/// after every item.
batch_interval: Option<Duration>,
}
impl<D, I, T> StreamConfig<D, I, T>
where
D: Clone + Send + Sync + 'static,
I: Send + 'static,
T: Clone + Send + 'static,
{
#[must_use = "the builder must be consumed with `.go()` to spawn the stream"]
pub fn batch(mut self, interval: Duration) -> Self {
self.batch_interval = Some(interval);
self
}
#[must_use = "the config must be consumed with `.go()` to spawn the stream"]
pub fn go(self) -> StreamHandle {
let (cancel_tx, mut cancel_rx) = tokio::sync::watch::channel(false);
let StreamConfig {
domain,
rt,
stream,
setter,
on_item,
initial,
batch_interval,
} = self;
let inner = rt.spawn(async move {
let mut buffer = initial;
let mut last_flush = std::time::Instant::now();
{
let s = setter.clone();
let status = StreamStatus::streaming(buffer.clone());
domain.modify(move |d| s(d, status));
}
let mut stream = stream;
loop {
tokio::select! {
biased;
_ = cancel_rx.changed() => {
let s = setter.clone();
let status = StreamStatus::aborted(buffer.clone());
domain.modify(move |d| s(d, status));
break;
}
item = stream.next() => {
match item {
Some(Ok(item)) => {
on_item(&mut buffer, item);
let should_flush = match batch_interval {
Some(interval) => last_flush.elapsed() >= interval,
None => true,
};
if should_flush {
last_flush = std::time::Instant::now();
let s = setter.clone();
let status = StreamStatus::streaming(buffer.clone());
domain.modify(move |d| s(d, status));
}
}
Some(Err(e)) => {
let s = setter.clone();
let status = StreamStatus::error(&e, buffer.clone());
domain.modify(move |d| s(d, status));
break;
}
None => {
let s = setter.clone();
let status = StreamStatus::completed(buffer.clone());
domain.modify(move |d| s(d, status));
break;
}
}
}
}
}
});
StreamHandle {
cancel: cancel_tx,
inner,
}
}
}