1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
//! Utilities used for [`super::AxumHttpServerFactory`]

use async_compression::tokio::write::BrotliDecoder;
use async_compression::tokio::write::GzipDecoder;
use async_compression::tokio::write::ZlibDecoder;
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::*;
use futures::prelude::*;
use http::header::CONTENT_ENCODING;
use http::Request;
use hyper::Body;
use opentelemetry::global;
use opentelemetry::trace::TraceContextExt;
use tokio::io::AsyncWriteExt;
use tower_http::trace::MakeSpan;
use tracing::Span;

use crate::uplink::license_enforcement::LicenseState;
use crate::uplink::license_enforcement::LICENSE_EXPIRED_SHORT_MESSAGE;

pub(crate) const REQUEST_SPAN_NAME: &str = "request";

/// Axum middleware that transparently decompresses the request body.
///
/// The `Content-Encoding` request header selects the decoder:
///
/// * `br` — Brotli ([`BrotliDecoder`])
/// * `gzip`, or its legacy alias `x-gzip` — gzip ([`GzipDecoder`])
/// * `deflate` — zlib-wrapped deflate ([`ZlibDecoder`])
/// * `identity`, or no header at all — the body is forwarded untouched
///
/// Coding names are matched case-insensitively, ignoring surrounding whitespace, because
/// content-coding tokens are case-insensitive (RFC 9110 §8.4.1). A list of several codings
/// (e.g. `gzip, br`) is not supported and is rejected as an unknown value.
///
/// On success the (possibly decompressed) request is handed to `next`, and its response is
/// returned in `Ok`.
///
/// # Errors
///
/// Returns `Err` holding a `400 Bad Request` response when:
/// * the body cannot be read,
/// * the selected decoder fails,
/// * the header value cannot be converted to a string (`HeaderValue::to_str` fails), or
/// * the coding is not one of the values listed above.
///
/// The last two cases also emit an error event carrying the
/// `apollo_router_http_requests_total` counter.
pub(super) async fn decompress_request_body(
    req: Request<Body>,
    next: Next<Body>,
) -> Result<Response, Response> {
    let (parts, body) = req.into_parts();
    let content_encoding = parts.headers.get(&CONTENT_ENCODING);

    // Buffers the whole body, runs it through `$decoder`, and forwards the decoded bytes to the
    // next layer. Any failure is turned into a 400 response and returned early via `?`.
    // NOTE(review): neither the compressed nor the decompressed size is bounded here; confirm
    // that a body-size limit is enforced elsewhere in the stack.
    macro_rules! decode_body {
        ($decoder: ident, $error_message: expr) => {{
            let body_bytes = hyper::body::to_bytes(body)
                .map_err(|err| {
                    (
                        StatusCode::BAD_REQUEST,
                        format!("cannot read request body: {err}"),
                    )
                        .into_response()
                })
                .await?;
            let mut decoder = $decoder::new(Vec::new());
            decoder.write_all(&body_bytes).await.map_err(|err| {
                (
                    StatusCode::BAD_REQUEST,
                    format!("{}: {err}", $error_message),
                )
                    .into_response()
            })?;
            // `shutdown` flushes any data still buffered in the decoder and checks the stream
            // is complete.
            decoder.shutdown().await.map_err(|err| {
                (
                    StatusCode::BAD_REQUEST,
                    format!("{}: {err}", $error_message),
                )
                    .into_response()
            })?;

            Ok(next
                .run(Request::from_parts(parts, Body::from(decoder.into_inner())))
                .await)
        }};
    }

    match content_encoding {
        Some(content_encoding) => match content_encoding.to_str() {
            Ok(content_encoding_str) => {
                // Content-coding tokens are case-insensitive (RFC 9110 §8.4.1), so normalise
                // before matching. The raw value is kept for the error message below.
                let normalized = content_encoding_str.trim().to_ascii_lowercase();
                match normalized.as_str() {
                    "br" => decode_body!(BrotliDecoder, "cannot decompress (brotli) request body"),
                    // RFC 9110 §8.4.1.3: a recipient SHOULD treat `x-gzip` as `gzip`.
                    "gzip" | "x-gzip" => {
                        decode_body!(GzipDecoder, "cannot decompress (gzip) request body")
                    }
                    "deflate" => {
                        decode_body!(ZlibDecoder, "cannot decompress (deflate) request body")
                    }
                    "identity" => Ok(next.run(Request::from_parts(parts, body)).await),
                    _ => {
                        let message = format!(
                            "unknown content-encoding header value {content_encoding_str:?}"
                        );
                        tracing::error!(message);
                        ::tracing::error!(
                           monotonic_counter.apollo_router_http_requests_total = 1u64,
                           status = %400u16,
                           error = %message,
                        );

                        Err((StatusCode::BAD_REQUEST, message).into_response())
                    }
                }
            }

            Err(err) => {
                let message = format!("cannot read content-encoding header: {err}");
                ::tracing::error!(
                   monotonic_counter.apollo_router_http_requests_total = 1u64,
                   status = %400u16,
                   error = %message,
                );
                Err((StatusCode::BAD_REQUEST, message).into_response())
            }
        },
        None => Ok(next.run(Request::from_parts(parts, body)).await),
    }
}

/// `tower_http` [`MakeSpan`] that builds the router's root `request` span, parenting it on any
/// trace context propagated in the incoming request headers.
#[derive(Default, Clone)]
pub(crate) struct PropagatingMakeSpan {
    /// License state consulted when the span is built. The `LicensedWarn` and `LicensedHalt`
    /// states produce an error-level span tagged with the license-expired message.
    pub(crate) license: LicenseState,
}

impl<B> MakeSpan<B> for PropagatingMakeSpan {
    /// Creates the root request span. If the request headers carry a remote trace context, that
    /// context is first attached to the current thread so the new span becomes its child.
    fn make_span(&mut self, request: &http::Request<B>) -> Span {
        // TODO: move this into the telemetry plugin once there is a hook for the http request.

        // Extract whatever trace context the client propagated, using the globally configured
        // text-map propagator.
        let remote_context = global::get_text_map_propagator(|propagator| {
            propagator.extract(&opentelemetry_http::HeaderExtractor(request.headers()))
        });

        // With no propagated span the extracted context holds the NOOP span. Attaching that
        // would suppress all further tracing, so only attach when a trace id is present.
        let carries_trace = {
            let remote_span = remote_context.span();
            let span_context = remote_span.span_context();
            span_context.is_valid()
                || span_context.trace_id() != opentelemetry::trace::TraceId::INVALID
        };

        // The guard (if any) lives until this function returns, so the span below is created
        // while the remote context is attached. The context is detached when the guard drops.
        let _attach_guard = carries_trace.then(|| remote_context.attach());
        self.create_span(request)
    }
}

impl PropagatingMakeSpan {
    /// Builds the root `request` span for `request`.
    ///
    /// For the `LicensedWarn` and `LicensedHalt` license states, the span is created at error
    /// level. It is pre-populated with `http.status = 500` and tagged with
    /// `apollo_router.license`. Every other state yields an info-level span.
    ///
    /// In both variants, `otel.name`, `graphql.operation.name` and `graphql.operation.type`
    /// are declared with `field::Empty`, so they can be recorded on the span afterwards.
    fn create_span<B>(&mut self, request: &Request<B>) -> Span {
        match self.license {
            LicenseState::LicensedWarn | LicenseState::LicensedHalt => tracing::error_span!(
                REQUEST_SPAN_NAME,
                "http.method" = %request.method(),
                "http.route" = %request.uri(),
                "http.flavor" = ?request.version(),
                "http.status" = 500, // This prevents setting later
                "otel.name" = ::tracing::field::Empty,
                "otel.kind" = "SERVER",
                "graphql.operation.name" = ::tracing::field::Empty,
                "graphql.operation.type" = ::tracing::field::Empty,
                "apollo_router.license" = LICENSE_EXPIRED_SHORT_MESSAGE,
                "apollo_private.request" = true,
            ),
            _ => tracing::info_span!(
                REQUEST_SPAN_NAME,
                "http.method" = %request.method(),
                "http.route" = %request.uri(),
                "http.flavor" = ?request.version(),
                "otel.name" = ::tracing::field::Empty,
                "otel.kind" = "SERVER",
                "graphql.operation.name" = ::tracing::field::Empty,
                "graphql.operation.type" = ::tracing::field::Empty,
                "apollo_private.request" = true,
            ),
        }
    }
}