Skip to main content

camel_processor/
streaming_splitter.rs

1use futures::{StreamExt, pin_mut};
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use tokio_util::sync::CancellationToken;
6use tower::Service;
7
8use camel_api::{
9    AggregationStrategy, Body, BoxProcessor, CamelError, Exchange, StreamingSplitExpression, Value,
10};
11
/// Exchange property set on every fragment to its zero-based position within the split.
pub const CAMEL_SPLIT_INDEX: &str = "CamelSplitIndex";
/// Exchange property set on every fragment: `true` for the final fragment, `false` otherwise.
pub const CAMEL_SPLIT_COMPLETE: &str = "CamelSplitComplete";
14
15#[derive(Clone)]
16pub struct StreamingSplitterService {
17    expression: StreamingSplitExpression,
18    sub_pipeline: BoxProcessor,
19    aggregation: AggregationStrategy,
20    stop_on_exception: bool,
21    cancel_token: CancellationToken,
22}
23
24impl StreamingSplitterService {
25    pub fn new(
26        expression: StreamingSplitExpression,
27        sub_pipeline: BoxProcessor,
28        aggregation: AggregationStrategy,
29        stop_on_exception: bool,
30    ) -> Self {
31        Self {
32            expression,
33            sub_pipeline,
34            aggregation,
35            stop_on_exception,
36            cancel_token: CancellationToken::new(),
37        }
38    }
39
40    pub fn cancel(&self) {
41        self.cancel_token.cancel();
42    }
43
44    pub fn is_cancelled(&self) -> bool {
45        self.cancel_token.is_cancelled()
46    }
47}
48
49impl Service<Exchange> for StreamingSplitterService {
50    type Response = Exchange;
51    type Error = CamelError;
52    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
53
54    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
55        self.sub_pipeline.poll_ready(cx)
56    }
57
58    fn call(&mut self, exchange: Exchange) -> Self::Future {
59        let mut original = exchange.clone();
60        if matches!(original.input.body, Body::Stream(_)) {
61            original.input.body = Body::Empty;
62        }
63        let expression = self.expression.clone();
64        let sub_pipeline = self.sub_pipeline.clone();
65        let aggregation = self.aggregation.clone();
66        let stop_on_exception = self.stop_on_exception;
67        let cancel_token = self.cancel_token.clone();
68
69        Box::pin(async move {
70            let stream = expression(exchange);
71            pin_mut!(stream);
72
73            let mut acc: Option<Exchange> = None;
74            let mut acc_bodies: Vec<Value> = Vec::new();
75            let mut index: u64 = 0;
76
77            // One-entry lookahead for CamelSplitComplete
78            let mut current = stream.next().await;
79
80            while let Some(fragment_result) = current.take() {
81                if cancel_token.is_cancelled() {
82                    return Err(CamelError::ProcessorError(
83                        "StreamingSplitter cancelled".to_string(),
84                    ));
85                }
86
87                let fragment = fragment_result?;
88
89                // Peek next to know if this is the last entry
90                let next = stream.next().await;
91                let is_last = next.is_none();
92
93                let mut fragment = fragment;
94                fragment.set_property(CAMEL_SPLIT_INDEX, Value::from(index));
95                fragment.set_property(CAMEL_SPLIT_COMPLETE, Value::Bool(is_last));
96
97                let mut pipeline = sub_pipeline.clone();
98                let ready = tower::ServiceExt::ready(&mut pipeline).await;
99                let result = match ready {
100                    Ok(svc) => svc.call(fragment).await,
101                    Err(e) => Err(e),
102                };
103
104                match result {
105                    Ok(processed) => {
106                        match &aggregation {
107                            AggregationStrategy::CollectAll => {
108                                let v = match &processed.input.body {
109                                    Body::Text(s) => Value::String(s.clone()),
110                                    Body::Json(v) => v.clone(),
111                                    Body::Xml(s) => Value::String(s.clone()),
112                                    Body::Bytes(b) => {
113                                        Value::String(String::from_utf8_lossy(b).into_owned())
114                                    }
115                                    Body::Empty => Value::Null,
116                                    Body::Stream(_) => {
117                                        return Err(CamelError::TypeConversionFailed(
118                                            "StreamingSplitter CollectAll cannot aggregate Body::Stream — use 'stream_cache' or 'convert_body_to' before this step".to_string(),
119                                        ));
120                                    }
121                                };
122                                acc_bodies.push(v);
123                            }
124                            AggregationStrategy::Custom(fold_fn) => {
125                                acc = Some(match acc {
126                                    Some(prev) => fold_fn(prev, processed),
127                                    None => processed,
128                                });
129                            }
130                            _ => {
131                                acc = Some(processed);
132                            }
133                        }
134                        index += 1;
135                    }
136                    Err(e) => {
137                        if stop_on_exception {
138                            return Err(e);
139                        }
140                        index += 1;
141                    }
142                }
143
144                current = next;
145            }
146
147            match &aggregation {
148                AggregationStrategy::LastWins => Ok(acc.unwrap_or(original)),
149                AggregationStrategy::Original => Ok(original),
150                AggregationStrategy::CollectAll => {
151                    let mut out = original;
152                    out.input.body = Body::Json(Value::Array(acc_bodies));
153                    Ok(out)
154                }
155                AggregationStrategy::Custom(_) => Ok(acc.unwrap_or(original)),
156            }
157        })
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use camel_api::{BoxProcessorExt, Message, StreamBody, StreamMetadata};
    use futures::stream;
    use std::sync::Arc;
    use tokio::sync::Mutex;
    use tower::ServiceExt;

    use crate::stream_codec::{StreamSplitInput, resolve_format, resolve_incremental_codec};

    // ---------------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------------

    /// Sub-pipeline that hands every fragment back untouched.
    fn passthrough_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    /// Sub-pipeline that upper-cases `Body::Text` fragments and ignores every other body.
    fn uppercase_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|mut ex: Exchange| {
            Box::pin(async move {
                if let Body::Text(text) = &ex.input.body {
                    let shouted = text.to_uppercase();
                    ex.input.body = Body::Text(shouted);
                }
                Ok(ex)
            })
        })
    }

    /// Exchange whose body is the given text.
    fn make_exchange(text: &str) -> Exchange {
        let message = Message::new(text);
        Exchange::new(message)
    }

    /// Expression that ignores its input and replays clones of `fragments` as `Ok` items.
    fn test_expression(fragments: Vec<Exchange>) -> StreamingSplitExpression {
        Arc::new(move |_| {
            let batch = fragments.clone();
            Box::pin(stream::iter(batch.into_iter().map(Ok)))
        })
    }

    /// Expression whose stream yields exactly one error item.
    fn error_expression() -> StreamingSplitExpression {
        Arc::new(|_| {
            let failure = CamelError::ProcessorError("stream error".to_string());
            Box::pin(stream::iter(vec![Err(failure)]))
        })
    }

    /// Builds a `StreamingSplitExpression` that takes the byte stream out of a `Body::Stream`
    /// and splits it with the codec resolved from `config` (NDJSON in these tests).
    /// It follows the same resolution steps as `step_resolution.rs`.
    fn ndjson_stream_expression(config: camel_api::StreamSplitConfig) -> StreamingSplitExpression {
        Arc::new(move |exchange: Exchange| {
            let split_config = config.clone();
            let (stream_body, parent) = match &exchange.input.body {
                Body::Stream(body) => {
                    // The parent keeps headers/properties but must not hold on to the stream.
                    let mut parent = exchange.clone();
                    parent.input.body = Body::Empty;
                    (body.clone(), parent)
                }
                _ => {
                    return Box::pin(futures::stream::once(async {
                        Err(CamelError::ProcessorError(
                            "streaming split requires Body::Stream".into(),
                        ))
                    }));
                }
            };

            // The expression is synchronous, so the lock is only tried, never awaited.
            let stream = match stream_body.stream.try_lock() {
                Err(_) => {
                    return Box::pin(futures::stream::once(async {
                        Err(CamelError::ProcessorError("stream body locked".into()))
                    }));
                }
                Ok(mut guard) => match guard.take() {
                    Some(s) => s,
                    None => {
                        return Box::pin(futures::stream::once(async {
                            Err(CamelError::ProcessorError(
                                "stream body already consumed".into(),
                            ))
                        }));
                    }
                },
            };

            let input = StreamSplitInput {
                parent,
                stream,
                metadata: stream_body.metadata,
            };

            let format = match resolve_format(&split_config.format, &input.metadata) {
                Ok(f) => f,
                Err(e) => return Box::pin(futures::stream::once(async { Err(e) })),
            };
            let codec = match resolve_incremental_codec(&format) {
                Ok(c) => c,
                Err(e) => return Box::pin(futures::stream::once(async { Err(e) })),
            };
            codec.split(input, split_config)
        })
    }

    // ---------------------------------------------------------------------------
    // Integration test: Body::Stream NDJSON → streaming split → Body::Json fragments
    // ---------------------------------------------------------------------------

    #[tokio::test]
    async fn test_ndjson_body_stream_streaming_split() {
        // Given: a Body::Stream carrying three NDJSON records, one record per chunk.
        let chunks: Vec<Result<Bytes, CamelError>> = vec![
            Ok(Bytes::from("{\"id\":1,\"name\":\"a\"}\n")),
            Ok(Bytes::from("{\"id\":2,\"name\":\"b\"}\n")),
            Ok(Bytes::from("{\"id\":3,\"name\":\"c\"}\n")),
        ];
        let byte_stream = futures::stream::iter(chunks);
        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(byte_stream)))),
            metadata: StreamMetadata {
                content_type: Some("application/x-ndjson".into()),
                size_hint: None,
                origin: Some("test://ndjson".into()),
            },
        };
        let ex = Exchange::new(Message::new(Body::Stream(stream_body)));

        let split_config = camel_api::StreamSplitConfig {
            format: camel_api::StreamSplitFormat::Ndjson,
            ..Default::default()
        };

        // The recorder stores (Json body, CamelSplitIndex, CamelSplitComplete) per fragment.
        type Recorded = (Option<serde_json::Value>, Option<Value>, Option<Value>);
        let recorded: Arc<Mutex<Vec<Recorded>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&recorded);
        let recorder = BoxProcessor::from_fn(move |ex: Exchange| {
            let sink = Arc::clone(&sink);
            Box::pin(async move {
                let json = if let Body::Json(v) = &ex.input.body {
                    Some(v.clone())
                } else {
                    None
                };
                let split_index = ex.property(CAMEL_SPLIT_INDEX).cloned();
                let split_complete = ex.property(CAMEL_SPLIT_COMPLETE).cloned();
                let mut entries = sink.lock().await;
                entries.push((json, split_index, split_complete));
                Ok(ex)
            })
        });

        // When
        let mut splitter = StreamingSplitterService::new(
            ndjson_stream_expression(split_config),
            recorder,
            AggregationStrategy::CollectAll,
            true, // stop_on_exception
        );
        let ready = splitter.ready().await.expect("splitter ready");
        let result = ready.call(ex).await.expect("splitter call");

        // Then
        let frags = recorded.lock().await;

        // Exactly one fragment per NDJSON line.
        assert_eq!(frags.len(), 3, "expected 3 NDJSON fragments");

        // Every fragment was decoded into a JSON body.
        for (i, (body_json, _, _)) in frags.iter().enumerate() {
            assert!(
                body_json.is_some(),
                "fragment {i}: expected Body::Json body, got non-Json"
            );
        }

        // CamelSplitIndex counts 0, 1, 2.
        for (i, (_, idx, _)) in frags.iter().enumerate() {
            assert_eq!(
                *idx,
                Some(Value::Number(serde_json::Number::from(i as u64))),
                "fragment {i}: CamelSplitIndex mismatch"
            );
        }

        // CamelSplitComplete flips to true only on the final fragment.
        let expected_complete = [
            (0, false, "first fragment: CamelSplitComplete should be false"),
            (1, false, "second fragment: CamelSplitComplete should be false"),
            (2, true, "last fragment: CamelSplitComplete should be true"),
        ];
        for (pos, want, msg) in expected_complete {
            assert_eq!(frags[pos].2, Some(Value::Bool(want)), "{msg}");
        }

        // CollectAll gathered the decoded records, in order, into one JSON array.
        match &result.input.body {
            Body::Json(v) => {
                let arr = v.as_array().expect("CollectAll result should be array");
                assert_eq!(arr.len(), 3);
                assert_eq!(arr[0], serde_json::json!({"id":1,"name":"a"}));
                assert_eq!(arr[1], serde_json::json!({"id":2,"name":"b"}));
                assert_eq!(arr[2], serde_json::json!({"id":3,"name":"c"}));
            }
            other => panic!("expected Body::Json from CollectAll, got {other:?}"),
        }

        // The aggregate no longer carries the input stream.
        assert!(
            matches!(result.input.body, Body::Json(_)),
            "aggregate body should be Json, not Stream"
        );
    }

    // ---------------------------------------------------------------------------
    // Integration test: Empty Body::Stream → aggregate result is empty
    // ---------------------------------------------------------------------------

    #[tokio::test]
    async fn test_ndjson_body_stream_empty_stream() {
        // Given: an NDJSON Body::Stream that yields no bytes at all.
        let byte_stream = futures::stream::iter(Vec::<Result<Bytes, CamelError>>::new());
        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(byte_stream)))),
            metadata: StreamMetadata {
                content_type: Some("application/x-ndjson".into()),
                size_hint: None,
                origin: None,
            },
        };
        let mut ex = Exchange::new(Message::new(Body::Stream(stream_body)));
        ex.set_property("trace_id", Value::String("empty-test".into()));

        let split_config = camel_api::StreamSplitConfig {
            format: camel_api::StreamSplitFormat::Ndjson,
            ..Default::default()
        };
        let mut splitter = StreamingSplitterService::new(
            ndjson_stream_expression(split_config),
            passthrough_pipeline(),
            AggregationStrategy::CollectAll,
            true,
        );

        // When
        let ready = splitter.ready().await.expect("splitter ready");
        let result = ready.call(ex).await.expect("splitter call");

        // Then: CollectAll over zero fragments is an empty JSON array...
        match &result.input.body {
            Body::Json(v) => {
                let arr = v.as_array().expect("CollectAll result should be array");
                assert!(
                    arr.is_empty(),
                    "empty stream should produce empty array, got {arr:?}"
                );
            }
            other => {
                panic!("expected Body::Json([]) from CollectAll on empty stream, got {other:?}")
            }
        }

        // ...and the exchange's properties survive the split.
        assert_eq!(
            result.property("trace_id"),
            Some(&Value::String("empty-test".into()))
        );
    }

    #[tokio::test]
    async fn test_streaming_sequential_last_wins() {
        let fragments = vec![make_exchange("a"), make_exchange("b"), make_exchange("c")];
        let mut svc = StreamingSplitterService::new(
            test_expression(fragments),
            uppercase_pipeline(),
            AggregationStrategy::LastWins,
            true,
        );

        let ready = svc.ready().await.unwrap();
        let result = ready.call(make_exchange("original")).await.unwrap();

        // Only the final fragment survives, after being upper-cased.
        assert_eq!(result.input.body.as_text(), Some("C"));
    }

    #[tokio::test]
    async fn test_streaming_sequential_original() {
        let fragments = vec![make_exchange("a"), make_exchange("b")];
        let mut svc = StreamingSplitterService::new(
            test_expression(fragments),
            uppercase_pipeline(),
            AggregationStrategy::Original,
            true,
        );

        let ready = svc.ready().await.unwrap();
        let result = ready.call(make_exchange("original")).await.unwrap();

        // Fragment results are discarded; the incoming exchange comes back unchanged.
        assert_eq!(result.input.body.as_text(), Some("original"));
    }

    #[tokio::test]
    async fn test_streaming_stop_on_exception() {
        let fragments = vec![make_exchange("a"), make_exchange("b")];
        let fail_pipeline = BoxProcessor::from_fn(|_| {
            Box::pin(async { Err(CamelError::ProcessorError("boom".into())) })
        });
        let mut svc = StreamingSplitterService::new(
            test_expression(fragments),
            fail_pipeline,
            AggregationStrategy::LastWins,
            true,
        );

        let ready = svc.ready().await.unwrap();
        let outcome = ready.call(make_exchange("original")).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_streaming_empty_stream() {
        let expr: StreamingSplitExpression = Arc::new(|_| Box::pin(futures::stream::empty()));
        let mut svc = StreamingSplitterService::new(
            expr,
            passthrough_pipeline(),
            AggregationStrategy::LastWins,
            true,
        );

        let mut input = make_exchange("original");
        input.set_property("marker", Value::Bool(true));

        let ready = svc.ready().await.unwrap();
        let result = ready.call(input).await.unwrap();

        // With no fragments, LastWins falls back to the original exchange, properties included.
        assert_eq!(result.input.body.as_text(), Some("original"));
        assert_eq!(result.property("marker"), Some(&Value::Bool(true)));
    }

    #[tokio::test]
    async fn test_streaming_error_in_expression() {
        let mut svc = StreamingSplitterService::new(
            error_expression(),
            passthrough_pipeline(),
            AggregationStrategy::LastWins,
            true,
        );

        let ready = svc.ready().await.unwrap();
        let outcome = ready.call(make_exchange("original")).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_streaming_cancellation() {
        let fragments = vec![make_exchange("a"), make_exchange("b")];
        let slow_pipeline = BoxProcessor::from_fn(|ex| {
            Box::pin(async move {
                tokio::time::sleep(std::time::Duration::from_secs(60)).await;
                Ok(ex)
            })
        });
        let svc = StreamingSplitterService::new(
            test_expression(fragments),
            slow_pipeline,
            AggregationStrategy::LastWins,
            true,
        );
        svc.cancel();

        // The clone shares the already-cancelled token, so its split must fail.
        let mut cloned = svc.clone();
        let ready = cloned.ready().await.unwrap();
        let outcome = ready.call(make_exchange("original")).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_streaming_sequential_collect_all() {
        let fragments = vec![make_exchange("a"), make_exchange("b"), make_exchange("c")];
        let mut svc = StreamingSplitterService::new(
            test_expression(fragments),
            uppercase_pipeline(),
            AggregationStrategy::CollectAll,
            true,
        );

        let ready = svc.ready().await.unwrap();
        let result = ready.call(make_exchange("original")).await.unwrap();

        let expected = serde_json::json!(["A", "B", "C"]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_streaming_sequential_custom_aggregation() {
        // Joins bodies as "<acc>+<next>".
        let joiner: Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync> =
            Arc::new(|mut acc: Exchange, next: Exchange| {
                let left = acc.input.body.as_text().unwrap_or("").to_string();
                let right = next.input.body.as_text().unwrap_or("").to_string();
                acc.input.body = Body::Text(format!("{left}+{right}"));
                acc
            });

        let fragments = vec![make_exchange("a"), make_exchange("b"), make_exchange("c")];
        let mut svc = StreamingSplitterService::new(
            test_expression(fragments),
            uppercase_pipeline(),
            AggregationStrategy::Custom(joiner),
            true,
        );

        let ready = svc.ready().await.unwrap();
        let result = ready.call(make_exchange("original")).await.unwrap();

        assert_eq!(result.input.body.as_text(), Some("A+B+C"));
    }

    #[tokio::test]
    async fn test_streaming_error_continue_on_exception() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_pipeline = Arc::clone(&calls);
        // Fails on the very first invocation and succeeds on every later one.
        let fail_on_first = BoxProcessor::from_fn(move |ex: Exchange| {
            let counter = calls_in_pipeline.clone();
            Box::pin(async move {
                if counter.fetch_add(1, Ordering::SeqCst) == 0 {
                    Err(CamelError::ProcessorError("first fails".into()))
                } else {
                    Ok(ex)
                }
            })
        });

        let fragments = vec![make_exchange("a"), make_exchange("b")];
        let mut svc = StreamingSplitterService::new(
            test_expression(fragments),
            fail_on_first,
            AggregationStrategy::LastWins,
            false,
        );

        let ready = svc.ready().await.unwrap();
        let result = ready.call(make_exchange("original")).await.unwrap();

        // The failed first fragment is skipped; the second one is the result.
        assert_eq!(result.input.body.as_text(), Some("b"));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_streaming_metadata_lookahead() {
        // Replaces each fragment's body with the split metadata it was handed.
        let recorder = BoxProcessor::from_fn(|ex: Exchange| {
            Box::pin(async move {
                let position = ex.property(CAMEL_SPLIT_INDEX).cloned();
                let last = ex.property(CAMEL_SPLIT_COMPLETE).cloned();
                let snapshot = serde_json::json!({
                    "index": position,
                    "complete": last,
                });
                let mut out = ex;
                out.input.body = Body::Json(snapshot);
                Ok(out)
            })
        });

        let fragments = vec![make_exchange("x"), make_exchange("y"), make_exchange("z")];
        let mut svc = StreamingSplitterService::new(
            test_expression(fragments),
            recorder,
            AggregationStrategy::CollectAll,
            true,
        );

        let ready = svc.ready().await.unwrap();
        let result = ready.call(make_exchange("original")).await.unwrap();

        let expected = serde_json::json!([
            {"index": 0, "complete": false},
            {"index": 1, "complete": false},
            {"index": 2, "complete": true},
        ]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_streaming_split_sanitizes_stream_body_in_original() {
        let chunks = vec![Ok(Bytes::from("line1\n"))];
        let byte_stream = futures::stream::iter(chunks);
        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(byte_stream)))),
            metadata: Default::default(),
        };
        let ex = Exchange::new(Message::new(Body::Stream(stream_body)));

        let only_fragment = Exchange::new(Message::new(Body::Text("frag".into())));
        let mut splitter = StreamingSplitterService::new(
            test_expression(vec![only_fragment]),
            passthrough_pipeline(),
            AggregationStrategy::Original,
            true,
        );

        let ready = splitter.ready().await.expect("ready");
        let result = ready.call(ex).await.expect("call");

        // The original exchange returned by `Original` must not expose the consumed stream.
        assert!(
            matches!(result.input.body, Body::Empty),
            "original body should be sanitized to Empty"
        );
    }
}