1use futures::future::join_all;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6use tokio::sync::Semaphore;
7use tokio_util::sync::CancellationToken;
8use tower::Service;
9
10use camel_api::{
11 AggregationStrategy, Body, BoxProcessor, CamelError, Exchange, SplitterConfig, Value,
12};
13
/// Exchange property set on every fragment with its zero-based position in the split.
pub const CAMEL_SPLIT_INDEX: &str = "CamelSplitIndex";
/// Exchange property set on every fragment with the total number of fragments in the split.
pub const CAMEL_SPLIT_SIZE: &str = "CamelSplitSize";
/// Exchange property set on every fragment; `true` only for the last fragment of the split.
pub const CAMEL_SPLIT_COMPLETE: &str = "CamelSplitComplete";
22
/// Tower service that splits an incoming exchange into fragments, sends each
/// fragment through a sub-pipeline, and aggregates the fragment results into a
/// single exchange.
///
/// The service is `Clone`; the tests in this file rely on a clone observing a
/// cancellation requested on the value it was cloned from.
#[derive(Clone)]
pub struct SplitterService {
    /// Produces the fragments for an incoming exchange.
    expression: camel_api::SplitExpression,
    /// Processor every fragment is sent through (cloned per fragment).
    sub_pipeline: BoxProcessor,
    /// How the per-fragment results are combined into the final exchange.
    aggregation: AggregationStrategy,
    /// When `true`, fragments are processed concurrently instead of one by one.
    parallel: bool,
    /// Optional cap on concurrently processed fragments in parallel mode
    /// (used as the permit count of a semaphore).
    parallel_limit: Option<usize>,
    /// When `true`, processing stops at the first failing fragment.
    stop_on_exception: bool,
    /// Token checked before and during fragment processing to abort the split.
    cancel_token: CancellationToken,
}
43
44impl SplitterService {
45 pub fn new(config: SplitterConfig, sub_pipeline: BoxProcessor) -> Result<Self, CamelError> {
47 config.validate()?;
48 Ok(Self {
49 expression: config.expression,
50 sub_pipeline,
51 aggregation: config.aggregation,
52 parallel: config.parallel,
53 parallel_limit: config.parallel_limit,
54 stop_on_exception: config.stop_on_exception,
55 cancel_token: CancellationToken::new(),
56 })
57 }
58
59 pub fn cancel(&self) {
61 self.cancel_token.cancel();
62 }
63
64 pub fn is_cancelled(&self) -> bool {
66 self.cancel_token.is_cancelled()
67 }
68}
69
70impl Service<Exchange> for SplitterService {
71 type Response = Exchange;
72 type Error = CamelError;
73 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
74
75 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
76 self.sub_pipeline.poll_ready(cx)
77 }
78
79 fn call(&mut self, exchange: Exchange) -> Self::Future {
80 let original = exchange.clone();
81 let expression = self.expression.clone();
82 let sub_pipeline = self.sub_pipeline.clone();
83 let aggregation = self.aggregation.clone();
84 let parallel = self.parallel;
85 let parallel_limit = self.parallel_limit;
86 let stop_on_exception = self.stop_on_exception;
87 let cancel_token = self.cancel_token.clone();
88
89 Box::pin(async move {
90 let mut fragments = expression(&exchange);
92
93 if fragments.is_empty() {
95 return Ok(original);
96 }
97
98 let total = fragments.len();
99
100 for (i, frag) in fragments.iter_mut().enumerate() {
102 frag.set_property(CAMEL_SPLIT_INDEX, Value::from(i as u64));
103 frag.set_property(CAMEL_SPLIT_SIZE, Value::from(total as u64));
104 frag.set_property(CAMEL_SPLIT_COMPLETE, Value::Bool(i == total - 1));
105 }
106
107 if cancel_token.is_cancelled() {
109 return Err(CamelError::ProcessorError(
110 "Splitter cancelled, dropping exchange".to_string(),
111 ));
112 }
113
114 let results = if parallel {
116 process_parallel(
117 fragments,
118 sub_pipeline,
119 parallel_limit,
120 stop_on_exception,
121 cancel_token,
122 )
123 .await
124 } else {
125 process_sequential(fragments, sub_pipeline, stop_on_exception).await
126 };
127
128 aggregate(results, original, aggregation)
130 })
131 }
132}
133
134async fn process_sequential(
137 fragments: Vec<Exchange>,
138 sub_pipeline: BoxProcessor,
139 stop_on_exception: bool,
140) -> Vec<Result<Exchange, CamelError>> {
141 let mut results = Vec::with_capacity(fragments.len());
142
143 for fragment in fragments {
144 let mut pipeline = sub_pipeline.clone();
145 match tower::ServiceExt::ready(&mut pipeline).await {
146 Err(e) => {
147 results.push(Err(e));
148 if stop_on_exception {
149 break;
150 }
151 }
152 Ok(svc) => {
153 let result = svc.call(fragment).await;
154 let is_err = result.is_err();
155 results.push(result);
156 if stop_on_exception && is_err {
157 break;
158 }
159 }
160 }
161 }
162
163 results
164}
165
166async fn process_parallel(
169 fragments: Vec<Exchange>,
170 sub_pipeline: BoxProcessor,
171 parallel_limit: Option<usize>,
172 _stop_on_exception: bool,
173 cancel_token: CancellationToken,
174) -> Vec<Result<Exchange, CamelError>> {
175 let semaphore = parallel_limit.map(|limit| Arc::new(Semaphore::new(limit)));
176
177 let futures: Vec<_> = fragments
178 .into_iter()
179 .map(|fragment| {
180 let mut pipeline = sub_pipeline.clone();
181 let sem = semaphore.clone();
182 let cancel = cancel_token.clone();
183 async move {
184 if cancel.is_cancelled() {
186 return Err(CamelError::ProcessorError("Splitter cancelled".to_string()));
187 }
188
189 let _permit = match &sem {
191 Some(s) => {
192 tokio::select! {
193 result = s.acquire() => {
194 Some(result.map_err(|e| {
195 CamelError::ProcessorError(format!("semaphore error: {e}"))
196 })?)
197 }
198 _ = cancel.cancelled() => {
199 return Err(CamelError::ProcessorError(
200 "Splitter cancelled while waiting for semaphore".to_string(),
201 ));
202 }
203 }
204 }
205 None => None,
206 };
207
208 if cancel.is_cancelled() {
210 return Err(CamelError::ProcessorError("Splitter cancelled".to_string()));
211 }
212
213 tokio::select! {
214 result = async {
215 tower::ServiceExt::ready(&mut pipeline).await?;
216 pipeline.call(fragment).await
217 } => result,
218 _ = cancel.cancelled() => {
219 Err(CamelError::ProcessorError(
220 "Splitter cancelled during processing".to_string(),
221 ))
222 }
223 }
224 }
225 })
226 .collect();
227
228 join_all(futures).await
229}
230
231fn aggregate(
234 results: Vec<Result<Exchange, CamelError>>,
235 original: Exchange,
236 strategy: AggregationStrategy,
237) -> Result<Exchange, CamelError> {
238 match strategy {
239 AggregationStrategy::LastWins => {
240 results.into_iter().last().unwrap_or_else(|| Ok(original))
242 }
243 AggregationStrategy::CollectAll => {
244 let mut bodies = Vec::new();
246 for result in results {
247 let ex = result?;
248 let value = match &ex.input.body {
249 Body::Text(s) => Value::String(s.clone()),
250 Body::Json(v) => v.clone(),
251 Body::Xml(s) => Value::String(s.clone()),
252 Body::Bytes(b) => Value::String(String::from_utf8_lossy(b).into_owned()),
253 Body::Empty => Value::Null,
254 Body::Stream(s) => serde_json::json!({
255 "_stream": {
256 "origin": s.metadata.origin,
257 "placeholder": true,
258 "hint": "Materialize exchange body with .into_bytes() before aggregation if content needed"
259 }
260 }),
261 };
262 bodies.push(value);
263 }
264 let mut out = original;
265 out.input.body = Body::Json(Value::Array(bodies));
266 Ok(out)
267 }
268 AggregationStrategy::Original => Ok(original),
269 AggregationStrategy::Custom(fold_fn) => {
270 let mut iter = results.into_iter();
272 let first = iter.next().unwrap_or_else(|| Ok(original.clone()))?;
273 iter.try_fold(first, |acc, next_result| {
274 let next = next_result?;
275 Ok(fold_fn(acc, next))
276 })
277 }
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessorExt, Message};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::ServiceExt;

    /// Sub-pipeline that returns every exchange untouched.
    fn passthrough_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    /// Sub-pipeline that upper-cases text bodies and leaves other bodies alone.
    fn uppercase_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|mut ex: Exchange| {
            Box::pin(async move {
                if let Body::Text(s) = &ex.input.body {
                    ex.input.body = Body::Text(s.to_uppercase());
                }
                Ok(ex)
            })
        })
    }

    /// Sub-pipeline that fails every exchange.
    fn failing_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|_ex| {
            Box::pin(async { Err(CamelError::ProcessorError("boom".into())) })
        })
    }

    /// Sub-pipeline that fails only on its `n`-th invocation (zero-based).
    fn fail_on_nth(n: usize) -> BoxProcessor {
        let count = Arc::new(AtomicUsize::new(0));
        BoxProcessor::from_fn(move |ex: Exchange| {
            let count = Arc::clone(&count);
            Box::pin(async move {
                let c = count.fetch_add(1, Ordering::SeqCst);
                if c == n {
                    Err(CamelError::ProcessorError(format!("fail on {c}")))
                } else {
                    Ok(ex)
                }
            })
        })
    }

    fn make_exchange(text: &str) -> Exchange {
        Exchange::new(Message::new(text))
    }

    /// Aggregates a single stream-bodied fragment with `CollectAll` and returns
    /// the resulting JSON after a serialize/parse round trip, proving that the
    /// aggregated body is valid JSON.
    fn collect_single_stream_fragment(origin: Option<String>) -> serde_json::Value {
        use bytes::Bytes;
        use camel_api::{StreamBody, StreamMetadata};
        use futures::stream;
        use tokio::sync::Mutex;

        let chunks = vec![Ok(Bytes::from("test"))];
        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
            metadata: StreamMetadata {
                origin,
                ..Default::default()
            },
        };

        let original = Exchange::new(Message {
            headers: Default::default(),
            body: Body::Empty,
        });

        let results = vec![Ok(Exchange::new(Message {
            headers: Default::default(),
            body: Body::Stream(stream_body),
        }))];

        let exchange = aggregate(results, original, AggregationStrategy::CollectAll)
            .expect("Expected Ok result");

        match exchange.input.body {
            Body::Json(value) => {
                let json_str = serde_json::to_string(&value).unwrap();
                serde_json::from_str(&json_str).unwrap()
            }
            other => panic!("Expected Json body, got {other:?}"),
        }
    }

    #[test]
    fn test_splitter_zero_parallel_limit_rejected() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .parallel_limit(0);
        let result = SplitterService::new(config, passthrough_pipeline());
        assert!(result.is_err(), "zero parallel_limit should return Err");
    }

    #[tokio::test]
    async fn test_split_sequential_last_wins() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::LastWins);
        let mut svc = SplitterService::new(config, uppercase_pipeline()).unwrap();

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("C"));
    }

    #[tokio::test]
    async fn test_split_sequential_collect_all() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, uppercase_pipeline()).unwrap();

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        let expected = serde_json::json!(["A", "B", "C"]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_split_sequential_original() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::Original);
        let mut svc = SplitterService::new(config, uppercase_pipeline()).unwrap();

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        // The fragments were upper-cased, but `Original` ignores their results.
        assert_eq!(result.input.body.as_text(), Some("a\nb\nc"));
    }

    #[tokio::test]
    async fn test_split_sequential_custom_aggregation() {
        let joiner: Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync> =
            Arc::new(|mut acc: Exchange, next: Exchange| {
                let acc_text = acc.input.body.as_text().unwrap_or("").to_string();
                let next_text = next.input.body.as_text().unwrap_or("").to_string();
                acc.input.body = Body::Text(format!("{acc_text}+{next_text}"));
                acc
            });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::Custom(joiner));
        let mut svc = SplitterService::new(config, uppercase_pipeline()).unwrap();

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("A+B+C"));
    }

    #[tokio::test]
    async fn test_split_stop_on_exception() {
        let config = SplitterConfig::new(camel_api::split_body_lines()).stop_on_exception(true);
        let mut svc = SplitterService::new(config, fail_on_nth(1)).unwrap();

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc\nd\ne"))
            .await;

        assert!(result.is_err(), "expected error due to stop_on_exception");
    }

    #[tokio::test]
    async fn test_split_continue_on_exception() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .stop_on_exception(false)
            .aggregation(AggregationStrategy::LastWins);
        let mut svc = SplitterService::new(config, fail_on_nth(1)).unwrap();

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await;

        assert!(result.is_ok(), "last fragment should succeed");
    }

    #[tokio::test]
    async fn test_split_empty_fragments() {
        let config = SplitterConfig::new(camel_api::split_body_lines());
        let mut svc = SplitterService::new(config, passthrough_pipeline()).unwrap();

        let mut ex = Exchange::new(Message::default());
        ex.set_property("marker", Value::Bool(true));

        // With nothing to split, the incoming exchange must come back unchanged.
        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
        assert!(result.input.body.is_empty());
        assert_eq!(result.property("marker"), Some(&Value::Bool(true)));
    }

    #[tokio::test]
    async fn test_split_metadata_properties() {
        // Replaces each fragment body with the split metadata it was given.
        let recorder = BoxProcessor::from_fn(|ex: Exchange| {
            Box::pin(async move {
                let idx = ex.property(CAMEL_SPLIT_INDEX).cloned();
                let size = ex.property(CAMEL_SPLIT_SIZE).cloned();
                let complete = ex.property(CAMEL_SPLIT_COMPLETE).cloned();
                let body = serde_json::json!({
                    "index": idx,
                    "size": size,
                    "complete": complete,
                });
                let mut out = ex;
                out.input.body = Body::Json(body);
                Ok(out)
            })
        });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, recorder).unwrap();

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("x\ny\nz"))
            .await
            .unwrap();

        let expected = serde_json::json!([
            {"index": 0, "size": 3, "complete": false},
            {"index": 1, "size": 3, "complete": false},
            {"index": 2, "size": 3, "complete": true},
        ]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_poll_ready_delegates_to_sub_pipeline() {
        use std::sync::atomic::AtomicBool;

        /// Service that reports `Pending` until its shared flag is set.
        #[derive(Clone)]
        struct DelayedReady {
            ready: Arc<AtomicBool>,
        }

        impl Service<Exchange> for DelayedReady {
            type Response = Exchange;
            type Error = CamelError;
            type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;

            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                if self.ready.load(Ordering::SeqCst) {
                    Poll::Ready(Ok(()))
                } else {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            }

            fn call(&mut self, exchange: Exchange) -> Self::Future {
                Box::pin(async move { Ok(exchange) })
            }
        }

        let ready_flag = Arc::new(AtomicBool::new(false));
        let inner = DelayedReady {
            ready: Arc::clone(&ready_flag),
        };
        let boxed: BoxProcessor = BoxProcessor::new(inner);

        let config = SplitterConfig::new(camel_api::split_body_lines());
        let mut svc = SplitterService::new(config, boxed).unwrap();

        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        let poll = Pin::new(&mut svc).poll_ready(&mut cx);
        assert!(
            poll.is_pending(),
            "expected Pending when sub_pipeline not ready"
        );

        ready_flag.store(true, Ordering::SeqCst);

        let poll = Pin::new(&mut svc).poll_ready(&mut cx);
        assert!(
            matches!(poll, Poll::Ready(Ok(()))),
            "expected Ready after sub_pipeline becomes ready"
        );
    }

    #[tokio::test]
    async fn test_split_parallel_basic() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, uppercase_pipeline()).unwrap();

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();

        let expected = serde_json::json!(["A", "B", "C"]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_split_parallel_with_limit() {
        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        let c = Arc::clone(&concurrent);
        let mc = Arc::clone(&max_concurrent);
        // Tracks how many invocations overlap; the yield gives other fragments
        // a chance to start while this one is still counted as running.
        let pipeline = BoxProcessor::from_fn(move |ex: Exchange| {
            let c = Arc::clone(&c);
            let mc = Arc::clone(&mc);
            Box::pin(async move {
                let current = c.fetch_add(1, Ordering::SeqCst) + 1;
                mc.fetch_max(current, Ordering::SeqCst);
                tokio::task::yield_now().await;
                c.fetch_sub(1, Ordering::SeqCst);
                Ok(ex)
            })
        });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .parallel_limit(2)
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, pipeline).unwrap();

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc\nd"))
            .await;
        assert!(result.is_ok());

        let observed_max = max_concurrent.load(Ordering::SeqCst);
        assert!(
            observed_max >= 1,
            "sub-pipeline should have processed at least one fragment"
        );
        assert!(
            observed_max <= 2,
            "max concurrency was {observed_max}, expected <= 2"
        );
    }

    #[tokio::test]
    async fn test_split_parallel_stop_on_exception() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .stop_on_exception(true);
        let mut svc = SplitterService::new(config, failing_pipeline()).unwrap();

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await;

        assert!(result.is_err(), "expected error when all fragments fail");
    }

    #[tokio::test]
    async fn test_splitter_stream_bodies_creates_valid_json() {
        let parsed = collect_single_stream_fragment(Some("kafka://topic/partition".to_string()));

        assert!(parsed.is_array());
        let arr = parsed.as_array().unwrap();
        assert!(arr[0].is_object());
        assert!(arr[0]["_stream"].is_object());
        assert_eq!(arr[0]["_stream"]["origin"], "kafka://topic/partition");
        assert_eq!(arr[0]["_stream"]["placeholder"], true);
    }

    #[tokio::test]
    async fn test_splitter_stream_with_none_origin_creates_valid_json() {
        let parsed = collect_single_stream_fragment(None);

        assert!(parsed.is_array());
        let arr = parsed.as_array().unwrap();
        assert!(arr[0].is_object());
        assert!(arr[0]["_stream"].is_object());
        assert_eq!(arr[0]["_stream"]["origin"], serde_json::Value::Null);
        assert_eq!(arr[0]["_stream"]["placeholder"], true);
    }

    #[tokio::test]
    async fn test_splitter_parallel_cancel_aborts_processing() {
        use std::sync::atomic::AtomicBool;

        let started = Arc::new(AtomicBool::new(false));

        let s = Arc::clone(&started);
        // Would block for a minute if it ever ran; the test relies on the
        // cancelled splitter never invoking it.
        let pipeline = BoxProcessor::from_fn(move |ex: Exchange| {
            let s = Arc::clone(&s);
            Box::pin(async move {
                s.store(true, Ordering::SeqCst);
                tokio::time::sleep(std::time::Duration::from_secs(60)).await;
                Ok(ex)
            })
        });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .aggregation(AggregationStrategy::LastWins);
        let svc = SplitterService::new(config, pipeline).unwrap();

        svc.cancel();
        assert!(svc.is_cancelled());

        // Cancellation requested on `svc` must also be seen by its clone.
        let mut svc_clone = svc.clone();
        let result = svc_clone
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await;

        assert!(result.is_err(), "cancelled splitter should return error");
        assert!(
            !started.load(Ordering::SeqCst),
            "sub-pipeline must not run once the splitter is cancelled"
        );
    }
}