Skip to main content

camel_processor/
streaming_splitter.rs

1use futures::{StreamExt, pin_mut};
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use tokio_util::sync::CancellationToken;
6use tower::Service;
7
8use camel_api::{
9    AggregationStrategy, Body, BoxProcessor, CamelError, Exchange, StreamingSplitExpression, Value,
10};
11
12pub const CAMEL_SPLIT_INDEX: &str = "CamelSplitIndex";
13pub const CAMEL_SPLIT_COMPLETE: &str = "CamelSplitComplete";
14
15#[derive(Clone)]
16pub struct StreamingSplitterService {
17    expression: StreamingSplitExpression,
18    sub_pipeline: BoxProcessor,
19    aggregation: AggregationStrategy,
20    stop_on_exception: bool,
21    cancel_token: CancellationToken,
22}
23
24impl StreamingSplitterService {
25    pub fn new(
26        expression: StreamingSplitExpression,
27        sub_pipeline: BoxProcessor,
28        aggregation: AggregationStrategy,
29        stop_on_exception: bool,
30    ) -> Self {
31        Self {
32            expression,
33            sub_pipeline,
34            aggregation,
35            stop_on_exception,
36            cancel_token: CancellationToken::new(),
37        }
38    }
39
40    pub fn cancel(&self) {
41        self.cancel_token.cancel();
42    }
43
44    pub fn is_cancelled(&self) -> bool {
45        self.cancel_token.is_cancelled()
46    }
47}
48
49impl Service<Exchange> for StreamingSplitterService {
50    type Response = Exchange;
51    type Error = CamelError;
52    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
53
54    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
55        self.sub_pipeline.poll_ready(cx)
56    }
57
58    fn call(&mut self, exchange: Exchange) -> Self::Future {
59        let mut original = exchange.clone();
60        if matches!(original.input.body, Body::Stream(_)) {
61            original.input.body = Body::Empty;
62        }
63        let expression = self.expression.clone();
64        let sub_pipeline = self.sub_pipeline.clone();
65        let aggregation = self.aggregation.clone();
66        let stop_on_exception = self.stop_on_exception;
67        let cancel_token = self.cancel_token.clone();
68
69        Box::pin(async move {
70            let stream = expression(exchange);
71            pin_mut!(stream);
72
73            let mut acc: Option<Exchange> = None;
74            let mut acc_bodies: Vec<Value> = Vec::new();
75            let mut index: u64 = 0;
76
77            // One-entry lookahead for CamelSplitComplete
78            let mut current = stream.next().await;
79
80            while let Some(fragment_result) = current.take() {
81                if cancel_token.is_cancelled() {
82                    return Err(CamelError::ProcessorError(
83                        "StreamingSplitter cancelled".to_string(),
84                    ));
85                }
86
87                let fragment = fragment_result?;
88
89                // Peek next to know if this is the last entry
90                let next = stream.next().await;
91                let is_last = next.is_none();
92
93                let mut fragment = fragment;
94                fragment.set_property(CAMEL_SPLIT_INDEX, Value::from(index));
95                fragment.set_property(CAMEL_SPLIT_COMPLETE, Value::Bool(is_last));
96
97                let mut pipeline = sub_pipeline.clone();
98                let ready = tower::ServiceExt::ready(&mut pipeline).await;
99                let result = match ready {
100                    Ok(svc) => svc.call(fragment).await,
101                    Err(e) => Err(e),
102                };
103
104                match result {
105                    Ok(processed) => {
106                        match &aggregation {
107                            AggregationStrategy::CollectAll => {
108                                let v = match &processed.input.body {
109                                    Body::Text(s) => Value::String(s.clone()),
110                                    Body::Json(v) => v.clone(),
111                                    Body::Xml(s) => Value::String(s.clone()),
112                                    Body::Bytes(b) => {
113                                        Value::String(String::from_utf8_lossy(b).into_owned())
114                                    }
115                                    Body::Empty => Value::Null,
116                                    Body::Stream(_) => {
117                                        return Err(CamelError::TypeConversionFailed(
118                                            "StreamingSplitter CollectAll cannot aggregate Body::Stream — use 'stream_cache' or 'convert_body_to' before this step".to_string(),
119                                        ));
120                                    }
121                                };
122                                acc_bodies.push(v);
123                            }
124                            AggregationStrategy::Custom(fold_fn) => {
125                                acc = Some(match acc {
126                                    Some(prev) => fold_fn(prev, processed),
127                                    None => processed,
128                                });
129                            }
130                            _ => {
131                                acc = Some(processed);
132                            }
133                        }
134                        index += 1;
135                    }
136                    Err(e) => {
137                        if stop_on_exception {
138                            return Err(e);
139                        }
140                        index += 1;
141                    }
142                }
143
144                current = next;
145            }
146
147            match &aggregation {
148                AggregationStrategy::LastWins => Ok(acc.unwrap_or(original)),
149                AggregationStrategy::Original => Ok(original),
150                AggregationStrategy::CollectAll => {
151                    let mut out = original;
152                    out.input.body = Body::Json(Value::Array(acc_bodies));
153                    Ok(out)
154                }
155                AggregationStrategy::Custom(_) => Ok(acc.unwrap_or(original)),
156            }
157        })
158    }
159}
160
161#[cfg(test)]
162mod tests {
163    use super::*;
164    use bytes::Bytes;
165    use camel_api::{BoxProcessorExt, Message, StreamBody, StreamMetadata};
166    use futures::stream;
167    use std::sync::Arc;
168    use tokio::sync::Mutex;
169    use tower::ServiceExt;
170
171    use crate::stream_codec::{StreamSplitInput, resolve_codec, resolve_format};
172
173    fn passthrough_pipeline() -> BoxProcessor {
174        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
175    }
176
177    fn uppercase_pipeline() -> BoxProcessor {
178        BoxProcessor::from_fn(|mut ex: Exchange| {
179            Box::pin(async move {
180                if let Body::Text(s) = &ex.input.body {
181                    ex.input.body = Body::Text(s.to_uppercase());
182                }
183                Ok(ex)
184            })
185        })
186    }
187
188    fn make_exchange(text: &str) -> Exchange {
189        Exchange::new(Message::new(text))
190    }
191
192    fn test_expression(fragments: Vec<Exchange>) -> StreamingSplitExpression {
193        Arc::new(move |_| {
194            let frags = fragments.clone();
195            Box::pin(stream::iter(frags.into_iter().map(Ok)))
196        })
197    }
198
199    fn error_expression() -> StreamingSplitExpression {
200        Arc::new(|_| {
201            Box::pin(stream::iter(vec![Err(CamelError::ProcessorError(
202                "stream error".to_string(),
203            ))]))
204        })
205    }
206
207    /// Build a `StreamingSplitExpression` that reads from `Body::Stream` and
208    /// splits using the NdjsonCodec. Mirrors the resolution logic in
209    /// `step_resolution.rs`.
210    fn ndjson_stream_expression(config: camel_api::StreamSplitConfig) -> StreamingSplitExpression {
211        Arc::new(move |exchange: Exchange| {
212            let config = config.clone();
213            let (stream_body, parent) = match &exchange.input.body {
214                Body::Stream(sb) => (sb.clone(), {
215                    let mut p = exchange.clone();
216                    p.input.body = Body::Empty;
217                    p
218                }),
219                _ => {
220                    return Box::pin(futures::stream::once(async {
221                        Err(CamelError::ProcessorError(
222                            "streaming split requires Body::Stream".into(),
223                        ))
224                    }));
225                }
226            };
227
228            let stream = match stream_body.stream.try_lock() {
229                Ok(mut guard) => match guard.take() {
230                    Some(s) => s,
231                    None => {
232                        return Box::pin(futures::stream::once(async {
233                            Err(CamelError::ProcessorError(
234                                "stream body already consumed".into(),
235                            ))
236                        }));
237                    }
238                },
239                Err(_) => {
240                    return Box::pin(futures::stream::once(async {
241                        Err(CamelError::ProcessorError("stream body locked".into()))
242                    }));
243                }
244            };
245
246            let input = StreamSplitInput {
247                parent,
248                stream,
249                metadata: stream_body.metadata,
250            };
251
252            match resolve_format(&config.format, &input.metadata) {
253                Ok(f) => {
254                    let codec = resolve_codec(&f);
255                    codec.split(input, config)
256                }
257                Err(e) => Box::pin(futures::stream::once(async { Err(e) })),
258            }
259        })
260    }
261
262    // ---------------------------------------------------------------------------
263    // Integration test: Body::Stream NDJSON → streaming split → Body::Json fragments
264    // ---------------------------------------------------------------------------
265
266    #[tokio::test]
267    async fn test_ndjson_body_stream_streaming_split() {
268        // ── Arrange ────────────────────────────────────────────────────────────
269        // 3 lines of NDJSON as a byte stream
270        let ndjson_lines: Vec<Result<Bytes, CamelError>> = vec![
271            Ok(Bytes::from("{\"id\":1,\"name\":\"a\"}\n")),
272            Ok(Bytes::from("{\"id\":2,\"name\":\"b\"}\n")),
273            Ok(Bytes::from("{\"id\":3,\"name\":\"c\"}\n")),
274        ];
275        let byte_stream = futures::stream::iter(ndjson_lines);
276
277        let stream_body = StreamBody {
278            stream: Arc::new(Mutex::new(Some(Box::pin(byte_stream)))),
279            metadata: StreamMetadata {
280                content_type: Some("application/x-ndjson".into()),
281                size_hint: None,
282                origin: Some("test://ndjson".into()),
283            },
284        };
285
286        let ex = Exchange::new(Message::new(Body::Stream(stream_body)));
287
288        // Streaming split config — Ndjson format
289        let split_config = camel_api::StreamSplitConfig {
290            format: camel_api::StreamSplitFormat::Ndjson,
291            ..Default::default()
292        };
293
294        // Recorder sub-pipeline: captures per-fragment body + properties
295        let fragments: Arc<Mutex<Vec<(Option<serde_json::Value>, Option<Value>, Option<Value>)>>> =
296            Arc::new(Mutex::new(Vec::new()));
297        let fragments_clone = Arc::clone(&fragments);
298        let recorder = BoxProcessor::from_fn(move |ex: Exchange| {
299            let frags = Arc::clone(&fragments_clone);
300            Box::pin(async move {
301                let body_json = match &ex.input.body {
302                    Body::Json(v) => Some(v.clone()),
303                    _ => None,
304                };
305                let split_index = ex.property(CAMEL_SPLIT_INDEX).cloned();
306                let split_complete = ex.property(CAMEL_SPLIT_COMPLETE).cloned();
307                let mut guard = frags.lock().await;
308                guard.push((body_json, split_index, split_complete));
309                Ok(ex)
310            })
311        });
312
313        let expression = ndjson_stream_expression(split_config);
314
315        // ── Act ────────────────────────────────────────────────────────────────
316        let mut splitter = StreamingSplitterService::new(
317            expression,
318            recorder,
319            AggregationStrategy::CollectAll,
320            true, // stop_on_exception
321        );
322
323        let result = splitter
324            .ready()
325            .await
326            .expect("splitter ready")
327            .call(ex)
328            .await
329            .expect("splitter call");
330
331        // ── Assert ─────────────────────────────────────────────────────────────
332        let guard = fragments.lock().await;
333
334        // 1. Three fragments were produced
335        assert_eq!(guard.len(), 3, "expected 3 NDJSON fragments");
336
337        // 2. Each fragment has Body::Json
338        for (i, (body_json, _idx, _complete)) in guard.iter().enumerate() {
339            assert!(
340                body_json.is_some(),
341                "fragment {i}: expected Body::Json body, got non-Json"
342            );
343        }
344
345        // 3. Each fragment has CamelSplitIndex property (0, 1, 2)
346        for (i, (_body, idx, _complete)) in guard.iter().enumerate() {
347            assert_eq!(
348                *idx,
349                Some(Value::Number(serde_json::Number::from(i as u64))),
350                "fragment {i}: CamelSplitIndex mismatch"
351            );
352        }
353
354        // 4. CamelSplitComplete: first two false, last one true
355        assert_eq!(
356            guard[0].2,
357            Some(Value::Bool(false)),
358            "first fragment: CamelSplitComplete should be false"
359        );
360        assert_eq!(
361            guard[1].2,
362            Some(Value::Bool(false)),
363            "second fragment: CamelSplitComplete should be false"
364        );
365        assert_eq!(
366            guard[2].2,
367            Some(Value::Bool(true)),
368            "last fragment: CamelSplitComplete should be true"
369        );
370
371        // 5. CollectAll aggregated into JSON array with correct values
372        match &result.input.body {
373            Body::Json(v) => {
374                let arr = v.as_array().expect("CollectAll result should be array");
375                assert_eq!(arr.len(), 3);
376                assert_eq!(arr[0], serde_json::json!({"id":1,"name":"a"}));
377                assert_eq!(arr[1], serde_json::json!({"id":2,"name":"b"}));
378                assert_eq!(arr[2], serde_json::json!({"id":3,"name":"c"}));
379            }
380            other => panic!("expected Body::Json from CollectAll, got {other:?}"),
381        }
382
383        // 6. Original stream body sanitized (already Empty, becomes part of aggregate)
384        //    The aggregate exchange's body is Json, not Stream
385        assert!(
386            matches!(result.input.body, Body::Json(_)),
387            "aggregate body should be Json, not Stream"
388        );
389    }
390
391    // ---------------------------------------------------------------------------
392    // Integration test: Empty Body::Stream → aggregate result is empty
393    // ---------------------------------------------------------------------------
394
395    #[tokio::test]
396    async fn test_ndjson_body_stream_empty_stream() {
397        // ── Arrange ────────────────────────────────────────────────────────────
398        // Empty byte stream
399        let byte_stream = futures::stream::iter(Vec::<Result<Bytes, CamelError>>::new());
400
401        let stream_body = StreamBody {
402            stream: Arc::new(Mutex::new(Some(Box::pin(byte_stream)))),
403            metadata: StreamMetadata {
404                content_type: Some("application/x-ndjson".into()),
405                size_hint: None,
406                origin: None,
407            },
408        };
409
410        let mut ex = Exchange::new(Message::new(Body::Stream(stream_body)));
411        ex.set_property("trace_id", Value::String("empty-test".into()));
412
413        let split_config = camel_api::StreamSplitConfig {
414            format: camel_api::StreamSplitFormat::Ndjson,
415            ..Default::default()
416        };
417
418        let expression = ndjson_stream_expression(split_config);
419
420        // ── Act ────────────────────────────────────────────────────────────────
421        let mut splitter = StreamingSplitterService::new(
422            expression,
423            passthrough_pipeline(),
424            AggregationStrategy::CollectAll,
425            true,
426        );
427
428        let result = splitter
429            .ready()
430            .await
431            .expect("splitter ready")
432            .call(ex)
433            .await
434            .expect("splitter call");
435
436        // ── Assert ─────────────────────────────────────────────────────────────
437        // Empty stream → CollectAll produces Body::Json([])
438        match &result.input.body {
439            Body::Json(v) => {
440                let arr = v.as_array().expect("CollectAll result should be array");
441                assert!(
442                    arr.is_empty(),
443                    "empty stream should produce empty array, got {arr:?}"
444                );
445            }
446            other => {
447                panic!("expected Body::Json([]) from CollectAll on empty stream, got {other:?}")
448            }
449        }
450
451        // Properties preserved
452        assert_eq!(
453            result.property("trace_id"),
454            Some(&Value::String("empty-test".into()))
455        );
456    }
457
458    #[tokio::test]
459    async fn test_streaming_sequential_last_wins() {
460        let expr = test_expression(vec![
461            make_exchange("a"),
462            make_exchange("b"),
463            make_exchange("c"),
464        ]);
465        let mut svc = StreamingSplitterService::new(
466            expr,
467            uppercase_pipeline(),
468            AggregationStrategy::LastWins,
469            true,
470        );
471
472        let result = svc
473            .ready()
474            .await
475            .unwrap()
476            .call(make_exchange("original"))
477            .await
478            .unwrap();
479        assert_eq!(result.input.body.as_text(), Some("C"));
480    }
481
482    #[tokio::test]
483    async fn test_streaming_sequential_original() {
484        let expr = test_expression(vec![make_exchange("a"), make_exchange("b")]);
485        let mut svc = StreamingSplitterService::new(
486            expr,
487            uppercase_pipeline(),
488            AggregationStrategy::Original,
489            true,
490        );
491
492        let result = svc
493            .ready()
494            .await
495            .unwrap()
496            .call(make_exchange("original"))
497            .await
498            .unwrap();
499        assert_eq!(result.input.body.as_text(), Some("original"));
500    }
501
502    #[tokio::test]
503    async fn test_streaming_stop_on_exception() {
504        let expr = test_expression(vec![make_exchange("a"), make_exchange("b")]);
505        let fail_pipeline = BoxProcessor::from_fn(|_| {
506            Box::pin(async { Err(CamelError::ProcessorError("boom".into())) })
507        });
508        let mut svc =
509            StreamingSplitterService::new(expr, fail_pipeline, AggregationStrategy::LastWins, true);
510
511        let result = svc
512            .ready()
513            .await
514            .unwrap()
515            .call(make_exchange("original"))
516            .await;
517        assert!(result.is_err());
518    }
519
520    #[tokio::test]
521    async fn test_streaming_empty_stream() {
522        let expr: StreamingSplitExpression = Arc::new(|_| Box::pin(futures::stream::empty()));
523        let mut svc = StreamingSplitterService::new(
524            expr,
525            passthrough_pipeline(),
526            AggregationStrategy::LastWins,
527            true,
528        );
529
530        let mut ex = make_exchange("original");
531        ex.set_property("marker", Value::Bool(true));
532        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
533        assert_eq!(result.input.body.as_text(), Some("original"));
534        assert_eq!(result.property("marker"), Some(&Value::Bool(true)));
535    }
536
537    #[tokio::test]
538    async fn test_streaming_error_in_expression() {
539        let mut svc = StreamingSplitterService::new(
540            error_expression(),
541            passthrough_pipeline(),
542            AggregationStrategy::LastWins,
543            true,
544        );
545
546        let result = svc
547            .ready()
548            .await
549            .unwrap()
550            .call(make_exchange("original"))
551            .await;
552        assert!(result.is_err());
553    }
554
555    #[tokio::test]
556    async fn test_streaming_cancellation() {
557        let expr = test_expression(vec![make_exchange("a"), make_exchange("b")]);
558        let slow_pipeline = BoxProcessor::from_fn(|ex| {
559            Box::pin(async move {
560                tokio::time::sleep(std::time::Duration::from_secs(60)).await;
561                Ok(ex)
562            })
563        });
564        let svc =
565            StreamingSplitterService::new(expr, slow_pipeline, AggregationStrategy::LastWins, true);
566        svc.cancel();
567
568        let mut svc_clone = svc.clone();
569        let result = svc_clone
570            .ready()
571            .await
572            .unwrap()
573            .call(make_exchange("original"))
574            .await;
575        assert!(result.is_err());
576    }
577
578    #[tokio::test]
579    async fn test_streaming_sequential_collect_all() {
580        let expr = test_expression(vec![
581            make_exchange("a"),
582            make_exchange("b"),
583            make_exchange("c"),
584        ]);
585        let mut svc = StreamingSplitterService::new(
586            expr,
587            uppercase_pipeline(),
588            AggregationStrategy::CollectAll,
589            true,
590        );
591
592        let result = svc
593            .ready()
594            .await
595            .unwrap()
596            .call(make_exchange("original"))
597            .await
598            .unwrap();
599        let expected = serde_json::json!(["A", "B", "C"]);
600        match &result.input.body {
601            Body::Json(v) => assert_eq!(*v, expected),
602            other => panic!("expected JSON body, got {other:?}"),
603        }
604    }
605
606    #[tokio::test]
607    async fn test_streaming_sequential_custom_aggregation() {
608        let joiner: Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync> =
609            Arc::new(|mut acc: Exchange, next: Exchange| {
610                let acc_text = acc.input.body.as_text().unwrap_or("").to_string();
611                let next_text = next.input.body.as_text().unwrap_or("").to_string();
612                acc.input.body = Body::Text(format!("{acc_text}+{next_text}"));
613                acc
614            });
615
616        let expr = test_expression(vec![
617            make_exchange("a"),
618            make_exchange("b"),
619            make_exchange("c"),
620        ]);
621        let mut svc = StreamingSplitterService::new(
622            expr,
623            uppercase_pipeline(),
624            AggregationStrategy::Custom(joiner),
625            true,
626        );
627
628        let result = svc
629            .ready()
630            .await
631            .unwrap()
632            .call(make_exchange("original"))
633            .await
634            .unwrap();
635        assert_eq!(result.input.body.as_text(), Some("A+B+C"));
636    }
637
638    #[tokio::test]
639    async fn test_streaming_error_continue_on_exception() {
640        let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
641        let count_clone = call_count.clone();
642        let fail_on_first = BoxProcessor::from_fn(move |ex: Exchange| {
643            let count = count_clone.clone();
644            Box::pin(async move {
645                let n = count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
646                if n == 0 {
647                    Err(CamelError::ProcessorError("first fails".into()))
648                } else {
649                    Ok(ex)
650                }
651            })
652        });
653
654        let expr = test_expression(vec![make_exchange("a"), make_exchange("b")]);
655        let mut svc = StreamingSplitterService::new(
656            expr,
657            fail_on_first,
658            AggregationStrategy::LastWins,
659            false,
660        );
661
662        let result = svc
663            .ready()
664            .await
665            .unwrap()
666            .call(make_exchange("original"))
667            .await
668            .unwrap();
669        assert_eq!(result.input.body.as_text(), Some("b"));
670        assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 2);
671    }
672
673    #[tokio::test]
674    async fn test_streaming_metadata_lookahead() {
675        let recorder = BoxProcessor::from_fn(|ex: Exchange| {
676            Box::pin(async move {
677                let idx = ex.property(CAMEL_SPLIT_INDEX).cloned();
678                let complete = ex.property(CAMEL_SPLIT_COMPLETE).cloned();
679                let body = serde_json::json!({
680                    "index": idx,
681                    "complete": complete,
682                });
683                let mut out = ex;
684                out.input.body = Body::Json(body);
685                Ok(out)
686            })
687        });
688
689        let expr = test_expression(vec![
690            make_exchange("x"),
691            make_exchange("y"),
692            make_exchange("z"),
693        ]);
694        let mut svc =
695            StreamingSplitterService::new(expr, recorder, AggregationStrategy::CollectAll, true);
696
697        let result = svc
698            .ready()
699            .await
700            .unwrap()
701            .call(make_exchange("original"))
702            .await
703            .unwrap();
704        let expected = serde_json::json!([
705            {"index": 0, "complete": false},
706            {"index": 1, "complete": false},
707            {"index": 2, "complete": true},
708        ]);
709        match &result.input.body {
710            Body::Json(v) => assert_eq!(*v, expected),
711            other => panic!("expected JSON body, got {other:?}"),
712        }
713    }
714
715    #[tokio::test]
716    async fn test_streaming_split_sanitizes_stream_body_in_original() {
717        let chunks = vec![Ok(Bytes::from("line1\n"))];
718        let stream = futures::stream::iter(chunks);
719        let sb = StreamBody {
720            stream: Arc::new(Mutex::new(Some(Box::pin(stream)))),
721            metadata: Default::default(),
722        };
723        let ex = Exchange::new(Message::new(Body::Stream(sb)));
724
725        let expression =
726            test_expression(vec![Exchange::new(Message::new(Body::Text("frag".into())))]);
727        let sub_pipeline = passthrough_pipeline();
728        let mut splitter = StreamingSplitterService::new(
729            expression,
730            sub_pipeline,
731            AggregationStrategy::Original,
732            true,
733        );
734
735        let result = splitter
736            .ready()
737            .await
738            .expect("ready")
739            .call(ex)
740            .await
741            .expect("call");
742        assert!(
743            matches!(result.input.body, Body::Empty),
744            "original body should be sanitized to Empty"
745        );
746    }
747}