1use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8
9use futures::{StreamExt, pin_mut};
10
11use camel_api::{
12 AggregationStrategy, CamelError, Exchange, OutcomePipeline, OutcomeSegment, PipelineOutcome,
13 StreamingSplitExpression,
14};
15
16use crate::split_segment::aggregate_completed;
17
/// Split segment that pulls fragment exchanges from a stream one at a time and
/// runs a body segment on each, in stream order.
///
/// How a run ends:
/// - the stream yields an `Err`: the run fails immediately with that error;
/// - the body returns `Stopped` for a fragment: that stopped exchange is
///   returned immediately and nothing is aggregated;
/// - the body returns `Failed` for a fragment: see `stop_on_exception`;
/// - the stream is exhausted with no failure: the completed fragment outputs
///   are combined with the original exchange using `aggregation`.
pub struct StreamingSplitSegment {
    /// Called with the incoming exchange to produce the fragment stream. Each
    /// stream item is either a fragment [`Exchange`] or a [`CamelError`].
    pub expression: StreamingSplitExpression,
    /// Segment run once per fragment.
    pub body: OutcomeSegment,
    /// Strategy handed to `aggregate_completed` to combine the completed
    /// fragment outputs with the original exchange.
    pub aggregation: AggregationStrategy,
    /// When `true`, the first fragment whose body fails aborts the split with
    /// that error. When `false`, the remaining fragments are still processed.
    /// The most recent failure is then returned once the stream is exhausted,
    /// instead of an aggregated result.
    pub stop_on_exception: bool,
}
49
50impl Clone for StreamingSplitSegment {
51 fn clone(&self) -> Self {
52 Self {
53 expression: Arc::clone(&self.expression),
54 body: self.body.clone(),
55 aggregation: self.aggregation.clone(),
56 stop_on_exception: self.stop_on_exception,
57 }
58 }
59}
60
61impl OutcomePipeline for StreamingSplitSegment {
62 fn clone_box(&self) -> Box<dyn OutcomePipeline> {
63 Box::new(self.clone())
64 }
65
66 fn run<'a>(
67 &'a mut self,
68 exchange: Exchange,
69 ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
70 let expression = Arc::clone(&self.expression);
71 let mut body = self.body.clone();
72 let aggregation = self.aggregation.clone();
73 let stop_on_exception = self.stop_on_exception;
74
75 Box::pin(async move {
76 let original = exchange.clone();
77 let stream = expression(exchange);
78 pin_mut!(stream);
79
80 let mut outputs: Vec<Exchange> = Vec::new();
81 let mut last_error: Option<CamelError> = None;
82
83 while let Some(frag_result) = stream.next().await {
84 let frag = match frag_result {
85 Ok(f) => f,
86 Err(e) => return PipelineOutcome::Failed(e),
87 };
88
89 match body.run(frag).await {
90 PipelineOutcome::Completed(ex) => outputs.push(ex),
91 PipelineOutcome::Stopped(stopped_ex) => {
92 return PipelineOutcome::Stopped(stopped_ex);
96 }
97 PipelineOutcome::Failed(err) => {
98 if stop_on_exception {
99 return PipelineOutcome::Failed(err);
100 }
101 last_error = Some(err);
104 }
105 }
106 }
107
108 if let Some(err) = last_error {
110 return PipelineOutcome::Failed(err);
111 }
112 PipelineOutcome::Completed(aggregate_completed(outputs, original, aggregation))
113 })
114 }
115}
116
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use camel_api::{Body, CamelError, Message};

    /// Builds an expression that, on every invocation, yields `count` fragment
    /// exchanges with bodies `"frag-0"`, `"frag-1"`, ….
    fn fragment_expression(count: usize) -> StreamingSplitExpression {
        let fragments: Vec<Exchange> = (0..count)
            .map(|i| Exchange::new(Message::new(format!("frag-{i}"))))
            .collect();
        Arc::new(move |_| {
            let frags = fragments.clone();
            Box::pin(futures::stream::iter(frags.into_iter().map(Ok)))
        })
    }

    /// Body that completes its first `stop_at` invocations and returns
    /// `Stopped` from then on. Every invocation bumps `counter`. With
    /// `stop_at = usize::MAX` it acts as a pass-through invocation counter.
    #[derive(Clone)]
    struct StopOnNthBody {
        counter: Arc<AtomicUsize>,
        stop_at: usize,
    }
    impl OutcomePipeline for StopOnNthBody {
        fn clone_box(&self) -> Box<dyn OutcomePipeline> {
            Box::new(self.clone())
        }
        fn run<'a>(
            &'a mut self,
            exchange: Exchange,
        ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
            let count = self.counter.fetch_add(1, Ordering::SeqCst);
            let stop_at = self.stop_at;
            Box::pin(async move {
                if count >= stop_at {
                    PipelineOutcome::Stopped(exchange)
                } else {
                    PipelineOutcome::Completed(exchange)
                }
            })
        }
    }

    /// Body that overwrites the fragment body and then stops.
    #[derive(Clone)]
    struct MutateAndStopBody;
    impl OutcomePipeline for MutateAndStopBody {
        fn clone_box(&self) -> Box<dyn OutcomePipeline> {
            Box::new(MutateAndStopBody)
        }
        fn run<'a>(
            &'a mut self,
            mut exchange: Exchange,
        ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
            Box::pin(async move {
                exchange.input.body = Body::Text("mutated-by-body".to_string());
                PipelineOutcome::Stopped(exchange)
            })
        }
    }

    /// Body that fails on its second invocation and completes all others.
    #[derive(Clone)]
    struct FailOnSecondBody {
        counter: Arc<AtomicUsize>,
    }
    impl OutcomePipeline for FailOnSecondBody {
        fn clone_box(&self) -> Box<dyn OutcomePipeline> {
            Box::new(self.clone())
        }
        fn run<'a>(
            &'a mut self,
            exchange: Exchange,
        ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
            let count = self.counter.fetch_add(1, Ordering::SeqCst);
            Box::pin(async move {
                if count == 1 {
                    PipelineOutcome::Failed(CamelError::ProcessorError("fail".into()))
                } else {
                    PipelineOutcome::Completed(exchange)
                }
            })
        }
    }

    #[tokio::test]
    async fn streaming_split_stop_halts_stream_consumption() {
        let invoke_count = Arc::new(AtomicUsize::new(0));
        let body = StopOnNthBody {
            counter: Arc::clone(&invoke_count),
            stop_at: 2, // the third invocation returns Stopped
        };

        let mut seg = StreamingSplitSegment {
            expression: fragment_expression(5),
            body: OutcomeSegment::new(Box::new(body)),
            aggregation: AggregationStrategy::LastWins,
            stop_on_exception: true,
        };

        let ex = Exchange::new(Message::new("trigger"));
        let result = OutcomePipeline::run(&mut seg, ex).await;

        assert!(
            matches!(result, PipelineOutcome::Stopped(_)),
            "Expected Stopped, got {result:?}"
        );
        // Fragments 4 and 5 must never reach the body.
        assert_eq!(invoke_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn streaming_split_stop_no_aggregate() {
        let mut seg = StreamingSplitSegment {
            expression: fragment_expression(3),
            body: OutcomeSegment::new(Box::new(MutateAndStopBody)),
            aggregation: AggregationStrategy::CollectAll,
            stop_on_exception: true,
        };

        let ex = Exchange::new(Message::new("original"));
        let result = OutcomePipeline::run(&mut seg, ex).await;

        match result {
            PipelineOutcome::Stopped(ex) => {
                assert_eq!(
                    ex.input.body.as_text(),
                    Some("mutated-by-body"),
                    "Stopped exchange should carry body mutation"
                );
            }
            other => panic!(
                "Expected Stopped with mutated body, got {other:?} — aggregation should NOT fire"
            ),
        }
    }

    #[tokio::test]
    async fn streaming_split_stop_preserves_exchange_mutations() {
        let mut seg = StreamingSplitSegment {
            expression: fragment_expression(5),
            body: OutcomeSegment::new(Box::new(MutateAndStopBody)),
            aggregation: AggregationStrategy::CollectAll,
            stop_on_exception: true,
        };

        let ex = Exchange::new(Message::new("original"));
        let result = OutcomePipeline::run(&mut seg, ex).await;

        match result {
            PipelineOutcome::Stopped(ex) => {
                assert_eq!(
                    ex.input.body.as_text(),
                    Some("mutated-by-body"),
                    "Stopped exchange should carry the body mutation from the segment"
                );
            }
            other => panic!("Expected Stopped with mutation, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn streaming_split_stop_on_exception_true() {
        let invoke_count = Arc::new(AtomicUsize::new(0));
        let body = FailOnSecondBody {
            counter: Arc::clone(&invoke_count),
        };

        let mut seg = StreamingSplitSegment {
            expression: fragment_expression(3),
            body: OutcomeSegment::new(Box::new(body)),
            aggregation: AggregationStrategy::LastWins,
            stop_on_exception: true,
        };

        let ex = Exchange::new(Message::new("trigger"));
        let result = OutcomePipeline::run(&mut seg, ex).await;

        assert!(
            matches!(result, PipelineOutcome::Failed(_)),
            "stop_on_exception=true should propagate failure immediately"
        );
        assert_eq!(invoke_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn streaming_split_stop_on_exception_false() {
        let invoke_count = Arc::new(AtomicUsize::new(0));
        let body = FailOnSecondBody {
            counter: Arc::clone(&invoke_count),
        };

        let mut seg = StreamingSplitSegment {
            expression: fragment_expression(3),
            body: OutcomeSegment::new(Box::new(body)),
            aggregation: AggregationStrategy::LastWins,
            stop_on_exception: false,
        };

        let ex = Exchange::new(Message::new("trigger"));
        let result = OutcomePipeline::run(&mut seg, ex).await;

        assert!(
            matches!(result, PipelineOutcome::Failed(_)),
            "stop_on_exception=false should still propagate error at end"
        );
        assert_eq!(
            invoke_count.load(Ordering::SeqCst),
            3,
            "all 3 fragments should be processed"
        );
    }

    #[tokio::test]
    async fn streaming_split_all_completed_returns_completed() {
        let invoke_count = Arc::new(AtomicUsize::new(0));
        let body = StopOnNthBody {
            counter: Arc::clone(&invoke_count),
            stop_at: usize::MAX, // never stops: plain pass-through
        };

        let mut seg = StreamingSplitSegment {
            expression: fragment_expression(3),
            body: OutcomeSegment::new(Box::new(body)),
            aggregation: AggregationStrategy::LastWins,
            stop_on_exception: true,
        };

        let ex = Exchange::new(Message::new("trigger"));
        let result = OutcomePipeline::run(&mut seg, ex).await;

        assert!(
            matches!(result, PipelineOutcome::Completed(_)),
            "Expected Completed, got {result:?}"
        );
        assert_eq!(invoke_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn streaming_split_stream_error_fails_immediately() {
        // Second stream item is an error; the third must never be processed.
        let expression: StreamingSplitExpression = Arc::new(|_| {
            Box::pin(futures::stream::iter(vec![
                Ok(Exchange::new(Message::new("frag-0"))),
                Err(CamelError::ProcessorError("stream broke".into())),
                Ok(Exchange::new(Message::new("frag-2"))),
            ]))
        });

        let invoke_count = Arc::new(AtomicUsize::new(0));
        let body = StopOnNthBody {
            counter: Arc::clone(&invoke_count),
            stop_at: usize::MAX,
        };

        let mut seg = StreamingSplitSegment {
            expression,
            body: OutcomeSegment::new(Box::new(body)),
            aggregation: AggregationStrategy::LastWins,
            // Stream errors abort even in continue-on-error mode.
            stop_on_exception: false,
        };

        let ex = Exchange::new(Message::new("trigger"));
        let result = OutcomePipeline::run(&mut seg, ex).await;

        assert!(
            matches!(result, PipelineOutcome::Failed(_)),
            "stream error should fail the split, got {result:?}"
        );
        assert_eq!(
            invoke_count.load(Ordering::SeqCst),
            1,
            "fragments after the stream error must not be processed"
        );
    }
}