1use std::time::Duration;
2
3use tokio::sync::mpsc::{self, WeakSender};
4use tokio::task::JoinHandle;
5use tokio_util::sync::CancellationToken;
6
7use crate::config::redact_secret;
8use crate::envelope::Envelope;
9use crate::observability::{NodeCtx, NodeKind, ObsHandle};
10use crate::sinks::Sink;
11use crate::sources::Source;
12use crate::transforms::Transform;
13
/// Period between two consecutive channel-depth samples taken by each edge sampler task
/// (see `spawn_edge_sampler`).
const CHANNEL_DEPTH_SAMPLE_INTERVAL: Duration = Duration::from_millis(300);
18
/// Policy describing how an error should be handled.
///
/// NOTE(review): nothing in this part of the file reads `ErrorPolicy`, so the per-variant
/// notes below are inferred from the variant names only — confirm against the code that
/// consumes this type.
#[derive(Debug, Clone, Default, Eq, PartialEq)]
pub enum ErrorPolicy {
    /// The default policy. Presumably the offending envelope is dropped and processing
    /// continues — TODO confirm.
    #[default]
    Drop,
    /// Presumably the error brings the whole pipeline down — TODO confirm.
    FailPipeline,
}
31
/// Description of a linear processing chain: one source, any number of transforms and any
/// number of sinks.
///
/// A `Pipeline` does nothing by itself; `spawn_pipeline` consumes it and starts the tasks.
pub struct Pipeline {
    /// Pipeline identifier. `spawn_pipeline` uses it as the prefix of every node id
    /// (`{id}/src`, `{id}/t0`, `{id}/sink0`, `{id}/broadcast`, ...).
    pub id: String,
    /// The node that produces envelopes.
    pub source: Box<dyn Source>,
    /// Transforms, applied in the order they are stored.
    pub transforms: Vec<Box<dyn Transform>>,
    /// Sinks that receive the output of the last transform (or of the source when there are
    /// no transforms). With more than one sink every envelope is cloned to each of them.
    pub sinks: Vec<Box<dyn Sink>>,
    /// Buffer size of every bounded channel created between two nodes.
    pub channel_capacity: usize,
    /// Optional observability handle. When present, `spawn_pipeline` gives every node a
    /// `NodeCtx` built from it; when it additionally reports `is_enabled()`, channel-depth
    /// samplers are started for the edges.
    pub(crate) obs: Option<ObsHandle>,
}
49
50impl Pipeline {
51 pub fn new(id: impl Into<String>, source: Box<dyn Source>) -> Self {
52 Self {
53 id: id.into(),
54 source,
55 transforms: Vec::new(),
56 sinks: Vec::new(),
57 channel_capacity: 64,
58 obs: None,
59 }
60 }
61
62 pub fn with_transform(mut self, t: Box<dyn Transform>) -> Self {
63 self.transforms.push(t);
64 self
65 }
66
67 pub fn with_sink(mut self, s: Box<dyn Sink>) -> Self {
68 self.sinks.push(s);
69 self
70 }
71
72 pub fn with_channel_capacity(mut self, cap: usize) -> Self {
73 self.channel_capacity = cap;
74 self
75 }
76
77 pub fn with_observability(mut self, obs: Option<ObsHandle>) -> Self {
78 self.obs = obs;
79 self
80 }
81}
82
83pub(crate) fn spawn_pipeline(p: Pipeline, cancel: CancellationToken) -> Vec<JoinHandle<()>> {
96 let Pipeline {
97 id,
98 mut source,
99 mut transforms,
100 mut sinks,
101 channel_capacity: cap,
102 obs,
103 } = p;
104
105 tracing::info!(pipeline = %redact_secret(&id), "spawning pipeline");
106 let mut handles = Vec::new();
107
108 let (src_tx, mut prev_rx) = mpsc::channel::<Envelope>(cap);
109 let mut prev_node_id = format!("{id}/src");
110 let transforms_total = transforms.len();
111
112 if let Some(handle) = &obs {
113 source.set_node_ctx(NodeCtx::for_node(
114 &id,
115 &prev_node_id,
116 NodeKind::Source,
117 handle.clone(),
118 ));
119 }
120
121 if let Some(handle) = obs.as_ref().filter(|h| h.is_enabled()) {
123 spawn_edge_sampler(
124 &id,
125 &prev_node_id,
126 &next_transform_or_sink_id(&id, &transforms, &sinks),
127 cap,
128 src_tx.downgrade(),
129 handle.clone(),
130 cancel.clone(),
131 &mut handles,
132 );
133 }
134
135 let c = cancel.clone();
136 handles.push(tokio::spawn(async move { source.run(src_tx, c).await }));
137
138 for (i, mut t) in transforms.drain(..).enumerate() {
139 let node_id = format!("{id}/t{i}");
140 if let Some(handle) = &obs {
141 t.set_node_ctx(NodeCtx::for_node(
142 &id,
143 &node_id,
144 NodeKind::Transform,
145 handle.clone(),
146 ));
147 }
148
149 let (next_tx, next_rx) = mpsc::channel::<Envelope>(cap);
150
151 if let Some(handle) = obs.as_ref().filter(|h| h.is_enabled()) {
153 let dest_node_id = transform_or_sink_id_after(&id, i + 1, transforms_total, &sinks);
154 spawn_edge_sampler(
155 &id,
156 &node_id,
157 &dest_node_id,
158 cap,
159 next_tx.downgrade(),
160 handle.clone(),
161 cancel.clone(),
162 &mut handles,
163 );
164 }
165
166 let rx = prev_rx;
167 let c = cancel.clone();
168 handles.push(tokio::spawn(async move { t.run(rx, next_tx, c).await }));
169 prev_rx = next_rx;
170 prev_node_id = node_id;
171 }
172
173 match sinks.len() {
174 0 => {
175 tracing::warn!(
176 pipeline = %redact_secret(&id),
177 "pipeline has no sinks; envelopes will be discarded"
178 );
179 let c = cancel.clone();
180 handles.push(tokio::spawn(async move {
181 loop {
182 tokio::select! {
183 _ = c.cancelled() => break,
184 m = prev_rx.recv() => if m.is_none() { break },
185 }
186 }
187 }));
188 }
189 1 => {
190 let mut sink = sinks.into_iter().next().unwrap();
191 let sink_node_id = format!("{id}/sink0");
192 if let Some(handle) = &obs {
193 sink.set_node_ctx(NodeCtx::for_node(
194 &id,
195 &sink_node_id,
196 NodeKind::Sink,
197 handle.clone(),
198 ));
199 }
200 let c = cancel.clone();
201 handles.push(tokio::spawn(async move { sink.run(prev_rx, c).await }));
202 let _ = prev_node_id; }
204 _ => {
205 let splitter_id = format!("{id}/broadcast");
206 let mut sink_txs = Vec::with_capacity(sinks.len());
207 for (i, mut sink) in sinks.drain(..).enumerate() {
208 let sink_node_id = format!("{id}/sink{i}");
209 if let Some(handle) = &obs {
210 sink.set_node_ctx(NodeCtx::for_node(
211 &id,
212 &sink_node_id,
213 NodeKind::Sink,
214 handle.clone(),
215 ));
216 }
217 let (tx, rx) = mpsc::channel::<Envelope>(cap);
218 if let Some(handle) = obs.as_ref().filter(|h| h.is_enabled()) {
219 spawn_edge_sampler(
220 &id,
221 &splitter_id,
222 &sink_node_id,
223 cap,
224 tx.downgrade(),
225 handle.clone(),
226 cancel.clone(),
227 &mut handles,
228 );
229 }
230 sink_txs.push(tx);
231 let c = cancel.clone();
232 handles.push(tokio::spawn(async move { sink.run(rx, c).await }));
233 }
234 let c = cancel.clone();
235 let splitter_log_id = splitter_id.clone();
236 handles.push(tokio::spawn(async move {
237 'splitter: loop {
238 tokio::select! {
239 _ = c.cancelled() => break,
240 maybe = prev_rx.recv() => {
241 let Some(env) = maybe else { break };
242 for tx in &sink_txs {
243 tokio::select! {
244 _ = c.cancelled() => break 'splitter,
245 res = tx.send(env.clone()) => {
246 if res.is_err() {
247 tracing::debug!(node_id = %redact_secret(&splitter_log_id), "downstream sink closed");
248 }
249 }
250 }
251 }
252 }
253 }
254 }
255 }));
256 }
257 }
258
259 handles
260}
261
262fn next_transform_or_sink_id(
266 id: &str,
267 transforms: &[Box<dyn Transform>],
268 sinks: &[Box<dyn Sink>],
269) -> String {
270 if !transforms.is_empty() {
271 format!("{id}/t0")
272 } else if sinks.len() > 1 {
273 format!("{id}/broadcast")
274 } else {
275 format!("{id}/sink0")
276 }
277}
278
279fn transform_or_sink_id_after(
282 id: &str,
283 next_index: usize,
284 total_transforms: usize,
285 sinks: &[Box<dyn Sink>],
286) -> String {
287 if next_index < total_transforms {
288 format!("{id}/t{next_index}")
289 } else if sinks.len() > 1 {
290 format!("{id}/broadcast")
291 } else {
292 format!("{id}/sink0")
293 }
294}
295
296#[allow(clippy::too_many_arguments)]
301fn spawn_edge_sampler(
302 pipeline: &str,
303 src_node_id: &str,
304 dest_node_id: &str,
305 capacity: usize,
306 tx: WeakSender<Envelope>,
307 handle: ObsHandle,
308 cancel: CancellationToken,
309 handles: &mut Vec<JoinHandle<()>>,
310) {
311 let edge_id = format!(
312 "{pipeline}/edge/{}->{}",
313 short_node_id(pipeline, src_node_id),
314 short_node_id(pipeline, dest_node_id)
315 );
316 let ctx = NodeCtx::for_node(pipeline, &edge_id, NodeKind::Edge, handle);
317 handles.push(tokio::spawn(async move {
318 let mut ticker = tokio::time::interval(CHANNEL_DEPTH_SAMPLE_INTERVAL);
319 ticker.tick().await;
323 loop {
324 tokio::select! {
325 _ = cancel.cancelled() => break,
326 _ = ticker.tick() => {
327 let Some(tx) = tx.upgrade() else {
328 break;
329 };
330 let used = capacity.saturating_sub(tx.capacity()) as u64;
331 ctx.record_channel_capacity_used(used);
332 if tx.is_closed() {
335 break;
336 }
337 }
338 }
339 }
340 }));
341}
342
/// Strips the leading `{pipeline}/` from `node_id`.
///
/// `node_id` is returned unchanged when it does not start with `pipeline` immediately
/// followed by a `/`.
fn short_node_id<'a>(pipeline: &str, node_id: &'a str) -> &'a str {
    let Some(rest) = node_id.strip_prefix(pipeline) else {
        return node_id;
    };
    match rest.strip_prefix('/') {
        Some(short) => short,
        None => node_id,
    }
}
349
// Tests for pipeline spawning: per-node metrics, cancellation of the broadcast splitter,
// and trace-context propagation through source, transform and sink.
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use async_trait::async_trait;
    use futures::future::join_all;
    use opentelemetry::trace::TracerProvider;
    use opentelemetry_sdk::trace::{InMemorySpanExporter, SdkTracerProvider};
    use serde_json::json;
    use std::sync::{Arc, Mutex, OnceLock};
    use tokio::sync::{
        Notify,
        mpsc::{self, Receiver, Sender},
    };
    use tracing_subscriber::layer::SubscriberExt;

    use super::*;
    use crate::observability::metrics::testing::{
        counter_sum, histogram_count, obs_handle_in_memory,
    };
    use crate::observability::trace_context::TRACEPARENT;
    use crate::observability::{SendStopped, SourceCtx};
    use crate::sinks::{ManagedSink, WriteOne};
    use crate::transforms::{BasicTransform, MapOne};

    // Guards the one-time installation of the global tracing subscriber.
    static TEST_TRACING_GLOBAL: OnceLock<()> = OnceLock::new();

    /// Installs a TRACE-level global subscriber at most once per process and rebuilds the
    /// callsite interest cache on every call.
    fn install_test_tracing_global() {
        TEST_TRACING_GLOBAL.get_or_init(|| {
            let subscriber =
                tracing_subscriber::registry().with(tracing_subscriber::filter::LevelFilter::TRACE);
            // The result is ignored on purpose: setting the global default fails when one is
            // already installed.
            let _ = tracing::subscriber::set_global_default(subscriber);
        });
        tracing::callsite::rebuild_interest_cache();
    }

    /// Source that emits 100 envelopes with payload `{"n": 0}` .. `{"n": 99}`.
    struct HundredSource {
        source_ctx: SourceCtx,
    }

    impl HundredSource {
        fn new() -> Self {
            Self {
                source_ctx: SourceCtx::new("src"),
            }
        }
    }

    #[async_trait]
    impl Source for HundredSource {
        fn id(&self) -> &str {
            "src"
        }

        fn set_node_ctx(&mut self, ctx: NodeCtx) {
            self.source_ctx = SourceCtx::from_node_ctx(ctx);
        }

        async fn run(self: Box<Self>, tx: Sender<Envelope>, cancel: CancellationToken) {
            for i in 0..100 {
                let env = Envelope::new("src", json!({ "n": i }));
                // Sending goes through the source context; stop on cancellation or when the
                // downstream side is gone.
                match self.source_ctx.send(&tx, env, &cancel).await {
                    Ok(()) => {}
                    Err(SendStopped::Cancelled) | Err(SendStopped::DownstreamClosed) => break,
                }
            }
        }
    }

    /// Mapper that keeps envelopes whose `n` is even and filters out (returns `None` for)
    /// the others.
    struct EvenOnly;

    #[async_trait]
    impl MapOne for EvenOnly {
        fn id(&self) -> &str {
            "even_only"
        }

        async fn map(&self, env: Envelope) -> Result<Option<Envelope>> {
            let n = env.payload["n"].as_i64().unwrap();
            Ok((n % 2 == 0).then_some(env))
        }
    }

    /// Writer that accepts every envelope and does nothing with it.
    struct AcceptSink;

    #[async_trait]
    impl WriteOne for AcceptSink {
        fn id(&self) -> &str {
            "accept"
        }

        async fn write(&self, _env: &Envelope) -> Result<()> {
            Ok(())
        }
    }

    /// Source that sends `count` envelopes directly on the channel, stopping early on
    /// cancellation or when the receiver is closed.
    struct BurstSource {
        count: usize,
    }

    #[async_trait]
    impl Source for BurstSource {
        fn id(&self) -> &str {
            "burst"
        }

        async fn run(self: Box<Self>, tx: Sender<Envelope>, cancel: CancellationToken) {
            for i in 0..self.count {
                let env = Envelope::new("burst", json!({ "n": i }));
                tokio::select! {
                    _ = cancel.cancelled() => break,
                    res = tx.send(env) => {
                        if res.is_err() {
                            break;
                        }
                    }
                }
            }
        }
    }

    /// Sink that receives a single envelope, signals `first_received`, and then never
    /// finishes. It keeps its receiver alive without reading from it and ignores the
    /// cancellation token.
    struct StallAfterFirstReceiveSink {
        first_received: Arc<Notify>,
    }

    #[async_trait]
    impl Sink for StallAfterFirstReceiveSink {
        fn id(&self) -> &str {
            "stall"
        }

        async fn run(self: Box<Self>, mut rx: Receiver<Envelope>, _cancel: CancellationToken) {
            if rx.recv().await.is_some() {
                self.first_received.notify_one();
                futures::future::pending::<()>().await;
            }
        }
    }

    /// Sink that discards everything until cancelled or until its channel is closed.
    struct DrainSink;

    #[async_trait]
    impl Sink for DrainSink {
        fn id(&self) -> &str {
            "drain"
        }

        async fn run(self: Box<Self>, mut rx: Receiver<Envelope>, cancel: CancellationToken) {
            loop {
                tokio::select! {
                    _ = cancel.cancelled() => break,
                    maybe = rx.recv() => {
                        if maybe.is_none() {
                            break;
                        }
                    }
                }
            }
        }
    }

    /// Source that emits exactly one envelope carrying a fixed `TRACEPARENT` header.
    struct TraceSource {
        source_ctx: SourceCtx,
    }

    impl TraceSource {
        fn new() -> Self {
            Self {
                source_ctx: SourceCtx::new("trace/src"),
            }
        }
    }

    #[async_trait]
    impl Source for TraceSource {
        fn id(&self) -> &str {
            "src"
        }

        fn set_node_ctx(&mut self, ctx: NodeCtx) {
            self.source_ctx = SourceCtx::from_node_ctx(ctx);
        }

        async fn run(self: Box<Self>, tx: Sender<Envelope>, cancel: CancellationToken) {
            let source = self.source_ctx.clone();
            let mut env = Envelope::new("src", json!({ "n": 1 }));
            // The second dash-separated field is the trace id the test later expects on
            // every exported span.
            env.meta.headers.insert(
                TRACEPARENT.to_string(),
                "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01".to_string(),
            );
            // Every outcome is acceptable here; the match only makes that explicit.
            match source.send(&tx, env, &cancel).await {
                Ok(()) | Err(SendStopped::Cancelled) | Err(SendStopped::DownstreamClosed) => {}
            }
        }
    }

    /// Mapper that forwards every envelope unchanged.
    struct PassThrough;

    #[async_trait]
    impl MapOne for PassThrough {
        fn id(&self) -> &str {
            "pass"
        }

        async fn map(&self, env: Envelope) -> Result<Option<Envelope>> {
            Ok(Some(env))
        }
    }

    /// Writer that stores a clone of every envelope in the shared `seen` vector.
    struct CaptureSink {
        seen: Arc<Mutex<Vec<Envelope>>>,
    }

    #[async_trait]
    impl WriteOne for CaptureSink {
        fn id(&self) -> &str {
            "capture"
        }

        async fn write(&self, env: &Envelope) -> Result<()> {
            self.seen.lock().unwrap().push(env.clone());
            Ok(())
        }
    }

    /// Runs source -> even-only transform -> sink with an in-memory metrics exporter and
    /// checks the per-node counters and duration histograms.
    #[tokio::test]
    async fn node_ctx_records_pipeline_metrics() {
        let (handle, exporter) = obs_handle_in_memory();
        let pipeline = Pipeline::new("metrics", Box::new(HundredSource::new()))
            .with_observability(Some(handle.clone()))
            .with_transform(Box::new(BasicTransform::new(EvenOnly)))
            .with_sink(Box::new(ManagedSink::new(AcceptSink)));

        // The token is never cancelled; the test waits for every task to finish on its own.
        let handles = spawn_pipeline(pipeline, CancellationToken::new());
        join_all(handles).await;
        handle.shutdown();

        // The source processed all 100 envelopes.
        assert_eq!(
            counter_sum(
                &exporter,
                "courier_envelopes_processed_total",
                &[("pipeline", "metrics"), ("node_id", "metrics/src")]
            ),
            100
        );
        // The transform let the 50 even values through ...
        assert_eq!(
            counter_sum(
                &exporter,
                "courier_envelopes_processed_total",
                &[("pipeline", "metrics"), ("node_id", "metrics/t0")]
            ),
            50
        );
        // ... and filtered the 50 odd ones.
        assert_eq!(
            counter_sum(
                &exporter,
                "courier_envelopes_filtered_total",
                &[("pipeline", "metrics"), ("node_id", "metrics/t0")]
            ),
            50
        );
        // Only the surviving envelopes reach the sink.
        assert_eq!(
            counter_sum(
                &exporter,
                "courier_envelopes_processed_total",
                &[("pipeline", "metrics"), ("node_id", "metrics/sink0")]
            ),
            50
        );
        assert_eq!(
            histogram_count(
                &exporter,
                "courier_stage_duration_milliseconds",
                &[("pipeline", "metrics"), ("node_id", "metrics/src")]
            ),
            100
        );
        // The transform's duration histogram counts every input, filtered or not.
        assert_eq!(
            histogram_count(
                &exporter,
                "courier_stage_duration_milliseconds",
                &[("pipeline", "metrics"), ("node_id", "metrics/t0")]
            ),
            100
        );
        assert_eq!(
            histogram_count(
                &exporter,
                "courier_stage_duration_milliseconds",
                &[("pipeline", "metrics"), ("node_id", "metrics/sink0")]
            ),
            50
        );
    }

    /// With two sinks and one-slot channels, a sink that stops reading leaves the splitter
    /// waiting on a send; cancelling must still make the splitter exit quickly.
    #[tokio::test]
    async fn broadcast_splitter_observes_cancel_while_blocked_on_sink_send() {
        let first_received = Arc::new(Notify::new());
        let pipeline = Pipeline::new("broadcast-cancel", Box::new(BurstSource { count: 32 }))
            .with_channel_capacity(1)
            .with_sink(Box::new(StallAfterFirstReceiveSink {
                first_received: first_received.clone(),
            }))
            .with_sink(Box::new(DrainSink));

        let cancel = CancellationToken::new();
        let mut handles = spawn_pipeline(pipeline, cancel.clone());
        // Relies on `spawn_pipeline` pushing the splitter's handle last.
        let splitter = handles
            .pop()
            .expect("broadcast splitter should be the final spawned task");

        // Wait until the stalling sink has taken its one envelope, then give the splitter
        // time to fill that sink's channel and block before cancelling.
        first_received.notified().await;
        tokio::time::sleep(Duration::from_millis(50)).await;
        cancel.cancel();

        let result = tokio::time::timeout(Duration::from_millis(250), splitter).await;
        // The stalling sink never returns, so the remaining tasks are aborted rather than
        // awaited.
        for handle in handles {
            handle.abort();
        }

        assert!(
            result.is_ok(),
            "broadcast splitter did not exit promptly after cancellation"
        );
    }

    /// Runs a source, a transform and a sink one after another under an OpenTelemetry
    /// tracing layer and checks that all exported spans carry the incoming trace id.
    #[test]
    fn trace_context_propagates_across_pipeline() {
        install_test_tracing_global();

        let exporter = InMemorySpanExporter::default();
        let provider = SdkTracerProvider::builder()
            .with_simple_exporter(exporter.clone())
            .build();
        let tracer = provider.tracer("courier_test");
        let subscriber =
            tracing_subscriber::registry().with(tracing_opentelemetry::layer().with_tracer(tracer));
        let dispatch = tracing::Dispatch::new(subscriber);
        // A current-thread runtime keeps all work on the thread where `dispatch` is the
        // default subscriber.
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        let (metrics, _metric_exporter) = obs_handle_in_memory();
        let seen = Arc::new(Mutex::new(Vec::new()));

        tracing::dispatcher::with_default(&dispatch, || {
            tracing::callsite::rebuild_interest_cache();
            runtime.block_on(async {
                let cancel = CancellationToken::new();
                // The nodes are driven sequentially, not spawned: each `run` is awaited to
                // completion before the next node is built.
                let (source_tx, transform_rx) = mpsc::channel(8);
                let mut source = TraceSource::new();
                source.set_node_ctx(NodeCtx::for_node(
                    "trace",
                    "trace/src",
                    NodeKind::Source,
                    metrics.clone(),
                ));
                Box::new(source).run(source_tx, cancel.clone()).await;

                let (sink_tx, sink_rx) = mpsc::channel(8);
                let mut transform = BasicTransform::new(PassThrough);
                transform.set_node_ctx(NodeCtx::for_node(
                    "trace",
                    "trace/t0",
                    NodeKind::Transform,
                    metrics.clone(),
                ));
                Box::new(transform)
                    .run(transform_rx, sink_tx, cancel.clone())
                    .await;

                let mut sink = ManagedSink::new(CaptureSink { seen: seen.clone() });
                sink.set_node_ctx(NodeCtx::for_node(
                    "trace",
                    "trace/sink0",
                    NodeKind::Sink,
                    metrics,
                ));
                Box::new(sink).run(sink_rx, cancel).await;
            });
            tracing::callsite::rebuild_interest_cache();
        });
        provider.force_flush().unwrap();

        // Exactly one envelope reached the sink, and it still carries a trace header.
        let captured = seen.lock().unwrap().clone();
        assert_eq!(captured.len(), 1);
        assert!(
            captured[0].meta.headers.contains_key(TRACEPARENT),
            "sink should see refreshed trace context"
        );

        let spans = exporter.get_finished_spans().unwrap();
        let source_span = spans
            .iter()
            .find(|s| s.name == "courier.source")
            .unwrap_or_else(|| panic!("missing source span: {spans:?}"));
        // The source span must be tagged with the pipeline name.
        assert!(
            source_span.attributes.iter().any(|attr| {
                attr.key.as_str() == "pipeline"
                    && matches!(&attr.value, opentelemetry::Value::String(value) if value.as_ref() == "trace")
            }),
            "source span missing pipeline attribute: {source_span:?}"
        );
        assert!(
            spans.iter().any(|s| s.name == "courier.transform"),
            "missing transform span: {spans:?}"
        );
        assert!(
            spans.iter().any(|s| s.name == "courier.sink"),
            "missing sink span: {spans:?}"
        );
        // Trace id taken from the `TRACEPARENT` header set by `TraceSource`.
        let incoming_trace_id = "4bf92f3577b34da6a3ce929d0e0e4736";
        assert!(
            spans
                .iter()
                .all(|s| s.span_context.trace_id().to_string() == incoming_trace_id),
            "spans did not share incoming trace id: {spans:?}"
        );
    }
}