// courier/sinks/api.rs

1use std::collections::HashMap;
2use std::time::Duration;
3
4use anyhow::{Context, Result, anyhow};
5use async_trait::async_trait;
6use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
7use reqwest::{Client, Method, Url};
8use serde::Deserialize;
9use serde_json::Value;
10
11use crate::config::{parse_config, redact_secret};
12use crate::envelope::Envelope;
13use crate::observability::trace_context::{TRACEPARENT, TRACESTATE};
14use crate::pipeline::ErrorPolicy;
15use crate::retry::RetryPolicy;
16use crate::sinks::{ManagedSink, Sink, WriteOne};
17
18/// HTTP sink. Sends each envelope to a configured endpoint as a JSON body.
19///
20/// Returns `Err` for any non-2xx response, network error, or timeout — so
21/// `ManagedSink` applies the configured retry / `on_error` policy uniformly.
22pub struct ApiSink {
23    id: String,
24    url: String,
25    method: Method,
26    headers: HeaderMap,
27    body_format: BodyFormat,
28    client: Client,
29}
30
31/// Shape of the JSON body sent to the endpoint.
32///
33/// `Payload` (default) sends only `env.payload` — the common case for
34/// webhook receivers that already know the schema. `Envelope` sends the
35/// whole envelope (`meta` + `payload`) for receivers that need the
36/// metadata as well.
37#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Deserialize)]
38#[serde(rename_all = "snake_case")]
39pub enum BodyFormat {
40    #[default]
41    Payload,
42    Envelope,
43}
44
45impl ApiSink {
46    pub fn new(
47        id: impl Into<String>,
48        url: impl Into<String>,
49        method: Method,
50        headers: HeaderMap,
51        body_format: BodyFormat,
52        timeout: Option<Duration>,
53    ) -> Result<Self> {
54        let mut builder = Client::builder();
55        if let Some(t) = timeout {
56            builder = builder.timeout(t);
57        }
58        let client = builder
59            .build()
60            .map_err(|e| anyhow!("failed to build HTTP client: {e}"))?;
61        Ok(Self {
62            id: id.into(),
63            url: url.into(),
64            method,
65            headers,
66            body_format,
67            client,
68        })
69    }
70}
71
72#[async_trait]
73impl WriteOne for ApiSink {
74    fn id(&self) -> &str {
75        &self.id
76    }
77
78    async fn write(&self, env: &Envelope) -> Result<()> {
79        let body = match self.body_format {
80            BodyFormat::Payload => &env.payload,
81            BodyFormat::Envelope => &serde_json::to_value(env)?,
82        };
83
84        let mut headers = self.headers.clone();
85        for key in [TRACEPARENT, TRACESTATE] {
86            if let Some(value) = env.meta.headers.get(key) {
87                let name = HeaderName::from_static(key);
88                match HeaderValue::try_from(value) {
89                    Ok(value) => {
90                        headers.insert(name, value);
91                    }
92                    Err(_) => {
93                        log::warn!("skipping invalid trace context header value for {key}");
94                    }
95                }
96            }
97        }
98
99        let resp = self
100            .client
101            .request(self.method.clone(), &self.url)
102            .headers(headers)
103            .json(body)
104            .send()
105            .await
106            .map_err(|e| {
107                let e = e.without_url();
108                anyhow!("HTTP request to {} failed: {e}", redact_secret(&self.url))
109            })?;
110
111        let status = resp.status();
112        if !status.is_success() {
113            // Surface the response body when present so failures are
114            // diagnosable from logs / dead-letter entries.
115            let body = resp.text().await.unwrap_or_default();
116            return Err(anyhow!("HTTP error {status}: {body}"));
117        }
118
119        log::debug!(
120            "[{}] {} {} -> {}",
121            redact_secret(&self.id),
122            self.method,
123            redact_secret(&self.url),
124            status
125        );
126        Ok(())
127    }
128}
129
/// Raw configuration for the `api` sink, deserialized by `parse_config`
/// inside `api_sink_factory`.
#[derive(Debug, Deserialize)]
struct ApiSinkConfig {
    /// Endpoint URL (required). The factory checks it with `Url::parse`.
    url: String,
    /// HTTP method name; when absent the factory uses `POST`.
    #[serde(default)]
    method: Option<String>,
    /// Extra request headers as name -> value pairs; empty when omitted.
    #[serde(default)]
    headers: HashMap<String, String>,
    /// Request body shape (`payload` or `envelope`); defaults to `payload`.
    #[serde(default)]
    body: BodyFormat,
    /// Whole-request timeout in seconds. When absent, no timeout is passed
    /// to `ApiSink::new`.
    #[serde(default)]
    timeout_secs: Option<u64>,
}
142
143/// Registry factory for [`ApiSink`]. Registered by
144/// `courier::registry::register_builtin` under kind `"api"`.
145///
146/// Retry and error policy are managed centrally by the registry and applied
147/// to every sink uniformly — no per-sink config needed.
148pub fn api_sink_factory(
149    id: &str,
150    config: Value,
151    on_error: ErrorPolicy,
152    retry: Option<RetryPolicy>,
153) -> Result<Box<dyn Sink>> {
154    let config: ApiSinkConfig = parse_config("api", config)?;
155    Url::parse(&config.url).with_context(|| {
156        format!(
157            "invalid config for component type 'api': invalid url '{}'",
158            redact_secret(&config.url)
159        )
160    })?;
161
162    let method = match config.method.as_deref() {
163        None => Method::POST,
164        Some(m) => m.parse::<Method>().map_err(|_| {
165            anyhow!(
166                "invalid config for component type 'api': unsupported HTTP method '{}'",
167                redact_secret(m)
168            )
169        })?,
170    };
171
172    let mut headers = HeaderMap::with_capacity(config.headers.len());
173    for (k, v) in config.headers {
174        let name = HeaderName::try_from(&k).map_err(|_| {
175            anyhow!("invalid config for component type 'api': invalid header name '{k}'")
176        })?;
177        let value = HeaderValue::try_from(&v).map_err(|_| {
178            anyhow!("invalid config for component type 'api': invalid value for header '{k}'")
179        })?;
180        headers.insert(name, value);
181    }
182
183    let timeout = config.timeout_secs.map(Duration::from_secs);
184    let api = ApiSink::new(id, config.url, method, headers, config.body, timeout)?;
185
186    let mut sink = ManagedSink::new(api).with_error_policy(on_error);
187    if let Some(policy) = retry {
188        sink = sink.with_retry(policy);
189    }
190    Ok(Box::new(sink))
191}
192
// Unit tests. HTTP behavior is exercised against a local `wiremock` server;
// factory tests drive `api_sink_factory` with JSON configs and check the
// resulting errors.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{body_json, header, method as method_matcher, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds an `ApiSink` with id "api-sink" and no timeout.
    fn build_sink(
        url: String,
        method: Method,
        headers: HeaderMap,
        body_format: BodyFormat,
    ) -> ApiSink {
        ApiSink::new("api-sink", url, method, headers, body_format, None).unwrap()
    }

    /// Returns an `http://127.0.0.1:<port><path>` URL whose listener accepts
    /// one connection and drops it immediately, so a request to it fails
    /// before any response arrives. The accept runs on a detached thread.
    fn closing_local_url(path: &str) -> String {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        std::thread::spawn(move || {
            if let Ok((stream, _)) = listener.accept() {
                drop(stream);
            }
        });
        format!("http://{addr}{path}")
    }

    // An unparseable URL is rejected at config time with the uniform prefix.
    #[test]
    fn factory_rejects_invalid_url() {
        let err = api_sink_factory(
            "api",
            json!({
                "url": "not a url"
            }),
            ErrorPolicy::Drop,
            None,
        )
        .err()
        .expect("expected invalid URL to fail");
        let msg = format!("{err:#}");
        assert!(
            msg.contains("invalid config for component type 'api'"),
            "{msg}"
        );
        assert!(msg.contains("invalid url"), "{msg}");
    }

    // With `BodyFormat::Payload` the request body is exactly the payload
    // (`body_json` matches the whole body, not a subset).
    #[tokio::test]
    async fn posts_payload_as_json_by_default() {
        let server = MockServer::start().await;
        Mock::given(method_matcher("POST"))
            .and(path("/hook"))
            .and(body_json(json!({ "n": 7 })))
            .respond_with(ResponseTemplate::new(202))
            .expect(1)
            .mount(&server)
            .await;

        let sink = build_sink(
            format!("{}/hook", server.uri()),
            Method::POST,
            HeaderMap::new(),
            BodyFormat::Payload,
        );

        let env = Envelope::new("src", json!({ "n": 7 }));
        sink.write(&env).await.expect("write should succeed");
    }

    // With `BodyFormat::Envelope` the body carries both `meta` and `payload`.
    #[tokio::test]
    async fn sends_full_envelope_when_body_envelope() {
        let server = MockServer::start().await;
        Mock::given(method_matcher("POST"))
            .and(path("/hook"))
            // Match on a payload field nested under "payload" — proves the
            // envelope wrapper is present, not just the bare payload.
            .and(wiremock::matchers::body_partial_json(json!({
                "payload": { "n": 1 },
                "meta": { "source_id": "src" }
            })))
            .respond_with(ResponseTemplate::new(200))
            .expect(1)
            .mount(&server)
            .await;

        let sink = build_sink(
            format!("{}/hook", server.uri()),
            Method::POST,
            HeaderMap::new(),
            BodyFormat::Envelope,
        );

        let env = Envelope::new("src", json!({ "n": 1 }));
        sink.write(&env).await.unwrap();
    }

    // Configured method and static headers are sent on the request; a 204
    // counts as success.
    #[tokio::test]
    async fn forwards_custom_headers_and_method() {
        let server = MockServer::start().await;
        Mock::given(method_matcher("PUT"))
            .and(path("/items/42"))
            .and(header("authorization", "Bearer token-123"))
            .and(header("x-courier-source", "courier"))
            .respond_with(ResponseTemplate::new(204))
            .expect(1)
            .mount(&server)
            .await;

        let mut headers = HeaderMap::new();
        headers.insert(
            "authorization",
            HeaderValue::from_static("Bearer token-123"),
        );
        headers.insert("x-courier-source", HeaderValue::from_static("courier"));

        let sink = build_sink(
            format!("{}/items/42", server.uri()),
            Method::PUT,
            headers,
            BodyFormat::Payload,
        );

        sink.write(&Envelope::new("src", json!({}))).await.unwrap();
    }

    // A `traceparent` stored in the envelope metadata is forwarded as an
    // HTTP header.
    #[tokio::test]
    async fn forwards_trace_context_headers() {
        let server = MockServer::start().await;
        let traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";
        Mock::given(method_matcher("POST"))
            .and(path("/hook"))
            .and(header("traceparent", traceparent))
            .respond_with(ResponseTemplate::new(202))
            .expect(1)
            .mount(&server)
            .await;

        let sink = build_sink(
            format!("{}/hook", server.uri()),
            Method::POST,
            HeaderMap::new(),
            BodyFormat::Payload,
        );

        let mut env = Envelope::new("src", json!({}));
        env.meta
            .headers
            .insert(TRACEPARENT.to_string(), traceparent.to_string());
        sink.write(&env).await.unwrap();
    }

    // Trace-context values that are not valid header values (embedded
    // newline) are skipped; delivery still succeeds.
    #[tokio::test]
    async fn skips_invalid_trace_context_headers() {
        let server = MockServer::start().await;
        Mock::given(method_matcher("POST"))
            .and(path("/hook"))
            .and(body_json(json!({ "n": 1 })))
            .respond_with(ResponseTemplate::new(202))
            .expect(1)
            .mount(&server)
            .await;

        let sink = build_sink(
            format!("{}/hook", server.uri()),
            Method::POST,
            HeaderMap::new(),
            BodyFormat::Payload,
        );

        let mut env = Envelope::new("src", json!({ "n": 1 }));
        env.meta
            .headers
            .insert(TRACEPARENT.to_string(), "invalid\ntraceparent".to_string());
        env.meta
            .headers
            .insert(TRACESTATE.to_string(), "invalid\ntracestate".to_string());

        sink.write(&env)
            .await
            .expect("invalid trace headers should not fail delivery");
    }

    // A non-2xx status is an error whose message includes the status code
    // and the response body.
    #[tokio::test]
    async fn non_2xx_response_is_an_error() {
        let server = MockServer::start().await;
        Mock::given(method_matcher("POST"))
            .and(path("/hook"))
            .respond_with(ResponseTemplate::new(500).set_body_string("boom"))
            .mount(&server)
            .await;

        let sink = build_sink(
            format!("{}/hook", server.uri()),
            Method::POST,
            HeaderMap::new(),
            BodyFormat::Payload,
        );

        let err = sink
            .write(&Envelope::new("src", json!({})))
            .await
            .expect_err("expected non-2xx to surface as an error");
        let msg = format!("{err:#}");
        assert!(msg.contains("500"), "{msg}");
        assert!(msg.contains("boom"), "{msg}");
    }

    // On a transport failure the URL must appear exactly once in the full
    // error chain: reqwest's own copy is stripped, only the redacted one
    // from the sink remains.
    #[tokio::test]
    async fn send_errors_do_not_repeat_url_from_reqwest_error() {
        let url = closing_local_url("/token-in-url");
        let sink = ApiSink::new(
            "api-sink",
            url.clone(),
            Method::POST,
            HeaderMap::new(),
            BodyFormat::Payload,
            Some(Duration::from_millis(500)),
        )
        .unwrap();

        let err = sink
            .write(&Envelope::new("src", json!({})))
            .await
            .expect_err("expected connection failure");
        let msg = format!("{err:#}");
        assert_eq!(msg.matches(&url).count(), 1, "{msg}");
    }

    // -----------------------------------------------------------------
    // Factory / config parsing
    // -----------------------------------------------------------------

    // With no `method` in the config the sink sends POST. The envelope goes
    // through the managed sink's `run` loop; dropping `tx` closes the channel,
    // which presumably lets `run` return so the handle can be awaited.
    #[tokio::test]
    async fn factory_defaults_method_to_post() {
        let server = MockServer::start().await;
        Mock::given(method_matcher("POST"))
            .and(path("/hook"))
            .respond_with(ResponseTemplate::new(200))
            .expect(1)
            .mount(&server)
            .await;

        let sink = api_sink_factory(
            "api",
            json!({ "url": format!("{}/hook", server.uri()) }),
            ErrorPolicy::Drop,
            None,
        )
        .unwrap();

        let (tx, rx) = tokio::sync::mpsc::channel(1);
        let cancel = tokio_util::sync::CancellationToken::new();
        let handle = tokio::spawn(async move { sink.run(rx, cancel).await });

        tx.send(Envelope::new("src", json!({"hello": "world"})))
            .await
            .unwrap();
        drop(tx);
        handle.await.unwrap();
    }

    // A method string that is not a valid HTTP token (contains a space) is
    // rejected at config time.
    #[test]
    fn factory_rejects_invalid_method() {
        let err = api_sink_factory(
            "api",
            json!({ "url": "https://example.test/", "method": "FOO BAR" }),
            ErrorPolicy::Drop,
            None,
        )
        .err()
        .expect("expected invalid-method error");
        let msg = format!("{err:#}");
        assert!(msg.contains("unsupported HTTP method"), "{msg}");
    }

    // A header name containing a space is rejected at config time.
    #[test]
    fn factory_rejects_invalid_header_name() {
        let err = api_sink_factory(
            "api",
            json!({
                "url": "https://example.test/",
                "headers": { "bad header": "value" }
            }),
            ErrorPolicy::Drop,
            None,
        )
        .err()
        .expect("expected invalid-header error");
        let msg = format!("{err:#}");
        assert!(msg.contains("invalid header name"), "{msg}");
    }

    // Errors raised by `parse_config` (missing required `url`) carry the
    // same prefix as the factory's own validation errors.
    #[test]
    fn factory_reports_missing_url_with_uniform_prefix() {
        let err = api_sink_factory("api", json!({}), ErrorPolicy::Drop, None)
            .err()
            .expect("expected missing-url error");
        let msg = format!("{err:#}");
        assert!(
            msg.contains("invalid config for component type 'api'"),
            "{msg}"
        );
    }
}