use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use futures::{StreamExt, pin_mut};
use camel_api::{
AggregationStrategy, CamelError, Exchange, OutcomePipeline, OutcomeSegment, PipelineOutcome,
StreamingSplitExpression,
};
use crate::split_segment::aggregate_completed;
pub struct StreamingSplitSegment {
pub expression: StreamingSplitExpression,
pub body: OutcomeSegment,
pub aggregation: AggregationStrategy,
pub stop_on_exception: bool,
}
impl Clone for StreamingSplitSegment {
fn clone(&self) -> Self {
Self {
expression: Arc::clone(&self.expression),
body: self.body.clone(),
aggregation: self.aggregation.clone(),
stop_on_exception: self.stop_on_exception,
}
}
}
impl OutcomePipeline for StreamingSplitSegment {
    /// Boxes an independent copy of this segment.
    fn clone_box(&self) -> Box<dyn OutcomePipeline> {
        Box::new(self.clone())
    }

    /// Splits `exchange` into fragments and drives `body` over each one in
    /// stream order.
    ///
    /// Outcomes:
    /// - the fragment stream yields `Err` -> `Failed` immediately (regardless
    ///   of `stop_on_exception`);
    /// - `body` returns `Stopped` -> that fragment's exchange is returned as
    ///   `Stopped`, no aggregation happens and the stream is not polled again;
    /// - `body` returns `Failed` -> returned immediately when
    ///   `stop_on_exception` is set, otherwise remembered (the most recent
    ///   failure wins) and returned after the stream is exhausted;
    /// - otherwise the completed fragments are aggregated with a copy of the
    ///   original exchange and returned as `Completed`.
    fn run<'a>(
        &'a mut self,
        exchange: Exchange,
    ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
        // Snapshot the configuration so the returned future owns what it uses.
        let splitter = Arc::clone(&self.expression);
        let mut segment = self.body.clone();
        let strategy = self.aggregation.clone();
        let halt_on_failure = self.stop_on_exception;

        Box::pin(async move {
            // Keep the pre-split exchange: it is the base for aggregation.
            let template = exchange.clone();
            let fragments = splitter(exchange);
            pin_mut!(fragments);

            let mut completed: Vec<Exchange> = Vec::new();
            let mut deferred_failure: Option<CamelError> = None;

            while let Some(next) = fragments.next().await {
                let fragment = match next {
                    Ok(fragment) => fragment,
                    Err(stream_err) => return PipelineOutcome::Failed(stream_err),
                };

                match segment.run(fragment).await {
                    PipelineOutcome::Completed(done) => completed.push(done),
                    PipelineOutcome::Stopped(halted) => {
                        return PipelineOutcome::Stopped(halted);
                    }
                    PipelineOutcome::Failed(cause) if halt_on_failure => {
                        return PipelineOutcome::Failed(cause);
                    }
                    PipelineOutcome::Failed(cause) => deferred_failure = Some(cause),
                }
            }

            match deferred_failure {
                Some(cause) => PipelineOutcome::Failed(cause),
                None => PipelineOutcome::Completed(aggregate_completed(
                    completed, template, strategy,
                )),
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{Body, CamelError, Message};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a split expression yielding `count` fragments with bodies
    /// `frag-0`, `frag-1`, ... and returns it together with a counter of how
    /// many fragments were actually pulled from the stream.
    fn fragment_expression(count: usize) -> (StreamingSplitExpression, Arc<AtomicUsize>) {
        let fragments: Vec<Exchange> = (0..count)
            .map(|i| Exchange::new(Message::new(format!("frag-{i}"))))
            .collect();
        let pulled = Arc::new(AtomicUsize::new(0));
        let pulled_in_stream = Arc::clone(&pulled);
        let expression: StreamingSplitExpression = Arc::new(move |_| {
            let frags = fragments.clone();
            let pulled = Arc::clone(&pulled_in_stream);
            // `stream::iter` advances the iterator only when polled, so this
            // `map` closure runs once per fragment the segment consumes.
            Box::pin(futures::stream::iter(frags.into_iter().map(move |frag| {
                pulled.fetch_add(1, Ordering::SeqCst);
                Ok::<Exchange, CamelError>(frag)
            })))
        });
        (expression, pulled)
    }

    /// Body that completes until its `stop_at`-th (0-based) invocation, then stops.
    #[derive(Clone)]
    struct StopOnNthBody {
        counter: Arc<AtomicUsize>,
        stop_at: usize,
    }

    impl OutcomePipeline for StopOnNthBody {
        fn clone_box(&self) -> Box<dyn OutcomePipeline> {
            Box::new(self.clone())
        }

        fn run<'a>(
            &'a mut self,
            exchange: Exchange,
        ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
            let count = self.counter.fetch_add(1, Ordering::SeqCst);
            let stop_at = self.stop_at;
            Box::pin(async move {
                if count >= stop_at {
                    PipelineOutcome::Stopped(exchange)
                } else {
                    PipelineOutcome::Completed(exchange)
                }
            })
        }
    }

    /// Body that rewrites the fragment body and stops on the first fragment.
    #[derive(Clone)]
    struct MutateAndStopBody;

    impl OutcomePipeline for MutateAndStopBody {
        fn clone_box(&self) -> Box<dyn OutcomePipeline> {
            Box::new(MutateAndStopBody)
        }

        fn run<'a>(
            &'a mut self,
            mut exchange: Exchange,
        ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
            Box::pin(async move {
                exchange.input.body = Body::Text("mutated-by-body".to_string());
                PipelineOutcome::Stopped(exchange)
            })
        }
    }

    /// Body that fails on its second invocation and completes otherwise.
    #[derive(Clone)]
    struct FailOnSecondBody {
        counter: Arc<AtomicUsize>,
    }

    impl OutcomePipeline for FailOnSecondBody {
        fn clone_box(&self) -> Box<dyn OutcomePipeline> {
            Box::new(self.clone())
        }

        fn run<'a>(
            &'a mut self,
            exchange: Exchange,
        ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
            let count = self.counter.fetch_add(1, Ordering::SeqCst);
            Box::pin(async move {
                if count == 1 {
                    PipelineOutcome::Failed(CamelError::ProcessorError("fail".into()))
                } else {
                    PipelineOutcome::Completed(exchange)
                }
            })
        }
    }

    #[tokio::test]
    async fn streaming_split_stop_halts_stream_consumption() {
        let (expression, pulled) = fragment_expression(5);
        let invoke_count = Arc::new(AtomicUsize::new(0));
        let body = StopOnNthBody {
            counter: Arc::clone(&invoke_count),
            stop_at: 2,
        };
        let mut seg = StreamingSplitSegment {
            expression,
            body: OutcomeSegment::new(Box::new(body)),
            aggregation: AggregationStrategy::LastWins,
            stop_on_exception: true,
        };
        let ex = Exchange::new(Message::new("trigger"));
        let result = OutcomePipeline::run(&mut seg, ex).await;
        assert!(
            matches!(result, PipelineOutcome::Stopped(_)),
            "Expected Stopped, got {result:?}"
        );
        assert_eq!(invoke_count.load(Ordering::SeqCst), 3);
        assert_eq!(
            pulled.load(Ordering::SeqCst),
            3,
            "fragments after the stopping one must not be pulled from the stream"
        );
    }

    #[tokio::test]
    async fn streaming_split_stop_no_aggregate() {
        let (expression, pulled) = fragment_expression(3);
        let mut seg = StreamingSplitSegment {
            expression,
            body: OutcomeSegment::new(Box::new(MutateAndStopBody)),
            aggregation: AggregationStrategy::CollectAll,
            stop_on_exception: true,
        };
        let ex = Exchange::new(Message::new("original"));
        let result = OutcomePipeline::run(&mut seg, ex).await;
        match result {
            PipelineOutcome::Stopped(ex) => {
                assert_eq!(
                    ex.input.body.as_text(),
                    Some("mutated-by-body"),
                    "Stopped exchange should carry body mutation"
                );
            }
            other => panic!(
                "Expected Stopped with mutated body, got {other:?} — aggregation should NOT fire"
            ),
        }
        assert_eq!(pulled.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn streaming_split_stop_preserves_exchange_mutations() {
        let (expression, pulled) = fragment_expression(5);
        let mut seg = StreamingSplitSegment {
            expression,
            body: OutcomeSegment::new(Box::new(MutateAndStopBody)),
            aggregation: AggregationStrategy::CollectAll,
            stop_on_exception: true,
        };
        let ex = Exchange::new(Message::new("original"));
        let result = OutcomePipeline::run(&mut seg, ex).await;
        match result {
            PipelineOutcome::Stopped(ex) => {
                assert_eq!(
                    ex.input.body.as_text(),
                    Some("mutated-by-body"),
                    "Stopped exchange should carry the body mutation from the segment"
                );
            }
            other => panic!("Expected Stopped with mutation, got {other:?}"),
        }
        assert_eq!(pulled.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn streaming_split_stop_on_exception_true() {
        let (expression, pulled) = fragment_expression(3);
        let invoke_count = Arc::new(AtomicUsize::new(0));
        let body = FailOnSecondBody {
            counter: Arc::clone(&invoke_count),
        };
        let mut seg = StreamingSplitSegment {
            expression,
            body: OutcomeSegment::new(Box::new(body)),
            aggregation: AggregationStrategy::LastWins,
            stop_on_exception: true,
        };
        let ex = Exchange::new(Message::new("trigger"));
        let result = OutcomePipeline::run(&mut seg, ex).await;
        assert!(
            matches!(result, PipelineOutcome::Failed(_)),
            "stop_on_exception=true should propagate failure immediately"
        );
        assert_eq!(invoke_count.load(Ordering::SeqCst), 2);
        assert_eq!(
            pulled.load(Ordering::SeqCst),
            2,
            "no fragment should be pulled after the failing one"
        );
    }

    #[tokio::test]
    async fn streaming_split_stop_on_exception_false() {
        let (expression, pulled) = fragment_expression(3);
        let invoke_count = Arc::new(AtomicUsize::new(0));
        let body = FailOnSecondBody {
            counter: Arc::clone(&invoke_count),
        };
        let mut seg = StreamingSplitSegment {
            expression,
            body: OutcomeSegment::new(Box::new(body)),
            aggregation: AggregationStrategy::LastWins,
            stop_on_exception: false,
        };
        let ex = Exchange::new(Message::new("trigger"));
        let result = OutcomePipeline::run(&mut seg, ex).await;
        assert!(
            matches!(result, PipelineOutcome::Failed(_)),
            "stop_on_exception=false should still propagate error at end"
        );
        assert_eq!(
            invoke_count.load(Ordering::SeqCst),
            3,
            "all 3 fragments should be processed"
        );
        assert_eq!(pulled.load(Ordering::SeqCst), 3);
    }
}