opentelemetry_lambda_extension/
receiver.rs

1//! OTLP receiver for collecting signals from Lambda functions.
2//!
3//! This module provides HTTP endpoints that receive OTLP signals (traces, metrics, logs)
4//! from the Lambda function. It supports both protobuf and JSON content types.
5
6use crate::config::ReceiverConfig;
7use axum::{
8    Json, Router,
9    body::Bytes,
10    extract::State,
11    http::{
12        HeaderMap, StatusCode,
13        header::{CONTENT_ENCODING, CONTENT_TYPE},
14    },
15    response::IntoResponse,
16    routing::{get, post},
17};
18use flate2::read::GzDecoder;
19use opentelemetry_proto::tonic::collector::{
20    logs::v1::ExportLogsServiceRequest, metrics::v1::ExportMetricsServiceRequest,
21    trace::v1::ExportTraceServiceRequest,
22};
23use prost::Message;
24use serde::Serialize;
25use std::io::Read;
26use std::net::SocketAddr;
27use std::sync::Arc;
28use std::sync::atomic::{AtomicU64, Ordering};
29use tokio::net::TcpListener;
30use tokio::sync::{Notify, mpsc};
31use tokio_util::sync::CancellationToken;
32
33/// Signals received from the Lambda function.
34#[non_exhaustive]
35#[derive(Debug, Clone)]
36pub enum Signal {
37    /// Trace spans.
38    Traces(ExportTraceServiceRequest),
39    /// Metrics.
40    Metrics(ExportMetricsServiceRequest),
41    /// Log records.
42    Logs(ExportLogsServiceRequest),
43}
44
45/// Handle for interacting with a running OTLP receiver.
46///
47/// This handle can be used to query the receiver's status, trigger flushes,
48/// and get the actual bound address (useful when port 0 is used for dynamic allocation).
49#[derive(Clone)]
50pub struct ReceiverHandle {
51    state: Arc<ReceiverState>,
52    local_addr: SocketAddr,
53}
54
55impl ReceiverHandle {
56    /// Returns the actual bound address of the receiver.
57    pub fn local_addr(&self) -> SocketAddr {
58        self.local_addr
59    }
60
61    /// Returns the port the receiver is listening on.
62    pub fn port(&self) -> u16 {
63        self.local_addr.port()
64    }
65
66    /// Returns the URL for the OTLP HTTP receiver.
67    pub fn url(&self) -> String {
68        format!("http://{}", self.local_addr)
69    }
70
71    /// Returns the number of signals received.
72    pub fn signals_received(&self) -> u64 {
73        self.state.signals_received.load(Ordering::Relaxed)
74    }
75
76    /// Triggers an immediate flush and waits for it to complete.
77    ///
78    /// Returns `Ok(())` when the flush completes, or `Err` on timeout.
79    pub async fn flush(&self, timeout: std::time::Duration) -> Result<(), FlushError> {
80        // Signal that a flush is requested
81        self.state.flush_requested.notify_one();
82
83        // Wait for flush to complete
84        tokio::time::timeout(timeout, self.state.flush_complete.notified())
85            .await
86            .map_err(|_| FlushError::Timeout)?;
87
88        Ok(())
89    }
90
91    /// Notifies that a flush has completed.
92    ///
93    /// This should be called by the runtime after flushing all signals.
94    pub fn notify_flush_complete(&self) {
95        self.state.flush_complete.notify_waiters();
96    }
97
98    /// Returns a future that resolves when a flush is requested.
99    pub async fn wait_for_flush_request(&self) {
100        self.state.flush_requested.notified().await;
101    }
102
103    /// Returns a reference to the flush request notifier.
104    pub fn flush_requested_notify(&self) -> Arc<Notify> {
105        self.state.flush_requested.clone()
106    }
107}
108
109/// Error returned when a flush operation fails.
110#[non_exhaustive]
111#[derive(Debug, Clone, thiserror::Error)]
112pub enum FlushError {
113    /// The flush operation timed out.
114    #[error("flush operation timed out")]
115    Timeout,
116}
117
118/// OTLP HTTP receiver for collecting signals.
119pub struct OtlpReceiver {
120    config: ReceiverConfig,
121    signal_tx: mpsc::Sender<Signal>,
122    cancel_token: CancellationToken,
123}
124
125impl OtlpReceiver {
126    /// Creates a new OTLP receiver.
127    ///
128    /// # Arguments
129    ///
130    /// * `config` - Receiver configuration
131    /// * `signal_tx` - Channel for sending received signals to the aggregator
132    /// * `cancel_token` - Token for graceful shutdown
133    pub fn new(
134        config: ReceiverConfig,
135        signal_tx: mpsc::Sender<Signal>,
136        cancel_token: CancellationToken,
137    ) -> Self {
138        Self {
139            config,
140            signal_tx,
141            cancel_token,
142        }
143    }
144
145    /// Starts the HTTP receiver and returns a handle for interacting with it.
146    ///
147    /// The handle can be used to query the receiver's status, trigger flushes,
148    /// and get the actual bound address.
149    ///
150    /// # Returns
151    ///
152    /// Returns `Ok((handle, future))` where `handle` can be used to interact with
153    /// the receiver and `future` should be spawned to run the server.
154    ///
155    /// # Errors
156    ///
157    /// Returns an error if the server fails to bind to the address.
158    pub async fn start(
159        self,
160    ) -> Result<
161        (
162            ReceiverHandle,
163            std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>,
164        ),
165        std::io::Error,
166    > {
167        if !self.config.http_enabled {
168            tracing::info!("HTTP receiver disabled");
169            let state = Arc::new(ReceiverState::new(self.signal_tx));
170            let handle = ReceiverHandle {
171                state,
172                local_addr: SocketAddr::from(([127, 0, 0, 1], 0)),
173            };
174            let cancel_token = self.cancel_token;
175            let future = Box::pin(async move {
176                cancel_token.cancelled().await;
177            });
178            return Ok((handle, future));
179        }
180
181        let addr = SocketAddr::from(([127, 0, 0, 1], self.config.http_port));
182        let listener = TcpListener::bind(addr).await?;
183        let local_addr = listener.local_addr()?;
184
185        let state = Arc::new(ReceiverState::new(self.signal_tx));
186        let handle = ReceiverHandle {
187            state: state.clone(),
188            local_addr,
189        };
190
191        let app = Router::new()
192            .route("/health", get(handle_health))
193            .route("/v1/traces", post(handle_traces))
194            .route("/v1/metrics", post(handle_metrics))
195            .route("/v1/logs", post(handle_logs))
196            .with_state(state);
197
198        tracing::info!(port = local_addr.port(), "OTLP HTTP receiver started");
199
200        let cancel_token = self.cancel_token;
201        let future = Box::pin(async move {
202            let _ = axum::serve(listener, app)
203                .with_graceful_shutdown(cancel_token.cancelled_owned())
204                .await;
205        });
206
207        Ok((handle, future))
208    }
209}
210
211/// Health check response.
212#[derive(Debug, Clone, Serialize)]
213pub struct HealthResponse {
214    /// Status of the receiver ("ready" or "starting").
215    pub status: &'static str,
216    /// Number of signals received.
217    pub signals_received: u64,
218}
219
220struct ReceiverState {
221    signal_tx: mpsc::Sender<Signal>,
222    signals_received: AtomicU64,
223    flush_requested: Arc<Notify>,
224    flush_complete: Arc<Notify>,
225}
226
227impl ReceiverState {
228    fn new(signal_tx: mpsc::Sender<Signal>) -> Self {
229        Self {
230            signal_tx,
231            signals_received: AtomicU64::new(0),
232            flush_requested: Arc::new(Notify::new()),
233            flush_complete: Arc::new(Notify::new()),
234        }
235    }
236}
237
238async fn handle_health(State(state): State<Arc<ReceiverState>>) -> Json<HealthResponse> {
239    Json(HealthResponse {
240        status: "ready",
241        signals_received: state.signals_received.load(Ordering::Relaxed),
242    })
243}
244
245async fn handle_traces(
246    State(state): State<Arc<ReceiverState>>,
247    headers: HeaderMap,
248    body: Bytes,
249) -> impl IntoResponse {
250    let content_type = headers.get(CONTENT_TYPE).cloned();
251    let content_encoding = headers.get(CONTENT_ENCODING).cloned();
252    let request =
253        match parse_request::<ExportTraceServiceRequest>(&content_type, &content_encoding, &body) {
254            Ok(req) => req,
255            Err(e) => return e,
256        };
257
258    match state.signal_tx.try_send(Signal::Traces(request)) {
259        Ok(()) => {
260            state.signals_received.fetch_add(1, Ordering::Relaxed);
261            StatusCode::OK
262        }
263        Err(mpsc::error::TrySendError::Full(_)) => {
264            tracing::warn!("Trace signal channel full, signalling backpressure");
265            StatusCode::SERVICE_UNAVAILABLE
266        }
267        Err(mpsc::error::TrySendError::Closed(_)) => {
268            tracing::error!("Trace signal channel closed");
269            StatusCode::INTERNAL_SERVER_ERROR
270        }
271    }
272}
273
274async fn handle_metrics(
275    State(state): State<Arc<ReceiverState>>,
276    headers: HeaderMap,
277    body: Bytes,
278) -> impl IntoResponse {
279    let content_type = headers.get(CONTENT_TYPE).cloned();
280    let content_encoding = headers.get(CONTENT_ENCODING).cloned();
281    let request =
282        match parse_request::<ExportMetricsServiceRequest>(&content_type, &content_encoding, &body)
283        {
284            Ok(req) => req,
285            Err(e) => return e,
286        };
287
288    match state.signal_tx.try_send(Signal::Metrics(request)) {
289        Ok(()) => {
290            state.signals_received.fetch_add(1, Ordering::Relaxed);
291            StatusCode::OK
292        }
293        Err(mpsc::error::TrySendError::Full(_)) => {
294            tracing::warn!("Metrics signal channel full, signalling backpressure");
295            StatusCode::SERVICE_UNAVAILABLE
296        }
297        Err(mpsc::error::TrySendError::Closed(_)) => {
298            tracing::error!("Metrics signal channel closed");
299            StatusCode::INTERNAL_SERVER_ERROR
300        }
301    }
302}
303
304async fn handle_logs(
305    State(state): State<Arc<ReceiverState>>,
306    headers: HeaderMap,
307    body: Bytes,
308) -> impl IntoResponse {
309    let content_type = headers.get(CONTENT_TYPE).cloned();
310    let content_encoding = headers.get(CONTENT_ENCODING).cloned();
311    let request =
312        match parse_request::<ExportLogsServiceRequest>(&content_type, &content_encoding, &body) {
313            Ok(req) => req,
314            Err(e) => return e,
315        };
316
317    match state.signal_tx.try_send(Signal::Logs(request)) {
318        Ok(()) => {
319            state.signals_received.fetch_add(1, Ordering::Relaxed);
320            StatusCode::OK
321        }
322        Err(mpsc::error::TrySendError::Full(_)) => {
323            tracing::warn!("Logs signal channel full, signalling backpressure");
324            StatusCode::SERVICE_UNAVAILABLE
325        }
326        Err(mpsc::error::TrySendError::Closed(_)) => {
327            tracing::error!("Logs signal channel closed");
328            StatusCode::INTERNAL_SERVER_ERROR
329        }
330    }
331}
332
333fn parse_request<T>(
334    content_type: &Option<axum::http::HeaderValue>,
335    content_encoding: &Option<axum::http::HeaderValue>,
336    body: &Bytes,
337) -> Result<T, StatusCode>
338where
339    T: Message + Default + serde::de::DeserializeOwned,
340{
341    let is_gzip = content_encoding
342        .as_ref()
343        .and_then(|ce| ce.to_str().ok())
344        .is_some_and(|ce| ce.contains("gzip"));
345
346    let decompressed: Vec<u8>;
347    let body_bytes: &[u8] = if is_gzip {
348        decompressed = decompress_gzip(body)?;
349        &decompressed
350    } else {
351        body.as_ref()
352    };
353
354    let is_json = content_type
355        .as_ref()
356        .and_then(|ct| ct.to_str().ok())
357        .is_some_and(|ct| ct.contains("application/json"));
358
359    if is_json {
360        serde_json::from_slice(body_bytes).map_err(|e| {
361            tracing::error!(error = %e, "Failed to parse JSON request");
362            StatusCode::BAD_REQUEST
363        })
364    } else {
365        T::decode(body_bytes).map_err(|e| {
366            tracing::error!(error = %e, "Failed to parse protobuf request");
367            StatusCode::BAD_REQUEST
368        })
369    }
370}
371
372fn decompress_gzip(body: &Bytes) -> Result<Vec<u8>, StatusCode> {
373    let mut decoder = GzDecoder::new(body.as_ref());
374    let mut decompressed = Vec::new();
375    decoder.read_to_end(&mut decompressed).map_err(|e| {
376        tracing::error!(error = %e, "Failed to decompress gzip body");
377        StatusCode::BAD_REQUEST
378    })?;
379    Ok(decompressed)
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;
    use flate2::Compression;
    use flate2::write::GzEncoder;
    use opentelemetry_proto::tonic::trace::v1::{ResourceSpans, ScopeSpans, Span};
    use std::io::Write;

    /// Builds a trace export request containing exactly one span named `name`.
    fn single_span_request(name: &str) -> ExportTraceServiceRequest {
        ExportTraceServiceRequest {
            resource_spans: vec![ResourceSpans {
                scope_spans: vec![ScopeSpans {
                    spans: vec![Span {
                        name: name.to_string(),
                        ..Default::default()
                    }],
                    ..Default::default()
                }],
                ..Default::default()
            }],
        }
    }

    /// Returns the name of the first span in `request`.
    fn first_span_name(request: &ExportTraceServiceRequest) -> &str {
        &request.resource_spans[0].scope_spans[0].spans[0].name
    }

    /// Gzip-compresses `data` with the default compression level.
    fn gzip(data: &[u8]) -> Vec<u8> {
        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(data).unwrap();
        encoder.finish().unwrap()
    }

    /// Wraps a static string as an optional header value.
    fn header(value: &'static str) -> Option<HeaderValue> {
        Some(HeaderValue::from_static(value))
    }

    #[test]
    fn test_signal_debug() {
        let signal = Signal::Traces(ExportTraceServiceRequest::default());
        assert!(format!("{:?}", signal).contains("Traces"));
    }

    #[test]
    fn test_parse_traces_protobuf() {
        let encoded = single_span_request("test-span").encode_to_vec();

        let parsed: ExportTraceServiceRequest = parse_request(
            &header("application/x-protobuf"),
            &None,
            &Bytes::from(encoded),
        )
        .unwrap();

        assert_eq!(first_span_name(&parsed), "test-span");
    }

    #[test]
    fn test_parse_traces_defaults_to_protobuf_without_content_type() {
        let encoded = single_span_request("untyped-span").encode_to_vec();

        let parsed: ExportTraceServiceRequest =
            parse_request(&None, &None, &Bytes::from(encoded)).unwrap();

        assert_eq!(first_span_name(&parsed), "untyped-span");
    }

    #[test]
    fn test_parse_traces_json() {
        let json = r#"{"resourceSpans":[]}"#;

        let parsed: ExportTraceServiceRequest =
            parse_request(&header("application/json"), &None, &Bytes::from(json)).unwrap();

        assert!(parsed.resource_spans.is_empty());
    }

    #[test]
    fn test_parse_invalid_protobuf() {
        let result: Result<ExportTraceServiceRequest, _> = parse_request(
            &header("application/x-protobuf"),
            &None,
            &Bytes::from("invalid"),
        );

        assert_eq!(result.err(), Some(StatusCode::BAD_REQUEST));
    }

    #[test]
    fn test_parse_invalid_json() {
        let result: Result<ExportTraceServiceRequest, _> =
            parse_request(&header("application/json"), &None, &Bytes::from("{invalid}"));

        assert_eq!(result.err(), Some(StatusCode::BAD_REQUEST));
    }

    #[test]
    fn test_parse_gzip_compressed_protobuf() {
        let compressed = gzip(&single_span_request("compressed-span").encode_to_vec());

        let parsed: ExportTraceServiceRequest = parse_request(
            &header("application/x-protobuf"),
            &header("gzip"),
            &Bytes::from(compressed),
        )
        .unwrap();

        assert_eq!(first_span_name(&parsed), "compressed-span");
    }

    #[test]
    fn test_parse_gzip_compressed_json() {
        let compressed = gzip(br#"{"resourceSpans":[]}"#);

        let parsed: ExportTraceServiceRequest = parse_request(
            &header("application/json"),
            &header("gzip"),
            &Bytes::from(compressed),
        )
        .unwrap();

        assert!(parsed.resource_spans.is_empty());
    }

    #[test]
    fn test_parse_corrupt_gzip_body() {
        let result: Result<ExportTraceServiceRequest, _> = parse_request(
            &header("application/x-protobuf"),
            &header("gzip"),
            &Bytes::from("definitely not gzip"),
        );

        assert_eq!(result.err(), Some(StatusCode::BAD_REQUEST));
    }
}