1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures::future::join_all;
6use tokio::sync::Semaphore;
7use tower::Service;
8
9use camel_api::{
10 AggregationStrategy, Body, BoxProcessor, CamelError, Exchange, SplitterConfig, Value,
11};
12
/// Exchange property holding the zero-based position of a fragment within its split.
pub const CAMEL_SPLIT_INDEX: &str = "CamelSplitIndex";
/// Exchange property holding the total number of fragments produced by the split.
pub const CAMEL_SPLIT_SIZE: &str = "CamelSplitSize";
/// Exchange property that is `true` only on the final fragment of the split.
pub const CAMEL_SPLIT_COMPLETE: &str = "CamelSplitComplete";
21
22#[derive(Clone)]
33pub struct SplitterService {
34 expression: camel_api::SplitExpression,
35 sub_pipeline: BoxProcessor,
36 aggregation: AggregationStrategy,
37 parallel: bool,
38 parallel_limit: Option<usize>,
39 stop_on_exception: bool,
40}
41
42impl SplitterService {
43 pub fn new(config: SplitterConfig, sub_pipeline: BoxProcessor) -> Self {
45 Self {
46 expression: config.expression,
47 sub_pipeline,
48 aggregation: config.aggregation,
49 parallel: config.parallel,
50 parallel_limit: config.parallel_limit,
51 stop_on_exception: config.stop_on_exception,
52 }
53 }
54}
55
56impl Service<Exchange> for SplitterService {
57 type Response = Exchange;
58 type Error = CamelError;
59 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
60
61 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
62 self.sub_pipeline.poll_ready(cx)
63 }
64
65 fn call(&mut self, exchange: Exchange) -> Self::Future {
66 let original = exchange.clone();
67 let expression = self.expression.clone();
68 let sub_pipeline = self.sub_pipeline.clone();
69 let aggregation = self.aggregation.clone();
70 let parallel = self.parallel;
71 let parallel_limit = self.parallel_limit;
72 let stop_on_exception = self.stop_on_exception;
73
74 Box::pin(async move {
75 let mut fragments = expression(&exchange);
77
78 if fragments.is_empty() {
80 return Ok(original);
81 }
82
83 let total = fragments.len();
84
85 for (i, frag) in fragments.iter_mut().enumerate() {
87 frag.set_property(CAMEL_SPLIT_INDEX, Value::from(i as u64));
88 frag.set_property(CAMEL_SPLIT_SIZE, Value::from(total as u64));
89 frag.set_property(CAMEL_SPLIT_COMPLETE, Value::Bool(i == total - 1));
90 }
91
92 let results = if parallel {
94 process_parallel(fragments, sub_pipeline, parallel_limit, stop_on_exception).await
95 } else {
96 process_sequential(fragments, sub_pipeline, stop_on_exception).await
97 };
98
99 aggregate(results, original, aggregation)
101 })
102 }
103}
104
105async fn process_sequential(
108 fragments: Vec<Exchange>,
109 sub_pipeline: BoxProcessor,
110 stop_on_exception: bool,
111) -> Vec<Result<Exchange, CamelError>> {
112 let mut results = Vec::with_capacity(fragments.len());
113
114 for fragment in fragments {
115 let mut pipeline = sub_pipeline.clone();
116 match tower::ServiceExt::ready(&mut pipeline).await {
117 Err(e) => {
118 results.push(Err(e));
119 if stop_on_exception {
120 break;
121 }
122 }
123 Ok(svc) => {
124 let result = svc.call(fragment).await;
125 let is_err = result.is_err();
126 results.push(result);
127 if stop_on_exception && is_err {
128 break;
129 }
130 }
131 }
132 }
133
134 results
135}
136
137async fn process_parallel(
140 fragments: Vec<Exchange>,
141 sub_pipeline: BoxProcessor,
142 parallel_limit: Option<usize>,
143 _stop_on_exception: bool,
144) -> Vec<Result<Exchange, CamelError>> {
145 let semaphore = parallel_limit.map(|limit| std::sync::Arc::new(Semaphore::new(limit)));
146
147 let futures: Vec<_> = fragments
148 .into_iter()
149 .map(|fragment| {
150 let mut pipeline = sub_pipeline.clone();
151 let sem = semaphore.clone();
152 async move {
153 let _permit = match &sem {
155 Some(s) => Some(s.acquire().await.map_err(|e| {
156 CamelError::ProcessorError(format!("semaphore error: {e}"))
157 })?),
158 None => None,
159 };
160
161 tower::ServiceExt::ready(&mut pipeline).await?;
162 pipeline.call(fragment).await
163 }
164 })
165 .collect();
166
167 join_all(futures).await
168}
169
170fn aggregate(
173 results: Vec<Result<Exchange, CamelError>>,
174 original: Exchange,
175 strategy: AggregationStrategy,
176) -> Result<Exchange, CamelError> {
177 match strategy {
178 AggregationStrategy::LastWins => {
179 results.into_iter().last().unwrap_or_else(|| Ok(original))
181 }
182 AggregationStrategy::CollectAll => {
183 let mut bodies = Vec::new();
185 for result in results {
186 let ex = result?;
187 let value = match &ex.input.body {
188 Body::Text(s) => Value::String(s.clone()),
189 Body::Json(v) => v.clone(),
190 Body::Xml(s) => Value::String(s.clone()),
191 Body::Bytes(b) => Value::String(String::from_utf8_lossy(b).into_owned()),
192 Body::Empty => Value::Null,
193 Body::Stream(s) => serde_json::json!({
194 "_stream": {
195 "origin": s.metadata.origin,
196 "placeholder": true,
197 "hint": "Materialize exchange body with .into_bytes() before aggregation if content needed"
198 }
199 }),
200 };
201 bodies.push(value);
202 }
203 let mut out = original;
204 out.input.body = Body::Json(Value::Array(bodies));
205 Ok(out)
206 }
207 AggregationStrategy::Original => Ok(original),
208 AggregationStrategy::Custom(fold_fn) => {
209 let mut iter = results.into_iter();
211 let first = iter.next().unwrap_or_else(|| Ok(original.clone()))?;
212 iter.try_fold(first, |acc, next_result| {
213 let next = next_result?;
214 Ok(fold_fn(acc, next))
215 })
216 }
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessorExt, Message};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::ServiceExt;

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Pipeline that returns every exchange unchanged.
    fn passthrough_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    /// Pipeline that upper-cases text bodies and leaves other bodies untouched.
    fn uppercase_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|mut ex: Exchange| {
            Box::pin(async move {
                if let Body::Text(s) = &ex.input.body {
                    ex.input.body = Body::Text(s.to_uppercase());
                }
                Ok(ex)
            })
        })
    }

    /// Pipeline that fails for every exchange.
    fn failing_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|_ex| {
            Box::pin(async { Err(CamelError::ProcessorError("boom".into())) })
        })
    }

    /// Pipeline that fails on its `n`-th invocation (zero-based) and succeeds
    /// otherwise, together with a handle to the invocation counter so tests can
    /// check how many fragments were actually processed.
    fn counting_fail_on_nth(n: usize) -> (BoxProcessor, Arc<AtomicUsize>) {
        let count = Arc::new(AtomicUsize::new(0));
        let calls = Arc::clone(&count);
        let pipeline = BoxProcessor::from_fn(move |ex: Exchange| {
            let count = Arc::clone(&count);
            Box::pin(async move {
                let c = count.fetch_add(1, Ordering::SeqCst);
                if c == n {
                    Err(CamelError::ProcessorError(format!("fail on {c}")))
                } else {
                    Ok(ex)
                }
            })
        });
        (pipeline, calls)
    }

    /// Pipeline that fails on its `n`-th invocation (zero-based) and succeeds otherwise.
    fn fail_on_nth(n: usize) -> BoxProcessor {
        counting_fail_on_nth(n).0
    }

    /// Exchange with a text body.
    fn make_exchange(text: &str) -> Exchange {
        Exchange::new(Message::new(text))
    }

    /// Waits for `svc` to become ready (panicking if readiness fails) and sends
    /// `input` through it.
    async fn split(svc: &mut SplitterService, input: Exchange) -> Result<Exchange, CamelError> {
        svc.ready().await.unwrap().call(input).await
    }

    /// Single successful fragment result whose body is a stream with the given origin.
    fn stream_results(origin: Option<&str>) -> Vec<Result<Exchange, CamelError>> {
        use bytes::Bytes;
        use camel_api::{StreamBody, StreamMetadata};
        use futures::stream;
        use tokio::sync::Mutex;

        let chunks = vec![Ok(Bytes::from("test"))];
        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
            metadata: StreamMetadata {
                origin: origin.map(str::to_string),
                ..Default::default()
            },
        };

        vec![Ok(Exchange::new(Message {
            headers: Default::default(),
            body: Body::Stream(stream_body),
        }))]
    }

    /// Aggregates `results` with `CollectAll` and checks that the single entry is
    /// a serialisable stream placeholder carrying `expected_origin`.
    fn assert_stream_placeholder(
        results: Vec<Result<Exchange, CamelError>>,
        expected_origin: serde_json::Value,
    ) {
        let original = Exchange::new(Message {
            headers: Default::default(),
            body: Body::Empty,
        });

        let exchange = aggregate(results, original, AggregationStrategy::CollectAll)
            .expect("Expected Ok result");

        let value = match exchange.input.body {
            Body::Json(value) => value,
            _ => panic!("Expected Json body"),
        };

        // Round-trip through text to prove the placeholder is valid JSON.
        let json_str = serde_json::to_string(&value).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();

        assert!(parsed.is_array());
        let arr = parsed.as_array().unwrap();
        assert!(arr[0].is_object());
        assert!(arr[0]["_stream"].is_object());
        assert_eq!(arr[0]["_stream"]["origin"], expected_origin);
        assert_eq!(arr[0]["_stream"]["placeholder"], true);
    }

    // ---------------------------------------------------------------------
    // Sequential mode
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_split_sequential_last_wins() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::LastWins);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = split(&mut svc, make_exchange("a\nb\nc")).await.unwrap();
        assert_eq!(result.input.body.as_text(), Some("C"));
    }

    #[tokio::test]
    async fn test_split_sequential_collect_all() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = split(&mut svc, make_exchange("a\nb\nc")).await.unwrap();
        let expected = serde_json::json!(["A", "B", "C"]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_split_sequential_original() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::Original);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = split(&mut svc, make_exchange("a\nb\nc")).await.unwrap();
        // The sub-pipeline's upper-casing must not leak into the original exchange.
        assert_eq!(result.input.body.as_text(), Some("a\nb\nc"));
    }

    #[tokio::test]
    async fn test_split_sequential_custom_aggregation() {
        let joiner: Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync> =
            Arc::new(|mut acc: Exchange, next: Exchange| {
                let acc_text = acc.input.body.as_text().unwrap_or("").to_string();
                let next_text = next.input.body.as_text().unwrap_or("").to_string();
                acc.input.body = Body::Text(format!("{acc_text}+{next_text}"));
                acc
            });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::Custom(joiner));
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = split(&mut svc, make_exchange("a\nb\nc")).await.unwrap();
        assert_eq!(result.input.body.as_text(), Some("A+B+C"));
    }

    #[tokio::test]
    async fn test_split_stop_on_exception() {
        let config = SplitterConfig::new(camel_api::split_body_lines()).stop_on_exception(true);
        let (pipeline, calls) = counting_fail_on_nth(1);
        let mut svc = SplitterService::new(config, pipeline);

        let result = split(&mut svc, make_exchange("a\nb\nc\nd\ne")).await;

        assert!(result.is_err(), "expected error due to stop_on_exception");
        // Fragment 0 succeeds, fragment 1 fails, and nothing after it may run.
        assert_eq!(
            calls.load(Ordering::SeqCst),
            2,
            "fragments after the failing one must not be processed"
        );
    }

    #[tokio::test]
    async fn test_split_continue_on_exception() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .stop_on_exception(false)
            .aggregation(AggregationStrategy::LastWins);
        let (pipeline, calls) = counting_fail_on_nth(1);
        let mut svc = SplitterService::new(config, pipeline);

        let result = split(&mut svc, make_exchange("a\nb\nc")).await;

        assert!(result.is_ok(), "last fragment should succeed");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            3,
            "every fragment must be processed when not stopping on errors"
        );
    }

    #[tokio::test]
    async fn test_split_empty_fragments() {
        let config = SplitterConfig::new(camel_api::split_body_lines());
        let mut svc = SplitterService::new(config, passthrough_pipeline());

        let mut ex = Exchange::new(Message::default());
        ex.set_property("marker", Value::Bool(true));

        // With nothing to split, the incoming exchange must come back untouched.
        let result = split(&mut svc, ex).await.unwrap();
        assert!(result.input.body.is_empty());
        assert_eq!(result.property("marker"), Some(&Value::Bool(true)));
    }

    #[tokio::test]
    async fn test_split_metadata_properties() {
        // Echoes the split properties of each fragment back as its JSON body.
        let recorder = BoxProcessor::from_fn(|ex: Exchange| {
            Box::pin(async move {
                let idx = ex.property(CAMEL_SPLIT_INDEX).cloned();
                let size = ex.property(CAMEL_SPLIT_SIZE).cloned();
                let complete = ex.property(CAMEL_SPLIT_COMPLETE).cloned();
                let body = serde_json::json!({
                    "index": idx,
                    "size": size,
                    "complete": complete,
                });
                let mut out = ex;
                out.input.body = Body::Json(body);
                Ok(out)
            })
        });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, recorder);

        let result = split(&mut svc, make_exchange("x\ny\nz")).await.unwrap();

        let expected = serde_json::json!([
            {"index": 0, "size": 3, "complete": false},
            {"index": 1, "size": 3, "complete": false},
            {"index": 2, "size": 3, "complete": true},
        ]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    // ---------------------------------------------------------------------
    // Readiness
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_poll_ready_delegates_to_sub_pipeline() {
        use std::sync::atomic::AtomicBool;

        /// Service whose readiness is controlled by a shared flag.
        #[derive(Clone)]
        struct DelayedReady {
            ready: Arc<AtomicBool>,
        }

        impl Service<Exchange> for DelayedReady {
            type Response = Exchange;
            type Error = CamelError;
            type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;

            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                if self.ready.load(Ordering::SeqCst) {
                    Poll::Ready(Ok(()))
                } else {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            }

            fn call(&mut self, exchange: Exchange) -> Self::Future {
                Box::pin(async move { Ok(exchange) })
            }
        }

        let ready_flag = Arc::new(AtomicBool::new(false));
        let inner = DelayedReady {
            ready: Arc::clone(&ready_flag),
        };
        let boxed: BoxProcessor = BoxProcessor::new(inner);

        let config = SplitterConfig::new(camel_api::split_body_lines());
        let mut svc = SplitterService::new(config, boxed);

        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        let poll = Pin::new(&mut svc).poll_ready(&mut cx);
        assert!(
            poll.is_pending(),
            "expected Pending when sub_pipeline not ready"
        );

        ready_flag.store(true, Ordering::SeqCst);

        let poll = Pin::new(&mut svc).poll_ready(&mut cx);
        assert!(
            matches!(poll, Poll::Ready(Ok(()))),
            "expected Ready after sub_pipeline becomes ready"
        );
    }

    // ---------------------------------------------------------------------
    // Parallel mode
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_split_parallel_basic() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = split(&mut svc, make_exchange("a\nb\nc")).await.unwrap();

        // Results must come back in fragment order even when run concurrently.
        let expected = serde_json::json!(["A", "B", "C"]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_split_parallel_with_limit() {
        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        let c = Arc::clone(&concurrent);
        let mc = Arc::clone(&max_concurrent);
        let pipeline = BoxProcessor::from_fn(move |ex: Exchange| {
            let c = Arc::clone(&c);
            let mc = Arc::clone(&mc);
            Box::pin(async move {
                let current = c.fetch_add(1, Ordering::SeqCst) + 1;
                mc.fetch_max(current, Ordering::SeqCst);
                // Yield so other fragments get a chance to overlap with this one.
                tokio::task::yield_now().await;
                c.fetch_sub(1, Ordering::SeqCst);
                Ok(ex)
            })
        });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .parallel_limit(2)
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, pipeline);

        let result = split(&mut svc, make_exchange("a\nb\nc\nd")).await;
        assert!(result.is_ok());

        let observed_max = max_concurrent.load(Ordering::SeqCst);
        assert!(
            observed_max <= 2,
            "max concurrency was {observed_max}, expected <= 2"
        );
    }

    #[tokio::test]
    async fn test_split_parallel_stop_on_exception() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .stop_on_exception(true);
        let mut svc = SplitterService::new(config, failing_pipeline());

        let result = split(&mut svc, make_exchange("a\nb\nc")).await;

        assert!(result.is_err(), "expected error when all fragments fail");
    }

    // ---------------------------------------------------------------------
    // Stream bodies in CollectAll aggregation
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_splitter_stream_bodies_creates_valid_json() {
        assert_stream_placeholder(
            stream_results(Some("kafka://topic/partition")),
            serde_json::Value::String("kafka://topic/partition".to_string()),
        );
    }

    #[tokio::test]
    async fn test_splitter_stream_with_none_origin_creates_valid_json() {
        assert_stream_placeholder(stream_results(None), serde_json::Value::Null);
    }
}
703}