1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures::future::join_all;
6use tokio::sync::Semaphore;
7use tower::Service;
8
9use camel_api::{
10 AggregationStrategy, Body, BoxProcessor, CamelError, Exchange, SplitterConfig, Value,
11};
12
/// Exchange property holding the zero-based position of a fragment within its split.
pub const CAMEL_SPLIT_INDEX: &str = "CamelSplitIndex";
/// Exchange property holding the total number of fragments produced by the split.
pub const CAMEL_SPLIT_SIZE: &str = "CamelSplitSize";
/// Exchange property that is `true` only on the final fragment of the split.
pub const CAMEL_SPLIT_COMPLETE: &str = "CamelSplitComplete";
21
22#[derive(Clone)]
33pub struct SplitterService {
34 expression: camel_api::SplitExpression,
35 sub_pipeline: BoxProcessor,
36 aggregation: AggregationStrategy,
37 parallel: bool,
38 parallel_limit: Option<usize>,
39 stop_on_exception: bool,
40}
41
42impl SplitterService {
43 pub fn new(config: SplitterConfig, sub_pipeline: BoxProcessor) -> Self {
45 if config.parallel
46 && let Some(limit) = config.parallel_limit
47 {
48 assert!(limit > 0, "parallel_limit must be > 0");
49 }
50 Self {
51 expression: config.expression,
52 sub_pipeline,
53 aggregation: config.aggregation,
54 parallel: config.parallel,
55 parallel_limit: config.parallel_limit,
56 stop_on_exception: config.stop_on_exception,
57 }
58 }
59}
60
61impl Service<Exchange> for SplitterService {
62 type Response = Exchange;
63 type Error = CamelError;
64 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
65
66 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
67 self.sub_pipeline.poll_ready(cx)
68 }
69
70 fn call(&mut self, exchange: Exchange) -> Self::Future {
71 let original = exchange.clone();
72 let expression = self.expression.clone();
73 let sub_pipeline = self.sub_pipeline.clone();
74 let aggregation = self.aggregation.clone();
75 let parallel = self.parallel;
76 let parallel_limit = self.parallel_limit;
77 let stop_on_exception = self.stop_on_exception;
78
79 Box::pin(async move {
80 let mut fragments = expression(&exchange);
82
83 if fragments.is_empty() {
85 return Ok(original);
86 }
87
88 let total = fragments.len();
89
90 for (i, frag) in fragments.iter_mut().enumerate() {
92 frag.set_property(CAMEL_SPLIT_INDEX, Value::from(i as u64));
93 frag.set_property(CAMEL_SPLIT_SIZE, Value::from(total as u64));
94 frag.set_property(CAMEL_SPLIT_COMPLETE, Value::Bool(i == total - 1));
95 }
96
97 let results = if parallel {
99 process_parallel(fragments, sub_pipeline, parallel_limit, stop_on_exception).await
100 } else {
101 process_sequential(fragments, sub_pipeline, stop_on_exception).await
102 };
103
104 aggregate(results, original, aggregation)
106 })
107 }
108}
109
110async fn process_sequential(
113 fragments: Vec<Exchange>,
114 sub_pipeline: BoxProcessor,
115 stop_on_exception: bool,
116) -> Vec<Result<Exchange, CamelError>> {
117 let mut results = Vec::with_capacity(fragments.len());
118
119 for fragment in fragments {
120 let mut pipeline = sub_pipeline.clone();
121 match tower::ServiceExt::ready(&mut pipeline).await {
122 Err(e) => {
123 results.push(Err(e));
124 if stop_on_exception {
125 break;
126 }
127 }
128 Ok(svc) => {
129 let result = svc.call(fragment).await;
130 let is_err = result.is_err();
131 results.push(result);
132 if stop_on_exception && is_err {
133 break;
134 }
135 }
136 }
137 }
138
139 results
140}
141
142async fn process_parallel(
145 fragments: Vec<Exchange>,
146 sub_pipeline: BoxProcessor,
147 parallel_limit: Option<usize>,
148 _stop_on_exception: bool,
149) -> Vec<Result<Exchange, CamelError>> {
150 let semaphore = parallel_limit.map(|limit| std::sync::Arc::new(Semaphore::new(limit)));
151
152 let futures: Vec<_> = fragments
153 .into_iter()
154 .map(|fragment| {
155 let mut pipeline = sub_pipeline.clone();
156 let sem = semaphore.clone();
157 async move {
158 let _permit = match &sem {
160 Some(s) => Some(s.acquire().await.map_err(|e| {
161 CamelError::ProcessorError(format!("semaphore error: {e}"))
162 })?),
163 None => None,
164 };
165
166 tower::ServiceExt::ready(&mut pipeline).await?;
167 pipeline.call(fragment).await
168 }
169 })
170 .collect();
171
172 join_all(futures).await
173}
174
175fn aggregate(
178 results: Vec<Result<Exchange, CamelError>>,
179 original: Exchange,
180 strategy: AggregationStrategy,
181) -> Result<Exchange, CamelError> {
182 match strategy {
183 AggregationStrategy::LastWins => {
184 results.into_iter().last().unwrap_or_else(|| Ok(original))
186 }
187 AggregationStrategy::CollectAll => {
188 let mut bodies = Vec::new();
190 for result in results {
191 let ex = result?;
192 let value = match &ex.input.body {
193 Body::Text(s) => Value::String(s.clone()),
194 Body::Json(v) => v.clone(),
195 Body::Xml(s) => Value::String(s.clone()),
196 Body::Bytes(b) => Value::String(String::from_utf8_lossy(b).into_owned()),
197 Body::Empty => Value::Null,
198 Body::Stream(s) => serde_json::json!({
199 "_stream": {
200 "origin": s.metadata.origin,
201 "placeholder": true,
202 "hint": "Materialize exchange body with .into_bytes() before aggregation if content needed"
203 }
204 }),
205 };
206 bodies.push(value);
207 }
208 let mut out = original;
209 out.input.body = Body::Json(Value::Array(bodies));
210 Ok(out)
211 }
212 AggregationStrategy::Original => Ok(original),
213 AggregationStrategy::Custom(fold_fn) => {
214 let mut iter = results.into_iter();
216 let first = iter.next().unwrap_or_else(|| Ok(original.clone()))?;
217 iter.try_fold(first, |acc, next_result| {
218 let next = next_result?;
219 Ok(fold_fn(acc, next))
220 })
221 }
222 }
223}
224
// Unit tests for the splitter: sequential and parallel processing, every
// aggregation strategy, split metadata properties, readiness delegation, and the
// JSON placeholder produced for stream bodies.
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessorExt, Message};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::ServiceExt;

    /// Pipeline that returns every exchange unchanged.
    fn passthrough_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    /// Pipeline that upper-cases text bodies and leaves other bodies untouched.
    fn uppercase_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|mut ex: Exchange| {
            Box::pin(async move {
                if let Body::Text(s) = &ex.input.body {
                    ex.input.body = Body::Text(s.to_uppercase());
                }
                Ok(ex)
            })
        })
    }

    /// Pipeline that fails every exchange with `ProcessorError("boom")`.
    fn failing_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|_ex| {
            Box::pin(async { Err(CamelError::ProcessorError("boom".into())) })
        })
    }

    /// Pipeline that fails only on its `n`-th invocation (zero-based, tracked by a
    /// shared atomic counter) and passes every other exchange through.
    fn fail_on_nth(n: usize) -> BoxProcessor {
        let count = Arc::new(AtomicUsize::new(0));
        BoxProcessor::from_fn(move |ex: Exchange| {
            let count = Arc::clone(&count);
            Box::pin(async move {
                let c = count.fetch_add(1, Ordering::SeqCst);
                if c == n {
                    Err(CamelError::ProcessorError(format!("fail on {c}")))
                } else {
                    Ok(ex)
                }
            })
        })
    }

    /// Builds an exchange whose body is the given text.
    fn make_exchange(text: &str) -> Exchange {
        Exchange::new(Message::new(text))
    }

    // Constructing a parallel splitter with a zero limit must panic.
    #[test]
    fn test_splitter_zero_parallel_limit_rejected() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .parallel_limit(0);
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            SplitterService::new(config, passthrough_pipeline());
        }));
        assert!(result.is_err(), "zero parallel_limit should panic");
    }

    // LastWins keeps only the last processed fragment.
    #[tokio::test]
    async fn test_split_sequential_last_wins() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::LastWins);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("C"));
    }

    // CollectAll gathers every processed fragment body into a JSON array.
    #[tokio::test]
    async fn test_split_sequential_collect_all() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        let expected = serde_json::json!(["A", "B", "C"]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    // Original returns the incoming exchange untouched by the sub-pipeline.
    #[tokio::test]
    async fn test_split_sequential_original() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::Original);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("a\nb\nc"));
    }

    // Custom folds fragments in order with a user-supplied function.
    #[tokio::test]
    async fn test_split_sequential_custom_aggregation() {
        let joiner: Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync> =
            Arc::new(|mut acc: Exchange, next: Exchange| {
                let acc_text = acc.input.body.as_text().unwrap_or("").to_string();
                let next_text = next.input.body.as_text().unwrap_or("").to_string();
                acc.input.body = Body::Text(format!("{acc_text}+{next_text}"));
                acc
            });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::Custom(joiner));
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("A+B+C"));
    }

    // With stop_on_exception, a failure on the second fragment surfaces as an error.
    #[tokio::test]
    async fn test_split_stop_on_exception() {
        let config = SplitterConfig::new(camel_api::split_body_lines()).stop_on_exception(true);
        let mut svc = SplitterService::new(config, fail_on_nth(1));

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc\nd\ne"))
            .await;

        assert!(result.is_err(), "expected error due to stop_on_exception");
    }

    // Without stop_on_exception, LastWins returns the (successful) last fragment
    // even though an earlier fragment failed.
    #[tokio::test]
    async fn test_split_continue_on_exception() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .stop_on_exception(false)
            .aggregation(AggregationStrategy::LastWins);
        let mut svc = SplitterService::new(config, fail_on_nth(1));

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await;

        assert!(result.is_ok(), "last fragment should succeed");
    }

    // An exchange that yields no fragments comes back unchanged, properties included.
    #[tokio::test]
    async fn test_split_empty_fragments() {
        let config = SplitterConfig::new(camel_api::split_body_lines());
        let mut svc = SplitterService::new(config, passthrough_pipeline());

        let mut ex = Exchange::new(Message::default());
        ex.set_property("marker", Value::Bool(true));

        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
        assert!(result.input.body.is_empty());
        assert_eq!(result.property("marker"), Some(&Value::Bool(true)));
    }

    // Every fragment sees its index, the split size and the completion flag.
    #[tokio::test]
    async fn test_split_metadata_properties() {
        // Replaces each fragment body with the split properties it observed.
        let recorder = BoxProcessor::from_fn(|ex: Exchange| {
            Box::pin(async move {
                let idx = ex.property(CAMEL_SPLIT_INDEX).cloned();
                let size = ex.property(CAMEL_SPLIT_SIZE).cloned();
                let complete = ex.property(CAMEL_SPLIT_COMPLETE).cloned();
                let body = serde_json::json!({
                    "index": idx,
                    "size": size,
                    "complete": complete,
                });
                let mut out = ex;
                out.input.body = Body::Json(body);
                Ok(out)
            })
        });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, recorder);

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("x\ny\nz"))
            .await
            .unwrap();

        let expected = serde_json::json!([
            {"index": 0, "size": 3, "complete": false},
            {"index": 1, "size": 3, "complete": false},
            {"index": 2, "size": 3, "complete": true},
        ]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    // The splitter's readiness mirrors its sub-pipeline's readiness.
    #[tokio::test]
    async fn test_poll_ready_delegates_to_sub_pipeline() {
        use std::sync::atomic::AtomicBool;

        // Service whose readiness is controlled by a shared flag.
        #[derive(Clone)]
        struct DelayedReady {
            ready: Arc<AtomicBool>,
        }

        impl Service<Exchange> for DelayedReady {
            type Response = Exchange;
            type Error = CamelError;
            type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;

            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                if self.ready.load(Ordering::SeqCst) {
                    Poll::Ready(Ok(()))
                } else {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            }

            fn call(&mut self, exchange: Exchange) -> Self::Future {
                Box::pin(async move { Ok(exchange) })
            }
        }

        let ready_flag = Arc::new(AtomicBool::new(false));
        let inner = DelayedReady {
            ready: Arc::clone(&ready_flag),
        };
        let boxed: BoxProcessor = BoxProcessor::new(inner);

        let config = SplitterConfig::new(camel_api::split_body_lines());
        let mut svc = SplitterService::new(config, boxed);

        // Poll manually with a no-op waker so the Pending result can be observed.
        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        let poll = Pin::new(&mut svc).poll_ready(&mut cx);
        assert!(
            poll.is_pending(),
            "expected Pending when sub_pipeline not ready"
        );

        ready_flag.store(true, Ordering::SeqCst);

        let poll = Pin::new(&mut svc).poll_ready(&mut cx);
        assert!(
            matches!(poll, Poll::Ready(Ok(()))),
            "expected Ready after sub_pipeline becomes ready"
        );
    }

    // Parallel mode still returns CollectAll results in fragment order.
    #[tokio::test]
    async fn test_split_parallel_basic() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();

        let expected = serde_json::json!(["A", "B", "C"]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    // parallel_limit caps the number of fragments observed running at once.
    #[tokio::test]
    async fn test_split_parallel_with_limit() {
        use std::sync::atomic::AtomicUsize;

        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        let c = Arc::clone(&concurrent);
        let mc = Arc::clone(&max_concurrent);
        let pipeline = BoxProcessor::from_fn(move |ex: Exchange| {
            let c = Arc::clone(&c);
            let mc = Arc::clone(&mc);
            Box::pin(async move {
                let current = c.fetch_add(1, Ordering::SeqCst) + 1;
                mc.fetch_max(current, Ordering::SeqCst);
                // Yield so other in-flight fragments get a chance to overlap.
                tokio::task::yield_now().await;
                c.fetch_sub(1, Ordering::SeqCst);
                Ok(ex)
            })
        });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .parallel_limit(2)
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, pipeline);

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc\nd"))
            .await;
        assert!(result.is_ok());

        let observed_max = max_concurrent.load(Ordering::SeqCst);
        assert!(
            observed_max <= 2,
            "max concurrency was {observed_max}, expected <= 2"
        );
    }

    // Parallel mode with every fragment failing returns an error.
    #[tokio::test]
    async fn test_split_parallel_stop_on_exception() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .stop_on_exception(true);
        let mut svc = SplitterService::new(config, failing_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await;

        assert!(result.is_err(), "expected error when all fragments fail");
    }

    // CollectAll emits a serializable placeholder object (with origin) for stream bodies.
    #[tokio::test]
    async fn test_splitter_stream_bodies_creates_valid_json() {
        use bytes::Bytes;
        use camel_api::{StreamBody, StreamMetadata};
        use futures::stream;
        use tokio::sync::Mutex;

        let chunks = vec![Ok(Bytes::from("test"))];
        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
            metadata: StreamMetadata {
                origin: Some("kafka://topic/partition".to_string()),
                ..Default::default()
            },
        };

        let original = Exchange::new(Message {
            headers: Default::default(),
            body: Body::Empty,
        });

        let results = vec![Ok(Exchange::new(Message {
            headers: Default::default(),
            body: Body::Stream(stream_body),
        }))];

        let result = aggregate(results, original, AggregationStrategy::CollectAll);

        let exchange = result.expect("Expected Ok result");
        assert!(
            matches!(exchange.input.body, Body::Json(_)),
            "Expected Json body"
        );

        if let Body::Json(value) = exchange.input.body {
            // Round-trip through a string to prove the value is valid JSON.
            let json_str = serde_json::to_string(&value).unwrap();
            let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();

            assert!(parsed.is_array());
            let arr = parsed.as_array().unwrap();
            assert!(arr[0].is_object());
            assert!(arr[0]["_stream"].is_object());
            assert_eq!(arr[0]["_stream"]["origin"], "kafka://topic/partition");
            assert_eq!(arr[0]["_stream"]["placeholder"], true);
        }
    }

    // A stream without an origin yields `null` for the placeholder's origin.
    #[tokio::test]
    async fn test_splitter_stream_with_none_origin_creates_valid_json() {
        use bytes::Bytes;
        use camel_api::{StreamBody, StreamMetadata};
        use futures::stream;
        use tokio::sync::Mutex;

        let chunks = vec![Ok(Bytes::from("test"))];
        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
            metadata: StreamMetadata {
                origin: None,
                ..Default::default()
            },
        };

        let original = Exchange::new(Message {
            headers: Default::default(),
            body: Body::Empty,
        });

        let results = vec![Ok(Exchange::new(Message {
            headers: Default::default(),
            body: Body::Stream(stream_body),
        }))];

        let result = aggregate(results, original, AggregationStrategy::CollectAll);

        let exchange = result.expect("Expected Ok result");
        assert!(
            matches!(exchange.input.body, Body::Json(_)),
            "Expected Json body"
        );

        if let Body::Json(value) = exchange.input.body {
            // Round-trip through a string to prove the value is valid JSON.
            let json_str = serde_json::to_string(&value).unwrap();
            let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();

            assert!(parsed.is_array());
            let arr = parsed.as_array().unwrap();
            assert!(arr[0].is_object());
            assert!(arr[0]["_stream"].is_object());
            assert_eq!(arr[0]["_stream"]["origin"], serde_json::Value::Null);
            assert_eq!(arr[0]["_stream"]["placeholder"], true);
        }
    }
}