1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures::future::join_all;
6use tokio::sync::Semaphore;
7use tower::Service;
8
9use camel_api::{
10 AggregationStrategy, Body, BoxProcessor, CamelError, Exchange, SplitterConfig, Value,
11};
12
/// Exchange property: zero-based index of the current fragment within the split.
pub const CAMEL_SPLIT_INDEX: &str = "CamelSplitIndex";
/// Exchange property: total number of fragments produced by the split.
pub const CAMEL_SPLIT_SIZE: &str = "CamelSplitSize";
/// Exchange property: `true` only on the last fragment of the split.
pub const CAMEL_SPLIT_COMPLETE: &str = "CamelSplitComplete";
21
/// Tower service implementing the Splitter EIP: splits an incoming
/// [`Exchange`] into fragments via `expression`, routes each fragment
/// through `sub_pipeline`, and recombines the results with `aggregation`.
#[derive(Clone)]
pub struct SplitterService {
    // Produces the fragment exchanges for an incoming exchange.
    expression: camel_api::SplitExpression,
    // Processor each fragment is routed through.
    sub_pipeline: BoxProcessor,
    // How per-fragment results are folded back into one exchange.
    aggregation: AggregationStrategy,
    // Process fragments concurrently instead of one at a time.
    parallel: bool,
    // Optional cap on in-flight fragments when `parallel` is set.
    parallel_limit: Option<usize>,
    // Stop processing remaining fragments after the first error.
    stop_on_exception: bool,
}
41
42impl SplitterService {
43 pub fn new(config: SplitterConfig, sub_pipeline: BoxProcessor) -> Self {
45 Self {
46 expression: config.expression,
47 sub_pipeline,
48 aggregation: config.aggregation,
49 parallel: config.parallel,
50 parallel_limit: config.parallel_limit,
51 stop_on_exception: config.stop_on_exception,
52 }
53 }
54}
55
impl Service<Exchange> for SplitterService {
    type Response = Exchange;
    type Error = CamelError;
    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;

    /// Readiness is delegated straight to the sub-pipeline; the splitter
    /// itself adds no backpressure of its own.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.sub_pipeline.poll_ready(cx)
    }

    /// Splits `exchange` into fragments, stamps split metadata on each,
    /// processes them (sequentially or in parallel), then aggregates the
    /// results into one exchange.
    fn call(&mut self, exchange: Exchange) -> Self::Future {
        // Preserve the pre-split exchange for the empty-split fast path and
        // for the `Original` aggregation strategy.
        let original = exchange.clone();
        // Clone everything the returned future needs so it can be 'static
        // and independent of `self`.
        let expression = self.expression.clone();
        let sub_pipeline = self.sub_pipeline.clone();
        let aggregation = self.aggregation.clone();
        let parallel = self.parallel;
        let parallel_limit = self.parallel_limit;
        let stop_on_exception = self.stop_on_exception;

        Box::pin(async move {
            let mut fragments = expression(&exchange);

            // Nothing to split: hand back the original exchange untouched.
            if fragments.is_empty() {
                return Ok(original);
            }

            let total = fragments.len();

            // Stamp each fragment with the standard Camel split metadata
            // properties before it enters the sub-pipeline.
            for (i, frag) in fragments.iter_mut().enumerate() {
                frag.set_property(CAMEL_SPLIT_INDEX, Value::from(i as u64));
                frag.set_property(CAMEL_SPLIT_SIZE, Value::from(total as u64));
                frag.set_property(CAMEL_SPLIT_COMPLETE, Value::Bool(i == total - 1));
            }

            let results = if parallel {
                process_parallel(fragments, sub_pipeline, parallel_limit, stop_on_exception).await
            } else {
                process_sequential(fragments, sub_pipeline, stop_on_exception).await
            };

            // Fold the per-fragment results back into a single exchange.
            aggregate(results, original, aggregation)
        })
    }
}
104
105async fn process_sequential(
108 fragments: Vec<Exchange>,
109 sub_pipeline: BoxProcessor,
110 stop_on_exception: bool,
111) -> Vec<Result<Exchange, CamelError>> {
112 let mut results = Vec::with_capacity(fragments.len());
113
114 for fragment in fragments {
115 let mut pipeline = sub_pipeline.clone();
116 match tower::ServiceExt::ready(&mut pipeline).await {
117 Err(e) => {
118 results.push(Err(e));
119 if stop_on_exception {
120 break;
121 }
122 }
123 Ok(svc) => {
124 let result = svc.call(fragment).await;
125 let is_err = result.is_err();
126 results.push(result);
127 if stop_on_exception && is_err {
128 break;
129 }
130 }
131 }
132 }
133
134 results
135}
136
137async fn process_parallel(
140 fragments: Vec<Exchange>,
141 sub_pipeline: BoxProcessor,
142 parallel_limit: Option<usize>,
143 _stop_on_exception: bool,
144) -> Vec<Result<Exchange, CamelError>> {
145 let semaphore = parallel_limit.map(|limit| std::sync::Arc::new(Semaphore::new(limit)));
146
147 let futures: Vec<_> = fragments
148 .into_iter()
149 .map(|fragment| {
150 let mut pipeline = sub_pipeline.clone();
151 let sem = semaphore.clone();
152 async move {
153 let _permit = match &sem {
155 Some(s) => Some(s.acquire().await.map_err(|e| {
156 CamelError::ProcessorError(format!("semaphore error: {e}"))
157 })?),
158 None => None,
159 };
160
161 tower::ServiceExt::ready(&mut pipeline).await?;
162 pipeline.call(fragment).await
163 }
164 })
165 .collect();
166
167 join_all(futures).await
168}
169
170fn aggregate(
173 results: Vec<Result<Exchange, CamelError>>,
174 original: Exchange,
175 strategy: AggregationStrategy,
176) -> Result<Exchange, CamelError> {
177 match strategy {
178 AggregationStrategy::LastWins => {
179 results.into_iter().last().unwrap_or_else(|| Ok(original))
181 }
182 AggregationStrategy::CollectAll => {
183 let mut bodies = Vec::new();
185 for result in results {
186 let ex = result?;
187 let value = match &ex.input.body {
188 Body::Text(s) => Value::String(s.clone()),
189 Body::Json(v) => v.clone(),
190 Body::Bytes(b) => Value::String(String::from_utf8_lossy(b).into_owned()),
191 Body::Empty => Value::Null,
192 };
193 bodies.push(value);
194 }
195 let mut out = original;
196 out.input.body = Body::Json(Value::Array(bodies));
197 Ok(out)
198 }
199 AggregationStrategy::Original => Ok(original),
200 AggregationStrategy::Custom(fold_fn) => {
201 let mut iter = results.into_iter();
203 let first = iter.next().unwrap_or_else(|| Ok(original.clone()))?;
204 iter.try_fold(first, |acc, next_result| {
205 let next = next_result?;
206 Ok(fold_fn(acc, next))
207 })
208 }
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessorExt, Message};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::ServiceExt;

    // Sub-pipeline that returns every exchange unchanged.
    fn passthrough_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|ex| Box::pin(async move { Ok(ex) }))
    }

    // Sub-pipeline that upper-cases text bodies; non-text bodies pass through.
    fn uppercase_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|mut ex: Exchange| {
            Box::pin(async move {
                if let Body::Text(s) = &ex.input.body {
                    ex.input.body = Body::Text(s.to_uppercase());
                }
                Ok(ex)
            })
        })
    }

    // Sub-pipeline that fails every exchange it receives.
    fn failing_pipeline() -> BoxProcessor {
        BoxProcessor::from_fn(|_ex| {
            Box::pin(async { Err(CamelError::ProcessorError("boom".into())) })
        })
    }

    // Sub-pipeline that fails only the n-th (zero-based) call it sees;
    // the shared atomic counter persists across pipeline clones.
    fn fail_on_nth(n: usize) -> BoxProcessor {
        let count = Arc::new(AtomicUsize::new(0));
        BoxProcessor::from_fn(move |ex: Exchange| {
            let count = Arc::clone(&count);
            Box::pin(async move {
                let c = count.fetch_add(1, Ordering::SeqCst);
                if c == n {
                    Err(CamelError::ProcessorError(format!("fail on {c}")))
                } else {
                    Ok(ex)
                }
            })
        })
    }

    // Builds an exchange with a plain text body.
    fn make_exchange(text: &str) -> Exchange {
        Exchange::new(Message::new(text))
    }

    // LastWins keeps only the final fragment's processed body.
    #[tokio::test]
    async fn test_split_sequential_last_wins() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::LastWins);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("C"));
    }

    // CollectAll gathers all processed fragment bodies into a JSON array.
    #[tokio::test]
    async fn test_split_sequential_collect_all() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        let expected = serde_json::json!(["A", "B", "C"]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    // Original discards fragment results and returns the pre-split body.
    #[tokio::test]
    async fn test_split_sequential_original() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::Original);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("a\nb\nc"));
    }

    // Custom aggregation left-folds processed fragments with a user closure.
    #[tokio::test]
    async fn test_split_sequential_custom_aggregation() {
        let joiner: Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync> =
            Arc::new(|mut acc: Exchange, next: Exchange| {
                let acc_text = acc.input.body.as_text().unwrap_or("").to_string();
                let next_text = next.input.body.as_text().unwrap_or("").to_string();
                acc.input.body = Body::Text(format!("{acc_text}+{next_text}"));
                acc
            });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::Custom(joiner));
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("A+B+C"));
    }

    // With stop_on_exception the fragment-1 failure is the final outcome.
    #[tokio::test]
    async fn test_split_stop_on_exception() {
        let config = SplitterConfig::new(camel_api::split_body_lines()).stop_on_exception(true);
        let mut svc = SplitterService::new(config, fail_on_nth(1));

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc\nd\ne"))
            .await;

        assert!(result.is_err(), "expected error due to stop_on_exception");
    }

    // Without stop_on_exception, a mid-run failure is tolerated and the
    // last (successful) fragment wins.
    #[tokio::test]
    async fn test_split_continue_on_exception() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .stop_on_exception(false)
            .aggregation(AggregationStrategy::LastWins);
        let mut svc = SplitterService::new(config, fail_on_nth(1));

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await;

        assert!(result.is_ok(), "last fragment should succeed");
    }

    // An empty split returns the original exchange (body and properties
    // untouched) without invoking the sub-pipeline.
    #[tokio::test]
    async fn test_split_empty_fragments() {
        let config = SplitterConfig::new(camel_api::split_body_lines());
        let mut svc = SplitterService::new(config, passthrough_pipeline());

        let mut ex = Exchange::new(Message::default());
        ex.set_property("marker", Value::Bool(true));

        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
        assert!(result.input.body.is_empty());
        assert_eq!(result.property("marker"), Some(&Value::Bool(true)));
    }

    // Each fragment must carry CamelSplitIndex/Size/Complete properties;
    // the recorder pipeline snapshots them into the fragment body.
    #[tokio::test]
    async fn test_split_metadata_properties() {
        let recorder = BoxProcessor::from_fn(|ex: Exchange| {
            Box::pin(async move {
                let idx = ex.property(CAMEL_SPLIT_INDEX).cloned();
                let size = ex.property(CAMEL_SPLIT_SIZE).cloned();
                let complete = ex.property(CAMEL_SPLIT_COMPLETE).cloned();
                let body = serde_json::json!({
                    "index": idx,
                    "size": size,
                    "complete": complete,
                });
                let mut out = ex;
                out.input.body = Body::Json(body);
                Ok(out)
            })
        });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, recorder);

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("x\ny\nz"))
            .await
            .unwrap();

        let expected = serde_json::json!([
            {"index": 0, "size": 3, "complete": false},
            {"index": 1, "size": 3, "complete": false},
            {"index": 2, "size": 3, "complete": true},
        ]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    // poll_ready must mirror the sub-pipeline's readiness, not report
    // Ready unconditionally.
    #[tokio::test]
    async fn test_poll_ready_delegates_to_sub_pipeline() {
        use std::sync::atomic::AtomicBool;

        // Inner service whose readiness is controlled by a shared flag.
        #[derive(Clone)]
        struct DelayedReady {
            ready: Arc<AtomicBool>,
        }

        impl Service<Exchange> for DelayedReady {
            type Response = Exchange;
            type Error = CamelError;
            type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;

            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                if self.ready.load(Ordering::SeqCst) {
                    Poll::Ready(Ok(()))
                } else {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            }

            fn call(&mut self, exchange: Exchange) -> Self::Future {
                Box::pin(async move { Ok(exchange) })
            }
        }

        let ready_flag = Arc::new(AtomicBool::new(false));
        let inner = DelayedReady {
            ready: Arc::clone(&ready_flag),
        };
        let boxed: BoxProcessor = BoxProcessor::new(inner);

        let config = SplitterConfig::new(camel_api::split_body_lines());
        let mut svc = SplitterService::new(config, boxed);

        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        let poll = Pin::new(&mut svc).poll_ready(&mut cx);
        assert!(
            poll.is_pending(),
            "expected Pending when sub_pipeline not ready"
        );

        ready_flag.store(true, Ordering::SeqCst);

        let poll = Pin::new(&mut svc).poll_ready(&mut cx);
        assert!(
            matches!(poll, Poll::Ready(Ok(()))),
            "expected Ready after sub_pipeline becomes ready"
        );
    }

    // Parallel processing must preserve fragment order in the aggregate.
    #[tokio::test]
    async fn test_split_parallel_basic() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, uppercase_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await
            .unwrap();

        let expected = serde_json::json!(["A", "B", "C"]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    // parallel_limit(2) must never allow more than 2 fragments in flight;
    // observed concurrency is tracked with atomic counters.
    #[tokio::test]
    async fn test_split_parallel_with_limit() {
        use std::sync::atomic::AtomicUsize;

        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        let c = Arc::clone(&concurrent);
        let mc = Arc::clone(&max_concurrent);
        let pipeline = BoxProcessor::from_fn(move |ex: Exchange| {
            let c = Arc::clone(&c);
            let mc = Arc::clone(&mc);
            Box::pin(async move {
                let current = c.fetch_add(1, Ordering::SeqCst) + 1;
                mc.fetch_max(current, Ordering::SeqCst);
                // Yield so overlapping fragments get a chance to run.
                tokio::task::yield_now().await;
                c.fetch_sub(1, Ordering::SeqCst);
                Ok(ex)
            })
        });

        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .parallel_limit(2)
            .aggregation(AggregationStrategy::CollectAll);
        let mut svc = SplitterService::new(config, pipeline);

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc\nd"))
            .await;
        assert!(result.is_ok());

        let observed_max = max_concurrent.load(Ordering::SeqCst);
        assert!(
            observed_max <= 2,
            "max concurrency was {observed_max}, expected <= 2"
        );
    }

    // Parallel + stop_on_exception with an always-failing pipeline must
    // surface an error.
    #[tokio::test]
    async fn test_split_parallel_stop_on_exception() {
        let config = SplitterConfig::new(camel_api::split_body_lines())
            .parallel(true)
            .stop_on_exception(true);
        let mut svc = SplitterService::new(config, failing_pipeline());

        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("a\nb\nc"))
            .await;

        assert!(result.is_err(), "expected error when all fragments fail");
    }
}
599}