Skip to main content

camel_processor/resequencer/
batch.rs

1//! Batch resequencing policy — buffer per correlation key, window completion,
2//! sort by expression, burst-emit in order.
3
4use std::collections::HashMap;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::{Arc, Mutex, Weak};
7use std::time::Duration;
8
9use async_trait::async_trait;
10use camel_api::exchange::Exchange;
11use camel_api::resequencer::BatchCompletion;
12use camel_api::value::cmp_values;
13use camel_language_api::Expression;
14use tokio::sync::mpsc;
15use tokio::task::JoinHandle;
16use tokio_util::sync::CancellationToken;
17
18use super::ResequencePolicy;
19
20/// Per-correlation-key bucket holding pending exchanges.
21#[derive(Default)]
22struct Bucket {
23    exchanges: Vec<Exchange>,
24}
25
26/// Batch resequencing policy.
27///
28/// Buffers exchanges per correlation key. Completion is triggered by
29/// window (size and/or timeout). On completion, sorts buffered exchanges
30/// by `sort_expr` and returns them as a burst. Timeout tasks hold a
31/// `Weak<Self>` reference obtained via `Arc::new_cyclic`.
32pub struct BatchPolicy {
33    correlation_expr: Arc<dyn Expression>,
34    sort_expr: Arc<dyn Expression>,
35    completion: BatchCompletion,
36
37    /// Weak self-reference so timeout tasks can upgrade to `Arc<Self>`.
38    weak_self: Weak<Self>,
39
40    /// Per-correlation-key buckets (exchanges pending completion).
41    buckets: Mutex<HashMap<String, Bucket>>,
42
43    /// Timeout cancellation tokens, keyed by correlation key.
44    timeout_tokens: Mutex<HashMap<String, CancellationToken>>,
45
46    /// Timeout task handles, keyed by correlation key.
47    timeout_handles: Mutex<HashMap<String, JoinHandle<()>>>,
48
49    /// Channel to the post-driver for timeout-triggered emissions.
50    /// Set by `ResequencerService` after channel creation.
51    driver_tx: Mutex<Option<mpsc::Sender<Exchange>>>,
52
53    /// Shutdown guard — timeout tasks check this before sending
54    /// to avoid racing with post-driver channel close (M7).
55    shutdown_started: AtomicBool,
56}
57
58impl BatchPolicy {
59    /// Create a new `Arc<BatchPolicy>` using `Arc::new_cyclic` so the
60    /// policy holds a `Weak<Self>` for timeout task spawning.
61    pub fn new_cyclic(
62        correlation_expr: Arc<dyn Expression>,
63        sort_expr: Arc<dyn Expression>,
64        completion: BatchCompletion,
65    ) -> Arc<Self> {
66        Arc::new_cyclic(|weak| Self {
67            correlation_expr,
68            sort_expr,
69            completion,
70            weak_self: weak.clone(),
71            buckets: Mutex::new(HashMap::new()),
72            timeout_tokens: Mutex::new(HashMap::new()),
73            timeout_handles: Mutex::new(HashMap::new()),
74            driver_tx: Mutex::new(None),
75            shutdown_started: AtomicBool::new(false),
76        })
77    }
78
79    /// Set the driver channel (via `set_timeout_tx` trait method).
80    /// Called by `ResequencerService` after channel creation.
81    fn set_driver_tx(&self, tx: mpsc::Sender<Exchange>) {
82        let mut guard = self.driver_tx.lock().unwrap_or_else(|e| e.into_inner());
83        *guard = Some(tx);
84    }
85
86    /// Evaluate the correlation expression against an exchange.
87    async fn eval_key(&self, exchange: &Exchange) -> Result<String, String> {
88        self.correlation_expr
89            .evaluate(exchange)
90            .await
91            // M4: avoid double-quoting for string values — use as_str() for
92            // strings, fall back to to_string() for other types.
93            .map(|v| match v {
94                serde_json::Value::String(s) => s,
95                other => other.to_string(),
96            })
97            .map_err(|e| format!("correlation expression evaluation failed: {e}"))
98    }
99
100    /// Drain a bucket, sort by sort_expr, return sorted Vec.
101    async fn drain_and_sort(&self, mut bucket: Bucket) -> Vec<Exchange> {
102        let mut indexed: Vec<(serde_json::Value, Exchange)> = Vec::new();
103        for ex in bucket.exchanges.drain(..) {
104            let val = self
105                .sort_expr
106                .evaluate(&ex)
107                .await
108                .unwrap_or(serde_json::Value::Null);
109            indexed.push((val, ex));
110        }
111        indexed.sort_by(|a, b| cmp_values(&a.0, &b.0));
112        indexed.into_iter().map(|(_, ex)| ex).collect()
113    }
114
115    /// Check if a bucket count satisfies the size-based completion condition.
116    fn is_complete_by_size(&self, count: usize) -> bool {
117        match self.completion {
118            BatchCompletion::Size(s) => count >= s,
119            BatchCompletion::Timeout(_) => false,
120            BatchCompletion::SizeOrTimeout(s, _) => count >= s,
121        }
122    }
123
124    /// Whether this completion variant needs timeout tasks spawned.
125    fn needs_timeout(&self) -> bool {
126        matches!(
127            self.completion,
128            BatchCompletion::Timeout(_) | BatchCompletion::SizeOrTimeout(..)
129        )
130    }
131
132    /// Take a bucket by key. Returns `Some(Bucket)` if it existed.
133    fn take_bucket(&self, key: &str) -> Option<Bucket> {
134        let mut buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
135        buckets.remove(key)
136    }
137
138    /// Cancel and remove timeout task for a key.
139    fn cancel_timeout(&self, key: &str) {
140        {
141            let mut tokens = self
142                .timeout_tokens
143                .lock()
144                .unwrap_or_else(|e| e.into_inner());
145            if let Some(token) = tokens.remove(key) {
146                token.cancel();
147            }
148        }
149        {
150            let mut handles = self
151                .timeout_handles
152                .lock()
153                .unwrap_or_else(|e| e.into_inner());
154            handles.remove(key);
155        }
156    }
157
158    /// Spawn a timeout task for the given key.
159    /// Must be called from a method that has access to `&self` (which has the `weak_self`).
160    fn spawn_timeout_task(&self, key: String, timeout_ms: u64) {
161        let cancel = CancellationToken::new();
162        let cancel_clone = cancel.clone();
163
164        // Store the cancellation token
165        {
166            let mut tokens = self
167                .timeout_tokens
168                .lock()
169                .unwrap_or_else(|e| e.into_inner());
170            tokens.insert(key.clone(), cancel);
171        }
172
173        let weak = self.weak_self.clone();
174        let key_clone = key.clone();
175        let driver_tx_opt = {
176            let guard = self.driver_tx.lock().unwrap_or_else(|e| e.into_inner());
177            guard.clone()
178        };
179
180        let handle = tokio::spawn(async move {
181            let timeout = Duration::from_millis(timeout_ms);
182
183            tokio::select! {
184                _ = tokio::time::sleep(timeout) => {
185                    if cancel_clone.is_cancelled() {
186                        return;
187                    }
188                }
189                _ = cancel_clone.cancelled() => {
190                    return;
191                }
192            }
193
194            // Upgrade the weak reference — policy may have been dropped (shutdown)
195            let Some(policy) = weak.upgrade() else {
196                return;
197            };
198
199            // M7: don't send if shutdown has started (driver channel may already be closed)
200            if policy.shutdown_started.load(Ordering::SeqCst) {
201                return;
202            }
203
204            // Drain the bucket
205            let bucket = policy.take_bucket(&key_clone);
206            let Some(bucket) = bucket else {
207                return; // bucket already drained by size-based completion
208            };
209
210            let sorted = policy.drain_and_sort(bucket).await;
211
212            // Send via driver channel
213            if let Some(tx) = driver_tx_opt {
214                for ex in sorted {
215                    if tx.send(ex).await.is_err() {
216                        tracing::debug!(
217                            key = %key_clone,
218                            "BatchPolicy timeout: driver channel closed during emission"
219                        );
220                        break;
221                    }
222                }
223            }
224
225            // Clean up handle entry
226            {
227                let mut handles = policy
228                    .timeout_handles
229                    .lock()
230                    .unwrap_or_else(|e| e.into_inner());
231                handles.remove(&key_clone);
232            }
233        });
234
235        {
236            let mut handles = self
237                .timeout_handles
238                .lock()
239                .unwrap_or_else(|e| e.into_inner());
240            handles.insert(key, handle);
241        }
242    }
243}
244
245#[async_trait]
246impl ResequencePolicy for BatchPolicy {
247    async fn accept(&self, input: Exchange) -> Vec<Exchange> {
248        let correlation_id = input.correlation_id().to_owned();
249        let key = match self.eval_key(&input).await {
250            Ok(k) => k,
251            Err(e) => {
252                // log-policy: handler-owned
253                tracing::warn!(
254                    error = %e,
255                    correlation_id = %correlation_id,
256                    "BatchPolicy: correlation expression failed, dropping exchange"
257                );
258                return vec![];
259            }
260        };
261
262        let bucket_count = {
263            let mut buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
264            let bucket = buckets.entry(key.clone()).or_default();
265            bucket.exchanges.push(input);
266            bucket.exchanges.len()
267        };
268
269        // Spawn timeout task if needed (first exchange for this key)
270        if bucket_count == 1 && self.needs_timeout() {
271            let timeout_ms = match self.completion {
272                BatchCompletion::Timeout(t) | BatchCompletion::SizeOrTimeout(_, t) => t,
273                _ => unreachable!(),
274            };
275            self.spawn_timeout_task(key.clone(), timeout_ms);
276        }
277
278        // Check if the bucket is complete (size-based)
279        if self.is_complete_by_size(bucket_count) {
280            self.cancel_timeout(&key);
281            if let Some(bucket) = self.take_bucket(&key) {
282                return self.drain_and_sort(bucket).await;
283            }
284        }
285
286        vec![]
287    }
288
289    async fn flush(&self) -> Vec<Exchange> {
290        // M7: signal timeout tasks that shutdown is in progress
291        self.shutdown_started.store(true, Ordering::SeqCst);
292
293        let all_keys: Vec<String> = {
294            let buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
295            buckets.keys().cloned().collect()
296        };
297
298        let mut all_sorted = Vec::new();
299        for key in &all_keys {
300            self.cancel_timeout(key);
301            if let Some(bucket) = self.take_bucket(key) {
302                let sorted = self.drain_and_sort(bucket).await;
303                all_sorted.extend(sorted);
304            }
305        }
306
307        // Cancel all remaining timeout tasks
308        {
309            let tokens: HashMap<String, CancellationToken> = {
310                let mut guard = self
311                    .timeout_tokens
312                    .lock()
313                    .unwrap_or_else(|e| e.into_inner());
314                std::mem::take(&mut *guard)
315            };
316            for (_, token) in tokens {
317                token.cancel();
318            }
319        }
320        // Drop handles — tasks wind down when cancelled
321        {
322            let _handles = {
323                let mut guard = self
324                    .timeout_handles
325                    .lock()
326                    .unwrap_or_else(|e| e.into_inner());
327                std::mem::take(&mut *guard)
328            };
329        }
330
331        all_sorted
332    }
333
334    fn name(&self) -> &'static str {
335        "batch-resequencer"
336    }
337
338    fn set_timeout_tx(&self, tx: tokio::sync::mpsc::Sender<Exchange>) {
339        self.set_driver_tx(tx);
340    }
341}
342
343// ── Tests ──
344
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::exchange::ExchangePattern;
    use camel_api::message::Message;

    /// Test expression: the named exchange property, or `null` when absent.
    struct PropExpr(String);

    #[async_trait::async_trait]
    impl Expression for PropExpr {
        async fn evaluate(
            &self,
            exchange: &Exchange,
        ) -> Result<serde_json::Value, camel_language_api::LanguageError> {
            let found = exchange.property(&self.0).cloned();
            Ok(found.unwrap_or(serde_json::Value::Null))
        }
    }

    /// Test expression: the same string for every exchange.
    struct ConstExpr(String);

    #[async_trait::async_trait]
    impl Expression for ConstExpr {
        async fn evaluate(
            &self,
            _exchange: &Exchange,
        ) -> Result<serde_json::Value, camel_language_api::LanguageError> {
            let text = self.0.clone();
            Ok(serde_json::Value::String(text))
        }
    }

    /// Test expression: every evaluation fails.
    struct FailingExpr;

    #[async_trait::async_trait]
    impl Expression for FailingExpr {
        async fn evaluate(
            &self,
            _exchange: &Exchange,
        ) -> Result<serde_json::Value, camel_language_api::LanguageError> {
            Err(camel_language_api::LanguageError::EvalError(
                "mock eval failure".into(),
            ))
        }
    }

    /// In-only text exchange `msg-{seq}` carrying a numeric `seq` property.
    fn mk_exchange(seq: i64) -> Exchange {
        let body = camel_api::body::Body::Text(format!("msg-{seq}"));
        let mut exchange = Exchange::new(Message::new(body));
        exchange.set_property("seq", serde_json::json!(seq));
        exchange.pattern = ExchangePattern::InOnly;
        exchange
    }

    /// Like `mk_exchange`, plus a string property `key_prop = key_val`.
    fn mk_exchange_with_key(seq: i64, key_prop: &str, key_val: &str) -> Exchange {
        let mut exchange = mk_exchange(seq);
        exchange.set_property(key_prop, serde_json::Value::String(key_val.to_string()));
        exchange
    }

    /// `seq` property of each exchange, in order (-1 when missing).
    fn seqs(exchanges: &[Exchange]) -> Vec<i64> {
        exchanges
            .iter()
            .map(|ex| {
                ex.property("seq")
                    .and_then(serde_json::Value::as_i64)
                    .unwrap_or(-1)
            })
            .collect()
    }

    /// C1.1: one key, window size 3, inputs seq [3,1,2].
    /// The third accept() returns the burst sorted as [1,2,3].
    #[tokio::test]
    async fn batch_size_completion_emits_sorted_burst() {
        let policy = BatchPolicy::new_cyclic(
            Arc::new(ConstExpr("same".into())),
            Arc::new(PropExpr("seq".into())),
            BatchCompletion::Size(3),
        );

        for seq in [3, 1] {
            assert!(policy.accept(mk_exchange(seq)).await.is_empty());
        }

        let emitted = policy.accept(mk_exchange(2)).await;
        assert_eq!(emitted.len(), 3, "should emit all 3 on completion");
        assert_eq!(seqs(&emitted), vec![1, 2, 3], "should be sorted ascending");
    }

    /// C1.2: two exchanges under a timeout window (size never reached).
    /// Once the timeout fires, the sorted batch arrives on the driver channel.
    #[tokio::test]
    async fn batch_timeout_completion_emits_after_timeout() {
        let policy = BatchPolicy::new_cyclic(
            Arc::new(ConstExpr("same".into())),
            Arc::new(PropExpr("seq".into())),
            BatchCompletion::Timeout(50),
        );

        let (tx, mut rx) = mpsc::channel::<Exchange>(16);
        policy.set_driver_tx(tx);

        for seq in [3, 1] {
            assert!(policy.accept(mk_exchange(seq)).await.is_empty());
        }

        let emitted = tokio::time::timeout(Duration::from_millis(500), async {
            let first = rx.recv().await.unwrap();
            let second = rx.recv().await.unwrap();
            vec![first, second]
        })
        .await
        .expect("timeout should fire within 500ms");

        assert_eq!(emitted.len(), 2);
        assert_eq!(seqs(&emitted), vec![1, 3], "should be sorted ascending");
    }

    /// C1.3: SizeOrTimeout(3, 5000ms); three inputs complete by size
    /// long before the timeout.
    #[tokio::test]
    async fn batch_size_or_timeout_size_wins() {
        let policy = BatchPolicy::new_cyclic(
            Arc::new(ConstExpr("same".into())),
            Arc::new(PropExpr("seq".into())),
            BatchCompletion::SizeOrTimeout(3, 5_000),
        );

        for seq in [2, 1] {
            assert!(policy.accept(mk_exchange(seq)).await.is_empty());
        }

        let emitted = policy.accept(mk_exchange(3)).await;
        assert_eq!(emitted.len(), 3);
        assert_eq!(seqs(&emitted), vec![1, 2, 3]);
    }

    /// C1.4: different correlation keys fill independent buckets.
    #[tokio::test]
    async fn batch_multi_key_independence() {
        let policy = BatchPolicy::new_cyclic(
            Arc::new(PropExpr("region".into())),
            Arc::new(PropExpr("seq".into())),
            BatchCompletion::Size(2),
        );

        let _ = policy
            .accept(mk_exchange_with_key(2, "region", "east"))
            .await;
        let east_emit = policy
            .accept(mk_exchange_with_key(1, "region", "east"))
            .await;
        assert_eq!(east_emit.len(), 2, "east bucket should complete at size 2");

        let west_result = policy
            .accept(mk_exchange_with_key(3, "region", "west"))
            .await;
        assert!(
            west_result.is_empty(),
            "west bucket should NOT complete yet"
        );
    }

    /// C1.5: flush() returns what is still buffered, sorted within its key.
    /// With a single key, everything is sorted together.
    #[tokio::test]
    async fn batch_flush_emits_remaining_sorted() {
        let policy = BatchPolicy::new_cyclic(
            Arc::new(ConstExpr("same".into())),
            Arc::new(PropExpr("seq".into())),
            BatchCompletion::Size(10),
        );

        for seq in [5, 3, 1] {
            assert!(policy.accept(mk_exchange(seq)).await.is_empty());
        }

        let flushed = policy.flush().await;
        assert_eq!(flushed.len(), 3);
        assert_eq!(seqs(&flushed), vec![1, 3, 5]);
    }

    /// C1.6: a failing correlation expression makes accept() return an
    /// empty vec instead of crashing.
    #[tokio::test]
    async fn batch_correlation_eval_failure_returns_empty() {
        let policy = BatchPolicy::new_cyclic(
            Arc::new(FailingExpr),
            Arc::new(PropExpr("seq".into())),
            BatchCompletion::Size(2),
        );

        let result = policy.accept(mk_exchange(1)).await;
        assert!(
            result.is_empty(),
            "failed correlation should return empty vec, not crash"
        );
    }

    /// Pure size completion never needs timeout tasks.
    #[tokio::test]
    async fn batch_pure_size_no_timeout_needed() {
        let policy = BatchPolicy::new_cyclic(
            Arc::new(ConstExpr("same".into())),
            Arc::new(PropExpr("seq".into())),
            BatchCompletion::Size(2),
        );

        assert!(!policy.needs_timeout());
    }
}