1use futures::{StreamExt, pin_mut};
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use tokio_util::sync::CancellationToken;
6use tower::Service;
7
8use camel_api::{
9 AggregationStrategy, Body, BoxProcessor, CamelError, Exchange, StreamingSplitExpression, Value,
10};
11
/// Exchange property set on every fragment: the fragment's zero-based position
/// in the split sequence.
pub const CAMEL_SPLIT_INDEX: &str = "CamelSplitIndex";

/// Exchange property set on every fragment: `true` only on the final fragment
/// of the split sequence.
pub const CAMEL_SPLIT_COMPLETE: &str = "CamelSplitComplete";
14
15#[derive(Clone)]
16pub struct StreamingSplitterService {
17 expression: StreamingSplitExpression,
18 sub_pipeline: BoxProcessor,
19 aggregation: AggregationStrategy,
20 stop_on_exception: bool,
21 cancel_token: CancellationToken,
22}
23
24impl StreamingSplitterService {
25 pub fn new(
26 expression: StreamingSplitExpression,
27 sub_pipeline: BoxProcessor,
28 aggregation: AggregationStrategy,
29 stop_on_exception: bool,
30 ) -> Self {
31 Self {
32 expression,
33 sub_pipeline,
34 aggregation,
35 stop_on_exception,
36 cancel_token: CancellationToken::new(),
37 }
38 }
39
40 pub fn cancel(&self) {
41 self.cancel_token.cancel();
42 }
43
44 pub fn is_cancelled(&self) -> bool {
45 self.cancel_token.is_cancelled()
46 }
47}
48
49impl Service<Exchange> for StreamingSplitterService {
50 type Response = Exchange;
51 type Error = CamelError;
52 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
53
54 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
55 self.sub_pipeline.poll_ready(cx)
56 }
57
58 fn call(&mut self, exchange: Exchange) -> Self::Future {
59 let mut original = exchange.clone();
60 if matches!(original.input.body, Body::Stream(_)) {
61 original.input.body = Body::Empty;
62 }
63 let expression = self.expression.clone();
64 let sub_pipeline = self.sub_pipeline.clone();
65 let aggregation = self.aggregation.clone();
66 let stop_on_exception = self.stop_on_exception;
67 let cancel_token = self.cancel_token.clone();
68
69 Box::pin(async move {
70 let stream = expression(exchange);
71 pin_mut!(stream);
72
73 let mut acc: Option<Exchange> = None;
74 let mut acc_bodies: Vec<Value> = Vec::new();
75 let mut index: u64 = 0;
76
77 let mut current = stream.next().await;
79
80 while let Some(fragment_result) = current.take() {
81 if cancel_token.is_cancelled() {
82 return Err(CamelError::ProcessorError(
83 "StreamingSplitter cancelled".to_string(),
84 ));
85 }
86
87 let fragment = fragment_result?;
88
89 let next = stream.next().await;
91 let is_last = next.is_none();
92
93 let mut fragment = fragment;
94 fragment.set_property(CAMEL_SPLIT_INDEX, Value::from(index));
95 fragment.set_property(CAMEL_SPLIT_COMPLETE, Value::Bool(is_last));
96
97 let mut pipeline = sub_pipeline.clone();
98 let ready = tower::ServiceExt::ready(&mut pipeline).await;
99 let result = match ready {
100 Ok(svc) => svc.call(fragment).await,
101 Err(e) => Err(e),
102 };
103
104 match result {
105 Ok(processed) => {
106 match &aggregation {
107 AggregationStrategy::CollectAll => {
108 let v = match &processed.input.body {
109 Body::Text(s) => Value::String(s.clone()),
110 Body::Json(v) => v.clone(),
111 Body::Xml(s) => Value::String(s.clone()),
112 Body::Bytes(b) => {
113 Value::String(String::from_utf8_lossy(b).into_owned())
114 }
115 Body::Empty => Value::Null,
116 Body::Stream(_) => {
117 return Err(CamelError::TypeConversionFailed(
118 "StreamingSplitter CollectAll cannot aggregate Body::Stream — use 'stream_cache' or 'convert_body_to' before this step".to_string(),
119 ));
120 }
121 };
122 acc_bodies.push(v);
123 }
124 AggregationStrategy::Custom(fold_fn) => {
125 acc = Some(match acc {
126 Some(prev) => fold_fn(prev, processed),
127 None => processed,
128 });
129 }
130 _ => {
131 acc = Some(processed);
132 }
133 }
134 index += 1;
135 }
136 Err(e) => {
137 if stop_on_exception {
138 return Err(e);
139 }
140 index += 1;
141 }
142 }
143
144 current = next;
145 }
146
147 match &aggregation {
148 AggregationStrategy::LastWins => Ok(acc.unwrap_or(original)),
149 AggregationStrategy::Original => Ok(original),
150 AggregationStrategy::CollectAll => {
151 let mut out = original;
152 out.input.body = Body::Json(Value::Array(acc_bodies));
153 Ok(out)
154 }
155 AggregationStrategy::Custom(_) => Ok(acc.unwrap_or(original)),
156 }
157 })
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use camel_api::{BoxProcessorExt, Message, StreamBody, StreamMetadata};
    use futures::stream;
    use std::sync::Arc;
    use tokio::sync::Mutex;
    use tower::ServiceExt;

    use crate::stream_codec::{StreamSplitInput, resolve_format, resolve_incremental_codec};

    /// Sub-pipeline that hands every fragment back unchanged.
    fn passthrough_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    /// Sub-pipeline that upper-cases text bodies and leaves other bodies alone.
    fn uppercase_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|mut ex: Exchange| {
            Box::pin(async move {
                let shouted = match &ex.input.body {
                    Body::Text(text) => Some(text.to_uppercase()),
                    _ => None,
                };
                if let Some(text) = shouted {
                    ex.input.body = Body::Text(text);
                }
                Ok(ex)
            })
        })
    }

    /// Builds an exchange whose body is the given text.
    fn make_exchange(text: &str) -> Exchange {
        Exchange::new(Message::new(text))
    }

    /// Split expression that ignores its input and yields `fragments` in order.
    fn test_expression(fragments: Vec<Exchange>) -> StreamingSplitExpression {
        Arc::new(move |_| {
            let items = fragments.clone().into_iter().map(Ok);
            Box::pin(stream::iter(items))
        })
    }

    /// Split expression whose stream yields a single error.
    fn error_expression() -> StreamingSplitExpression {
        Arc::new(|_| {
            Box::pin(stream::iter(vec![Err(CamelError::ProcessorError(
                "stream error".to_string(),
            ))]))
        })
    }

    /// Split expression that takes the byte stream out of a `Body::Stream`
    /// exchange and splits it with the incremental codec resolved from `config`.
    ///
    /// Every failure (non-stream body, stream already taken, lock contention,
    /// format or codec resolution) is reported as a one-item error stream.
    fn ndjson_stream_expression(config: camel_api::StreamSplitConfig) -> StreamingSplitExpression {
        Arc::new(move |exchange: Exchange| {
            let config = config.clone();

            let Body::Stream(stream_ref) = &exchange.input.body else {
                return Box::pin(futures::stream::once(async {
                    Err(CamelError::ProcessorError(
                        "streaming split requires Body::Stream".into(),
                    ))
                }));
            };
            let stream_body = stream_ref.clone();

            // Fragments inherit from a parent that no longer carries the stream.
            let mut parent = exchange.clone();
            parent.input.body = Body::Empty;

            let stream = match stream_body.stream.try_lock() {
                Ok(mut guard) => match guard.take() {
                    Some(s) => s,
                    None => {
                        return Box::pin(futures::stream::once(async {
                            Err(CamelError::ProcessorError(
                                "stream body already consumed".into(),
                            ))
                        }));
                    }
                },
                Err(_) => {
                    return Box::pin(futures::stream::once(async {
                        Err(CamelError::ProcessorError("stream body locked".into()))
                    }));
                }
            };

            let input = StreamSplitInput {
                parent,
                stream,
                metadata: stream_body.metadata,
            };

            let format = match resolve_format(&config.format, &input.metadata) {
                Ok(f) => f,
                Err(e) => return Box::pin(futures::stream::once(async { Err(e) })),
            };
            let codec = match resolve_incremental_codec(&format) {
                Ok(c) => c,
                Err(e) => return Box::pin(futures::stream::once(async { Err(e) })),
            };
            codec.split(input, config)
        })
    }

    /// An NDJSON byte stream is split into one JSON fragment per line, each
    /// tagged with index/complete properties, and collected into a JSON array.
    #[tokio::test]
    async fn test_ndjson_body_stream_streaming_split() {
        // (fragment JSON body, CamelSplitIndex, CamelSplitComplete)
        type Seen = (Option<serde_json::Value>, Option<Value>, Option<Value>);

        let ndjson_lines: Vec<Result<Bytes, CamelError>> = vec![
            Ok(Bytes::from("{\"id\":1,\"name\":\"a\"}\n")),
            Ok(Bytes::from("{\"id\":2,\"name\":\"b\"}\n")),
            Ok(Bytes::from("{\"id\":3,\"name\":\"c\"}\n")),
        ];

        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(futures::stream::iter(
                ndjson_lines,
            ))))),
            metadata: StreamMetadata {
                content_type: Some("application/x-ndjson".into()),
                size_hint: None,
                origin: Some("test://ndjson".into()),
            },
        };
        let ex = Exchange::new(Message::new(Body::Stream(stream_body)));

        let split_config = camel_api::StreamSplitConfig {
            format: camel_api::StreamSplitFormat::Ndjson,
            ..Default::default()
        };

        // The sub-pipeline records what it saw for every fragment.
        let seen: Arc<Mutex<Vec<Seen>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&seen);
        let recorder = BoxProcessor::from_fn(move |ex: Exchange| {
            let sink = Arc::clone(&sink);
            Box::pin(async move {
                let body_json = match &ex.input.body {
                    Body::Json(v) => Some(v.clone()),
                    _ => None,
                };
                let entry = (
                    body_json,
                    ex.property(CAMEL_SPLIT_INDEX).cloned(),
                    ex.property(CAMEL_SPLIT_COMPLETE).cloned(),
                );
                sink.lock().await.push(entry);
                Ok(ex)
            })
        });

        let mut splitter = StreamingSplitterService::new(
            ndjson_stream_expression(split_config),
            recorder,
            AggregationStrategy::CollectAll,
            true,
        );

        let result = splitter
            .ready()
            .await
            .expect("splitter ready")
            .call(ex)
            .await
            .expect("splitter call");

        let recorded = seen.lock().await;

        assert_eq!(recorded.len(), 3, "expected 3 NDJSON fragments");

        for (i, (body_json, _idx, _complete)) in recorded.iter().enumerate() {
            assert!(
                body_json.is_some(),
                "fragment {i}: expected Body::Json body, got non-Json"
            );
        }

        for (i, (_body, idx, _complete)) in recorded.iter().enumerate() {
            assert_eq!(
                *idx,
                Some(Value::Number(serde_json::Number::from(i as u64))),
                "fragment {i}: CamelSplitIndex mismatch"
            );
        }

        let completion_expectations = [
            (0, false, "first fragment: CamelSplitComplete should be false"),
            (1, false, "second fragment: CamelSplitComplete should be false"),
            (2, true, "last fragment: CamelSplitComplete should be true"),
        ];
        for (pos, expected, message) in completion_expectations {
            assert_eq!(recorded[pos].2, Some(Value::Bool(expected)), "{message}");
        }

        match &result.input.body {
            Body::Json(v) => {
                let arr = v.as_array().expect("CollectAll result should be array");
                assert_eq!(arr.len(), 3);
                assert_eq!(arr[0], serde_json::json!({"id":1,"name":"a"}));
                assert_eq!(arr[1], serde_json::json!({"id":2,"name":"b"}));
                assert_eq!(arr[2], serde_json::json!({"id":3,"name":"c"}));
            }
            other => panic!("expected Body::Json from CollectAll, got {other:?}"),
        }

        assert!(
            matches!(result.input.body, Body::Json(_)),
            "aggregate body should be Json, not Stream"
        );
    }

    /// An empty NDJSON stream yields an empty JSON array and keeps the
    /// properties of the incoming exchange.
    #[tokio::test]
    async fn test_ndjson_body_stream_empty_stream() {
        let no_chunks = futures::stream::iter(Vec::<Result<Bytes, CamelError>>::new());

        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(no_chunks)))),
            metadata: StreamMetadata {
                content_type: Some("application/x-ndjson".into()),
                size_hint: None,
                origin: None,
            },
        };

        let mut ex = Exchange::new(Message::new(Body::Stream(stream_body)));
        ex.set_property("trace_id", Value::String("empty-test".into()));

        let split_config = camel_api::StreamSplitConfig {
            format: camel_api::StreamSplitFormat::Ndjson,
            ..Default::default()
        };

        let mut splitter = StreamingSplitterService::new(
            ndjson_stream_expression(split_config),
            passthrough_pipeline(),
            AggregationStrategy::CollectAll,
            true,
        );

        let result = splitter
            .ready()
            .await
            .expect("splitter ready")
            .call(ex)
            .await
            .expect("splitter call");

        match &result.input.body {
            Body::Json(v) => {
                let arr = v.as_array().expect("CollectAll result should be array");
                assert!(
                    arr.is_empty(),
                    "empty stream should produce empty array, got {arr:?}"
                );
            }
            other => {
                panic!("expected Body::Json([]) from CollectAll on empty stream, got {other:?}")
            }
        }

        assert_eq!(
            result.property("trace_id"),
            Some(&Value::String("empty-test".into()))
        );
    }

    /// `LastWins` returns the last processed fragment.
    #[tokio::test]
    async fn test_streaming_sequential_last_wins() {
        let fragments = vec![make_exchange("a"), make_exchange("b"), make_exchange("c")];
        let mut svc = StreamingSplitterService::new(
            test_expression(fragments),
            uppercase_pipeline(),
            AggregationStrategy::LastWins,
            true,
        );

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("C"));
    }

    /// `Original` returns the incoming exchange regardless of the fragments.
    #[tokio::test]
    async fn test_streaming_sequential_original() {
        let fragments = vec![make_exchange("a"), make_exchange("b")];
        let mut svc = StreamingSplitterService::new(
            test_expression(fragments),
            uppercase_pipeline(),
            AggregationStrategy::Original,
            true,
        );

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("original"));
    }

    /// With `stop_on_exception` set, a failing sub-pipeline fails the split.
    #[tokio::test]
    async fn test_streaming_stop_on_exception() {
        let fragments = vec![make_exchange("a"), make_exchange("b")];
        let fail_pipeline = BoxProcessor::from_fn(|_| {
            Box::pin(async { Err(CamelError::ProcessorError("boom".into())) })
        });
        let mut svc = StreamingSplitterService::new(
            test_expression(fragments),
            fail_pipeline,
            AggregationStrategy::LastWins,
            true,
        );

        let outcome = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await;
        assert!(outcome.is_err());
    }

    /// An empty fragment stream returns the incoming exchange, properties included.
    #[tokio::test]
    async fn test_streaming_empty_stream() {
        let expr: StreamingSplitExpression = Arc::new(|_| Box::pin(futures::stream::empty()));
        let mut svc = StreamingSplitterService::new(
            expr,
            passthrough_pipeline(),
            AggregationStrategy::LastWins,
            true,
        );

        let mut input = make_exchange("original");
        input.set_property("marker", Value::Bool(true));

        let result = svc.ready().await.unwrap().call(input).await.unwrap();
        assert_eq!(result.input.body.as_text(), Some("original"));
        assert_eq!(result.property("marker"), Some(&Value::Bool(true)));
    }

    /// An error yielded by the split expression fails the split.
    #[tokio::test]
    async fn test_streaming_error_in_expression() {
        let mut svc = StreamingSplitterService::new(
            error_expression(),
            passthrough_pipeline(),
            AggregationStrategy::LastWins,
            true,
        );

        let outcome = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await;
        assert!(outcome.is_err());
    }

    /// Cancelling a service is observed by its clones: the split fails before
    /// the (deliberately slow) sub-pipeline is ever awaited.
    #[tokio::test]
    async fn test_streaming_cancellation() {
        let fragments = vec![make_exchange("a"), make_exchange("b")];
        let slow_pipeline = BoxProcessor::from_fn(|ex| {
            Box::pin(async move {
                tokio::time::sleep(std::time::Duration::from_secs(60)).await;
                Ok(ex)
            })
        });
        let svc = StreamingSplitterService::new(
            test_expression(fragments),
            slow_pipeline,
            AggregationStrategy::LastWins,
            true,
        );
        svc.cancel();

        let mut svc_clone = svc.clone();
        let outcome = svc_clone
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await;
        assert!(outcome.is_err());
    }

    /// `CollectAll` gathers every processed body into a JSON array, in order.
    #[tokio::test]
    async fn test_streaming_sequential_collect_all() {
        let fragments = vec![make_exchange("a"), make_exchange("b"), make_exchange("c")];
        let mut svc = StreamingSplitterService::new(
            test_expression(fragments),
            uppercase_pipeline(),
            AggregationStrategy::CollectAll,
            true,
        );

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await
            .unwrap();

        let expected = serde_json::json!(["A", "B", "C"]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    /// `Custom` folds processed fragments left to right with the given function.
    #[tokio::test]
    async fn test_streaming_sequential_custom_aggregation() {
        let joiner: Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync> =
            Arc::new(|mut acc: Exchange, next: Exchange| {
                let joined = format!(
                    "{}+{}",
                    acc.input.body.as_text().unwrap_or(""),
                    next.input.body.as_text().unwrap_or("")
                );
                acc.input.body = Body::Text(joined);
                acc
            });

        let fragments = vec![make_exchange("a"), make_exchange("b"), make_exchange("c")];
        let mut svc = StreamingSplitterService::new(
            test_expression(fragments),
            uppercase_pipeline(),
            AggregationStrategy::Custom(joiner),
            true,
        );

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("A+B+C"));
    }

    /// With `stop_on_exception` off, a failed fragment is skipped and the
    /// remaining fragments are still processed.
    #[tokio::test]
    async fn test_streaming_error_continue_on_exception() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let call_count = Arc::new(AtomicUsize::new(0));
        let counter = call_count.clone();
        let fail_on_first = BoxProcessor::from_fn(move |ex: Exchange| {
            let counter = counter.clone();
            Box::pin(async move {
                // Only the very first invocation fails.
                match counter.fetch_add(1, Ordering::SeqCst) {
                    0 => Err(CamelError::ProcessorError("first fails".into())),
                    _ => Ok(ex),
                }
            })
        });

        let fragments = vec![make_exchange("a"), make_exchange("b")];
        let mut svc = StreamingSplitterService::new(
            test_expression(fragments),
            fail_on_first,
            AggregationStrategy::LastWins,
            false,
        );

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("b"));
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    /// The index/complete properties seen by the sub-pipeline are correct for
    /// every fragment, including the `complete` flag on the last one.
    #[tokio::test]
    async fn test_streaming_metadata_lookahead() {
        // Replaces each fragment body with the split metadata it observed.
        let recorder = BoxProcessor::from_fn(|ex: Exchange| {
            Box::pin(async move {
                let idx = ex.property(CAMEL_SPLIT_INDEX).cloned();
                let complete = ex.property(CAMEL_SPLIT_COMPLETE).cloned();
                let mut tagged = ex;
                tagged.input.body = Body::Json(serde_json::json!({
                    "index": idx,
                    "complete": complete,
                }));
                Ok(tagged)
            })
        });

        let fragments = vec![make_exchange("x"), make_exchange("y"), make_exchange("z")];
        let mut svc = StreamingSplitterService::new(
            test_expression(fragments),
            recorder,
            AggregationStrategy::CollectAll,
            true,
        );

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("original"))
            .await
            .unwrap();

        let expected = serde_json::json!([
            {"index": 0, "complete": false},
            {"index": 1, "complete": false},
            {"index": 2, "complete": true},
        ]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    /// With `Original` aggregation, an incoming `Body::Stream` is returned as
    /// `Body::Empty`.
    #[tokio::test]
    async fn test_streaming_split_sanitizes_stream_body_in_original() {
        let chunks = vec![Ok(Bytes::from("line1\n"))];
        let sb = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(futures::stream::iter(chunks))))),
            metadata: Default::default(),
        };
        let ex = Exchange::new(Message::new(Body::Stream(sb)));

        let mut splitter = StreamingSplitterService::new(
            test_expression(vec![Exchange::new(Message::new(Body::Text("frag".into())))]),
            passthrough_pipeline(),
            AggregationStrategy::Original,
            true,
        );

        let result = splitter
            .ready()
            .await
            .expect("ready")
            .call(ex)
            .await
            .expect("call");
        assert!(
            matches!(result.input.body, Body::Empty),
            "original body should be sanitized to Empty"
        );
    }
}