camel_processor/resequencer/
batch.rs1use std::collections::HashMap;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::{Arc, Mutex, Weak};
7use std::time::Duration;
8
9use async_trait::async_trait;
10use camel_api::exchange::Exchange;
11use camel_api::resequencer::BatchCompletion;
12use camel_api::value::cmp_values;
13use camel_language_api::Expression;
14use tokio::sync::mpsc;
15use tokio::task::JoinHandle;
16use tokio_util::sync::CancellationToken;
17
18use super::ResequencePolicy;
19
20#[derive(Default)]
22struct Bucket {
23 exchanges: Vec<Exchange>,
24}
25
26pub struct BatchPolicy {
33 correlation_expr: Arc<dyn Expression>,
34 sort_expr: Arc<dyn Expression>,
35 completion: BatchCompletion,
36
37 weak_self: Weak<Self>,
39
40 buckets: Mutex<HashMap<String, Bucket>>,
42
43 timeout_tokens: Mutex<HashMap<String, CancellationToken>>,
45
46 timeout_handles: Mutex<HashMap<String, JoinHandle<()>>>,
48
49 driver_tx: Mutex<Option<mpsc::Sender<Exchange>>>,
52
53 shutdown_started: AtomicBool,
56}
57
58impl BatchPolicy {
59 pub fn new_cyclic(
62 correlation_expr: Arc<dyn Expression>,
63 sort_expr: Arc<dyn Expression>,
64 completion: BatchCompletion,
65 ) -> Arc<Self> {
66 Arc::new_cyclic(|weak| Self {
67 correlation_expr,
68 sort_expr,
69 completion,
70 weak_self: weak.clone(),
71 buckets: Mutex::new(HashMap::new()),
72 timeout_tokens: Mutex::new(HashMap::new()),
73 timeout_handles: Mutex::new(HashMap::new()),
74 driver_tx: Mutex::new(None),
75 shutdown_started: AtomicBool::new(false),
76 })
77 }
78
79 fn set_driver_tx(&self, tx: mpsc::Sender<Exchange>) {
82 let mut guard = self.driver_tx.lock().unwrap_or_else(|e| e.into_inner());
83 *guard = Some(tx);
84 }
85
86 async fn eval_key(&self, exchange: &Exchange) -> Result<String, String> {
88 self.correlation_expr
89 .evaluate(exchange)
90 .await
91 .map(|v| match v {
94 serde_json::Value::String(s) => s,
95 other => other.to_string(),
96 })
97 .map_err(|e| format!("correlation expression evaluation failed: {e}"))
98 }
99
100 async fn drain_and_sort(&self, mut bucket: Bucket) -> Vec<Exchange> {
102 let mut indexed: Vec<(serde_json::Value, Exchange)> = Vec::new();
103 for ex in bucket.exchanges.drain(..) {
104 let val = self
105 .sort_expr
106 .evaluate(&ex)
107 .await
108 .unwrap_or(serde_json::Value::Null);
109 indexed.push((val, ex));
110 }
111 indexed.sort_by(|a, b| cmp_values(&a.0, &b.0));
112 indexed.into_iter().map(|(_, ex)| ex).collect()
113 }
114
115 fn is_complete_by_size(&self, count: usize) -> bool {
117 match self.completion {
118 BatchCompletion::Size(s) => count >= s,
119 BatchCompletion::Timeout(_) => false,
120 BatchCompletion::SizeOrTimeout(s, _) => count >= s,
121 }
122 }
123
124 fn needs_timeout(&self) -> bool {
126 matches!(
127 self.completion,
128 BatchCompletion::Timeout(_) | BatchCompletion::SizeOrTimeout(..)
129 )
130 }
131
132 fn take_bucket(&self, key: &str) -> Option<Bucket> {
134 let mut buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
135 buckets.remove(key)
136 }
137
138 fn cancel_timeout(&self, key: &str) {
140 {
141 let mut tokens = self
142 .timeout_tokens
143 .lock()
144 .unwrap_or_else(|e| e.into_inner());
145 if let Some(token) = tokens.remove(key) {
146 token.cancel();
147 }
148 }
149 {
150 let mut handles = self
151 .timeout_handles
152 .lock()
153 .unwrap_or_else(|e| e.into_inner());
154 handles.remove(key);
155 }
156 }
157
158 fn spawn_timeout_task(&self, key: String, timeout_ms: u64) {
161 let cancel = CancellationToken::new();
162 let cancel_clone = cancel.clone();
163
164 {
166 let mut tokens = self
167 .timeout_tokens
168 .lock()
169 .unwrap_or_else(|e| e.into_inner());
170 tokens.insert(key.clone(), cancel);
171 }
172
173 let weak = self.weak_self.clone();
174 let key_clone = key.clone();
175 let driver_tx_opt = {
176 let guard = self.driver_tx.lock().unwrap_or_else(|e| e.into_inner());
177 guard.clone()
178 };
179
180 let handle = tokio::spawn(async move {
181 let timeout = Duration::from_millis(timeout_ms);
182
183 tokio::select! {
184 _ = tokio::time::sleep(timeout) => {
185 if cancel_clone.is_cancelled() {
186 return;
187 }
188 }
189 _ = cancel_clone.cancelled() => {
190 return;
191 }
192 }
193
194 let Some(policy) = weak.upgrade() else {
196 return;
197 };
198
199 if policy.shutdown_started.load(Ordering::SeqCst) {
201 return;
202 }
203
204 let bucket = policy.take_bucket(&key_clone);
206 let Some(bucket) = bucket else {
207 return; };
209
210 let sorted = policy.drain_and_sort(bucket).await;
211
212 if let Some(tx) = driver_tx_opt {
214 for ex in sorted {
215 if tx.send(ex).await.is_err() {
216 tracing::debug!(
217 key = %key_clone,
218 "BatchPolicy timeout: driver channel closed during emission"
219 );
220 break;
221 }
222 }
223 }
224
225 {
227 let mut handles = policy
228 .timeout_handles
229 .lock()
230 .unwrap_or_else(|e| e.into_inner());
231 handles.remove(&key_clone);
232 }
233 });
234
235 {
236 let mut handles = self
237 .timeout_handles
238 .lock()
239 .unwrap_or_else(|e| e.into_inner());
240 handles.insert(key, handle);
241 }
242 }
243}
244
245#[async_trait]
246impl ResequencePolicy for BatchPolicy {
247 async fn accept(&self, input: Exchange) -> Vec<Exchange> {
248 let correlation_id = input.correlation_id().to_owned();
249 let key = match self.eval_key(&input).await {
250 Ok(k) => k,
251 Err(e) => {
252 tracing::warn!(
254 error = %e,
255 correlation_id = %correlation_id,
256 "BatchPolicy: correlation expression failed, dropping exchange"
257 );
258 return vec![];
259 }
260 };
261
262 let bucket_count = {
263 let mut buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
264 let bucket = buckets.entry(key.clone()).or_default();
265 bucket.exchanges.push(input);
266 bucket.exchanges.len()
267 };
268
269 if bucket_count == 1 && self.needs_timeout() {
271 let timeout_ms = match self.completion {
272 BatchCompletion::Timeout(t) | BatchCompletion::SizeOrTimeout(_, t) => t,
273 _ => unreachable!(),
274 };
275 self.spawn_timeout_task(key.clone(), timeout_ms);
276 }
277
278 if self.is_complete_by_size(bucket_count) {
280 self.cancel_timeout(&key);
281 if let Some(bucket) = self.take_bucket(&key) {
282 return self.drain_and_sort(bucket).await;
283 }
284 }
285
286 vec![]
287 }
288
289 async fn flush(&self) -> Vec<Exchange> {
290 self.shutdown_started.store(true, Ordering::SeqCst);
292
293 let all_keys: Vec<String> = {
294 let buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
295 buckets.keys().cloned().collect()
296 };
297
298 let mut all_sorted = Vec::new();
299 for key in &all_keys {
300 self.cancel_timeout(key);
301 if let Some(bucket) = self.take_bucket(key) {
302 let sorted = self.drain_and_sort(bucket).await;
303 all_sorted.extend(sorted);
304 }
305 }
306
307 {
309 let tokens: HashMap<String, CancellationToken> = {
310 let mut guard = self
311 .timeout_tokens
312 .lock()
313 .unwrap_or_else(|e| e.into_inner());
314 std::mem::take(&mut *guard)
315 };
316 for (_, token) in tokens {
317 token.cancel();
318 }
319 }
320 {
322 let _handles = {
323 let mut guard = self
324 .timeout_handles
325 .lock()
326 .unwrap_or_else(|e| e.into_inner());
327 std::mem::take(&mut *guard)
328 };
329 }
330
331 all_sorted
332 }
333
334 fn name(&self) -> &'static str {
335 "batch-resequencer"
336 }
337
338 fn set_timeout_tx(&self, tx: tokio::sync::mpsc::Sender<Exchange>) {
339 self.set_driver_tx(tx);
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::exchange::ExchangePattern;
    use camel_api::message::Message;

    /// Test expression yielding the named exchange property, or `Null` when absent.
    struct PropExpr(String);

    #[async_trait::async_trait]
    impl Expression for PropExpr {
        async fn evaluate(
            &self,
            exchange: &Exchange,
        ) -> Result<serde_json::Value, camel_language_api::LanguageError> {
            let found = exchange.property(&self.0).cloned();
            Ok(found.unwrap_or(serde_json::Value::Null))
        }
    }

    /// Test expression yielding a fixed string, so every exchange shares one bucket.
    struct ConstExpr(String);

    #[async_trait::async_trait]
    impl Expression for ConstExpr {
        async fn evaluate(
            &self,
            _exchange: &Exchange,
        ) -> Result<serde_json::Value, camel_language_api::LanguageError> {
            let text = self.0.clone();
            Ok(serde_json::Value::String(text))
        }
    }

    /// Test expression whose evaluation always fails.
    struct FailingExpr;

    #[async_trait::async_trait]
    impl Expression for FailingExpr {
        async fn evaluate(
            &self,
            _exchange: &Exchange,
        ) -> Result<serde_json::Value, camel_language_api::LanguageError> {
            Err(camel_language_api::LanguageError::EvalError(
                "mock eval failure".into(),
            ))
        }
    }

    /// In-only exchange with body `msg-{seq}` and integer property `seq`.
    fn mk_exchange(seq: i64) -> Exchange {
        let body = camel_api::body::Body::Text(format!("msg-{seq}"));
        let mut ex = Exchange::new(Message::new(body));
        ex.set_property("seq", serde_json::json!(seq));
        ex.pattern = ExchangePattern::InOnly;
        ex
    }

    /// Same as `mk_exchange`, plus a string property `key_prop = key_val`.
    fn mk_exchange_with_key(seq: i64, key_prop: &str, key_val: &str) -> Exchange {
        let mut ex = mk_exchange(seq);
        ex.set_property(key_prop, serde_json::Value::String(key_val.to_string()));
        ex
    }

    /// `seq` property of each exchange, in order (-1 when missing).
    fn seqs(exchanges: &[Exchange]) -> Vec<i64> {
        exchanges
            .iter()
            .map(|ex| ex.property("seq").and_then(|v| v.as_i64()).unwrap_or(-1))
            .collect()
    }

    #[tokio::test]
    async fn batch_size_completion_emits_sorted_burst() {
        let policy = BatchPolicy::new_cyclic(
            Arc::new(ConstExpr("same".into())),
            Arc::new(PropExpr("seq".into())),
            BatchCompletion::Size(3),
        );

        for seq in [3, 1] {
            assert!(policy.accept(mk_exchange(seq)).await.is_empty());
        }

        let emitted = policy.accept(mk_exchange(2)).await;
        assert_eq!(emitted.len(), 3, "should emit all 3 on completion");
        assert_eq!(seqs(&emitted), vec![1, 2, 3], "should be sorted ascending");
    }

    #[tokio::test]
    async fn batch_timeout_completion_emits_after_timeout() {
        let policy = BatchPolicy::new_cyclic(
            Arc::new(ConstExpr("same".into())),
            Arc::new(PropExpr("seq".into())),
            BatchCompletion::Timeout(50),
        );

        let (tx, mut rx) = mpsc::channel::<Exchange>(16);
        policy.set_driver_tx(tx);

        for seq in [3, 1] {
            assert!(policy.accept(mk_exchange(seq)).await.is_empty());
        }

        let emitted: Vec<Exchange> = tokio::time::timeout(Duration::from_millis(500), async {
            let first = rx.recv().await.unwrap();
            let second = rx.recv().await.unwrap();
            vec![first, second]
        })
        .await
        .expect("timeout should fire within 500ms");

        assert_eq!(emitted.len(), 2);
        assert_eq!(seqs(&emitted), vec![1, 3], "should be sorted ascending");
    }

    #[tokio::test]
    async fn batch_size_or_timeout_size_wins() {
        let policy = BatchPolicy::new_cyclic(
            Arc::new(ConstExpr("same".into())),
            Arc::new(PropExpr("seq".into())),
            BatchCompletion::SizeOrTimeout(3, 5_000),
        );

        for seq in [2, 1] {
            assert!(policy.accept(mk_exchange(seq)).await.is_empty());
        }

        let emitted = policy.accept(mk_exchange(3)).await;
        assert_eq!(emitted.len(), 3);
        assert_eq!(seqs(&emitted), vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn batch_multi_key_independence() {
        let policy = BatchPolicy::new_cyclic(
            Arc::new(PropExpr("region".into())),
            Arc::new(PropExpr("seq".into())),
            BatchCompletion::Size(2),
        );

        let _ = policy
            .accept(mk_exchange_with_key(2, "region", "east"))
            .await;
        let east_emit = policy
            .accept(mk_exchange_with_key(1, "region", "east"))
            .await;
        assert_eq!(east_emit.len(), 2, "east bucket should complete at size 2");

        let west_result = policy
            .accept(mk_exchange_with_key(3, "region", "west"))
            .await;
        assert!(
            west_result.is_empty(),
            "west bucket should NOT complete yet"
        );
    }

    #[tokio::test]
    async fn batch_flush_emits_remaining_sorted() {
        let policy = BatchPolicy::new_cyclic(
            Arc::new(ConstExpr("same".into())),
            Arc::new(PropExpr("seq".into())),
            BatchCompletion::Size(10),
        );

        for seq in [5, 3, 1] {
            assert!(policy.accept(mk_exchange(seq)).await.is_empty());
        }

        let flushed = policy.flush().await;
        assert_eq!(flushed.len(), 3);
        assert_eq!(seqs(&flushed), vec![1, 3, 5]);
    }

    #[tokio::test]
    async fn batch_correlation_eval_failure_returns_empty() {
        let policy = BatchPolicy::new_cyclic(
            Arc::new(FailingExpr),
            Arc::new(PropExpr("seq".into())),
            BatchCompletion::Size(2),
        );

        let result = policy.accept(mk_exchange(1)).await;
        assert!(
            result.is_empty(),
            "failed correlation should return empty vec, not crash"
        );
    }

    #[tokio::test]
    async fn batch_pure_size_no_timeout_needed() {
        let policy = BatchPolicy::new_cyclic(
            Arc::new(ConstExpr("same".into())),
            Arc::new(PropExpr("seq".into())),
            BatchCompletion::Size(2),
        );

        assert!(!policy.needs_timeout());
    }
}