1use std::time::{Duration, Instant};
2
3use anyhow::Result;
4use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::Value;
7use tokio::sync::mpsc::{Receiver, Sender};
8use tokio_util::sync::CancellationToken;
9use tracing_opentelemetry::OpenTelemetrySpanExt;
10
11use crate::config::{parse_config, redact_secret};
12use crate::envelope::Envelope;
13use crate::observability::NodeCtx;
14use crate::observability::trace_context;
15use crate::pipeline::ErrorPolicy;
16use crate::transforms::Transform;
17
/// Keys that every emitted batch payload fills in itself (the batch size and
/// the timestamp of its first envelope). A configured `payload_key` equal to
/// one of these is rejected so the item array cannot collide with the metadata.
const RESERVED_PAYLOAD_KEYS: &[&str] = &["_batch_count", "_batch_first_timestamp_ms"];
19
/// Transform that buffers incoming envelopes and forwards them downstream as
/// a single envelope whose payload is an object holding the collected payloads.
pub struct BatchTransform {
    /// Identifier reported through `Transform::id` and used in error logs.
    id: String,
    /// Number of buffered envelopes at which a batch is emitted.
    max_size: usize,
    /// Optional maximum age of a partial batch; the deadline is armed when an
    /// envelope is buffered and no deadline is currently set.
    max_delay: Option<Duration>,
    /// Key in the emitted payload object under which the item array is stored.
    payload_key: String,
    /// On cancellation: `true` emits the partial batch, `false` discards it
    /// and counts its envelopes as filtered.
    flush_on_cancel: bool,
    /// Observability handle used for counters and span attributes; starts as
    /// `NodeCtx::noop()` until `set_node_ctx` replaces it.
    node_ctx: NodeCtx,
}
49
50impl BatchTransform {
51 pub fn new(
52 id: impl Into<String>,
53 max_size: usize,
54 max_delay_ms: Option<u64>,
55 payload_key: impl Into<String>,
56 flush_on_cancel: bool,
57 ) -> Self {
58 Self {
59 id: id.into(),
60 max_size,
61 max_delay: max_delay_ms.map(Duration::from_millis),
62 payload_key: payload_key.into(),
63 flush_on_cancel,
64 node_ctx: NodeCtx::noop(),
65 }
66 }
67
68 async fn emit_batch(&self, batch: Vec<Envelope>, tx: &Sender<Envelope>) -> Result<()> {
69 if batch.is_empty() {
70 return Ok(());
71 }
72 if is_reserved_payload_key(&self.payload_key) {
73 self.record_failed(batch.len());
74 anyhow::bail!(
75 "batch: payload_key '{}' is reserved for batch metadata",
76 self.payload_key
77 );
78 }
79
80 let first_timestamp_ms = batch[0].meta.timestamp_ms;
81 let mut meta = batch[0].meta.clone();
82 meta.key = Some(format!("batch-{}", first_timestamp_ms));
83
84 if let Some(parent) = trace_context::extract(&batch[0].meta.headers) {
86 trace_context::inject(&mut meta.headers, &parent);
87 }
88
89 let count = batch.len();
90 let payloads: Vec<Value> = batch.into_iter().map(|e| e.payload).collect();
91 let mut payload = serde_json::Map::with_capacity(3);
92 payload.insert(self.payload_key.clone(), Value::Array(payloads));
93 payload.insert("_batch_count".into(), serde_json::json!(count));
94 payload.insert(
95 "_batch_first_timestamp_ms".into(),
96 serde_json::json!(first_timestamp_ms),
97 );
98 let payload = Value::Object(payload);
99
100 let env = Envelope { meta, payload };
101
102 tx.send(env).await.map_err(|_| {
105 self.record_failed(count);
106 anyhow::anyhow!("downstream closed")
107 })?;
108 self.record_processed(count);
109 Ok(())
110 }
111
112 fn record_processed(&self, count: usize) {
113 for _ in 0..count {
114 self.node_ctx.record_processed();
115 }
116 }
117
118 fn record_filtered(&self, count: usize) {
119 for _ in 0..count {
120 self.node_ctx.record_filtered();
121 }
122 }
123
124 fn record_failed(&self, count: usize) {
125 for _ in 0..count {
126 self.node_ctx.record_failed();
127 }
128 }
129}
130
131#[async_trait]
132impl Transform for BatchTransform {
133 fn id(&self) -> &str {
134 &self.id
135 }
136
137 fn set_node_ctx(&mut self, ctx: NodeCtx) {
138 self.node_ctx = ctx;
139 }
140
141 async fn run(
142 self: Box<Self>,
143 mut rx: Receiver<Envelope>,
144 tx: Sender<Envelope>,
145 cancel: CancellationToken,
146 ) {
147 let id = self.id.clone();
148 let ctx = self.node_ctx.clone();
149 let mut batch: Vec<Envelope> = Vec::with_capacity(self.max_size);
150 let mut deadline: Option<tokio::time::Instant> = None;
151
152 loop {
153 let timeout = if let Some(d) = deadline {
154 let remaining = d.saturating_duration_since(tokio::time::Instant::now());
155 if remaining.is_zero() {
156 if let Err(e) = self.emit_batch(std::mem::take(&mut batch), &tx).await {
158 tracing::error!(node_id = %redact_secret(&id), error = %e, "batch emit failed");
159 break;
160 }
161 deadline = None;
162 continue;
163 }
164 Some(tokio::time::sleep(remaining))
165 } else {
166 None
167 };
168
169 tokio::select! {
170 _ = cancel.cancelled() => {
171 if self.flush_on_cancel
172 && !batch.is_empty()
173 && let Err(e) = self.emit_batch(std::mem::replace(&mut batch, Vec::with_capacity(self.max_size)), &tx).await
174 {
175 tracing::error!(node_id = %redact_secret(&id), error = %e, "batch flush on cancel failed");
176 } else if !self.flush_on_cancel {
177 self.record_filtered(batch.len());
178 }
179 break;
180 }
181
182 maybe = rx.recv() => {
183 let Some(env) = maybe else {
184 if !batch.is_empty()
186 && let Err(e) = self.emit_batch(std::mem::replace(&mut batch, Vec::with_capacity(self.max_size)), &tx).await
187 {
188 tracing::error!(node_id = %redact_secret(&id), error = %e, "batch flush on upstream close failed");
189 break;
190 }
191 break;
192 };
193
194 let span = tracing::info_span!(
195 "courier.transform",
196 pipeline = %redact_secret(ctx.pipeline()),
197 node_id = %redact_secret(ctx.node_id()),
198 node_kind = %ctx.node_kind_str(),
199 envelope.source_id = %env.meta.source_id,
200 envelope.key = if ctx.log_keys() { env.meta.key.as_deref().unwrap_or("") } else { "" },
201 );
202 if let Some(parent) = trace_context::extract(&env.meta.headers) {
203 let _ = span.set_parent(parent);
204 }
205 let span_context = span.context();
206 let started = Instant::now();
207
208 batch.push(env);
209 ctx.record_stage_duration_ms(started.elapsed().as_secs_f64() * 1000.0);
210 if batch.len() >= self.max_size {
211 if let Err(e) = self.emit_batch(std::mem::replace(&mut batch, Vec::with_capacity(self.max_size)), &tx).await {
212 tracing::error!(node_id = %redact_secret(&id), error = %e, "batch emit failed");
213 break;
214 }
215 deadline = None;
216 } else if deadline.is_none() && self.max_delay.is_some() {
217 deadline = Some(tokio::time::Instant::now() + self.max_delay.unwrap());
218 }
219
220 let _ = span_context;
223 }
224
225 _ = async { timeout.unwrap().await }, if timeout.is_some() => {
226 if !batch.is_empty()
227 && let Err(e) = self.emit_batch(std::mem::replace(&mut batch, Vec::with_capacity(self.max_size)), &tx).await
228 {
229 tracing::error!(node_id = %redact_secret(&id), error = %e, "batch emit on timeout failed");
230 break;
231 }
232 deadline = None;
233 }
234 }
235 }
236 }
237}
238
/// Deserialized configuration for the `batch` transform, consumed by
/// `batch_transform_factory`.
#[derive(Debug, Deserialize)]
struct BatchTransformConfig {
    /// Required. Number of envelopes per batch; the factory rejects `0`.
    max_size: usize,
    /// Optional maximum age of a partial batch, in milliseconds. Absent means
    /// no time-based emit.
    #[serde(default)]
    max_delay_ms: Option<u64>,
    /// Key of the item array in the emitted payload; defaults to
    /// `default_payload_key()`.
    #[serde(default = "default_payload_key")]
    payload_key: String,
    /// Whether a partial batch is emitted on cancellation; defaults to
    /// `default_flush_on_cancel()`.
    #[serde(default = "default_flush_on_cancel")]
    flush_on_cancel: bool,
}
253
/// Serde default for `payload_key`: the item array is stored under `"items"`.
fn default_payload_key() -> String {
    String::from("items")
}
257
/// Serde default for `flush_on_cancel`: partial batches are flushed unless
/// the configuration says otherwise.
fn default_flush_on_cancel() -> bool {
    const FLUSH_ON_CANCEL_BY_DEFAULT: bool = true;
    FLUSH_ON_CANCEL_BY_DEFAULT
}
261
262fn is_reserved_payload_key(key: &str) -> bool {
263 RESERVED_PAYLOAD_KEYS.contains(&key)
264}
265
266pub fn batch_transform_factory(
269 id: &str,
270 config: Value,
271 _on_error: ErrorPolicy,
272) -> Result<Box<dyn Transform>> {
273 let config: BatchTransformConfig = parse_config("batch", config)?;
274 if config.max_size == 0 {
275 anyhow::bail!("batch: max_size must be greater than 0");
276 }
277 if is_reserved_payload_key(&config.payload_key) {
278 anyhow::bail!(
279 "batch: payload_key '{}' is reserved for batch metadata",
280 config.payload_key
281 );
282 }
283 Ok(Box::new(BatchTransform::new(
284 id,
285 config.max_size,
286 config.max_delay_ms,
287 config.payload_key,
288 config.flush_on_cancel,
289 )))
290}
291
// Unit tests for the batch transform: emit triggers (size, timeout, upstream
// close, cancel), factory validation, and recorded metrics.
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use serde_json::json;
    use tokio::sync::mpsc;
    use tokio_util::sync::CancellationToken;

    use super::*;
    use crate::Registry;
    use crate::config::{ErrorPolicyConfig, TransformSpec};
    use crate::envelope::Envelope;
    use crate::observability::metrics::testing::{
        counter_sum, histogram_count, obs_handle_in_memory,
    };
    use crate::observability::{NodeCtx, NodeKind};

    /// A batch is emitted as soon as `max_size` envelopes have been received.
    #[tokio::test]
    async fn emits_batch_when_max_size_reached() {
        let (in_tx, in_rx) = mpsc::channel(10);
        let t = BatchTransform::new("t", 3, None, "items", true);
        let (out_tx, mut out_rx) = mpsc::channel(10);
        let cancel = CancellationToken::new();
        let cancel2 = cancel.clone();
        let h = tokio::spawn(async move {
            Box::new(t).run(in_rx, out_tx, cancel2).await;
        });

        for i in 0..3 {
            in_tx
                .send(Envelope::new("src", json!({ "i": i })))
                .await
                .unwrap();
        }

        let out = out_rx.recv().await.unwrap();
        assert_eq!(out.payload["items"].as_array().unwrap().len(), 3);
        assert_eq!(out.payload["_batch_count"], 3);

        // Closing the input lets the transform task finish.
        drop(in_tx);
        let _ = tokio::time::timeout(Duration::from_secs(1), h).await;
    }

    /// A partial batch is emitted once the 50 ms `max_delay` elapses.
    #[tokio::test]
    async fn emits_batch_on_timeout() {
        let t = BatchTransform::new("t", 10, Some(50), "items", true);
        let (in_tx, in_rx) = mpsc::channel(10);
        let (out_tx, mut out_rx) = mpsc::channel(10);
        let cancel = CancellationToken::new();
        let cancel2 = cancel.clone();
        let h = tokio::spawn(async move {
            Box::new(t).run(in_rx, out_tx, cancel2).await;
        });

        in_tx
            .send(Envelope::new("src", json!({ "i": 1 })))
            .await
            .unwrap();

        let out = tokio::time::timeout(Duration::from_millis(200), out_rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(out.payload["items"].as_array().unwrap().len(), 1);

        drop(in_tx);
        let _ = tokio::time::timeout(Duration::from_secs(1), h).await;
    }

    /// Closing the upstream channel flushes whatever is buffered.
    #[tokio::test]
    async fn flushes_partial_batch_on_upstream_close() {
        let t = BatchTransform::new("t", 10, None, "items", true);
        let (in_tx, in_rx) = mpsc::channel(10);
        let (out_tx, mut out_rx) = mpsc::channel(10);
        let cancel = CancellationToken::new();
        let cancel2 = cancel.clone();
        let h = tokio::spawn(async move {
            Box::new(t).run(in_rx, out_tx, cancel2).await;
        });

        in_tx
            .send(Envelope::new("src", json!({ "i": 1 })))
            .await
            .unwrap();
        in_tx
            .send(Envelope::new("src", json!({ "i": 2 })))
            .await
            .unwrap();
        drop(in_tx);

        let out = tokio::time::timeout(Duration::from_secs(1), out_rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(out.payload["items"].as_array().unwrap().len(), 2);

        let _ = tokio::time::timeout(Duration::from_secs(1), h).await;
    }

    /// With `flush_on_cancel = true`, cancellation emits the partial batch.
    #[tokio::test]
    async fn flushes_partial_batch_on_cancel() {
        let t = BatchTransform::new("t", 10, None, "items", true);
        let (in_tx, in_rx) = mpsc::channel(10);
        let (out_tx, mut out_rx) = mpsc::channel(10);
        let cancel = CancellationToken::new();
        let cancel2 = cancel.clone();
        let h = tokio::spawn(async move {
            Box::new(t).run(in_rx, out_tx, cancel2).await;
        });

        in_tx
            .send(Envelope::new("src", json!({ "i": 1 })))
            .await
            .unwrap();
        // Yield so the transform task can buffer the envelope before cancel.
        tokio::task::yield_now().await;
        cancel.cancel();

        let out = tokio::time::timeout(Duration::from_secs(1), out_rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(out.payload["items"].as_array().unwrap().len(), 1);

        let _ = tokio::time::timeout(Duration::from_secs(1), h).await;
    }

    /// With `flush_on_cancel = false`, cancellation emits nothing downstream.
    #[tokio::test]
    async fn drops_partial_batch_on_cancel_when_flush_disabled() {
        let t = BatchTransform::new("t", 10, None, "items", false);
        let (in_tx, in_rx) = mpsc::channel(10);
        let (out_tx, mut out_rx) = mpsc::channel(10);
        let cancel = CancellationToken::new();
        let cancel2 = cancel.clone();
        let h = tokio::spawn(async move {
            Box::new(t).run(in_rx, out_tx, cancel2).await;
        });

        in_tx
            .send(Envelope::new("src", json!({ "i": 1 })))
            .await
            .unwrap();
        cancel.cancel();

        // Either the receive times out or the channel closes without a value.
        let result = tokio::time::timeout(Duration::from_millis(100), out_rx.recv()).await;
        assert!(result.is_err() || result.unwrap().is_none());

        let _ = tokio::time::timeout(Duration::from_secs(1), h).await;
    }

    /// Closing the input without sending anything produces no output.
    #[tokio::test]
    async fn empty_input_produces_nothing() {
        let t = BatchTransform::new("t", 3, None, "items", true);
        let (in_tx, in_rx) = mpsc::channel(10);
        let (out_tx, mut out_rx) = mpsc::channel(10);
        let cancel = CancellationToken::new();
        let cancel2 = cancel.clone();
        let h = tokio::spawn(async move {
            Box::new(t).run(in_rx, out_tx, cancel2).await;
        });

        drop(in_tx);

        let result = tokio::time::timeout(Duration::from_millis(100), out_rx.recv()).await;
        assert!(result.is_err() || result.unwrap().is_none());

        let _ = tokio::time::timeout(Duration::from_secs(1), h).await;
    }

    /// The `batch` kind can be built through the registry with minimal config.
    #[test]
    fn factory_resolves_through_registry() {
        let registry = Registry::with_builtins().unwrap();
        registry
            .build_transform(
                "p/t0",
                TransformSpec {
                    kind: "batch".into(),
                    config: json!({ "max_size": 10 }),
                    on_error: Some(ErrorPolicyConfig::Drop),
                },
            )
            .unwrap();
    }

    /// `max_size: 0` is rejected with a validation error.
    #[test]
    fn factory_rejects_zero_max_size() {
        let registry = Registry::with_builtins().unwrap();
        let err = registry
            .build_transform(
                "p/t0",
                TransformSpec {
                    kind: "batch".into(),
                    config: json!({ "max_size": 0 }),
                    on_error: None,
                },
            )
            .err()
            .expect("expected validation error");
        let msg = format!("{err:#}");
        assert!(msg.contains("max_size must be greater than 0"), "{msg}");
    }

    /// Every reserved key is rejected when used as `payload_key`.
    #[test]
    fn factory_rejects_reserved_payload_keys() {
        let registry = Registry::with_builtins().unwrap();
        for payload_key in RESERVED_PAYLOAD_KEYS {
            let err = registry
                .build_transform(
                    "p/t0",
                    TransformSpec {
                        kind: "batch".into(),
                        config: json!({ "max_size": 10, "payload_key": payload_key }),
                        on_error: None,
                    },
                )
                .err()
                .expect("expected validation error");
            let msg = format!("{err:#}");
            assert!(
                msg.contains("reserved for batch metadata"),
                "payload_key={payload_key}: {msg}"
            );
        }
    }

    /// A successful emit counts each envelope as processed and records one
    /// stage-duration sample per envelope.
    #[tokio::test]
    async fn records_metrics_for_emitted_batches() {
        let (handle, exporter) = obs_handle_in_memory();
        let mut t = BatchTransform::new("t", 3, None, "items", true);
        t.set_node_ctx(NodeCtx::for_node(
            "metrics",
            "metrics/t0",
            NodeKind::Transform,
            handle.clone(),
        ));
        let (in_tx, in_rx) = mpsc::channel(10);
        let (out_tx, mut out_rx) = mpsc::channel(10);
        let cancel = CancellationToken::new();
        let h = tokio::spawn(async move {
            Box::new(t).run(in_rx, out_tx, cancel).await;
        });

        for i in 0..3 {
            in_tx
                .send(Envelope::new("src", json!({ "i": i })))
                .await
                .unwrap();
        }

        let _ = out_rx.recv().await.unwrap();
        drop(in_tx);
        let _ = tokio::time::timeout(Duration::from_secs(1), h).await;
        handle.shutdown();

        let attrs = &[("pipeline", "metrics"), ("node_id", "metrics/t0")];
        assert_eq!(
            counter_sum(&exporter, "courier_envelopes_processed_total", attrs),
            3
        );
        assert_eq!(
            histogram_count(&exporter, "courier_stage_duration_milliseconds", attrs),
            3
        );
    }

    /// A batch dropped on cancel (flush disabled) is counted as filtered, not
    /// processed.
    #[tokio::test]
    async fn records_filtered_metrics_for_dropped_cancel_batch() {
        let (handle, exporter) = obs_handle_in_memory();
        let mut t = BatchTransform::new("t", 10, None, "items", false);
        t.set_node_ctx(NodeCtx::for_node(
            "metrics",
            "metrics/t0",
            NodeKind::Transform,
            handle.clone(),
        ));
        let (in_tx, in_rx) = mpsc::channel(10);
        let (out_tx, mut out_rx) = mpsc::channel(10);
        let cancel = CancellationToken::new();
        let cancel2 = cancel.clone();
        let h = tokio::spawn(async move {
            Box::new(t).run(in_rx, out_tx, cancel2).await;
        });

        in_tx
            .send(Envelope::new("src", json!({ "i": 1 })))
            .await
            .unwrap();
        // Give the transform time to buffer the envelope before cancelling.
        tokio::time::sleep(Duration::from_millis(10)).await;
        cancel.cancel();

        let result = tokio::time::timeout(Duration::from_millis(100), out_rx.recv()).await;
        assert!(result.is_err() || result.unwrap().is_none());
        drop(in_tx);
        let _ = tokio::time::timeout(Duration::from_secs(1), h).await;
        handle.shutdown();

        let attrs = &[("pipeline", "metrics"), ("node_id", "metrics/t0")];
        assert_eq!(
            counter_sum(&exporter, "courier_envelopes_filtered_total", attrs),
            1
        );
        assert_eq!(
            counter_sum(&exporter, "courier_envelopes_processed_total", attrs),
            0
        );
    }

    /// When the downstream receiver is gone, the batch is counted as failed.
    #[tokio::test]
    async fn records_failed_metrics_when_emit_fails() {
        let (handle, exporter) = obs_handle_in_memory();
        let mut t = BatchTransform::new("t", 2, None, "items", true);
        t.set_node_ctx(NodeCtx::for_node(
            "metrics",
            "metrics/t0",
            NodeKind::Transform,
            handle.clone(),
        ));
        let (in_tx, in_rx) = mpsc::channel(10);
        let (out_tx, out_rx) = mpsc::channel(10);
        let cancel = CancellationToken::new();
        let h = tokio::spawn(async move {
            Box::new(t).run(in_rx, out_tx, cancel).await;
        });

        // Close downstream before the batch fills so the send fails.
        drop(out_rx);
        for i in 0..2 {
            in_tx
                .send(Envelope::new("src", json!({ "i": i })))
                .await
                .unwrap();
        }

        let _ = tokio::time::timeout(Duration::from_secs(1), h)
            .await
            .unwrap();
        handle.shutdown();

        let attrs = &[("pipeline", "metrics"), ("node_id", "metrics/t0")];
        assert_eq!(
            counter_sum(&exporter, "courier_envelopes_failed_total", attrs),
            2
        );
        assert_eq!(
            counter_sum(&exporter, "courier_envelopes_processed_total", attrs),
            0
        );
    }

    /// The run loop terminates (within the timeout) once an emit hits a
    /// closed downstream channel.
    #[tokio::test]
    async fn stops_when_downstream_closes() {
        let t = BatchTransform::new("t", 2, None, "items", true);
        let (in_tx, in_rx) = mpsc::channel(10);
        let (out_tx, out_rx) = mpsc::channel(10);
        let cancel = CancellationToken::new();
        let h = tokio::spawn(async move {
            Box::new(t).run(in_rx, out_tx, cancel.clone()).await;
        });

        in_tx
            .send(Envelope::new("src", json!({ "i": 1 })))
            .await
            .unwrap();

        drop(out_rx);

        // The second envelope fills the batch and triggers the failing emit.
        in_tx
            .send(Envelope::new("src", json!({ "i": 2 })))
            .await
            .unwrap();

        let _ = tokio::time::timeout(Duration::from_secs(1), h)
            .await
            .unwrap();
    }
}