//! camel_processor/splitter.rs — Splitter EIP (Enterprise Integration Pattern) processor.
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::future::join_all;
use tokio::sync::Semaphore;
use tower::Service;

use camel_api::{
    AggregationStrategy, Body, BoxProcessor, CamelError, Exchange, SplitterConfig, Value,
};

// ── Metadata property keys ─────────────────────────────────────────────

/// Property key for the zero-based index of a fragment within the split.
pub const CAMEL_SPLIT_INDEX: &str = "CamelSplitIndex";
/// Property key for the total number of fragments produced by the split.
pub const CAMEL_SPLIT_SIZE: &str = "CamelSplitSize";
/// Property key indicating whether this fragment is the last one.
pub const CAMEL_SPLIT_COMPLETE: &str = "CamelSplitComplete";

22// ── SplitterService ────────────────────────────────────────────────────
23
24/// Tower Service implementing the Splitter EIP.
25///
26/// Splits an incoming exchange into fragments via a configurable expression,
27/// processes each fragment through a sub-pipeline, and aggregates the results.
28///
29/// **Note:** In parallel mode, `stop_on_exception` only affects the aggregation
30/// phase. All spawned fragments run to completion because `join_all` cannot
31/// cancel in-flight futures. Sequential mode stops processing immediately.
32#[derive(Clone)]
33pub struct SplitterService {
34    expression: camel_api::SplitExpression,
35    sub_pipeline: BoxProcessor,
36    aggregation: AggregationStrategy,
37    parallel: bool,
38    parallel_limit: Option<usize>,
39    stop_on_exception: bool,
40}
41
42impl SplitterService {
43    /// Create a new `SplitterService` from a [`SplitterConfig`] and a sub-pipeline.
44    pub fn new(config: SplitterConfig, sub_pipeline: BoxProcessor) -> Self {
45        Self {
46            expression: config.expression,
47            sub_pipeline,
48            aggregation: config.aggregation,
49            parallel: config.parallel,
50            parallel_limit: config.parallel_limit,
51            stop_on_exception: config.stop_on_exception,
52        }
53    }
54}
55
56impl Service<Exchange> for SplitterService {
57    type Response = Exchange;
58    type Error = CamelError;
59    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
60
61    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
62        self.sub_pipeline.poll_ready(cx)
63    }
64
65    fn call(&mut self, exchange: Exchange) -> Self::Future {
66        let original = exchange.clone();
67        let expression = self.expression.clone();
68        let sub_pipeline = self.sub_pipeline.clone();
69        let aggregation = self.aggregation.clone();
70        let parallel = self.parallel;
71        let parallel_limit = self.parallel_limit;
72        let stop_on_exception = self.stop_on_exception;
73
74        Box::pin(async move {
75            // Split the exchange into fragments.
76            let mut fragments = expression(&exchange);
77
78            // If no fragments were produced, return the original exchange.
79            if fragments.is_empty() {
80                return Ok(original);
81            }
82
83            let total = fragments.len();
84
85            // Set metadata on each fragment.
86            for (i, frag) in fragments.iter_mut().enumerate() {
87                frag.set_property(CAMEL_SPLIT_INDEX, Value::from(i as u64));
88                frag.set_property(CAMEL_SPLIT_SIZE, Value::from(total as u64));
89                frag.set_property(CAMEL_SPLIT_COMPLETE, Value::Bool(i == total - 1));
90            }
91
92            // Process fragments through the sub-pipeline.
93            let results = if parallel {
94                process_parallel(fragments, sub_pipeline, parallel_limit, stop_on_exception).await
95            } else {
96                process_sequential(fragments, sub_pipeline, stop_on_exception).await
97            };
98
99            // Aggregate the results.
100            aggregate(results, original, aggregation)
101        })
102    }
103}
104
105// ── Sequential processing ──────────────────────────────────────────────
106
107async fn process_sequential(
108    fragments: Vec<Exchange>,
109    sub_pipeline: BoxProcessor,
110    stop_on_exception: bool,
111) -> Vec<Result<Exchange, CamelError>> {
112    let mut results = Vec::with_capacity(fragments.len());
113
114    for fragment in fragments {
115        let mut pipeline = sub_pipeline.clone();
116        match tower::ServiceExt::ready(&mut pipeline).await {
117            Err(e) => {
118                results.push(Err(e));
119                if stop_on_exception {
120                    break;
121                }
122            }
123            Ok(svc) => {
124                let result = svc.call(fragment).await;
125                let is_err = result.is_err();
126                results.push(result);
127                if stop_on_exception && is_err {
128                    break;
129                }
130            }
131        }
132    }
133
134    results
135}
136
137// ── Parallel processing ────────────────────────────────────────────────
138
139async fn process_parallel(
140    fragments: Vec<Exchange>,
141    sub_pipeline: BoxProcessor,
142    parallel_limit: Option<usize>,
143    _stop_on_exception: bool,
144) -> Vec<Result<Exchange, CamelError>> {
145    let semaphore = parallel_limit.map(|limit| std::sync::Arc::new(Semaphore::new(limit)));
146
147    let futures: Vec<_> = fragments
148        .into_iter()
149        .map(|fragment| {
150            let mut pipeline = sub_pipeline.clone();
151            let sem = semaphore.clone();
152            async move {
153                // Acquire semaphore permit if a limit is set.
154                let _permit = match &sem {
155                    Some(s) => Some(s.acquire().await.map_err(|e| {
156                        CamelError::ProcessorError(format!("semaphore error: {e}"))
157                    })?),
158                    None => None,
159                };
160
161                tower::ServiceExt::ready(&mut pipeline).await?;
162                pipeline.call(fragment).await
163            }
164        })
165        .collect();
166
167    join_all(futures).await
168}
169
170// ── Aggregation ────────────────────────────────────────────────────────
171
172fn aggregate(
173    results: Vec<Result<Exchange, CamelError>>,
174    original: Exchange,
175    strategy: AggregationStrategy,
176) -> Result<Exchange, CamelError> {
177    match strategy {
178        AggregationStrategy::LastWins => {
179            // Return the last result (error or success).
180            results.into_iter().last().unwrap_or_else(|| Ok(original))
181        }
182        AggregationStrategy::CollectAll => {
183            // Collect all bodies into a JSON array. Errors propagate.
184            let mut bodies = Vec::new();
185            for result in results {
186                let ex = result?;
187                let value = match &ex.input.body {
188                    Body::Text(s) => Value::String(s.clone()),
189                    Body::Json(v) => v.clone(),
190                    Body::Xml(s) => Value::String(s.clone()),
191                    Body::Bytes(b) => Value::String(String::from_utf8_lossy(b).into_owned()),
192                    Body::Empty => Value::Null,
193                    Body::Stream(s) => serde_json::json!({
194                        "_stream": {
195                            "origin": s.metadata.origin,
196                            "placeholder": true,
197                            "hint": "Materialize exchange body with .into_bytes() before aggregation if content needed"
198                        }
199                    }),
200                };
201                bodies.push(value);
202            }
203            let mut out = original;
204            out.input.body = Body::Json(Value::Array(bodies));
205            Ok(out)
206        }
207        AggregationStrategy::Original => Ok(original),
208        AggregationStrategy::Custom(fold_fn) => {
209            // Fold using the custom function, starting from the first result.
210            let mut iter = results.into_iter();
211            let first = iter.next().unwrap_or_else(|| Ok(original.clone()))?;
212            iter.try_fold(first, |acc, next_result| {
213                let next = next_result?;
214                Ok(fold_fn(acc, next))
215            })
216        }
217    }
218}
219
// ── Tests ──────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessorExt, Message};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::ServiceExt;

    // ── Test helpers ───────────────────────────────────────────────────

    fn passthrough_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    fn uppercase_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|mut ex: Exchange| {
            Box::pin(async move {
                if let Body::Text(s) = &ex.input.body {
                    ex.input.body = Body::Text(s.to_uppercase());
                }
                Ok(ex)
            })
        })
    }

    fn failing_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|_ex| {
            Box::pin(async { Err(CamelError::ProcessorError("boom".into())) })
        })
    }

    fn fail_on_nth(n: usize) -> BoxProcessor {
        let count = Arc::new(AtomicUsize::new(0));
        BoxProcessor::from_fn(move |ex: Exchange| {
            let count = Arc::clone(&count);
            Box::pin(async move {
                let c = count.fetch_add(1, Ordering::SeqCst);
                if c == n {
                    Err(CamelError::ProcessorError(format!("fail on {c}")))
                } else {
                    Ok(ex)
                }
            })
        })
    }

    fn make_exchange(text: &str) -> Exchange {
        Exchange::new(Message::new(text))
    }

    // ── 1. Sequential + LastWins ───────────────────────────────────────

    #[tokio::test]
    async fn test_split_sequential_last_wins() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::LastWins);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("C"));
    }

    // ── 2. Sequential + CollectAll ─────────────────────────────────────

    #[tokio::test]
    async fn test_split_sequential_collect_all() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        let expected = serde_json::json!(["A", "B", "C"]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    // ── 3. Sequential + Original ───────────────────────────────────────

    #[tokio::test]
    async fn test_split_sequential_original() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::Original);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        // Original body should be unchanged.
        assert_eq!(result.input.body.as_text(), Some("a\nb\nc"));
    }

    // ── 4. Sequential + Custom aggregation ─────────────────────────────

    #[tokio::test]
    async fn test_split_sequential_custom_aggregation() {
        let joiner: Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync> =
            Arc::new(|mut acc: Exchange, next: Exchange| {
                let acc_text = acc.input.body.as_text().unwrap_or("").to_string();
                let next_text = next.input.body.as_text().unwrap_or("").to_string();
                acc.input.body = Body::Text(format!("{acc_text}+{next_text}"));
                acc
            });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::Custom(joiner));
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("A+B+C"));
    }

    // ── 5. Stop on exception ───────────────────────────────────────────

    #[tokio::test]
    async fn test_split_stop_on_exception() {
        // 5 fragments, fail on the 2nd (index 1), stop=true
        let config = SplitterConfig::new(camel_api::split_body_lines()).stop_on_exception(true);
        let mut svc = SplitterService::new(config, fail_on_nth(1));

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc\nd\ne"))
            .await;

        // LastWins is default, the last result should be the error from fragment 1.
        assert!(result.is_err(), "expected error due to stop_on_exception");
    }

    // ── 6. Continue on exception ───────────────────────────────────────

    #[tokio::test]
    async fn test_split_continue_on_exception() {
        // 3 fragments, fail on 2nd (index 1), stop=false, LastWins.
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .stop_on_exception(false)
            .aggregation(AggregationStrategy::LastWins);
        let mut svc = SplitterService::new(config, fail_on_nth(1));

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await;

        // LastWins: last fragment (index 2) succeeded.
        assert!(result.is_ok(), "last fragment should succeed");
    }

    // ── 7. Empty fragments ─────────────────────────────────────────────

    #[tokio::test]
    async fn test_split_empty_fragments() {
        // Body::Empty → no fragments → return original unchanged.
        let config = SplitterConfig::new(camel_api::split_body_lines());
        let mut svc = SplitterService::new(config, passthrough_pipeline());

        let mut ex = Exchange::new(Message::default()); // Body::Empty
        ex.set_property("marker", Value::Bool(true));

        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
        assert!(result.input.body.is_empty());
        assert_eq!(result.property("marker"), Some(&Value::Bool(true)));
    }

    // ── 8. Metadata properties ─────────────────────────────────────────

    #[tokio::test]
    async fn test_split_metadata_properties() {
        // Use passthrough so we can inspect metadata on returned fragments.
        // CollectAll won't preserve metadata, so use a pipeline that records
        // the metadata into the body as JSON.
        let recorder = BoxProcessor::from_fn(|ex: Exchange| {
            Box::pin(async move {
                let idx = ex.property(CAMEL_SPLIT_INDEX).cloned();
                let size = ex.property(CAMEL_SPLIT_SIZE).cloned();
                let complete = ex.property(CAMEL_SPLIT_COMPLETE).cloned();
                let body = serde_json::json!({
                    "index": idx,
                    "size": size,
                    "complete": complete,
                });
                let mut out = ex;
                out.input.body = Body::Json(body);
                Ok(out)
            })
        });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, recorder);

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("x\ny\nz"))
            .await
            .unwrap();

        let expected = serde_json::json!([
            {"index": 0, "size": 3, "complete": false},
            {"index": 1, "size": 3, "complete": false},
            {"index": 2, "size": 3, "complete": true},
        ]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    // ── 9. poll_ready delegates to sub-pipeline ────────────────────────

    #[tokio::test]
    async fn test_poll_ready_delegates_to_sub_pipeline() {
        use std::sync::atomic::AtomicBool;

        // A service that is initially not ready, then becomes ready.
        #[derive(Clone)]
        struct DelayedReady {
            ready: Arc<AtomicBool>,
        }

        impl Service<Exchange> for DelayedReady {
            type Response = Exchange;
            type Error = CamelError;
            type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;

            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                if self.ready.load(Ordering::SeqCst) {
                    Poll::Ready(Ok(()))
                } else {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            }

            fn call(&mut self, exchange: Exchange) -> Self::Future {
                Box::pin(async move { Ok(exchange) })
            }
        }

        let ready_flag = Arc::new(AtomicBool::new(false));
        let inner = DelayedReady {
            ready: Arc::clone(&ready_flag),
        };
        let boxed: BoxProcessor = BoxProcessor::new(inner);

        let config = SplitterConfig::new(camel_api::split_body_lines());
        let mut svc = SplitterService::new(config, boxed);

        // First poll should be Pending.
        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        let poll = Pin::new(&mut svc).poll_ready(&mut cx);
        assert!(
            poll.is_pending(),
            "expected Pending when sub_pipeline not ready"
        );

        // Mark inner as ready.
        ready_flag.store(true, Ordering::SeqCst);

        let poll = Pin::new(&mut svc).poll_ready(&mut cx);
        assert!(
            matches!(poll, Poll::Ready(Ok(()))),
            "expected Ready after sub_pipeline becomes ready"
        );
    }

    // ── 10. Parallel basic ─────────────────────────────────────────────

    #[tokio::test]
    async fn test_split_parallel_basic() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();

        let expected = serde_json::json!(["A", "B", "C"]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    // ── 11. Parallel with limit ────────────────────────────────────────

    #[tokio::test]
    async fn test_split_parallel_with_limit() {
        use std::sync::atomic::AtomicUsize;

        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        let c = Arc::clone(&concurrent);
        let mc = Arc::clone(&max_concurrent);
        let pipeline = BoxProcessor::from_fn(move |ex: Exchange| {
            let c = Arc::clone(&c);
            let mc = Arc::clone(&mc);
            Box::pin(async move {
                let current = c.fetch_add(1, Ordering::SeqCst) + 1;
                // Record the high-water mark.
                mc.fetch_max(current, Ordering::SeqCst);
                // Yield to let other tasks run.
                tokio::task::yield_now().await;
                c.fetch_sub(1, Ordering::SeqCst);
                Ok(ex)
            })
        });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .parallel_limit(2)
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, pipeline);

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc\nd"))
            .await;
        assert!(result.is_ok());

        let observed_max = max_concurrent.load(Ordering::SeqCst);
        assert!(
            observed_max <= 2,
            "max concurrency was {observed_max}, expected <= 2"
        );
    }

    // ── 12. Parallel stop on exception ─────────────────────────────────

    #[tokio::test]
    async fn test_split_parallel_stop_on_exception() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .stop_on_exception(true);
        let mut svc = SplitterService::new(config, failing_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await;

        // All fragments fail; LastWins returns the last error.
        assert!(result.is_err(), "expected error when all fragments fail");
    }

    // ── 13. Stream body aggregation creates valid JSON ───────────────────

    #[tokio::test]
    async fn test_splitter_stream_bodies_creates_valid_json() {
        use bytes::Bytes;
        use camel_api::{StreamBody, StreamMetadata};
        use futures::stream;
        use tokio::sync::Mutex;

        let chunks = vec![Ok(Bytes::from("test"))];
        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
            metadata: StreamMetadata {
                origin: Some("kafka://topic/partition".to_string()),
                ..Default::default()
            },
        };

        let original = Exchange::new(Message {
            headers: Default::default(),
            body: Body::Empty,
        });

        let results = vec![Ok(Exchange::new(Message {
            headers: Default::default(),
            body: Body::Stream(stream_body),
        }))];

        let result = aggregate(results, original, AggregationStrategy::CollectAll);

        let exchange = result.expect("Expected Ok result");
        assert!(
            matches!(exchange.input.body, Body::Json(_)),
            "Expected Json body"
        );

        if let Body::Json(value) = exchange.input.body {
            let json_str = serde_json::to_string(&value).unwrap();
            let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();

            assert!(parsed.is_array());
            let arr = parsed.as_array().unwrap();
            assert!(arr[0].is_object());
            assert!(arr[0]["_stream"].is_object());
            assert_eq!(arr[0]["_stream"]["origin"], "kafka://topic/partition");
            assert_eq!(arr[0]["_stream"]["placeholder"], true);
        }
    }

    #[tokio::test]
    async fn test_splitter_stream_with_none_origin_creates_valid_json() {
        use bytes::Bytes;
        use camel_api::{StreamBody, StreamMetadata};
        use futures::stream;
        use tokio::sync::Mutex;

        let chunks = vec![Ok(Bytes::from("test"))];
        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
            metadata: StreamMetadata {
                origin: None,
                ..Default::default()
            },
        };

        let original = Exchange::new(Message {
            headers: Default::default(),
            body: Body::Empty,
        });

        let results = vec![Ok(Exchange::new(Message {
            headers: Default::default(),
            body: Body::Stream(stream_body),
        }))];

        let result = aggregate(results, original, AggregationStrategy::CollectAll);

        let exchange = result.expect("Expected Ok result");
        assert!(
            matches!(exchange.input.body, Body::Json(_)),
            "Expected Json body"
        );

        if let Body::Json(value) = exchange.input.body {
            let json_str = serde_json::to_string(&value).unwrap();
            let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();

            assert!(parsed.is_array());
            let arr = parsed.as_array().unwrap();
            assert!(arr[0].is_object());
            assert!(arr[0]["_stream"].is_object());
            assert_eq!(arr[0]["_stream"]["origin"], serde_json::Value::Null);
            assert_eq!(arr[0]["_stream"]["placeholder"], true);
        }
    }
}