1use futures::{StreamExt, pin_mut};
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use tokio_util::sync::CancellationToken;
6use tower::Service;
7
8use camel_api::{
9 AggregationStrategy, Body, BoxProcessor, CamelError, Exchange, StreamingSplitExpression, Value,
10};
11
/// Exchange property the splitter sets on every fragment: the fragment's
/// zero-based position in the split stream.
pub const CAMEL_SPLIT_INDEX: &str = "CamelSplitIndex";
/// Exchange property the splitter sets on every fragment: `true` only when no
/// further item followed this fragment in the split stream.
pub const CAMEL_SPLIT_COMPLETE: &str = "CamelSplitComplete";
14
/// Tower service that turns one exchange into a stream of fragments, runs each
/// fragment through a sub-pipeline one at a time, and aggregates the results
/// into a single exchange.
#[derive(Clone)]
pub struct StreamingSplitterService {
    /// Produces the fragment stream from the incoming exchange.
    expression: StreamingSplitExpression,
    /// Processor applied to every fragment; cloned once per fragment.
    sub_pipeline: BoxProcessor,
    /// Decides how processed fragments are combined into the final exchange.
    aggregation: AggregationStrategy,
    /// When `true`, the first sub-pipeline error aborts the whole split;
    /// when `false`, the failing fragment is skipped.
    stop_on_exception: bool,
    /// Checked before each fragment is processed; see `cancel`.
    cancel_token: CancellationToken,
}
23
24impl StreamingSplitterService {
25 pub fn new(
26 expression: StreamingSplitExpression,
27 sub_pipeline: BoxProcessor,
28 aggregation: AggregationStrategy,
29 stop_on_exception: bool,
30 ) -> Self {
31 Self {
32 expression,
33 sub_pipeline,
34 aggregation,
35 stop_on_exception,
36 cancel_token: CancellationToken::new(),
37 }
38 }
39
40 pub fn cancel(&self) {
41 self.cancel_token.cancel();
42 }
43
44 pub fn is_cancelled(&self) -> bool {
45 self.cancel_token.is_cancelled()
46 }
47}
48
49impl Service<Exchange> for StreamingSplitterService {
50 type Response = Exchange;
51 type Error = CamelError;
52 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
53
54 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
55 self.sub_pipeline.poll_ready(cx)
56 }
57
58 fn call(&mut self, exchange: Exchange) -> Self::Future {
59 let mut original = exchange.clone();
60 if matches!(original.input.body, Body::Stream(_)) {
61 original.input.body = Body::Empty;
62 }
63 let expression = self.expression.clone();
64 let sub_pipeline = self.sub_pipeline.clone();
65 let aggregation = self.aggregation.clone();
66 let stop_on_exception = self.stop_on_exception;
67 let cancel_token = self.cancel_token.clone();
68
69 Box::pin(async move {
70 let stream = expression(exchange);
71 pin_mut!(stream);
72
73 let mut acc: Option<Exchange> = None;
74 let mut acc_bodies: Vec<Value> = Vec::new();
75 let mut index: u64 = 0;
76
77 let mut current = stream.next().await;
79
80 while let Some(fragment_result) = current.take() {
81 if cancel_token.is_cancelled() {
82 return Err(CamelError::ProcessorError(
83 "StreamingSplitter cancelled".to_string(),
84 ));
85 }
86
87 let fragment = fragment_result?;
88
89 let next = stream.next().await;
91 let is_last = next.is_none();
92
93 let mut fragment = fragment;
94 fragment.set_property(CAMEL_SPLIT_INDEX, Value::from(index));
95 fragment.set_property(CAMEL_SPLIT_COMPLETE, Value::Bool(is_last));
96
97 let mut pipeline = sub_pipeline.clone();
98 let ready = tower::ServiceExt::ready(&mut pipeline).await;
99 let result = match ready {
100 Ok(svc) => svc.call(fragment).await,
101 Err(e) => Err(e),
102 };
103
104 match result {
105 Ok(processed) => {
106 match &aggregation {
107 AggregationStrategy::CollectAll => {
108 let v = match &processed.input.body {
109 Body::Text(s) => Value::String(s.clone()),
110 Body::Json(v) => v.clone(),
111 Body::Xml(s) => Value::String(s.clone()),
112 Body::Bytes(b) => {
113 Value::String(String::from_utf8_lossy(b).into_owned())
114 }
115 Body::Empty => Value::Null,
116 Body::Stream(_) => {
117 return Err(CamelError::TypeConversionFailed(
118 "StreamingSplitter CollectAll cannot aggregate Body::Stream — use 'stream_cache' or 'convert_body_to' before this step".to_string(),
119 ));
120 }
121 };
122 acc_bodies.push(v);
123 }
124 AggregationStrategy::Custom(fold_fn) => {
125 acc = Some(match acc {
126 Some(prev) => fold_fn(prev, processed),
127 None => processed,
128 });
129 }
130 _ => {
131 acc = Some(processed);
132 }
133 }
134 index += 1;
135 }
136 Err(e) => {
137 if stop_on_exception {
138 return Err(e);
139 }
140 index += 1;
141 }
142 }
143
144 current = next;
145 }
146
147 match &aggregation {
148 AggregationStrategy::LastWins => Ok(acc.unwrap_or(original)),
149 AggregationStrategy::Original => Ok(original),
150 AggregationStrategy::CollectAll => {
151 let mut out = original;
152 out.input.body = Body::Json(Value::Array(acc_bodies));
153 Ok(out)
154 }
155 AggregationStrategy::Custom(_) => Ok(acc.unwrap_or(original)),
156 }
157 })
158 }
159}
160
// Tests for `StreamingSplitterService`: aggregation strategies, error and
// cancellation handling, split metadata, and `Body::Stream` inputs.
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use camel_api::{BoxProcessorExt, Message, StreamBody, StreamMetadata};
    use futures::stream;
    use std::sync::Arc;
    use tokio::sync::Mutex;
    use tower::ServiceExt;

    use crate::stream_codec::{StreamSplitInput, resolve_codec, resolve_format};

    /// Sub-pipeline that returns every exchange unchanged.
    fn passthrough_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    /// Sub-pipeline that upper-cases `Body::Text` bodies and leaves any other
    /// body untouched.
    fn uppercase_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|mut ex: Exchange| {
            Box::pin(async move {
                if let Body::Text(s) = &ex.input.body {
                    ex.input.body = Body::Text(s.to_uppercase());
                }
                Ok(ex)
            })
        })
    }

    /// Wraps `text` in a new `Message` and `Exchange`.
    fn make_exchange(text: &str) -> Exchange {
        Exchange::new(Message::new(text))
    }

    /// Split expression that ignores its input and yields a clone of each of
    /// `fragments`, in order, as `Ok` items.
    fn test_expression(fragments: Vec<Exchange>) -> StreamingSplitExpression {
        Arc::new(move |_| {
            let frags = fragments.clone();
            Box::pin(stream::iter(frags.into_iter().map(Ok)))
        })
    }

    /// Split expression whose stream consists of a single `ProcessorError`.
    fn error_expression() -> StreamingSplitExpression {
        Arc::new(|_| {
            Box::pin(stream::iter(vec![Err(CamelError::ProcessorError(
                "stream error".to_string(),
            ))]))
        })
    }

    /// Builds a split expression that takes the byte stream out of a
    /// `Body::Stream` input and hands it to the codec resolved from
    /// `config.format` and the stream metadata.
    ///
    /// The returned stream consists of a single error item when the body is
    /// not a stream, the stream was already taken, the stream mutex is
    /// currently locked, or the format cannot be resolved.
    fn ndjson_stream_expression(config: camel_api::StreamSplitConfig) -> StreamingSplitExpression {
        Arc::new(move |exchange: Exchange| {
            let config = config.clone();
            // `parent` is a copy of the input exchange with its body replaced
            // by `Body::Empty`.
            let (stream_body, parent) = match &exchange.input.body {
                Body::Stream(sb) => (sb.clone(), {
                    let mut p = exchange.clone();
                    p.input.body = Body::Empty;
                    p
                }),
                _ => {
                    return Box::pin(futures::stream::once(async {
                        Err(CamelError::ProcessorError(
                            "streaming split requires Body::Stream".into(),
                        ))
                    }));
                }
            };

            // Non-blocking lock: a contended mutex is reported as an error
            // instead of being awaited. `take()` leaves `None` behind, so a
            // second attempt sees "already consumed".
            let stream = match stream_body.stream.try_lock() {
                Ok(mut guard) => match guard.take() {
                    Some(s) => s,
                    None => {
                        return Box::pin(futures::stream::once(async {
                            Err(CamelError::ProcessorError(
                                "stream body already consumed".into(),
                            ))
                        }));
                    }
                },
                Err(_) => {
                    return Box::pin(futures::stream::once(async {
                        Err(CamelError::ProcessorError("stream body locked".into()))
                    }));
                }
            };

            let input = StreamSplitInput {
                parent,
                stream,
                metadata: stream_body.metadata,
            };

            match resolve_format(&config.format, &input.metadata) {
                Ok(f) => {
                    let codec = resolve_codec(&f);
                    codec.split(input, config)
                }
                Err(e) => Box::pin(futures::stream::once(async { Err(e) })),
            }
        })
    }

    /// Three NDJSON lines supplied as a `Body::Stream` are split into three
    /// JSON fragments carrying split index 0..=2, with the split-complete flag
    /// set only on the last one; `CollectAll` returns them as a JSON array.
    #[tokio::test]
    async fn test_ndjson_body_stream_streaming_split() {
        let ndjson_lines: Vec<Result<Bytes, CamelError>> = vec![
            Ok(Bytes::from("{\"id\":1,\"name\":\"a\"}\n")),
            Ok(Bytes::from("{\"id\":2,\"name\":\"b\"}\n")),
            Ok(Bytes::from("{\"id\":3,\"name\":\"c\"}\n")),
        ];
        let byte_stream = futures::stream::iter(ndjson_lines);

        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(byte_stream)))),
            metadata: StreamMetadata {
                content_type: Some("application/x-ndjson".into()),
                size_hint: None,
                origin: Some("test://ndjson".into()),
            },
        };

        let ex = Exchange::new(Message::new(Body::Stream(stream_body)));

        let split_config = camel_api::StreamSplitConfig {
            format: camel_api::StreamSplitFormat::Ndjson,
            ..Default::default()
        };

        // Each entry records (JSON body if any, split index, split-complete)
        // as seen by the sub-pipeline.
        let fragments: Arc<Mutex<Vec<(Option<serde_json::Value>, Option<Value>, Option<Value>)>>> =
            Arc::new(Mutex::new(Vec::new()));
        let fragments_clone = Arc::clone(&fragments);
        let recorder = BoxProcessor::from_fn(move |ex: Exchange| {
            let frags = Arc::clone(&fragments_clone);
            Box::pin(async move {
                let body_json = match &ex.input.body {
                    Body::Json(v) => Some(v.clone()),
                    _ => None,
                };
                let split_index = ex.property(CAMEL_SPLIT_INDEX).cloned();
                let split_complete = ex.property(CAMEL_SPLIT_COMPLETE).cloned();
                let mut guard = frags.lock().await;
                guard.push((body_json, split_index, split_complete));
                Ok(ex)
            })
        });

        let expression = ndjson_stream_expression(split_config);

        let mut splitter = StreamingSplitterService::new(
            expression,
            recorder,
            AggregationStrategy::CollectAll,
            true,
        );

        let result = splitter
            .ready()
            .await
            .expect("splitter ready")
            .call(ex)
            .await
            .expect("splitter call");

        let guard = fragments.lock().await;

        assert_eq!(guard.len(), 3, "expected 3 NDJSON fragments");

        // Every fragment reached the sub-pipeline as `Body::Json`.
        for (i, (body_json, _idx, _complete)) in guard.iter().enumerate() {
            assert!(
                body_json.is_some(),
                "fragment {i}: expected Body::Json body, got non-Json"
            );
        }

        // Split indices are sequential starting at zero.
        for (i, (_body, idx, _complete)) in guard.iter().enumerate() {
            assert_eq!(
                *idx,
                Some(Value::Number(serde_json::Number::from(i as u64))),
                "fragment {i}: CamelSplitIndex mismatch"
            );
        }

        // Only the final fragment is flagged as complete.
        assert_eq!(
            guard[0].2,
            Some(Value::Bool(false)),
            "first fragment: CamelSplitComplete should be false"
        );
        assert_eq!(
            guard[1].2,
            Some(Value::Bool(false)),
            "second fragment: CamelSplitComplete should be false"
        );
        assert_eq!(
            guard[2].2,
            Some(Value::Bool(true)),
            "last fragment: CamelSplitComplete should be true"
        );

        match &result.input.body {
            Body::Json(v) => {
                let arr = v.as_array().expect("CollectAll result should be array");
                assert_eq!(arr.len(), 3);
                assert_eq!(arr[0], serde_json::json!({"id":1,"name":"a"}));
                assert_eq!(arr[1], serde_json::json!({"id":2,"name":"b"}));
                assert_eq!(arr[2], serde_json::json!({"id":3,"name":"c"}));
            }
            other => panic!("expected Body::Json from CollectAll, got {other:?}"),
        }

        assert!(
            matches!(result.input.body, Body::Json(_)),
            "aggregate body should be Json, not Stream"
        );
    }

    /// An empty `Body::Stream` with `CollectAll` yields an empty JSON array
    /// and keeps the properties of the input exchange.
    #[tokio::test]
    async fn test_ndjson_body_stream_empty_stream() {
        let byte_stream = futures::stream::iter(Vec::<Result<Bytes, CamelError>>::new());

        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(byte_stream)))),
            metadata: StreamMetadata {
                content_type: Some("application/x-ndjson".into()),
                size_hint: None,
                origin: None,
            },
        };

        let mut ex = Exchange::new(Message::new(Body::Stream(stream_body)));
        ex.set_property("trace_id", Value::String("empty-test".into()));

        let split_config = camel_api::StreamSplitConfig {
            format: camel_api::StreamSplitFormat::Ndjson,
            ..Default::default()
        };

        let expression = ndjson_stream_expression(split_config);

        let mut splitter = StreamingSplitterService::new(
            expression,
            passthrough_pipeline(),
            AggregationStrategy::CollectAll,
            true,
        );

        let result = splitter
            .ready()
            .await
            .expect("splitter ready")
            .call(ex)
            .await
            .expect("splitter call");

        match &result.input.body {
            Body::Json(v) => {
                let arr = v.as_array().expect("CollectAll result should be array");
                assert!(
                    arr.is_empty(),
                    "empty stream should produce empty array, got {arr:?}"
                );
            }
            other => {
                panic!("expected Body::Json([]) from CollectAll on empty stream, got {other:?}")
            }
        }

        assert_eq!(
            result.property("trace_id"),
            Some(&Value::String("empty-test".into()))
        );
    }

    /// `LastWins` returns the processed form of the final fragment.
    #[tokio::test]
    async fn test_streaming_sequential_last_wins() {
        let expr = test_expression(vec![
            make_exchange("a"),
            make_exchange("b"),
            make_exchange("c"),
        ]);
        let mut svc = StreamingSplitterService::new(
            expr,
            uppercase_pipeline(),
            AggregationStrategy::LastWins,
            true,
        );

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("C"));
    }

    /// `Original` returns the input exchange, ignoring the processed fragments.
    #[tokio::test]
    async fn test_streaming_sequential_original() {
        let expr = test_expression(vec![make_exchange("a"), make_exchange("b")]);
        let mut svc = StreamingSplitterService::new(
            expr,
            uppercase_pipeline(),
            AggregationStrategy::Original,
            true,
        );

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("original"));
    }

    /// With `stop_on_exception` set, a failing sub-pipeline makes the call fail.
    #[tokio::test]
    async fn test_streaming_stop_on_exception() {
        let expr = test_expression(vec![make_exchange("a"), make_exchange("b")]);
        let fail_pipeline = BoxProcessor::from_fn(|_| {
            Box::pin(async { Err(CamelError::ProcessorError("boom".into())) })
        });
        let mut svc =
            StreamingSplitterService::new(expr, fail_pipeline, AggregationStrategy::LastWins, true);

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await;
        assert!(result.is_err());
    }

    /// `LastWins` over an empty fragment stream falls back to the input
    /// exchange, body and properties included.
    #[tokio::test]
    async fn test_streaming_empty_stream() {
        let expr: StreamingSplitExpression = Arc::new(|_| Box::pin(futures::stream::empty()));
        let mut svc = StreamingSplitterService::new(
            expr,
            passthrough_pipeline(),
            AggregationStrategy::LastWins,
            true,
        );

        let mut ex = make_exchange("original");
        ex.set_property("marker", Value::Bool(true));
        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
        assert_eq!(result.input.body.as_text(), Some("original"));
        assert_eq!(result.property("marker"), Some(&Value::Bool(true)));
    }

    /// An error item produced by the split expression makes the call fail.
    #[tokio::test]
    async fn test_streaming_error_in_expression() {
        let mut svc = StreamingSplitterService::new(
            error_expression(),
            passthrough_pipeline(),
            AggregationStrategy::LastWins,
            true,
        );

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await;
        assert!(result.is_err());
    }

    /// Cancelling the service before the call makes a clone of it fail; the
    /// splitter checks the token before it invokes the (60-second) sub-pipeline.
    #[tokio::test]
    async fn test_streaming_cancellation() {
        let expr = test_expression(vec![make_exchange("a"), make_exchange("b")]);
        let slow_pipeline = BoxProcessor::from_fn(|ex| {
            Box::pin(async move {
                tokio::time::sleep(std::time::Duration::from_secs(60)).await;
                Ok(ex)
            })
        });
        let svc =
            StreamingSplitterService::new(expr, slow_pipeline, AggregationStrategy::LastWins, true);
        svc.cancel();

        let mut svc_clone = svc.clone();
        let result = svc_clone
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await;
        assert!(result.is_err());
    }

    /// `CollectAll` gathers the processed text bodies into a JSON array, in order.
    #[tokio::test]
    async fn test_streaming_sequential_collect_all() {
        let expr = test_expression(vec![
            make_exchange("a"),
            make_exchange("b"),
            make_exchange("c"),
        ]);
        let mut svc = StreamingSplitterService::new(
            expr,
            uppercase_pipeline(),
            AggregationStrategy::CollectAll,
            true,
        );

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await
            .unwrap();
        let expected = serde_json::json!(["A", "B", "C"]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    /// `Custom` folds the processed fragments left to right, seeding the
    /// accumulator with the first processed fragment.
    #[tokio::test]
    async fn test_streaming_sequential_custom_aggregation() {
        let joiner: Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync> =
            Arc::new(|mut acc: Exchange, next: Exchange| {
                let acc_text = acc.input.body.as_text().unwrap_or("").to_string();
                let next_text = next.input.body.as_text().unwrap_or("").to_string();
                acc.input.body = Body::Text(format!("{acc_text}+{next_text}"));
                acc
            });

        let expr = test_expression(vec![
            make_exchange("a"),
            make_exchange("b"),
            make_exchange("c"),
        ]);
        let mut svc = StreamingSplitterService::new(
            expr,
            uppercase_pipeline(),
            AggregationStrategy::Custom(joiner),
            true,
        );

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("A+B+C"));
    }

    /// With `stop_on_exception` off, a failing first fragment is skipped: the
    /// second fragment is still processed and becomes the `LastWins` result.
    #[tokio::test]
    async fn test_streaming_error_continue_on_exception() {
        let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let count_clone = call_count.clone();
        // Fails only on its first invocation.
        let fail_on_first = BoxProcessor::from_fn(move |ex: Exchange| {
            let count = count_clone.clone();
            Box::pin(async move {
                let n = count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                if n == 0 {
                    Err(CamelError::ProcessorError("first fails".into()))
                } else {
                    Ok(ex)
                }
            })
        });

        let expr = test_expression(vec![make_exchange("a"), make_exchange("b")]);
        let mut svc = StreamingSplitterService::new(
            expr,
            fail_on_first,
            AggregationStrategy::LastWins,
            false,
        );

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("b"));
        assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 2);
    }

    /// The sub-pipeline sees the split index and split-complete flag on each
    /// fragment; the flag is `true` only for the last of three fragments.
    #[tokio::test]
    async fn test_streaming_metadata_lookahead() {
        // Replaces each fragment body with the split metadata it carried.
        let recorder = BoxProcessor::from_fn(|ex: Exchange| {
            Box::pin(async move {
                let idx = ex.property(CAMEL_SPLIT_INDEX).cloned();
                let complete = ex.property(CAMEL_SPLIT_COMPLETE).cloned();
                let body = serde_json::json!({
                    "index": idx,
                    "complete": complete,
                });
                let mut out = ex;
                out.input.body = Body::Json(body);
                Ok(out)
            })
        });

        let expr = test_expression(vec![
            make_exchange("x"),
            make_exchange("y"),
            make_exchange("z"),
        ]);
        let mut svc =
            StreamingSplitterService::new(expr, recorder, AggregationStrategy::CollectAll, true);

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await
            .unwrap();
        let expected = serde_json::json!([
            {"index": 0, "complete": false},
            {"index": 1, "complete": false},
            {"index": 2, "complete": true},
        ]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    /// With `Original` aggregation and a `Body::Stream` input, the returned
    /// exchange carries `Body::Empty` rather than the stream body.
    #[tokio::test]
    async fn test_streaming_split_sanitizes_stream_body_in_original() {
        let chunks = vec![Ok(Bytes::from("line1\n"))];
        let stream = futures::stream::iter(chunks);
        let sb = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(stream)))),
            metadata: Default::default(),
        };
        let ex = Exchange::new(Message::new(Body::Stream(sb)));

        let expression =
            test_expression(vec![Exchange::new(Message::new(Body::Text("frag".into())))]);
        let sub_pipeline = passthrough_pipeline();
        let mut splitter = StreamingSplitterService::new(
            expression,
            sub_pipeline,
            AggregationStrategy::Original,
            true,
        );

        let result = splitter
            .ready()
            .await
            .expect("ready")
            .call(ex)
            .await
            .expect("call");
        assert!(
            matches!(result.input.body, Body::Empty),
            "original body should be sanitized to Empty"
        );
    }
}