//! camel_processor/splitter.rs — Splitter EIP processor built on Tower services.
1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures::future::join_all;
6use tokio::sync::Semaphore;
7use tower::Service;
8
9use camel_api::{
10    AggregationStrategy, Body, BoxProcessor, CamelError, Exchange, SplitterConfig, Value,
11};
12
// ── Metadata property keys ─────────────────────────────────────────────
// NOTE(review): names appear to mirror Apache Camel's conventional splitter
// exchange properties — confirm against the Camel splitter documentation.

/// Property key for the zero-based index of a fragment within the split.
pub const CAMEL_SPLIT_INDEX: &str = "CamelSplitIndex";
/// Property key for the total number of fragments produced by the split.
pub const CAMEL_SPLIT_SIZE: &str = "CamelSplitSize";
/// Property key indicating whether this fragment is the last one.
pub const CAMEL_SPLIT_COMPLETE: &str = "CamelSplitComplete";
21
22// ── SplitterService ────────────────────────────────────────────────────
23
/// Tower Service implementing the Splitter EIP.
///
/// Splits an incoming exchange into fragments via a configurable expression,
/// processes each fragment through a sub-pipeline, and aggregates the results.
///
/// **Note:** In parallel mode, `stop_on_exception` only affects the aggregation
/// phase. All spawned fragments run to completion because `join_all` cannot
/// cancel in-flight futures. Sequential mode stops processing immediately.
#[derive(Clone)]
pub struct SplitterService {
    // Expression that turns one exchange into zero or more fragment exchanges.
    expression: camel_api::SplitExpression,
    // Pipeline each fragment is routed through; cloned per fragment/call.
    sub_pipeline: BoxProcessor,
    // How per-fragment results are folded back into a single exchange.
    aggregation: AggregationStrategy,
    // When true, fragments are processed concurrently via `join_all`.
    parallel: bool,
    // Optional cap on concurrent fragments in parallel mode (semaphore-backed).
    parallel_limit: Option<usize>,
    // When true, sequential processing stops at the first failed fragment.
    stop_on_exception: bool,
}
41
42impl SplitterService {
43    /// Create a new `SplitterService` from a [`SplitterConfig`] and a sub-pipeline.
44    pub fn new(config: SplitterConfig, sub_pipeline: BoxProcessor) -> Self {
45        Self {
46            expression: config.expression,
47            sub_pipeline,
48            aggregation: config.aggregation,
49            parallel: config.parallel,
50            parallel_limit: config.parallel_limit,
51            stop_on_exception: config.stop_on_exception,
52        }
53    }
54}
55
56impl Service<Exchange> for SplitterService {
57    type Response = Exchange;
58    type Error = CamelError;
59    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
60
61    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
62        self.sub_pipeline.poll_ready(cx)
63    }
64
65    fn call(&mut self, exchange: Exchange) -> Self::Future {
66        let original = exchange.clone();
67        let expression = self.expression.clone();
68        let sub_pipeline = self.sub_pipeline.clone();
69        let aggregation = self.aggregation.clone();
70        let parallel = self.parallel;
71        let parallel_limit = self.parallel_limit;
72        let stop_on_exception = self.stop_on_exception;
73
74        Box::pin(async move {
75            // Split the exchange into fragments.
76            let mut fragments = expression(&exchange);
77
78            // If no fragments were produced, return the original exchange.
79            if fragments.is_empty() {
80                return Ok(original);
81            }
82
83            let total = fragments.len();
84
85            // Set metadata on each fragment.
86            for (i, frag) in fragments.iter_mut().enumerate() {
87                frag.set_property(CAMEL_SPLIT_INDEX, Value::from(i as u64));
88                frag.set_property(CAMEL_SPLIT_SIZE, Value::from(total as u64));
89                frag.set_property(CAMEL_SPLIT_COMPLETE, Value::Bool(i == total - 1));
90            }
91
92            // Process fragments through the sub-pipeline.
93            let results = if parallel {
94                process_parallel(fragments, sub_pipeline, parallel_limit, stop_on_exception).await
95            } else {
96                process_sequential(fragments, sub_pipeline, stop_on_exception).await
97            };
98
99            // Aggregate the results.
100            aggregate(results, original, aggregation)
101        })
102    }
103}
104
105// ── Sequential processing ──────────────────────────────────────────────
106
107async fn process_sequential(
108    fragments: Vec<Exchange>,
109    sub_pipeline: BoxProcessor,
110    stop_on_exception: bool,
111) -> Vec<Result<Exchange, CamelError>> {
112    let mut results = Vec::with_capacity(fragments.len());
113
114    for fragment in fragments {
115        let mut pipeline = sub_pipeline.clone();
116        match tower::ServiceExt::ready(&mut pipeline).await {
117            Err(e) => {
118                results.push(Err(e));
119                if stop_on_exception {
120                    break;
121                }
122            }
123            Ok(svc) => {
124                let result = svc.call(fragment).await;
125                let is_err = result.is_err();
126                results.push(result);
127                if stop_on_exception && is_err {
128                    break;
129                }
130            }
131        }
132    }
133
134    results
135}
136
137// ── Parallel processing ────────────────────────────────────────────────
138
139async fn process_parallel(
140    fragments: Vec<Exchange>,
141    sub_pipeline: BoxProcessor,
142    parallel_limit: Option<usize>,
143    _stop_on_exception: bool,
144) -> Vec<Result<Exchange, CamelError>> {
145    let semaphore = parallel_limit.map(|limit| std::sync::Arc::new(Semaphore::new(limit)));
146
147    let futures: Vec<_> = fragments
148        .into_iter()
149        .map(|fragment| {
150            let mut pipeline = sub_pipeline.clone();
151            let sem = semaphore.clone();
152            async move {
153                // Acquire semaphore permit if a limit is set.
154                let _permit = match &sem {
155                    Some(s) => Some(s.acquire().await.map_err(|e| {
156                        CamelError::ProcessorError(format!("semaphore error: {e}"))
157                    })?),
158                    None => None,
159                };
160
161                tower::ServiceExt::ready(&mut pipeline).await?;
162                pipeline.call(fragment).await
163            }
164        })
165        .collect();
166
167    join_all(futures).await
168}
169
170// ── Aggregation ────────────────────────────────────────────────────────
171
172fn aggregate(
173    results: Vec<Result<Exchange, CamelError>>,
174    original: Exchange,
175    strategy: AggregationStrategy,
176) -> Result<Exchange, CamelError> {
177    match strategy {
178        AggregationStrategy::LastWins => {
179            // Return the last result (error or success).
180            results.into_iter().last().unwrap_or_else(|| Ok(original))
181        }
182        AggregationStrategy::CollectAll => {
183            // Collect all bodies into a JSON array. Errors propagate.
184            let mut bodies = Vec::new();
185            for result in results {
186                let ex = result?;
187                let value = match &ex.input.body {
188                    Body::Text(s) => Value::String(s.clone()),
189                    Body::Json(v) => v.clone(),
190                    Body::Bytes(b) => Value::String(String::from_utf8_lossy(b).into_owned()),
191                    Body::Empty => Value::Null,
192                    Body::Stream(s) => serde_json::json!({
193                        "_stream": {
194                            "origin": s.metadata.origin,
195                            "placeholder": true,
196                            "hint": "Materialize exchange body with .into_bytes() before aggregation if content needed"
197                        }
198                    }),
199                };
200                bodies.push(value);
201            }
202            let mut out = original;
203            out.input.body = Body::Json(Value::Array(bodies));
204            Ok(out)
205        }
206        AggregationStrategy::Original => Ok(original),
207        AggregationStrategy::Custom(fold_fn) => {
208            // Fold using the custom function, starting from the first result.
209            let mut iter = results.into_iter();
210            let first = iter.next().unwrap_or_else(|| Ok(original.clone()))?;
211            iter.try_fold(first, |acc, next_result| {
212                let next = next_result?;
213                Ok(fold_fn(acc, next))
214            })
215        }
216    }
217}
218
219// ── Tests ──────────────────────────────────────────────────────────────
220
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessorExt, Message};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::ServiceExt;

    // ── Test helpers ───────────────────────────────────────────────────

    /// Sub-pipeline that returns each exchange unchanged.
    fn passthrough_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    /// Sub-pipeline that upper-cases text bodies; other body kinds pass through.
    fn uppercase_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|mut ex: Exchange| {
            Box::pin(async move {
                if let Body::Text(s) = &ex.input.body {
                    ex.input.body = Body::Text(s.to_uppercase());
                }
                Ok(ex)
            })
        })
    }

    /// Sub-pipeline that fails every call with a `ProcessorError`.
    fn failing_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|_ex| {
            Box::pin(async { Err(CamelError::ProcessorError("boom".into())) })
        })
    }

    /// Sub-pipeline that fails only on the n-th call (zero-based).
    /// The counter is shared across clones via `Arc`, so the failure index
    /// counts calls globally, not per clone.
    fn fail_on_nth(n: usize) -> BoxProcessor {
        let count = Arc::new(AtomicUsize::new(0));
        BoxProcessor::from_fn(move |ex: Exchange| {
            let count = Arc::clone(&count);
            Box::pin(async move {
                let c = count.fetch_add(1, Ordering::SeqCst);
                if c == n {
                    Err(CamelError::ProcessorError(format!("fail on {c}")))
                } else {
                    Ok(ex)
                }
            })
        })
    }

    /// Build an exchange carrying a plain-text body.
    fn make_exchange(text: &str) -> Exchange {
        Exchange::new(Message::new(text))
    }

    // ── 1. Sequential + LastWins ───────────────────────────────────────

    #[tokio::test]
    async fn test_split_sequential_last_wins() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::LastWins);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        // LastWins keeps only the final fragment ("c" → "C").
        assert_eq!(result.input.body.as_text(), Some("C"));
    }

    // ── 2. Sequential + CollectAll ─────────────────────────────────────

    #[tokio::test]
    async fn test_split_sequential_collect_all() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        let expected = serde_json::json!(["A", "B", "C"]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    // ── 3. Sequential + Original ───────────────────────────────────────

    #[tokio::test]
    async fn test_split_sequential_original() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::Original);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        // Original body should be unchanged despite the uppercasing pipeline.
        assert_eq!(result.input.body.as_text(), Some("a\nb\nc"));
    }

    // ── 4. Sequential + Custom aggregation ─────────────────────────────

    #[tokio::test]
    async fn test_split_sequential_custom_aggregation() {
        // Custom fold: join successive bodies with '+'.
        let joiner: Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync> =
            Arc::new(|mut acc: Exchange, next: Exchange| {
                let acc_text = acc.input.body.as_text().unwrap_or("").to_string();
                let next_text = next.input.body.as_text().unwrap_or("").to_string();
                acc.input.body = Body::Text(format!("{acc_text}+{next_text}"));
                acc
            });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::Custom(joiner));
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("A+B+C"));
    }

    // ── 5. Stop on exception ───────────────────────────────────────────

    #[tokio::test]
    async fn test_split_stop_on_exception() {
        // 5 fragments, fail on the 2nd (index 1), stop=true
        let config = SplitterConfig::new(camel_api::split_body_lines()).stop_on_exception(true);
        let mut svc = SplitterService::new(config, fail_on_nth(1));

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc\nd\ne"))
            .await;

        // LastWins is default, the last result should be the error from fragment 1.
        assert!(result.is_err(), "expected error due to stop_on_exception");
    }

    // ── 6. Continue on exception ───────────────────────────────────────

    #[tokio::test]
    async fn test_split_continue_on_exception() {
        // 3 fragments, fail on 2nd (index 1), stop=false, LastWins.
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .stop_on_exception(false)
            .aggregation(AggregationStrategy::LastWins);
        let mut svc = SplitterService::new(config, fail_on_nth(1));

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await;

        // LastWins: last fragment (index 2) succeeded.
        assert!(result.is_ok(), "last fragment should succeed");
    }

    // ── 7. Empty fragments ─────────────────────────────────────────────

    #[tokio::test]
    async fn test_split_empty_fragments() {
        // Body::Empty → no fragments → return original unchanged.
        let config = SplitterConfig::new(camel_api::split_body_lines());
        let mut svc = SplitterService::new(config, passthrough_pipeline());

        let mut ex = Exchange::new(Message::default()); // Body::Empty
        ex.set_property("marker", Value::Bool(true));

        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
        // Body and properties survive untouched on the empty-split path.
        assert!(result.input.body.is_empty());
        assert_eq!(result.property("marker"), Some(&Value::Bool(true)));
    }

    // ── 8. Metadata properties ─────────────────────────────────────────

    #[tokio::test]
    async fn test_split_metadata_properties() {
        // Use passthrough so we can inspect metadata on returned fragments.
        // CollectAll won't preserve metadata, so use a pipeline that records
        // the metadata into the body as JSON.
        let recorder = BoxProcessor::from_fn(|ex: Exchange| {
            Box::pin(async move {
                let idx = ex.property(CAMEL_SPLIT_INDEX).cloned();
                let size = ex.property(CAMEL_SPLIT_SIZE).cloned();
                let complete = ex.property(CAMEL_SPLIT_COMPLETE).cloned();
                let body = serde_json::json!({
                    "index": idx,
                    "size": size,
                    "complete": complete,
                });
                let mut out = ex;
                out.input.body = Body::Json(body);
                Ok(out)
            })
        });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, recorder);

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("x\ny\nz"))
            .await
            .unwrap();

        let expected = serde_json::json!([
            {"index": 0, "size": 3, "complete": false},
            {"index": 1, "size": 3, "complete": false},
            {"index": 2, "size": 3, "complete": true},
        ]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    // ── 9. poll_ready delegates to sub-pipeline ────────────────────────

    #[tokio::test]
    async fn test_poll_ready_delegates_to_sub_pipeline() {
        use std::sync::atomic::AtomicBool;

        // A service that is initially not ready, then becomes ready.
        #[derive(Clone)]
        struct DelayedReady {
            ready: Arc<AtomicBool>,
        }

        impl Service<Exchange> for DelayedReady {
            type Response = Exchange;
            type Error = CamelError;
            type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;

            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                if self.ready.load(Ordering::SeqCst) {
                    Poll::Ready(Ok(()))
                } else {
                    // Wake immediately so a real executor would re-poll.
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            }

            fn call(&mut self, exchange: Exchange) -> Self::Future {
                Box::pin(async move { Ok(exchange) })
            }
        }

        let ready_flag = Arc::new(AtomicBool::new(false));
        let inner = DelayedReady {
            ready: Arc::clone(&ready_flag),
        };
        let boxed: BoxProcessor = BoxProcessor::new(inner);

        let config = SplitterConfig::new(camel_api::split_body_lines());
        let mut svc = SplitterService::new(config, boxed);

        // First poll should be Pending.
        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        let poll = Pin::new(&mut svc).poll_ready(&mut cx);
        assert!(
            poll.is_pending(),
            "expected Pending when sub_pipeline not ready"
        );

        // Mark inner as ready.
        ready_flag.store(true, Ordering::SeqCst);

        let poll = Pin::new(&mut svc).poll_ready(&mut cx);
        assert!(
            matches!(poll, Poll::Ready(Ok(()))),
            "expected Ready after sub_pipeline becomes ready"
        );
    }

    // ── 10. Parallel basic ─────────────────────────────────────────────

    #[tokio::test]
    async fn test_split_parallel_basic() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();

        // CollectAll preserves fragment order even in parallel mode.
        let expected = serde_json::json!(["A", "B", "C"]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    // ── 11. Parallel with limit ────────────────────────────────────────

    #[tokio::test]
    async fn test_split_parallel_with_limit() {
        use std::sync::atomic::AtomicUsize;

        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        let c = Arc::clone(&concurrent);
        let mc = Arc::clone(&max_concurrent);
        let pipeline = BoxProcessor::from_fn(move |ex: Exchange| {
            let c = Arc::clone(&c);
            let mc = Arc::clone(&mc);
            Box::pin(async move {
                let current = c.fetch_add(1, Ordering::SeqCst) + 1;
                // Record the high-water mark.
                mc.fetch_max(current, Ordering::SeqCst);
                // Yield to let other tasks run.
                tokio::task::yield_now().await;
                c.fetch_sub(1, Ordering::SeqCst);
                Ok(ex)
            })
        });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .parallel_limit(2)
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, pipeline);

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc\nd"))
            .await;
        assert!(result.is_ok());

        // The semaphore must have capped concurrency at the configured limit.
        let observed_max = max_concurrent.load(Ordering::SeqCst);
        assert!(
            observed_max <= 2,
            "max concurrency was {observed_max}, expected <= 2"
        );
    }

    // ── 12. Parallel stop on exception ─────────────────────────────────

    #[tokio::test]
    async fn test_split_parallel_stop_on_exception() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .stop_on_exception(true);
        let mut svc = SplitterService::new(config, failing_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await;

        // All fragments fail; LastWins returns the last error.
        assert!(result.is_err(), "expected error when all fragments fail");
    }

    // ── 13. Stream body aggregation creates valid JSON ───────────────────

    #[tokio::test]
    async fn test_splitter_stream_bodies_creates_valid_json() {
        use bytes::Bytes;
        use camel_api::{StreamBody, StreamMetadata};
        use futures::stream;
        use tokio::sync::Mutex;

        let chunks = vec![Ok(Bytes::from("test"))];
        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
            metadata: StreamMetadata {
                origin: Some("kafka://topic/partition".to_string()),
                ..Default::default()
            },
        };

        let original = Exchange::new(Message {
            headers: Default::default(),
            body: Body::Empty,
        });

        let results = vec![Ok(Exchange::new(Message {
            headers: Default::default(),
            body: Body::Stream(stream_body),
        }))];

        // Call the aggregation helper directly — no pipeline involved.
        let result = aggregate(results, original, AggregationStrategy::CollectAll);

        let exchange = result.expect("Expected Ok result");
        assert!(
            matches!(exchange.input.body, Body::Json(_)),
            "Expected Json body"
        );

        // Round-trip through serde to prove the placeholder is valid JSON.
        if let Body::Json(value) = exchange.input.body {
            let json_str = serde_json::to_string(&value).unwrap();
            let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();

            assert!(parsed.is_array());
            let arr = parsed.as_array().unwrap();
            assert!(arr[0].is_object());
            assert!(arr[0]["_stream"].is_object());
            assert_eq!(arr[0]["_stream"]["origin"], "kafka://topic/partition");
            assert_eq!(arr[0]["_stream"]["placeholder"], true);
        }
    }

    /// Same as above but with no stream origin: the placeholder must still
    /// serialize cleanly, with `origin` rendered as JSON null.
    #[tokio::test]
    async fn test_splitter_stream_with_none_origin_creates_valid_json() {
        use bytes::Bytes;
        use camel_api::{StreamBody, StreamMetadata};
        use futures::stream;
        use tokio::sync::Mutex;

        let chunks = vec![Ok(Bytes::from("test"))];
        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
            metadata: StreamMetadata {
                origin: None,
                ..Default::default()
            },
        };

        let original = Exchange::new(Message {
            headers: Default::default(),
            body: Body::Empty,
        });

        let results = vec![Ok(Exchange::new(Message {
            headers: Default::default(),
            body: Body::Stream(stream_body),
        }))];

        let result = aggregate(results, original, AggregationStrategy::CollectAll);

        let exchange = result.expect("Expected Ok result");
        assert!(
            matches!(exchange.input.body, Body::Json(_)),
            "Expected Json body"
        );

        if let Body::Json(value) = exchange.input.body {
            let json_str = serde_json::to_string(&value).unwrap();
            let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();

            assert!(parsed.is_array());
            let arr = parsed.as_array().unwrap();
            assert!(arr[0].is_object());
            assert!(arr[0]["_stream"].is_object());
            assert_eq!(arr[0]["_stream"]["origin"], serde_json::Value::Null);
            assert_eq!(arr[0]["_stream"]["placeholder"], true);
        }
    }
}