use std::{
any::Any,
future::Future,
pin::Pin,
task::{self, Poll},
};
use futures::{self, channel::mpsc, sink::SinkExt as _, stream, stream::StreamExt as _};
use pin_project::pin_project;
use sealed::sealed;
use crate::{
envelope::{Envelope, MessageKind},
message::{AnyMessage, Message},
scope::{self, Scope},
source::{SourceArc, SourceStream, UnattachedSource, UntypedSourceArc},
tracing::TraceId,
Addr,
};
/// A source handle backed by a stream whose items have type `M`
/// (`AnyMessage` when the parameter is omitted).
///
/// Values of this type are produced by the constructors of this type, each of
/// which returns an [`UnattachedSource`] that builds the handle.
pub struct Stream<M = AnyMessage> {
// The wrapped stream is stored type-erased (`dyn futures::Stream`), so the
// handle's type depends only on the item type `M`, not on the concrete stream.
source: SourceArc<StreamSource<dyn futures::Stream<Item = M> + Send + 'static>>,
}
#[sealed]
impl<M: StreamItem> crate::source::SourceHandle for Stream<M> {
    /// Reports whether the underlying source has been terminated.
    ///
    /// The answer comes straight from the wrapped source.
    fn is_terminated(&self) -> bool {
        let Self { source } = self;
        source.is_terminated()
    }

    /// Requests termination of the underlying source through a shared
    /// reference and forwards the boolean returned by the wrapped source.
    fn terminate_by_ref(&self) -> bool {
        let Self { source } = self;
        source.terminate_by_ref()
    }
}
impl<M: StreamItem> Stream<M> {
    /// Wraps a [`futures::Stream`] into a source that is not attached yet.
    ///
    /// Trace id rewriting is enabled for sources built this way: a new trace
    /// id is generated on construction and again after every yielded item.
    pub fn from_futures03<S>(stream: S) -> UnattachedSource<Self>
    where
        S: futures::Stream<Item = M> + Send + 'static,
    {
        let rewrite_trace_id = true;
        let oneshot = false;
        Self::from_futures03_inner(stream, rewrite_trace_id, oneshot)
    }

    /// Wraps a single future into a source that is not attached yet.
    ///
    /// The future is turned into a one-item stream via `stream::once`. Trace id
    /// rewriting is disabled and the source is created with the `oneshot`
    /// flag set.
    pub fn once<F>(future: F) -> UnattachedSource<Self>
    where
        F: Future<Output = M> + Send + 'static,
    {
        let rewrite_trace_id = false;
        let oneshot = true;
        Self::from_futures03_inner(stream::once(future), rewrite_trace_id, oneshot)
    }

    /// Shared constructor behind all public constructors.
    ///
    /// * `stream` — the stream to wrap.
    /// * `rewrite_trace_id` — when `true`, a new trace id is stored in the
    ///   source's scope immediately and after every yielded item.
    /// * `oneshot` — forwarded as is to `UntypedSourceArc::new`.
    fn from_futures03_inner(
        stream: impl futures::Stream<Item = M> + Send + 'static,
        rewrite_trace_id: bool,
        oneshot: bool,
    ) -> UnattachedSource<Self> {
        #[cfg(not(feature = "test-util"))]
        let scope = scope::expose();

        // With `test-util` enabled a missing scope is tolerated: a synthetic
        // test scope is used instead.
        #[cfg(feature = "test-util")]
        let scope = match scope::try_expose() {
            Some(exposed) => exposed,
            None => {
                let meta = crate::actor::ActorMeta {
                    group: "test".into(),
                    key: "test".into(),
                };
                Scope::test(Addr::NULL, std::sync::Arc::new(meta))
            }
        };

        // In rewrite mode the source starts with its own trace id.
        if rewrite_trace_id {
            scope.set_trace_id(TraceId::generate());
        }

        let typed = StreamSource {
            scope,
            rewrite_trace_id,
            inner: stream,
        };
        let untyped = UntypedSourceArc::new(typed, oneshot);
        let source = SourceArc::from_untyped(untyped);

        UnattachedSource::new(source, |source| Self { source })
    }
}
impl Stream<AnyMessage> {
    /// Builds a source from a generator: a closure that receives an
    /// [`Emitter`] and returns a future which pushes messages through it.
    ///
    /// The generator's future is converted into a stream that never yields an
    /// item (its `()` output is filtered out) and merged with the receiving
    /// half of the emitter's channel. Polling the merged stream therefore both
    /// drives the generator and surfaces the messages it emits.
    ///
    /// The channel is created with a buffer argument of `0`. Trace id
    /// rewriting is disabled and the source is not flagged as oneshot.
    pub fn generate<G, F>(generator: G) -> UnattachedSource<Self>
    where
        G: FnOnce(Emitter) -> F,
        F: Future<Output = ()> + Send + 'static,
    {
        let (sender, receiver) = mpsc::channel(0);

        // Exists only to be polled; it contributes no items of its own.
        let driver = stream::once(generator(Emitter(sender))).filter_map(|_| async { None });

        Self::from_futures03_inner(stream::select(driver, receiver), false, false)
    }
}
/// Bundles a stream with the scope it is polled in.
///
/// `S` may be unsized, which lets the same type hold a `dyn futures::Stream`.
#[pin_project]
struct StreamSource<S: ?Sized> {
// Scope that wraps (via `sync_within`) every poll of `inner`.
scope: Scope,
// When `true`, a freshly generated trace id is stored into `scope` after
// each item yielded by `inner`.
rewrite_trace_id: bool,
// The wrapped stream; marked `#[pin]` so it is projected as `Pin<&mut S>`.
#[pin]
inner: S,
}
impl<S, M> SourceStream for StreamSource<S>
where
    S: futures::Stream<Item = M> + ?Sized + Send + 'static,
    M: StreamItem,
{
    /// Not available for stream sources: panics with `unreachable!()` if it
    /// is ever called.
    fn as_any_mut(self: Pin<&mut Self>) -> Pin<&mut dyn Any> {
        unreachable!()
    }

    /// Polls the wrapped stream inside the source's scope and converts a
    /// yielded item into an [`Envelope`].
    ///
    /// * On an item: the item is packed with the trace id observed right after
    ///   the poll. The stored scope then receives either a newly generated
    ///   trace id (rewrite mode) or that same observed id.
    /// * On exhaustion: `Poll::Ready(None)` is returned and the stored scope
    ///   is left untouched.
    /// * On `Pending`: the trace id observed after the poll is written back to
    ///   the stored scope.
    fn poll_recv(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Envelope>> {
        let projected = self.project();
        let inner = projected.inner;
        let stored_scope = projected.scope;
        let rewrite = *projected.rewrite_trace_id;

        // A clone is entered so the stored scope remains usable from within
        // the closure.
        let entered = stored_scope.clone();
        entered.sync_within(|| {
            let message = match inner.poll_next(cx) {
                Poll::Ready(Some(message)) => message,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Pending => {
                    stored_scope.set_trace_id(scope::trace_id());
                    return Poll::Pending;
                }
            };

            // The envelope always carries the id observed here; rewriting only
            // affects what is stored for subsequent polls.
            let trace_id = scope::trace_id();
            let stored_trace_id = if rewrite {
                TraceId::generate()
            } else {
                trace_id
            };
            stored_scope.set_trace_id(stored_trace_id);

            Poll::Ready(Some(message.pack(trace_id)))
        })
    }
}
/// Handle passed to the generator of [`Stream::generate()`].
///
/// Wraps the sending half of the channel whose receiving half feeds the
/// generated stream.
pub struct Emitter(mpsc::Sender<AnyMessage>);
impl Emitter {
    /// Sends `message` into the stream, waiting until the channel accepts it.
    ///
    /// The message is converted to an [`AnyMessage`] first. A send error is
    /// deliberately ignored, so this method never reports failure.
    pub async fn emit<M: Message>(&mut self, message: M) {
        let erased = AnyMessage::new(message);
        let Self(sender) = self;
        let _ = sender.send(erased).await;
    }
}
/// An item type that a [`Stream`] can yield.
///
/// Implemented for every [`Message`] and for `Result<M1, M2>` where both
/// `M1` and `M2` are messages. The trait is marked `#[sealed]`.
#[sealed]
pub trait StreamItem: 'static {
/// Converts the item into an [`Envelope`] tagged with `trace_id`.
#[doc(hidden)]
fn pack(self, trace_id: TraceId) -> Envelope;
}
#[sealed]
impl<M: Message> StreamItem for M {
    /// Wraps the message into an envelope of the regular kind, built from
    /// `Addr::NULL`, and tags it with `trace_id`.
    #[doc(hidden)]
    fn pack(self, trace_id: TraceId) -> Envelope {
        Envelope::with_trace_id(self, MessageKind::regular(Addr::NULL), trace_id)
    }
}
#[sealed]
impl<M1: Message, M2: Message> StreamItem for Result<M1, M2> {
    /// Packs whichever message the result holds; `Ok` and `Err` payloads are
    /// treated the same way and both receive `trace_id`.
    #[doc(hidden)]
    fn pack(self, trace_id: TraceId) -> Envelope {
        self.map_or_else(
            |failure| failure.pack(trace_id),
            |success| success.pack(trace_id),
        )
    }
}