Skip to main content

apollo_router/services/
router.rs

1#![allow(missing_docs)] // FIXME
2
3use std::any::Any;
4use std::mem;
5
6use ahash::HashMap;
7use bytes::Bytes;
8use displaydoc::Display;
9use futures::Stream;
10use futures::StreamExt;
11use futures::future::Either;
12use http::HeaderValue;
13use http::Method;
14use http::StatusCode;
15use http::header::CONTENT_TYPE;
16use http::header::HeaderName;
17use http_body_util::BodyExt;
18use multer::Multipart;
19use multimap::MultiMap;
20use serde_json_bytes::ByteString;
21use serde_json_bytes::Map as JsonMap;
22use serde_json_bytes::Value;
23use static_assertions::assert_impl_all;
24use thiserror::Error;
25use tower::BoxError;
26use uuid::Uuid;
27
28use self::body::RouterBody;
29use self::service::MULTIPART_DEFER_CONTENT_TYPE_HEADER_VALUE;
30use self::service::MULTIPART_SUBSCRIPTION_CONTENT_TYPE_HEADER_VALUE;
31use super::supergraph;
32use crate::Context;
33use crate::context::CONTAINS_GRAPHQL_ERROR;
34use crate::context::ROUTER_RESPONSE_ERRORS;
35use crate::graphql;
36use crate::http_ext::header_map;
37use crate::json_ext::Path;
38use crate::plugins::telemetry::config_new::router::events::RouterResponseBodyExtensionType;
39use crate::services::TryIntoHeaderName;
40use crate::services::TryIntoHeaderValue;
41
42pub type BoxService = tower::util::BoxService<Request, Response, BoxError>;
43pub type BoxCloneService = tower::util::BoxCloneService<Request, Response, BoxError>;
44pub type ServiceResult = Result<Response, BoxError>;
45
46pub type Body = RouterBody;
47pub type Error = hyper::Error;
48
49mod batching;
50pub mod body;
51pub(crate) mod pipeline_handle;
52pub(crate) mod service;
53#[cfg(test)]
54mod tests;
55mod tower_compat;
56
57assert_impl_all!(Request: Send);
58/// Represents the router processing step of the processing pipeline.
59///
60/// This consists of the parsed graphql Request, HTTP headers and contextual data for extensions.
61#[non_exhaustive]
62pub struct Request {
63    /// Original request to the Router.
64    pub router_request: http::Request<Body>,
65
66    /// Context for extension
67    pub context: Context,
68}
69
70impl From<(http::Request<Body>, Context)> for Request {
71    fn from((router_request, context): (http::Request<Body>, Context)) -> Self {
72        Self {
73            router_request,
74            context,
75        }
76    }
77}
78
79/// Helper type to conveniently construct a body from several types used commonly in tests.
80///
81/// It's only meant for integration tests, as the "real" router should create bodies explicitly accounting for
82/// streaming, size limits, etc.
83pub struct IntoBody(Body);
84
85impl From<Body> for IntoBody {
86    fn from(value: Body) -> Self {
87        Self(value)
88    }
89}
90impl From<String> for IntoBody {
91    fn from(value: String) -> Self {
92        Self(body::from_bytes(value))
93    }
94}
95impl From<Bytes> for IntoBody {
96    fn from(value: Bytes) -> Self {
97        Self(body::from_bytes(value))
98    }
99}
100impl From<Vec<u8>> for IntoBody {
101    fn from(value: Vec<u8>) -> Self {
102        Self(body::from_bytes(value))
103    }
104}
105
106#[buildstructor::buildstructor]
107impl Request {
108    /// This is the constructor (or builder) to use when constructing a real Request.
109    ///
110    /// Required parameters are required in non-testing code to create a Request.
111    #[builder(visibility = "pub")]
112    fn new(
113        context: Context,
114        headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
115        uri: http::Uri,
116        method: Method,
117        body: Body,
118    ) -> Result<Request, BoxError> {
119        let mut router_request = http::Request::builder()
120            .uri(uri)
121            .method(method)
122            .body(body)?;
123        *router_request.headers_mut() = header_map(headers)?;
124        Ok(Self {
125            router_request,
126            context,
127        })
128    }
129
130    /// This is the constructor (or builder) to use when constructing a fake Request.
131    ///
132    /// Required parameters are required in non-testing code to create a Request.
133    #[builder(visibility = "pub")]
134    fn fake_new(
135        context: Option<Context>,
136        headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
137        uri: Option<http::Uri>,
138        method: Option<Method>,
139        body: Option<IntoBody>,
140    ) -> Result<Request, BoxError> {
141        let mut router_request = http::Request::builder()
142            .uri(uri.unwrap_or_else(|| http::Uri::from_static("http://example.com/")))
143            .method(method.unwrap_or(Method::GET))
144            .body(body.map_or_else(body::empty, |constructed| constructed.0))?;
145        *router_request.headers_mut() = header_map(headers)?;
146        Ok(Self {
147            router_request,
148            context: context.unwrap_or_default(),
149        })
150    }
151}
152
153#[derive(Error, Display, Debug)]
154pub enum ParseError {
155    /// couldn't create a valid http GET uri '{0}'
156    InvalidUri(http::uri::InvalidUri),
157    /// couldn't urlencode the GraphQL request body '{0}'
158    UrlEncodeError(serde_urlencoded::ser::Error),
159    /// couldn't serialize the GraphQL request body '{0}'
160    SerializationError(serde_json::Error),
161}
162
163/// This is handy for tests.
164impl TryFrom<supergraph::Request> for Request {
165    type Error = ParseError;
166    fn try_from(request: supergraph::Request) -> Result<Self, Self::Error> {
167        let supergraph::Request {
168            context,
169            supergraph_request,
170            ..
171        } = request;
172
173        let (mut parts, request) = supergraph_request.into_parts();
174
175        let router_request = if parts.method == Method::GET {
176            // get request
177            let get_path = serde_urlencoded::to_string([
178                ("query", request.query),
179                ("operationName", request.operation_name),
180                (
181                    "extensions",
182                    serde_json::to_string(&request.extensions).ok(),
183                ),
184                ("variables", serde_json::to_string(&request.variables).ok()),
185            ])
186            .map_err(ParseError::UrlEncodeError)?;
187
188            parts.uri = format!("{}?{}", parts.uri, get_path)
189                .parse()
190                .map_err(ParseError::InvalidUri)?;
191
192            http::Request::from_parts(parts, body::empty())
193        } else {
194            http::Request::from_parts(
195                parts,
196                body::from_bytes(
197                    serde_json::to_vec(&request).map_err(ParseError::SerializationError)?,
198                ),
199            )
200        };
201        Ok(Self {
202            router_request,
203            context,
204        })
205    }
206}
207
208assert_impl_all!(Response: Send);
209#[non_exhaustive]
210#[derive(Debug)]
211pub struct Response {
212    pub response: http::Response<Body>,
213    pub context: Context,
214}
215
216#[buildstructor::buildstructor]
217impl Response {
218    fn stash_the_body_in_extensions(&mut self, body_string: String) {
219        self.context.extensions().with_lock(|ext| {
220            ext.insert(RouterResponseBodyExtensionType(body_string));
221        });
222    }
223
224    pub async fn next_response(&mut self) -> Option<Result<Bytes, axum::Error>> {
225        self.response.body_mut().into_data_stream().next().await
226    }
227
228    /// This is the constructor (or builder) to use when constructing a real Response.
229    ///
230    /// Required parameters are required in non-testing code to create a Response.
231    #[builder(visibility = "pub")]
232    fn new(
233        label: Option<String>,
234        data: Option<serde_json_bytes::Value>,
235        path: Option<Path>,
236        errors: Vec<graphql::Error>,
237        // Skip the `Object` type alias to use buildstructor’s map special-casing
238        extensions: JsonMap<ByteString, serde_json_bytes::Value>,
239        status_code: Option<StatusCode>,
240        headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
241        context: Context,
242    ) -> Result<Self, BoxError> {
243        if !errors.is_empty() {
244            Self::add_errors_to_context(&errors, &context);
245        }
246
247        // Build a response
248        let b = graphql::Response::builder()
249            .and_label(label)
250            .and_path(path)
251            .errors(errors)
252            .extensions(extensions);
253        let res = match data {
254            Some(data) => b.data(data).build(),
255            None => b.build(),
256        };
257
258        // Build an HTTP Response
259        let mut builder = http::Response::builder().status(status_code.unwrap_or(StatusCode::OK));
260        for (key, values) in headers {
261            let header_name: HeaderName = key.try_into()?;
262            for value in values {
263                let header_value: HeaderValue = value.try_into()?;
264                builder = builder.header(header_name.clone(), header_value);
265            }
266        }
267
268        let body_string = serde_json::to_string(&res)?;
269
270        let body = body::from_bytes(body_string.clone());
271        let response = builder.body(body)?;
272        // Stash the body in the extensions so we can access it later
273        let mut response = Self { response, context };
274        response.stash_the_body_in_extensions(body_string);
275
276        Ok(response)
277    }
278
279    #[builder(visibility = "pub")]
280    fn http_response_new(
281        response: http::Response<Body>,
282        context: Context,
283        body_to_stash: Option<String>,
284        errors_for_context: Option<Vec<graphql::Error>>,
285    ) -> Result<Self, BoxError> {
286        // There are instances where we have errors that need to be counted for telemetry in this
287        // layer, but we don't want to deserialize the body. In these cases we can pass in the
288        // list of errors to add to context for counting later in the telemetry plugin.
289        if let Some(errors) = errors_for_context
290            && !errors.is_empty()
291        {
292            Self::add_errors_to_context(&errors, &context);
293        }
294        let mut res = Self { response, context };
295        if let Some(body_to_stash) = body_to_stash {
296            res.stash_the_body_in_extensions(body_to_stash)
297        }
298        Ok(res)
299    }
300
301    /// This is the constructor (or builder) to use when constructing a Response that represents a global error.
302    /// It has no path and no response data.
303    /// This is useful for things such as authentication errors.
304    #[builder(visibility = "pub")]
305    fn error_new(
306        errors: Vec<graphql::Error>,
307        status_code: Option<StatusCode>,
308        headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
309        context: Context,
310    ) -> Result<Self, BoxError> {
311        Response::new(
312            Default::default(),
313            Default::default(),
314            None,
315            errors,
316            Default::default(),
317            status_code,
318            headers,
319            context,
320        )
321    }
322
323    /// This is the constructor (or builder) to use when constructing a real Response.
324    ///
325    /// Required parameters are required in non-testing code to create a Response.
326    #[builder(visibility = "pub(crate)")]
327    fn infallible_new(
328        label: Option<String>,
329        data: Option<serde_json_bytes::Value>,
330        path: Option<Path>,
331        errors: Vec<graphql::Error>,
332        // Skip the `Object` type alias to use buildstructor’s map special-casing
333        extensions: JsonMap<ByteString, serde_json_bytes::Value>,
334        status_code: Option<StatusCode>,
335        headers: MultiMap<HeaderName, HeaderValue>,
336        context: Context,
337    ) -> Self {
338        if !errors.is_empty() {
339            Self::add_errors_to_context(&errors, &context);
340        }
341
342        // Build a response
343        let b = graphql::Response::builder()
344            .and_label(label)
345            .and_path(path)
346            .errors(errors)
347            .extensions(extensions);
348        let res = match data {
349            Some(data) => b.data(data).build(),
350            None => b.build(),
351        };
352
353        // Build an http Response
354        let mut builder = http::Response::builder().status(status_code.unwrap_or(StatusCode::OK));
355        for (header_name, values) in headers {
356            for header_value in values {
357                builder = builder.header(header_name.clone(), header_value);
358            }
359        }
360
361        let body_string = serde_json::to_string(&res).expect("JSON is always a valid string");
362
363        let body = body::from_bytes(body_string.clone());
364        let response = builder.body(body).expect("RouterBody is always valid");
365
366        Self { response, context }
367    }
368
369    fn add_errors_to_context(errors: &[graphql::Error], context: &Context) {
370        context.insert_json_value(CONTAINS_GRAPHQL_ERROR, Value::Bool(true));
371        // This is ONLY guaranteed to capture errors if any were added during router service
372        // processing. We will sometimes avoid this path if no router service errors exist, even
373        // if errors were passed from the supergraph service, because that path builds the
374        // router::Response using parts_new(). This is ok because we only need this context to
375        // count errors introduced in the router service; however, it means that we handle error
376        // counting differently in this layer than others.
377        context
378            .insert(
379                ROUTER_RESPONSE_ERRORS,
380                // We can't serialize the apollo_id, so make a map with id as the key
381                errors
382                    .iter()
383                    .cloned()
384                    .map(|err| (err.apollo_id(), err))
385                    .collect::<HashMap<Uuid, graphql::Error>>(),
386            )
387            .expect("Unable to serialize router response errors list for context");
388    }
389
390    /// EXPERIMENTAL: THIS FUNCTION IS EXPERIMENTAL AND SUBJECT TO POTENTIAL CHANGE.
391    pub async fn into_graphql_response_stream(
392        self,
393    ) -> impl Stream<Item = Result<graphql::Response, serde_json::Error>> {
394        Box::pin(
395            if self
396                .response
397                .headers()
398                .get(CONTENT_TYPE)
399                .iter()
400                .any(|value| {
401                    *value == MULTIPART_DEFER_CONTENT_TYPE_HEADER_VALUE
402                        || *value == MULTIPART_SUBSCRIPTION_CONTENT_TYPE_HEADER_VALUE
403                })
404            {
405                let multipart = Multipart::new(
406                    http_body_util::BodyDataStream::new(self.response.into_body()),
407                    "graphql",
408                );
409
410                Either::Left(futures::stream::unfold(multipart, |mut m| async {
411                    if let Ok(Some(response)) = m.next_field().await
412                        && let Ok(bytes) = response.bytes().await
413                    {
414                        return Some((serde_json::from_slice::<graphql::Response>(&bytes), m));
415                    }
416                    None
417                }))
418            } else {
419                let mut body = http_body_util::BodyDataStream::new(self.response.into_body());
420                let res = body.next().await.and_then(|res| res.ok());
421
422                Either::Right(
423                    futures::stream::iter(res)
424                        .map(|bytes| serde_json::from_slice::<graphql::Response>(&bytes)),
425                )
426            },
427        )
428    }
429
430    /// This is the constructor (or builder) to use when constructing a fake Response.
431    ///
432    /// Required parameters are required in non-testing code to create a Response.
433    #[builder(visibility = "pub")]
434    fn fake_new(
435        label: Option<String>,
436        data: Option<serde_json_bytes::Value>,
437        path: Option<Path>,
438        errors: Vec<graphql::Error>,
439        // Skip the `Object` type alias to use buildstructor’s map special-casing
440        extensions: JsonMap<ByteString, serde_json_bytes::Value>,
441        status_code: Option<StatusCode>,
442        headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
443        context: Option<Context>,
444    ) -> Result<Self, BoxError> {
445        // Build a response
446        Self::new(
447            label,
448            data,
449            path,
450            errors,
451            extensions,
452            status_code,
453            headers,
454            context.unwrap_or_default(),
455        )
456    }
457}
458
/// Response formats the client declared it accepts; all flags default to `false`.
// NOTE(review): presumably populated from the request's `Accept` header — confirm at call sites.
// Field order is significant: it determines the derived `Debug` output.
#[derive(Debug, Default, Clone)]
pub(crate) struct ClientRequestAccepts {
    /// The multipart content type used for deferred responses is accepted.
    pub(crate) multipart_defer: bool,
    /// The multipart content type used for subscriptions is accepted.
    pub(crate) multipart_subscription: bool,
    /// Plain JSON is accepted.
    pub(crate) json: bool,
    /// A wildcard media type is accepted.
    pub(crate) wildcard: bool,
}
466
467impl<T> From<http::Response<T>> for Response
468where
469    T: http_body::Body<Data = Bytes> + Send + 'static,
470    <T as http_body::Body>::Error: Into<BoxError>,
471{
472    fn from(response: http::Response<T>) -> Self {
473        let context: Context = response.extensions().get().cloned().unwrap_or_default();
474
475        Self {
476            response: response.map(convert_to_body),
477            context,
478        }
479    }
480}
481
482impl From<Response> for http::Response<Body> {
483    fn from(mut response: Response) -> Self {
484        response.response.extensions_mut().insert(response.context);
485        response.response
486    }
487}
488
489impl<T> From<http::Request<T>> for Request
490where
491    T: http_body::Body<Data = Bytes> + Send + 'static,
492    <T as http_body::Body>::Error: Into<BoxError>,
493{
494    fn from(request: http::Request<T>) -> Self {
495        let context: Context = request.extensions().get().cloned().unwrap_or_default();
496
497        Self {
498            router_request: request.map(convert_to_body),
499            context,
500        }
501    }
502}
503
504impl From<Request> for http::Request<Body> {
505    fn from(mut request: Request) -> Self {
506        request
507            .router_request
508            .extensions_mut()
509            .insert(request.context);
510        request.router_request
511    }
512}
513
514/// This function is used to convert an `http_body::Body` into a `Body`.
515/// It does a downcast check to see if the body is already a `Body` and if it is, then it just returns it.
516/// There is zero overhead if the body is already a `Body`.
517/// Note that ALL graphql responses are already a stream as they may be part of a deferred or stream response,
518/// therefore, if a body has to be wrapped, the cost is minimal.
519fn convert_to_body<T>(mut b: T) -> Body
520where
521    T: http_body::Body<Data = Bytes> + Send + 'static,
522    <T as http_body::Body>::Error: Into<BoxError>,
523{
524    let val_any = &mut b as &mut dyn Any;
525    match val_any.downcast_mut::<Body>() {
526        Some(body) => mem::take(body),
527        None => Body::new(http_body_util::BodyStream::new(b.map_err(axum::Error::new))),
528    }
529}
530
#[cfg(test)]
mod test {
    use std::pin::Pin;
    use std::task::Context;
    use std::task::Poll;

    use tower::BoxError;

    use super::convert_to_body;
    use crate::services::router;

    /// Minimal `http_body::Body` that yields its data as a single frame.
    ///
    /// Used to exercise the wrapping (non-downcast) path of `convert_to_body`.
    struct MockBody {
        data: Option<&'static str>,
    }
    impl http_body::Body for MockBody {
        type Data = bytes::Bytes;
        type Error = BoxError;

        fn poll_frame(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
            if let Some(data) = self.get_mut().data.take() {
                Poll::Ready(Some(Ok(http_body::Frame::data(bytes::Bytes::from(data)))))
            } else {
                Poll::Ready(None)
            }
        }
    }

    #[tokio::test]
    async fn test_convert_from_http_body() {
        let body = convert_to_body(MockBody { data: Some("test") });
        assert_eq!(
            &String::from_utf8(router::body::into_bytes(body).await.unwrap().to_vec()).unwrap(),
            "test"
        );
    }

    // Despite its name, this feeds a `String`, which implements `http_body::Body` directly.
    #[tokio::test]
    async fn test_convert_from_hyper_body() {
        let body = convert_to_body(String::from("test"));
        assert_eq!(
            &String::from_utf8(router::body::into_bytes(body).await.unwrap().to_vec()).unwrap(),
            "test"
        );
    }

    // A value that already is a router `Body` takes the downcast fast path (`mem::take`).
    // Its contents must survive the move.
    #[tokio::test]
    async fn test_convert_from_router_body() {
        let body = convert_to_body(router::body::from_bytes(String::from("test")));
        assert_eq!(
            &String::from_utf8(router::body::into_bytes(body).await.unwrap().to_vec()).unwrap(),
            "test"
        );
    }
}
574            &String::from_utf8(router::body::into_bytes(body).await.unwrap().to_vec()).unwrap(),
575            "test"
576        );
577    }
578}