1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures::future::join_all;
6use tokio::sync::Semaphore;
7use tower::Service;
8
9use camel_api::{
10 AggregationStrategy, Body, BoxProcessor, CamelError, Exchange, SplitterConfig, Value,
11};
12
/// Exchange property key under which each fragment's zero-based position in the split is stored.
pub const CAMEL_SPLIT_INDEX: &str = "CamelSplitIndex";
/// Exchange property key under which the total number of fragments in the split is stored.
pub const CAMEL_SPLIT_SIZE: &str = "CamelSplitSize";
/// Exchange property key holding a boolean that is `true` only on the last fragment of the split.
pub const CAMEL_SPLIT_COMPLETE: &str = "CamelSplitComplete";
21
/// Splitter processor: breaks an incoming exchange into fragments, sends each
/// fragment through `sub_pipeline`, and merges the per-fragment results
/// according to `aggregation`.
#[derive(Clone)]
pub struct SplitterService {
    /// Produces the fragments from the incoming exchange.
    expression: camel_api::SplitExpression,
    /// Processor every fragment is sent through.
    sub_pipeline: BoxProcessor,
    /// How the per-fragment results are combined into the returned exchange.
    aggregation: AggregationStrategy,
    /// When `true`, fragments are processed concurrently instead of one after another.
    parallel: bool,
    /// Optional cap on the number of fragments processed at the same time in parallel mode.
    parallel_limit: Option<usize>,
    /// When `true`, sequential processing stops after the first failed fragment.
    /// Parallel processing currently ignores this flag.
    stop_on_exception: bool,
}
41
42impl SplitterService {
43 pub fn new(config: SplitterConfig, sub_pipeline: BoxProcessor) -> Self {
45 Self {
46 expression: config.expression,
47 sub_pipeline,
48 aggregation: config.aggregation,
49 parallel: config.parallel,
50 parallel_limit: config.parallel_limit,
51 stop_on_exception: config.stop_on_exception,
52 }
53 }
54}
55
56impl Service<Exchange> for SplitterService {
57 type Response = Exchange;
58 type Error = CamelError;
59 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
60
61 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
62 self.sub_pipeline.poll_ready(cx)
63 }
64
65 fn call(&mut self, exchange: Exchange) -> Self::Future {
66 let original = exchange.clone();
67 let expression = self.expression.clone();
68 let sub_pipeline = self.sub_pipeline.clone();
69 let aggregation = self.aggregation.clone();
70 let parallel = self.parallel;
71 let parallel_limit = self.parallel_limit;
72 let stop_on_exception = self.stop_on_exception;
73
74 Box::pin(async move {
75 let mut fragments = expression(&exchange);
77
78 if fragments.is_empty() {
80 return Ok(original);
81 }
82
83 let total = fragments.len();
84
85 for (i, frag) in fragments.iter_mut().enumerate() {
87 frag.set_property(CAMEL_SPLIT_INDEX, Value::from(i as u64));
88 frag.set_property(CAMEL_SPLIT_SIZE, Value::from(total as u64));
89 frag.set_property(CAMEL_SPLIT_COMPLETE, Value::Bool(i == total - 1));
90 }
91
92 let results = if parallel {
94 process_parallel(fragments, sub_pipeline, parallel_limit, stop_on_exception).await
95 } else {
96 process_sequential(fragments, sub_pipeline, stop_on_exception).await
97 };
98
99 aggregate(results, original, aggregation)
101 })
102 }
103}
104
105async fn process_sequential(
108 fragments: Vec<Exchange>,
109 sub_pipeline: BoxProcessor,
110 stop_on_exception: bool,
111) -> Vec<Result<Exchange, CamelError>> {
112 let mut results = Vec::with_capacity(fragments.len());
113
114 for fragment in fragments {
115 let mut pipeline = sub_pipeline.clone();
116 match tower::ServiceExt::ready(&mut pipeline).await {
117 Err(e) => {
118 results.push(Err(e));
119 if stop_on_exception {
120 break;
121 }
122 }
123 Ok(svc) => {
124 let result = svc.call(fragment).await;
125 let is_err = result.is_err();
126 results.push(result);
127 if stop_on_exception && is_err {
128 break;
129 }
130 }
131 }
132 }
133
134 results
135}
136
137async fn process_parallel(
140 fragments: Vec<Exchange>,
141 sub_pipeline: BoxProcessor,
142 parallel_limit: Option<usize>,
143 _stop_on_exception: bool,
144) -> Vec<Result<Exchange, CamelError>> {
145 let semaphore = parallel_limit.map(|limit| std::sync::Arc::new(Semaphore::new(limit)));
146
147 let futures: Vec<_> = fragments
148 .into_iter()
149 .map(|fragment| {
150 let mut pipeline = sub_pipeline.clone();
151 let sem = semaphore.clone();
152 async move {
153 let _permit = match &sem {
155 Some(s) => Some(s.acquire().await.map_err(|e| {
156 CamelError::ProcessorError(format!("semaphore error: {e}"))
157 })?),
158 None => None,
159 };
160
161 tower::ServiceExt::ready(&mut pipeline).await?;
162 pipeline.call(fragment).await
163 }
164 })
165 .collect();
166
167 join_all(futures).await
168}
169
170fn aggregate(
173 results: Vec<Result<Exchange, CamelError>>,
174 original: Exchange,
175 strategy: AggregationStrategy,
176) -> Result<Exchange, CamelError> {
177 match strategy {
178 AggregationStrategy::LastWins => {
179 results.into_iter().last().unwrap_or_else(|| Ok(original))
181 }
182 AggregationStrategy::CollectAll => {
183 let mut bodies = Vec::new();
185 for result in results {
186 let ex = result?;
187 let value = match &ex.input.body {
188 Body::Text(s) => Value::String(s.clone()),
189 Body::Json(v) => v.clone(),
190 Body::Bytes(b) => Value::String(String::from_utf8_lossy(b).into_owned()),
191 Body::Empty => Value::Null,
192 Body::Stream(s) => serde_json::json!({
193 "_stream": {
194 "origin": s.metadata.origin,
195 "placeholder": true,
196 "hint": "Materialize exchange body with .into_bytes() before aggregation if content needed"
197 }
198 }),
199 };
200 bodies.push(value);
201 }
202 let mut out = original;
203 out.input.body = Body::Json(Value::Array(bodies));
204 Ok(out)
205 }
206 AggregationStrategy::Original => Ok(original),
207 AggregationStrategy::Custom(fold_fn) => {
208 let mut iter = results.into_iter();
210 let first = iter.next().unwrap_or_else(|| Ok(original.clone()))?;
211 iter.try_fold(first, |acc, next_result| {
212 let next = next_result?;
213 Ok(fold_fn(acc, next))
214 })
215 }
216 }
217}
218
219#[cfg(test)]
222mod tests {
223 use super::*;
224 use camel_api::{BoxProcessorExt, Message};
225 use std::sync::Arc;
226 use std::sync::atomic::{AtomicUsize, Ordering};
227 use tower::ServiceExt;
228
229 fn passthrough_pipeline() -> BoxProcessor {
232 BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
233 }
234
235 fn uppercase_pipeline() -> BoxProcessor {
236 BoxProcessor::from_fn(|mut ex: Exchange| {
237 Box::pin(async move {
238 if let Body::Text(s) = &ex.input.body {
239 ex.input.body = Body::Text(s.to_uppercase());
240 }
241 Ok(ex)
242 })
243 })
244 }
245
246 fn failing_pipeline() -> BoxProcessor {
247 BoxProcessor::from_fn(|_ex| {
248 Box::pin(async { Err(CamelError::ProcessorError("boom".into())) })
249 })
250 }
251
252 fn fail_on_nth(n: usize) -> BoxProcessor {
253 let count = Arc::new(AtomicUsize::new(0));
254 BoxProcessor::from_fn(move |ex: Exchange| {
255 let count = Arc::clone(&count);
256 Box::pin(async move {
257 let c = count.fetch_add(1, Ordering::SeqCst);
258 if c == n {
259 Err(CamelError::ProcessorError(format!("fail on {c}")))
260 } else {
261 Ok(ex)
262 }
263 })
264 })
265 }
266
267 fn make_exchange(text: &str) -> Exchange {
268 Exchange::new(Message::new(text))
269 }
270
271 #[tokio::test]
274 async fn test_split_sequential_last_wins() {
275 let config = SplitterConfig::new(camel_api::split_body_lines())
276 .aggregation(AggregationStrategy::LastWins);
277 let mut svc = SplitterService::new(config, uppercase_pipeline());
278
279 let result = svc
280 .ready()
281 .await
282 .unwrap()
283 .call(make_exchange("a\nb\nc"))
284 .await
285 .unwrap();
286 assert_eq!(result.input.body.as_text(), Some("C"));
287 }
288
289 #[tokio::test]
292 async fn test_split_sequential_collect_all() {
293 let config = SplitterConfig::new(camel_api::split_body_lines())
294 .aggregation(AggregationStrategy::CollectAll);
295 let mut svc = SplitterService::new(config, uppercase_pipeline());
296
297 let result = svc
298 .ready()
299 .await
300 .unwrap()
301 .call(make_exchange("a\nb\nc"))
302 .await
303 .unwrap();
304 let expected = serde_json::json!(["A", "B", "C"]);
305 match &result.input.body {
306 Body::Json(v) => assert_eq!(*v, expected),
307 other => panic!("expected JSON body, got {other:?}"),
308 }
309 }
310
311 #[tokio::test]
314 async fn test_split_sequential_original() {
315 let config = SplitterConfig::new(camel_api::split_body_lines())
316 .aggregation(AggregationStrategy::Original);
317 let mut svc = SplitterService::new(config, uppercase_pipeline());
318
319 let result = svc
320 .ready()
321 .await
322 .unwrap()
323 .call(make_exchange("a\nb\nc"))
324 .await
325 .unwrap();
326 assert_eq!(result.input.body.as_text(), Some("a\nb\nc"));
328 }
329
330 #[tokio::test]
333 async fn test_split_sequential_custom_aggregation() {
334 let joiner: Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync> =
335 Arc::new(|mut acc: Exchange, next: Exchange| {
336 let acc_text = acc.input.body.as_text().unwrap_or("").to_string();
337 let next_text = next.input.body.as_text().unwrap_or("").to_string();
338 acc.input.body = Body::Text(format!("{acc_text}+{next_text}"));
339 acc
340 });
341
342 let config = SplitterConfig::new(camel_api::split_body_lines())
343 .aggregation(AggregationStrategy::Custom(joiner));
344 let mut svc = SplitterService::new(config, uppercase_pipeline());
345
346 let result = svc
347 .ready()
348 .await
349 .unwrap()
350 .call(make_exchange("a\nb\nc"))
351 .await
352 .unwrap();
353 assert_eq!(result.input.body.as_text(), Some("A+B+C"));
354 }
355
356 #[tokio::test]
359 async fn test_split_stop_on_exception() {
360 let config = SplitterConfig::new(camel_api::split_body_lines()).stop_on_exception(true);
362 let mut svc = SplitterService::new(config, fail_on_nth(1));
363
364 let result = svc
365 .ready()
366 .await
367 .unwrap()
368 .call(make_exchange("a\nb\nc\nd\ne"))
369 .await;
370
371 assert!(result.is_err(), "expected error due to stop_on_exception");
373 }
374
375 #[tokio::test]
378 async fn test_split_continue_on_exception() {
379 let config = SplitterConfig::new(camel_api::split_body_lines())
381 .stop_on_exception(false)
382 .aggregation(AggregationStrategy::LastWins);
383 let mut svc = SplitterService::new(config, fail_on_nth(1));
384
385 let result = svc
386 .ready()
387 .await
388 .unwrap()
389 .call(make_exchange("a\nb\nc"))
390 .await;
391
392 assert!(result.is_ok(), "last fragment should succeed");
394 }
395
396 #[tokio::test]
399 async fn test_split_empty_fragments() {
400 let config = SplitterConfig::new(camel_api::split_body_lines());
402 let mut svc = SplitterService::new(config, passthrough_pipeline());
403
404 let mut ex = Exchange::new(Message::default()); ex.set_property("marker", Value::Bool(true));
406
407 let result = svc.ready().await.unwrap().call(ex).await.unwrap();
408 assert!(result.input.body.is_empty());
409 assert_eq!(result.property("marker"), Some(&Value::Bool(true)));
410 }
411
412 #[tokio::test]
415 async fn test_split_metadata_properties() {
416 let recorder = BoxProcessor::from_fn(|ex: Exchange| {
420 Box::pin(async move {
421 let idx = ex.property(CAMEL_SPLIT_INDEX).cloned();
422 let size = ex.property(CAMEL_SPLIT_SIZE).cloned();
423 let complete = ex.property(CAMEL_SPLIT_COMPLETE).cloned();
424 let body = serde_json::json!({
425 "index": idx,
426 "size": size,
427 "complete": complete,
428 });
429 let mut out = ex;
430 out.input.body = Body::Json(body);
431 Ok(out)
432 })
433 });
434
435 let config = SplitterConfig::new(camel_api::split_body_lines())
436 .aggregation(AggregationStrategy::CollectAll);
437 let mut svc = SplitterService::new(config, recorder);
438
439 let result = svc
440 .ready()
441 .await
442 .unwrap()
443 .call(make_exchange("x\ny\nz"))
444 .await
445 .unwrap();
446
447 let expected = serde_json::json!([
448 {"index": 0, "size": 3, "complete": false},
449 {"index": 1, "size": 3, "complete": false},
450 {"index": 2, "size": 3, "complete": true},
451 ]);
452 match &result.input.body {
453 Body::Json(v) => assert_eq!(*v, expected),
454 other => panic!("expected JSON body, got {other:?}"),
455 }
456 }
457
458 #[tokio::test]
461 async fn test_poll_ready_delegates_to_sub_pipeline() {
462 use std::sync::atomic::AtomicBool;
463
464 #[derive(Clone)]
466 struct DelayedReady {
467 ready: Arc<AtomicBool>,
468 }
469
470 impl Service<Exchange> for DelayedReady {
471 type Response = Exchange;
472 type Error = CamelError;
473 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
474
475 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
476 if self.ready.load(Ordering::SeqCst) {
477 Poll::Ready(Ok(()))
478 } else {
479 cx.waker().wake_by_ref();
480 Poll::Pending
481 }
482 }
483
484 fn call(&mut self, exchange: Exchange) -> Self::Future {
485 Box::pin(async move { Ok(exchange) })
486 }
487 }
488
489 let ready_flag = Arc::new(AtomicBool::new(false));
490 let inner = DelayedReady {
491 ready: Arc::clone(&ready_flag),
492 };
493 let boxed: BoxProcessor = BoxProcessor::new(inner);
494
495 let config = SplitterConfig::new(camel_api::split_body_lines());
496 let mut svc = SplitterService::new(config, boxed);
497
498 let waker = futures::task::noop_waker();
500 let mut cx = Context::from_waker(&waker);
501 let poll = Pin::new(&mut svc).poll_ready(&mut cx);
502 assert!(
503 poll.is_pending(),
504 "expected Pending when sub_pipeline not ready"
505 );
506
507 ready_flag.store(true, Ordering::SeqCst);
509
510 let poll = Pin::new(&mut svc).poll_ready(&mut cx);
511 assert!(
512 matches!(poll, Poll::Ready(Ok(()))),
513 "expected Ready after sub_pipeline becomes ready"
514 );
515 }
516
517 #[tokio::test]
520 async fn test_split_parallel_basic() {
521 let config = SplitterConfig::new(camel_api::split_body_lines())
522 .parallel(true)
523 .aggregation(AggregationStrategy::CollectAll);
524 let mut svc = SplitterService::new(config, uppercase_pipeline());
525
526 let result = svc
527 .ready()
528 .await
529 .unwrap()
530 .call(make_exchange("a\nb\nc"))
531 .await
532 .unwrap();
533
534 let expected = serde_json::json!(["A", "B", "C"]);
535 match &result.input.body {
536 Body::Json(v) => assert_eq!(*v, expected),
537 other => panic!("expected JSON body, got {other:?}"),
538 }
539 }
540
541 #[tokio::test]
544 async fn test_split_parallel_with_limit() {
545 use std::sync::atomic::AtomicUsize;
546
547 let concurrent = Arc::new(AtomicUsize::new(0));
548 let max_concurrent = Arc::new(AtomicUsize::new(0));
549
550 let c = Arc::clone(&concurrent);
551 let mc = Arc::clone(&max_concurrent);
552 let pipeline = BoxProcessor::from_fn(move |ex: Exchange| {
553 let c = Arc::clone(&c);
554 let mc = Arc::clone(&mc);
555 Box::pin(async move {
556 let current = c.fetch_add(1, Ordering::SeqCst) + 1;
557 mc.fetch_max(current, Ordering::SeqCst);
559 tokio::task::yield_now().await;
561 c.fetch_sub(1, Ordering::SeqCst);
562 Ok(ex)
563 })
564 });
565
566 let config = SplitterConfig::new(camel_api::split_body_lines())
567 .parallel(true)
568 .parallel_limit(2)
569 .aggregation(AggregationStrategy::CollectAll);
570 let mut svc = SplitterService::new(config, pipeline);
571
572 let result = svc
573 .ready()
574 .await
575 .unwrap()
576 .call(make_exchange("a\nb\nc\nd"))
577 .await;
578 assert!(result.is_ok());
579
580 let observed_max = max_concurrent.load(Ordering::SeqCst);
581 assert!(
582 observed_max <= 2,
583 "max concurrency was {observed_max}, expected <= 2"
584 );
585 }
586
587 #[tokio::test]
590 async fn test_split_parallel_stop_on_exception() {
591 let config = SplitterConfig::new(camel_api::split_body_lines())
592 .parallel(true)
593 .stop_on_exception(true);
594 let mut svc = SplitterService::new(config, failing_pipeline());
595
596 let result = svc
597 .ready()
598 .await
599 .unwrap()
600 .call(make_exchange("a\nb\nc"))
601 .await;
602
603 assert!(result.is_err(), "expected error when all fragments fail");
605 }
606
607 #[tokio::test]
610 async fn test_splitter_stream_bodies_creates_valid_json() {
611 use bytes::Bytes;
612 use camel_api::{StreamBody, StreamMetadata};
613 use futures::stream;
614 use tokio::sync::Mutex;
615
616 let chunks = vec![Ok(Bytes::from("test"))];
617 let stream_body = StreamBody {
618 stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
619 metadata: StreamMetadata {
620 origin: Some("kafka://topic/partition".to_string()),
621 ..Default::default()
622 },
623 };
624
625 let original = Exchange::new(Message {
626 headers: Default::default(),
627 body: Body::Empty,
628 });
629
630 let results = vec![Ok(Exchange::new(Message {
631 headers: Default::default(),
632 body: Body::Stream(stream_body),
633 }))];
634
635 let result = aggregate(results, original, AggregationStrategy::CollectAll);
636
637 let exchange = result.expect("Expected Ok result");
638 assert!(
639 matches!(exchange.input.body, Body::Json(_)),
640 "Expected Json body"
641 );
642
643 if let Body::Json(value) = exchange.input.body {
644 let json_str = serde_json::to_string(&value).unwrap();
645 let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();
646
647 assert!(parsed.is_array());
648 let arr = parsed.as_array().unwrap();
649 assert!(arr[0].is_object());
650 assert!(arr[0]["_stream"].is_object());
651 assert_eq!(arr[0]["_stream"]["origin"], "kafka://topic/partition");
652 assert_eq!(arr[0]["_stream"]["placeholder"], true);
653 }
654 }
655
656 #[tokio::test]
657 async fn test_splitter_stream_with_none_origin_creates_valid_json() {
658 use bytes::Bytes;
659 use camel_api::{StreamBody, StreamMetadata};
660 use futures::stream;
661 use tokio::sync::Mutex;
662
663 let chunks = vec![Ok(Bytes::from("test"))];
664 let stream_body = StreamBody {
665 stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
666 metadata: StreamMetadata {
667 origin: None,
668 ..Default::default()
669 },
670 };
671
672 let original = Exchange::new(Message {
673 headers: Default::default(),
674 body: Body::Empty,
675 });
676
677 let results = vec![Ok(Exchange::new(Message {
678 headers: Default::default(),
679 body: Body::Stream(stream_body),
680 }))];
681
682 let result = aggregate(results, original, AggregationStrategy::CollectAll);
683
684 let exchange = result.expect("Expected Ok result");
685 assert!(
686 matches!(exchange.input.body, Body::Json(_)),
687 "Expected Json body"
688 );
689
690 if let Body::Json(value) = exchange.input.body {
691 let json_str = serde_json::to_string(&value).unwrap();
692 let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();
693
694 assert!(parsed.is_array());
695 let arr = parsed.as_array().unwrap();
696 assert!(arr[0].is_object());
697 assert!(arr[0]["_stream"].is_object());
698 assert_eq!(arr[0]["_stream"]["origin"], serde_json::Value::Null);
699 assert_eq!(arr[0]["_stream"]["placeholder"], true);
700 }
701 }
702}