use crate::axon::{Axon, BoxFuture};
use futures_core::Stream;
use futures_util::StreamExt;
use ranvier_core::bus::Bus;
use ranvier_core::outcome::Outcome;
use ranvier_core::schematic::Schematic;
use ranvier_core::streaming::StreamTimeoutConfig;
use ranvier_core::transition::ResourceRequirement;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::fmt::Debug;
use std::pin::Pin;
use std::sync::Arc;
/// A pipeline whose `execute` yields a stream of `Item`s rather than a single value.
///
/// Type parameters: `In` is the input, `Item` the streamed element, `E` the
/// pipeline fault type carried by [`StreamingAxonError`], and `Res` the
/// resources borrowed for the duration of a run.
pub struct StreamingAxon<In, Item, E, Res = ()> {
    /// Schematic for this pipeline; returned by `export_schematic` and moved
    /// into the `Axon` built by `collect_into_vec`.
    pub schematic: Schematic,
    /// Closure invoked by `execute` to produce the item stream.
    pub(crate) stream_executor: StreamExecutor<In, Item, E, Res>,
    /// Optional init/idle/total limits; when any is set, the produced stream is
    /// wrapped in a timeout adapter.
    pub timeout_config: Option<StreamTimeoutConfig>,
    /// Requested buffer size, set via `with_buffer_size`.
    /// NOTE(review): not read by `execute` or `collect_into_vec` in this
    /// module — confirm where (or whether) it is consumed.
    pub buffer_size: usize,
}
/// Public name for the executor closure type, since `StreamExecutor` itself
/// is private to this module.
pub type StreamExecutorType<In, Item, E, Res> = StreamExecutor<In, Item, E, Res>;

/// Boxed, `Send` stream of `Item`s as handed back by a stream executor.
type BoxedItemStream<Item> = Pin<Box<dyn Stream<Item = Item> + Send>>;

/// Shared, thread-safe closure that takes an input plus borrowed resources and
/// bus, and asynchronously yields either the item stream or an error.
///
/// The returned future may borrow `Res` and `Bus` for the same lifetime `'a`.
type StreamExecutor<In, Item, E, Res> = Arc<
    dyn for<'a> Fn(
            In,
            &'a Res,
            &'a mut Bus,
        ) -> BoxFuture<'a, Result<BoxedItemStream<Item>, StreamingAxonError<E>>>
        + Send
        + Sync,
>;
/// Errors a [`StreamingAxon`] can report while setting up its stream.
#[derive(Debug)]
pub enum StreamingAxonError<E> {
    /// Wraps a fault value of the pipeline's own error type `E`.
    PipelineFault(E),
    /// Carries a message about an outcome the caller did not expect.
    UnexpectedOutcome(String),
    /// Carries a message explaining why the stream could not be initialised.
    StreamInitError(String),
    /// Names the timeout limit that was exceeded.
    Timeout(StreamTimeoutKind),
}

/// Identifies which stream timeout limit fired.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamTimeoutKind {
    /// Named after the `init` limit (waiting for the first item).
    Init,
    /// Named after the `idle` limit (gap between items).
    Idle,
    /// Named after the `total` limit (overall stream duration).
    Total,
}

impl<E: Debug> std::fmt::Display for StreamingAxonError<E> {
    /// Renders a one-line, human-readable description; the fault payload and
    /// the timeout kind are shown via their `Debug` representations.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            StreamingAxonError::PipelineFault(fault) => {
                write!(f, "Pipeline fault: {:?}", fault)
            }
            StreamingAxonError::UnexpectedOutcome(detail) => {
                write!(f, "Unexpected outcome: {}", detail)
            }
            StreamingAxonError::StreamInitError(reason) => {
                write!(f, "Stream init error: {}", reason)
            }
            StreamingAxonError::Timeout(which) => {
                write!(f, "Stream timeout: {:?}", which)
            }
        }
    }
}

/// Marker impl: no `source()` is exposed because `E` is only required to be `Debug`.
impl<E> std::error::Error for StreamingAxonError<E> where E: Debug {}
impl<In, Item, E, Res> Clone for StreamingAxon<In, Item, E, Res> {
fn clone(&self) -> Self {
Self {
schematic: self.schematic.clone(),
stream_executor: self.stream_executor.clone(),
timeout_config: self.timeout_config.clone(),
buffer_size: self.buffer_size,
}
}
}
impl<In, Item, E, Res> StreamingAxon<In, Item, E, Res>
where
    In: Send + Sync + 'static,
    Item: Send + 'static,
    E: Send + Sync + Debug + 'static,
    Res: ResourceRequirement,
{
    /// Runs the stream executor for `input` and returns the produced stream.
    ///
    /// When a timeout configuration with at least one of `init`, `idle` or
    /// `total` set is attached, the stream is wrapped in `TimeoutStream` so it
    /// ends early once a limit is exceeded; otherwise it is returned as-is.
    ///
    /// # Errors
    /// Propagates any [`StreamingAxonError`] returned by the executor.
    pub async fn execute(
        &self,
        input: In,
        resources: &Res,
        bus: &mut Bus,
    ) -> Result<Pin<Box<dyn Stream<Item = Item> + Send>>, StreamingAxonError<E>> {
        let produced = (self.stream_executor)(input, resources, bus).await?;
        let active_config = self
            .timeout_config
            .as_ref()
            .filter(|cfg| cfg.init.is_some() || cfg.idle.is_some() || cfg.total.is_some());
        if let Some(cfg) = active_config {
            return Ok(Box::pin(TimeoutStream::new(produced, cfg.clone())));
        }
        Ok(produced)
    }

    /// Returns this axon with its `buffer_size` replaced by `size`.
    pub fn with_buffer_size(self, size: usize) -> Self {
        Self {
            buffer_size: size,
            ..self
        }
    }

    /// Returns this axon with `config` installed as its timeout configuration.
    pub fn with_timeout(self, config: StreamTimeoutConfig) -> Self {
        Self {
            timeout_config: Some(config),
            ..self
        }
    }

    /// Borrows the schematic describing this pipeline.
    pub fn export_schematic(&self) -> &Schematic {
        &self.schematic
    }
}
impl<In, Item, E, Res> StreamingAxon<In, Item, E, Res>
where
In: Send + Sync + Serialize + DeserializeOwned + 'static,
Item: Send + Sync + Serialize + DeserializeOwned + 'static,
E: Send + Sync + Serialize + DeserializeOwned + Debug + 'static,
Res: ResourceRequirement,
{
pub fn collect_into_vec(self) -> Axon<In, Vec<Item>, String, Res> {
let stream_executor = self.stream_executor.clone();
let timeout_config = self.timeout_config.clone();
let executor: crate::axon::Executor<In, Vec<Item>, String, Res> = Arc::new(
move |input: In, res: &Res, bus: &mut Bus| -> BoxFuture<'_, Outcome<Vec<Item>, String>> {
let stream_executor = stream_executor.clone();
let timeout_config = timeout_config.clone();
Box::pin(async move {
let stream = match stream_executor(input, res, bus).await {
Ok(s) => s,
Err(e) => return Outcome::Fault(format!("{}", e)),
};
let stream = match &timeout_config {
Some(config)
if config.init.is_some()
|| config.idle.is_some()
|| config.total.is_some() =>
{
Box::pin(TimeoutStream::new(stream, config.clone()))
as Pin<Box<dyn Stream<Item = Item> + Send>>
}
_ => stream,
};
let items: Vec<Item> = stream.collect().await;
Outcome::Next(items)
})
},
);
Axon {
schematic: self.schematic,
executor,
execution_mode: crate::axon::ExecutionMode::Local,
persistence_store: None,
audit_sink: None,
dlq_sink: None,
dlq_policy: Default::default(),
dynamic_dlq_policy: None,
saga_policy: Default::default(),
dynamic_saga_policy: None,
saga_compensation_registry: Arc::new(std::sync::RwLock::new(
ranvier_core::saga::SagaCompensationRegistry::new(),
)),
iam_handle: None,
}
}
}
/// Stream adapter that ends the wrapped stream once a configured timeout
/// elapses.
///
/// Limits read from [`StreamTimeoutConfig`]:
/// - `total`: measured from construction and applies for the whole stream.
/// - `init`: measured from construction and applies only until the first
///   item has been yielded.
/// - `idle`: measured from the most recently yielded item and applies only
///   after the first item.
///
/// When a limit is exceeded, a `tracing` warning is logged and the stream
/// yields `None`; no error item is produced. Every later poll also yields
/// `None`.
///
/// While the inner stream is pending, a `tokio::time::Sleep` is armed for the
/// earliest applicable deadline. This wakes the consumer's task even if the
/// inner stream never wakes it again. As a consequence, polling this adapter
/// needs a Tokio runtime with the time driver enabled whenever a limit is
/// configured.
struct TimeoutStream<S> {
    /// Wrapped stream, pinned in its own box so the adapter needs no
    /// structural pinning.
    inner: Pin<Box<S>>,
    config: StreamTimeoutConfig,
    /// Reference point for the `init` and `total` limits.
    started_at: tokio::time::Instant,
    /// Switches the per-phase limit from `init` to `idle`.
    first_item_received: bool,
    /// Reference point for the `idle` limit.
    last_item_at: tokio::time::Instant,
    /// Set once the inner stream ends or a limit is exceeded.
    finished: bool,
    /// Wake-up timer for the nearest deadline. It is created lazily on the
    /// first `Pending` poll that has a deadline to watch, and re-armed after.
    timer: Option<Pin<Box<tokio::time::Sleep>>>,
}

impl<S, Item> TimeoutStream<S>
where
    S: Stream<Item = Item> + Send,
{
    fn new(inner: S, config: StreamTimeoutConfig) -> Self {
        let now = tokio::time::Instant::now();
        Self {
            inner: Box::pin(inner),
            config,
            started_at: now,
            first_item_received: false,
            last_item_at: now,
            finished: false,
            timer: None,
        }
    }

    /// Returns the earliest deadline that currently applies, together with a
    /// label and the configured limit for logging.
    ///
    /// The phase limit (`init` before the first item, `idle` after it) is
    /// listed before `total`, so it wins when both fall on the same instant.
    /// A limit whose deadline would overflow `Instant` can never be reached
    /// and is skipped rather than panicking.
    fn next_deadline(&self) -> Option<(tokio::time::Instant, &'static str, std::time::Duration)> {
        let phase = if self.first_item_received {
            self.config.idle.map(|limit| (self.last_item_at, "idle", limit))
        } else {
            self.config.init.map(|limit| (self.started_at, "init", limit))
        };
        let total = self.config.total.map(|limit| (self.started_at, "total", limit));
        phase
            .into_iter()
            .chain(total)
            .filter_map(|(base, label, limit)| {
                base.checked_add(limit).map(|deadline| (deadline, label, limit))
            })
            .min_by_key(|&(deadline, _, _)| deadline)
    }

    /// Handles a `Pending` result from the inner stream.
    ///
    /// If the earliest applicable deadline has passed, the stream is ended.
    /// Otherwise the timer is armed so that `cx`'s task is woken at that
    /// deadline.
    fn on_inner_pending(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Item>> {
        let (deadline, label, limit) = match self.next_deadline() {
            Some(found) => found,
            None => {
                // No limit applies in this phase; drop any stale timer so it
                // cannot cause a spurious wake-up.
                self.timer = None;
                return std::task::Poll::Pending;
            }
        };
        let expired = if tokio::time::Instant::now() >= deadline {
            true
        } else {
            let timer = self
                .timer
                .get_or_insert_with(|| Box::pin(tokio::time::sleep_until(deadline)));
            if timer.deadline() != deadline {
                timer.as_mut().reset(deadline);
            }
            // Polling registers `cx`'s waker with the timer; `Ready` means the
            // deadline was reached in the meantime.
            std::future::Future::poll(timer.as_mut(), cx).is_ready()
        };
        if expired {
            self.expire(label, limit)
        } else {
            std::task::Poll::Pending
        }
    }

    /// Logs the exceeded `label` limit, marks the stream finished, releases
    /// the timer, and yields end-of-stream.
    fn expire(&mut self, label: &str, limit: std::time::Duration) -> std::task::Poll<Option<Item>> {
        tracing::warn!("Stream {} timeout exceeded ({:?})", label, limit);
        self.finished = true;
        self.timer = None;
        std::task::Poll::Ready(None)
    }
}

impl<S, Item> Stream for TimeoutStream<S>
where
    S: Stream<Item = Item> + Send,
{
    type Item = Item;

    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // No field is structurally pinned (the inner stream sits in its own
        // `Pin<Box<_>>`), so the adapter is `Unpin` and `get_mut` is safe.
        let this = self.get_mut();
        if this.finished {
            return std::task::Poll::Ready(None);
        }
        // Enforce the overall budget before touching the inner stream, so an
        // always-ready producer cannot run past it.
        if let Some(total) = this.config.total {
            if this.started_at.elapsed() >= total {
                return this.expire("total", total);
            }
        }
        match this.inner.as_mut().poll_next(cx) {
            std::task::Poll::Ready(Some(item)) => {
                this.first_item_received = true;
                this.last_item_at = tokio::time::Instant::now();
                std::task::Poll::Ready(Some(item))
            }
            std::task::Poll::Ready(None) => {
                this.finished = true;
                this.timer = None;
                std::task::Poll::Ready(None)
            }
            std::task::Poll::Pending => this.on_inner_pending(cx),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream;
    use ranvier_core::bus::Bus;

    /// Boxes a finite list of strings into the trait-object stream type that
    /// stream executors hand back.
    fn boxed_stream(items: Vec<String>) -> Pin<Box<dyn Stream<Item = String> + Send>> {
        Box::pin(stream::iter(items))
    }

    /// Wraps `stream_executor` in a `StreamingAxon` with no timeouts and the
    /// 64-item buffer size used throughout these tests.
    fn axon_with<In>(
        name: &'static str,
        stream_executor: StreamExecutor<In, String, String, ()>,
    ) -> StreamingAxon<In, String, String, ()> {
        StreamingAxon {
            schematic: Schematic::new(name),
            stream_executor,
            timeout_config: None,
            buffer_size: 64,
        }
    }

    #[tokio::test]
    async fn test_streaming_axon_basic_execute() {
        let executor: StreamExecutor<String, String, String, ()> =
            Arc::new(|input: String, _res: &(), _bus: &mut Bus| {
                Box::pin(async move {
                    Ok(boxed_stream(vec![
                        format!("chunk1: {}", input),
                        format!("chunk2: {}", input),
                        format!("chunk3: {}", input),
                    ]))
                })
            });
        let axon = axon_with("test", executor);
        let mut bus = Bus::new();
        let chunks: Vec<String> = axon
            .execute("hello".to_string(), &(), &mut bus)
            .await
            .unwrap()
            .collect()
            .await;
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], "chunk1: hello");
    }

    #[tokio::test]
    async fn test_streaming_axon_collect_into_vec() {
        let executor: StreamExecutor<(), String, String, ()> =
            Arc::new(|_input: (), _res: &(), _bus: &mut Bus| {
                Box::pin(async move {
                    let letters = ["a", "b", "c"].iter().map(|s| s.to_string()).collect();
                    Ok(boxed_stream(letters))
                })
            });
        let axon = axon_with("test-collect", executor).collect_into_vec();
        let mut bus = Bus::new();
        let items = match axon.execute((), &(), &mut bus).await {
            Outcome::Next(items) => items,
            other => panic!("Expected Next, got {:?}", other),
        };
        assert_eq!(items, vec!["a", "b", "c"]);
    }

    #[tokio::test]
    async fn test_streaming_axon_pipeline_fault() {
        let executor: StreamExecutor<(), String, String, ()> =
            Arc::new(|_input: (), _res: &(), _bus: &mut Bus| {
                Box::pin(async move {
                    Err(StreamingAxonError::PipelineFault("step failed".to_string()))
                })
            });
        let axon = axon_with("test-fault", executor);
        let mut bus = Bus::new();
        let outcome = axon.execute((), &(), &mut bus).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_streaming_axon_clone() {
        let executor: StreamExecutor<(), String, String, ()> =
            Arc::new(|_input: (), _res: &(), _bus: &mut Bus| {
                Box::pin(async move { Ok(boxed_stream(vec!["x".to_string()])) })
            });
        let original = axon_with("test-clone", executor);
        let duplicate = original.clone();
        let mut bus = Bus::new();
        let items: Vec<String> = duplicate
            .execute((), &(), &mut bus)
            .await
            .unwrap()
            .collect()
            .await;
        assert_eq!(items, vec!["x"]);
    }
}