Skip to main content

courier/transforms/
batch.rs

1use std::time::{Duration, Instant};
2
3use anyhow::Result;
4use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::Value;
7use tokio::sync::mpsc::{Receiver, Sender};
8use tokio_util::sync::CancellationToken;
9use tracing_opentelemetry::OpenTelemetrySpanExt;
10
11use crate::config::{parse_config, redact_secret};
12use crate::envelope::Envelope;
13use crate::observability::NodeCtx;
14use crate::observability::trace_context;
15use crate::pipeline::ErrorPolicy;
16use crate::transforms::Transform;
17
/// Payload keys that every emitted batch envelope fills with batch metadata.
///
/// A configured `payload_key` must not be one of these, otherwise the array
/// of batched payloads would be overwritten by the metadata entry.
const RESERVED_PAYLOAD_KEYS: &[&str] = &["_batch_count", "_batch_first_timestamp_ms"];
19
20/// Groups envelopes into batches by maximum count and/or maximum time window.
21///
22/// When either limit is reached, the batch is emitted as a single envelope
23/// whose payload contains an array of the original payloads under the
24/// configured `payload_key`.
25///
26/// # Meta derivation
27/// The emitted envelope inherits `meta` from the first envelope in the batch.
28/// Its `key` is set to `batch-{first_timestamp_ms}` so downstream sinks
29/// that require a key have a deterministic value.
30///
31/// # Acknowledgement
32/// Because channels are the acknowledgement boundary, the batched envelope
33/// is acknowledged only when downstream work for it completes. This in turn
34/// acknowledges every constituent envelope, since none are released until
35/// the batch envelope is delivered to the next stage.
36///
37/// # Cancellation
38/// By default (`flush_on_cancel = true`), any partial batch is flushed
39/// immediately when the cancellation token fires. When set to `false`,
40/// partial batches are dropped on cancellation.
41pub struct BatchTransform {
42    id: String,
43    max_size: usize,
44    max_delay: Option<Duration>,
45    payload_key: String,
46    flush_on_cancel: bool,
47    node_ctx: NodeCtx,
48}
49
50impl BatchTransform {
51    pub fn new(
52        id: impl Into<String>,
53        max_size: usize,
54        max_delay_ms: Option<u64>,
55        payload_key: impl Into<String>,
56        flush_on_cancel: bool,
57    ) -> Self {
58        Self {
59            id: id.into(),
60            max_size,
61            max_delay: max_delay_ms.map(Duration::from_millis),
62            payload_key: payload_key.into(),
63            flush_on_cancel,
64            node_ctx: NodeCtx::noop(),
65        }
66    }
67
68    async fn emit_batch(&self, batch: Vec<Envelope>, tx: &Sender<Envelope>) -> Result<()> {
69        if batch.is_empty() {
70            return Ok(());
71        }
72        if is_reserved_payload_key(&self.payload_key) {
73            self.record_failed(batch.len());
74            anyhow::bail!(
75                "batch: payload_key '{}' is reserved for batch metadata",
76                self.payload_key
77            );
78        }
79
80        let first_timestamp_ms = batch[0].meta.timestamp_ms;
81        let mut meta = batch[0].meta.clone();
82        meta.key = Some(format!("batch-{}", first_timestamp_ms));
83
84        // Propagate trace context from the first envelope in the batch.
85        if let Some(parent) = trace_context::extract(&batch[0].meta.headers) {
86            trace_context::inject(&mut meta.headers, &parent);
87        }
88
89        let count = batch.len();
90        let payloads: Vec<Value> = batch.into_iter().map(|e| e.payload).collect();
91        let mut payload = serde_json::Map::with_capacity(3);
92        payload.insert(self.payload_key.clone(), Value::Array(payloads));
93        payload.insert("_batch_count".into(), serde_json::json!(count));
94        payload.insert(
95            "_batch_first_timestamp_ms".into(),
96            serde_json::json!(first_timestamp_ms),
97        );
98        let payload = Value::Object(payload);
99
100        let env = Envelope { meta, payload };
101
102        // Propagate downstream closure so the batcher exits instead of
103        // silently dropping every subsequent batch.
104        tx.send(env).await.map_err(|_| {
105            self.record_failed(count);
106            anyhow::anyhow!("downstream closed")
107        })?;
108        self.record_processed(count);
109        Ok(())
110    }
111
112    fn record_processed(&self, count: usize) {
113        for _ in 0..count {
114            self.node_ctx.record_processed();
115        }
116    }
117
118    fn record_filtered(&self, count: usize) {
119        for _ in 0..count {
120            self.node_ctx.record_filtered();
121        }
122    }
123
124    fn record_failed(&self, count: usize) {
125        for _ in 0..count {
126            self.node_ctx.record_failed();
127        }
128    }
129}
130
131#[async_trait]
132impl Transform for BatchTransform {
133    fn id(&self) -> &str {
134        &self.id
135    }
136
137    fn set_node_ctx(&mut self, ctx: NodeCtx) {
138        self.node_ctx = ctx;
139    }
140
141    async fn run(
142        self: Box<Self>,
143        mut rx: Receiver<Envelope>,
144        tx: Sender<Envelope>,
145        cancel: CancellationToken,
146    ) {
147        let id = self.id.clone();
148        let ctx = self.node_ctx.clone();
149        let mut batch: Vec<Envelope> = Vec::with_capacity(self.max_size);
150        let mut deadline: Option<tokio::time::Instant> = None;
151
152        loop {
153            let timeout = if let Some(d) = deadline {
154                let remaining = d.saturating_duration_since(tokio::time::Instant::now());
155                if remaining.is_zero() {
156                    // Deadline already passed — flush now.
157                    if let Err(e) = self.emit_batch(std::mem::take(&mut batch), &tx).await {
158                        tracing::error!(node_id = %redact_secret(&id), error = %e, "batch emit failed");
159                        break;
160                    }
161                    deadline = None;
162                    continue;
163                }
164                Some(tokio::time::sleep(remaining))
165            } else {
166                None
167            };
168
169            tokio::select! {
170                _ = cancel.cancelled() => {
171                    if self.flush_on_cancel
172                        && !batch.is_empty()
173                        && let Err(e) = self.emit_batch(std::mem::replace(&mut batch, Vec::with_capacity(self.max_size)), &tx).await
174                    {
175                        tracing::error!(node_id = %redact_secret(&id), error = %e, "batch flush on cancel failed");
176                    } else if !self.flush_on_cancel {
177                        self.record_filtered(batch.len());
178                    }
179                    break;
180                }
181
182                maybe = rx.recv() => {
183                    let Some(env) = maybe else {
184                        // Upstream closed; flush any remaining batch.
185                        if !batch.is_empty()
186                            && let Err(e) = self.emit_batch(std::mem::replace(&mut batch, Vec::with_capacity(self.max_size)), &tx).await
187                        {
188                            tracing::error!(node_id = %redact_secret(&id), error = %e, "batch flush on upstream close failed");
189                            break;
190                        }
191                        break;
192                    };
193
194                    let span = tracing::info_span!(
195                        "courier.transform",
196                        pipeline = %redact_secret(ctx.pipeline()),
197                        node_id = %redact_secret(ctx.node_id()),
198                        node_kind = %ctx.node_kind_str(),
199                        envelope.source_id = %env.meta.source_id,
200                        envelope.key = if ctx.log_keys() { env.meta.key.as_deref().unwrap_or("") } else { "" },
201                    );
202                    if let Some(parent) = trace_context::extract(&env.meta.headers) {
203                        let _ = span.set_parent(parent);
204                    }
205                    let span_context = span.context();
206                    let started = Instant::now();
207
208                    batch.push(env);
209                    ctx.record_stage_duration_ms(started.elapsed().as_secs_f64() * 1000.0);
210                    if batch.len() >= self.max_size {
211                        if let Err(e) = self.emit_batch(std::mem::replace(&mut batch, Vec::with_capacity(self.max_size)), &tx).await {
212                            tracing::error!(node_id = %redact_secret(&id), error = %e, "batch emit failed");
213                            break;
214                        }
215                        deadline = None;
216                    } else if deadline.is_none() && self.max_delay.is_some() {
217                        deadline = Some(tokio::time::Instant::now() + self.max_delay.unwrap());
218                    }
219
220                    // Continue the span context for the emitted batch envelope
221                    // (handled inside emit_batch via trace_context::inject on headers).
222                    let _ = span_context;
223                }
224
225                _ = async { timeout.unwrap().await }, if timeout.is_some() => {
226                    if !batch.is_empty()
227                        && let Err(e) = self.emit_batch(std::mem::replace(&mut batch, Vec::with_capacity(self.max_size)), &tx).await
228                    {
229                        tracing::error!(node_id = %redact_secret(&id), error = %e, "batch emit on timeout failed");
230                        break;
231                    }
232                    deadline = None;
233                }
234            }
235        }
236    }
237}
238
239// ------------------------------------------------------------------
240// Config + Factory
241// ------------------------------------------------------------------
242
243#[derive(Debug, Deserialize)]
244struct BatchTransformConfig {
245    max_size: usize,
246    #[serde(default)]
247    max_delay_ms: Option<u64>,
248    #[serde(default = "default_payload_key")]
249    payload_key: String,
250    #[serde(default = "default_flush_on_cancel")]
251    flush_on_cancel: bool,
252}
253
/// Serde default for `payload_key`: batched payloads are stored under `"items"`.
fn default_payload_key() -> String {
    String::from("items")
}
257
/// Serde default for `flush_on_cancel`: partial batches are flushed on cancel.
fn default_flush_on_cancel() -> bool {
    true
}
261
262fn is_reserved_payload_key(key: &str) -> bool {
263    RESERVED_PAYLOAD_KEYS.contains(&key)
264}
265
266/// Registry factory for [`BatchTransform`]. Registered by
267/// `courier::registry::register_builtin` under kind `"batch"`.
268pub fn batch_transform_factory(
269    id: &str,
270    config: Value,
271    _on_error: ErrorPolicy,
272) -> Result<Box<dyn Transform>> {
273    let config: BatchTransformConfig = parse_config("batch", config)?;
274    if config.max_size == 0 {
275        anyhow::bail!("batch: max_size must be greater than 0");
276    }
277    if is_reserved_payload_key(&config.payload_key) {
278        anyhow::bail!(
279            "batch: payload_key '{}' is reserved for batch metadata",
280            config.payload_key
281        );
282    }
283    Ok(Box::new(BatchTransform::new(
284        id,
285        config.max_size,
286        config.max_delay_ms,
287        config.payload_key,
288        config.flush_on_cancel,
289    )))
290}
291
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use serde_json::json;
    use tokio::sync::mpsc;
    use tokio_util::sync::CancellationToken;

    use super::*;
    use crate::Registry;
    use crate::config::{ErrorPolicyConfig, TransformSpec};
    use crate::envelope::Envelope;
    use crate::observability::metrics::testing::{
        counter_sum, histogram_count, obs_handle_in_memory,
    };
    use crate::observability::{NodeCtx, NodeKind};

    /// Everything a test needs to drive one running transform.
    struct Harness {
        /// Upstream side: envelopes sent here reach the transform.
        input: mpsc::Sender<Envelope>,
        /// Downstream side: emitted batch envelopes arrive here.
        output: mpsc::Receiver<Envelope>,
        /// Token shared with the running transform.
        cancel: CancellationToken,
        /// Handle of the task executing `run`.
        task: tokio::task::JoinHandle<()>,
    }

    /// Spawns `transform` on the runtime, wired to fresh channels of capacity 10.
    fn spawn_batcher(transform: BatchTransform) -> Harness {
        let (input, in_rx) = mpsc::channel(10);
        let (out_tx, output) = mpsc::channel(10);
        let cancel = CancellationToken::new();
        let task_cancel = cancel.clone();
        let task = tokio::spawn(async move {
            Box::new(transform).run(in_rx, out_tx, task_cancel).await;
        });
        Harness {
            input,
            output,
            cancel,
            task,
        }
    }

    /// Sends one envelope from source `"src"` with payload `{ "i": i }`.
    async fn send_item(input: &mpsc::Sender<Envelope>, i: i32) {
        input
            .send(Envelope::new("src", json!({ "i": i })))
            .await
            .unwrap();
    }

    /// Asserts that nothing is emitted within 100 ms (timeout or closed channel).
    async fn assert_no_output(output: &mut mpsc::Receiver<Envelope>) {
        let result = tokio::time::timeout(Duration::from_millis(100), output.recv()).await;
        assert!(matches!(result, Err(_) | Ok(None)));
    }

    #[tokio::test]
    async fn emits_batch_when_max_size_reached() {
        let mut h = spawn_batcher(BatchTransform::new("t", 3, None, "items", true));

        for i in 0..3 {
            send_item(&h.input, i).await;
        }

        let out = h.output.recv().await.unwrap();
        assert_eq!(out.payload["items"].as_array().unwrap().len(), 3);
        assert_eq!(out.payload["_batch_count"], 3);

        drop(h.input);
        let _ = tokio::time::timeout(Duration::from_secs(1), h.task).await;
    }

    #[tokio::test]
    async fn emits_batch_on_timeout() {
        let mut h = spawn_batcher(BatchTransform::new("t", 10, Some(50), "items", true));

        send_item(&h.input, 1).await;

        let out = tokio::time::timeout(Duration::from_millis(200), h.output.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(out.payload["items"].as_array().unwrap().len(), 1);

        drop(h.input);
        let _ = tokio::time::timeout(Duration::from_secs(1), h.task).await;
    }

    #[tokio::test]
    async fn flushes_partial_batch_on_upstream_close() {
        let mut h = spawn_batcher(BatchTransform::new("t", 10, None, "items", true));

        send_item(&h.input, 1).await;
        send_item(&h.input, 2).await;
        drop(h.input);

        let out = tokio::time::timeout(Duration::from_secs(1), h.output.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(out.payload["items"].as_array().unwrap().len(), 2);

        let _ = tokio::time::timeout(Duration::from_secs(1), h.task).await;
    }

    #[tokio::test]
    async fn flushes_partial_batch_on_cancel() {
        let mut h = spawn_batcher(BatchTransform::new("t", 10, None, "items", true));

        send_item(&h.input, 1).await;
        // Give the batch transform a chance to receive the item before cancelling.
        tokio::task::yield_now().await;
        h.cancel.cancel();

        let out = tokio::time::timeout(Duration::from_secs(1), h.output.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(out.payload["items"].as_array().unwrap().len(), 1);

        let _ = tokio::time::timeout(Duration::from_secs(1), h.task).await;
    }

    #[tokio::test]
    async fn drops_partial_batch_on_cancel_when_flush_disabled() {
        let mut h = spawn_batcher(BatchTransform::new("t", 10, None, "items", false));

        send_item(&h.input, 1).await;
        h.cancel.cancel();

        assert_no_output(&mut h.output).await;

        let _ = tokio::time::timeout(Duration::from_secs(1), h.task).await;
    }

    #[tokio::test]
    async fn empty_input_produces_nothing() {
        let mut h = spawn_batcher(BatchTransform::new("t", 3, None, "items", true));

        drop(h.input);

        assert_no_output(&mut h.output).await;

        let _ = tokio::time::timeout(Duration::from_secs(1), h.task).await;
    }

    #[test]
    fn factory_resolves_through_registry() {
        let registry = Registry::with_builtins().unwrap();
        registry
            .build_transform(
                "p/t0",
                TransformSpec {
                    kind: "batch".into(),
                    config: json!({ "max_size": 10 }),
                    on_error: Some(ErrorPolicyConfig::Drop),
                },
            )
            .unwrap();
    }

    #[test]
    fn factory_rejects_zero_max_size() {
        let registry = Registry::with_builtins().unwrap();
        let err = registry
            .build_transform(
                "p/t0",
                TransformSpec {
                    kind: "batch".into(),
                    config: json!({ "max_size": 0 }),
                    on_error: None,
                },
            )
            .err()
            .expect("expected validation error");
        let msg = format!("{err:#}");
        assert!(msg.contains("max_size must be greater than 0"), "{msg}");
    }

    #[test]
    fn factory_rejects_reserved_payload_keys() {
        let registry = Registry::with_builtins().unwrap();
        for payload_key in RESERVED_PAYLOAD_KEYS {
            let err = registry
                .build_transform(
                    "p/t0",
                    TransformSpec {
                        kind: "batch".into(),
                        config: json!({ "max_size": 10, "payload_key": payload_key }),
                        on_error: None,
                    },
                )
                .err()
                .expect("expected validation error");
            let msg = format!("{err:#}");
            assert!(
                msg.contains("reserved for batch metadata"),
                "payload_key={payload_key}: {msg}"
            );
        }
    }

    #[tokio::test]
    async fn records_metrics_for_emitted_batches() {
        let (handle, exporter) = obs_handle_in_memory();
        let mut t = BatchTransform::new("t", 3, None, "items", true);
        t.set_node_ctx(NodeCtx::for_node(
            "metrics",
            "metrics/t0",
            NodeKind::Transform,
            handle.clone(),
        ));
        let mut h = spawn_batcher(t);

        for i in 0..3 {
            send_item(&h.input, i).await;
        }

        let _ = h.output.recv().await.unwrap();
        drop(h.input);
        let _ = tokio::time::timeout(Duration::from_secs(1), h.task).await;
        handle.shutdown();

        let attrs = &[("pipeline", "metrics"), ("node_id", "metrics/t0")];
        assert_eq!(
            counter_sum(&exporter, "courier_envelopes_processed_total", attrs),
            3
        );
        assert_eq!(
            histogram_count(&exporter, "courier_stage_duration_milliseconds", attrs),
            3
        );
    }

    #[tokio::test]
    async fn records_filtered_metrics_for_dropped_cancel_batch() {
        let (handle, exporter) = obs_handle_in_memory();
        let mut t = BatchTransform::new("t", 10, None, "items", false);
        t.set_node_ctx(NodeCtx::for_node(
            "metrics",
            "metrics/t0",
            NodeKind::Transform,
            handle.clone(),
        ));
        let mut h = spawn_batcher(t);

        send_item(&h.input, 1).await;
        // Let the transform buffer the envelope before cancellation fires.
        tokio::time::sleep(Duration::from_millis(10)).await;
        h.cancel.cancel();

        assert_no_output(&mut h.output).await;
        drop(h.input);
        let _ = tokio::time::timeout(Duration::from_secs(1), h.task).await;
        handle.shutdown();

        let attrs = &[("pipeline", "metrics"), ("node_id", "metrics/t0")];
        assert_eq!(
            counter_sum(&exporter, "courier_envelopes_filtered_total", attrs),
            1
        );
        assert_eq!(
            counter_sum(&exporter, "courier_envelopes_processed_total", attrs),
            0
        );
    }

    #[tokio::test]
    async fn records_failed_metrics_when_emit_fails() {
        let (handle, exporter) = obs_handle_in_memory();
        let mut t = BatchTransform::new("t", 2, None, "items", true);
        t.set_node_ctx(NodeCtx::for_node(
            "metrics",
            "metrics/t0",
            NodeKind::Transform,
            handle.clone(),
        ));
        let h = spawn_batcher(t);

        // Close downstream first so the very first emit fails.
        drop(h.output);
        for i in 0..2 {
            send_item(&h.input, i).await;
        }

        let _ = tokio::time::timeout(Duration::from_secs(1), h.task)
            .await
            .unwrap();
        handle.shutdown();

        let attrs = &[("pipeline", "metrics"), ("node_id", "metrics/t0")];
        assert_eq!(
            counter_sum(&exporter, "courier_envelopes_failed_total", attrs),
            2
        );
        assert_eq!(
            counter_sum(&exporter, "courier_envelopes_processed_total", attrs),
            0
        );
    }

    #[tokio::test]
    async fn stops_when_downstream_closes() {
        let h = spawn_batcher(BatchTransform::new("t", 2, None, "items", true));

        // Send one item — not enough to trigger a batch yet.
        send_item(&h.input, 1).await;

        // Drop the downstream receiver to simulate sink exit.
        drop(h.output);

        // Send a second item — now the batcher tries to emit and sees
        // downstream is closed. It should stop instead of looping forever.
        send_item(&h.input, 2).await;

        // The batcher task should exit promptly.
        let _ = tokio::time::timeout(Duration::from_secs(1), h.task)
            .await
            .unwrap();
    }
}