Skip to main content

apollo_router/services/
router.rs

1#![allow(missing_docs)] // FIXME
2
3use std::any::Any;
4use std::mem;
5
6use ahash::HashMap;
7use bytes::Bytes;
8use displaydoc::Display;
9use futures::Stream;
10use futures::StreamExt;
11use futures::future::Either;
12use http::HeaderValue;
13use http::Method;
14use http::StatusCode;
15use http::header::CONTENT_TYPE;
16use http::header::HeaderName;
17use http_body_util::BodyExt;
18use multer::Multipart;
19use multimap::MultiMap;
20use serde_json_bytes::ByteString;
21use serde_json_bytes::Map as JsonMap;
22use serde_json_bytes::Value;
23use static_assertions::assert_impl_all;
24use thiserror::Error;
25use tower::BoxError;
26use uuid::Uuid;
27
28use self::body::RouterBody;
29use self::service::MULTIPART_DEFER_CONTENT_TYPE_HEADER_VALUE;
30use self::service::MULTIPART_SUBSCRIPTION_CONTENT_TYPE_HEADER_VALUE;
31use super::supergraph;
32use crate::Context;
33use crate::context::CHUNK_CONTAINS_GRAPHQL_ERROR;
34use crate::context::CONTAINS_GRAPHQL_ERROR;
35use crate::context::ROUTER_RESPONSE_ERRORS;
36use crate::graphql;
37use crate::http_ext::header_map;
38use crate::json_ext::Path;
39use crate::plugins::telemetry::config_new::router::events::RouterResponseBodyExtensionType;
40use crate::services::TryIntoHeaderName;
41use crate::services::TryIntoHeaderValue;
42
43pub type BoxService = tower::util::BoxService<Request, Response, BoxError>;
44pub type BoxCloneService = tower::util::BoxCloneService<Request, Response, BoxError>;
45pub type ServiceResult = Result<Response, BoxError>;
46
47pub type Body = RouterBody;
48pub type Error = hyper::Error;
49
50mod batching;
51pub mod body;
52pub(crate) mod pipeline_handle;
53pub(crate) mod service;
54#[cfg(test)]
55mod tests;
56mod tower_compat;
57
// Compile-time check that `Request` is `Send`.
assert_impl_all!(Request: Send);
/// Represents the router processing step of the processing pipeline.
///
/// This holds the HTTP request handed to the router (headers, URI, method and a raw body that has
/// not been parsed as GraphQL at this stage) together with the context shared with extensions.
#[non_exhaustive]
pub struct Request {
    /// Original request to the Router.
    pub router_request: http::Request<Body>,

    /// Context for extension
    pub context: Context,
}
70
71impl From<(http::Request<Body>, Context)> for Request {
72    fn from((router_request, context): (http::Request<Body>, Context)) -> Self {
73        Self {
74            router_request,
75            context,
76        }
77    }
78}
79
80/// Helper type to conveniently construct a body from several types used commonly in tests.
81///
82/// It's only meant for integration tests, as the "real" router should create bodies explicitly accounting for
83/// streaming, size limits, etc.
84pub struct IntoBody(Body);
85
86impl From<Body> for IntoBody {
87    fn from(value: Body) -> Self {
88        Self(value)
89    }
90}
91impl From<String> for IntoBody {
92    fn from(value: String) -> Self {
93        Self(body::from_bytes(value))
94    }
95}
96impl From<Bytes> for IntoBody {
97    fn from(value: Bytes) -> Self {
98        Self(body::from_bytes(value))
99    }
100}
101impl From<Vec<u8>> for IntoBody {
102    fn from(value: Vec<u8>) -> Self {
103        Self(body::from_bytes(value))
104    }
105}
106
107#[buildstructor::buildstructor]
108impl Request {
109    /// This is the constructor (or builder) to use when constructing a real Request.
110    ///
111    /// Required parameters are required in non-testing code to create a Request.
112    #[builder(visibility = "pub")]
113    fn new(
114        context: Context,
115        headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
116        uri: http::Uri,
117        method: Method,
118        body: Body,
119    ) -> Result<Request, BoxError> {
120        let mut router_request = http::Request::builder()
121            .uri(uri)
122            .method(method)
123            .body(body)?;
124        *router_request.headers_mut() = header_map(headers)?;
125        Ok(Self {
126            router_request,
127            context,
128        })
129    }
130
131    /// This is the constructor (or builder) to use when constructing a fake Request.
132    ///
133    /// Required parameters are required in non-testing code to create a Request.
134    #[builder(visibility = "pub")]
135    fn fake_new(
136        context: Option<Context>,
137        headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
138        uri: Option<http::Uri>,
139        method: Option<Method>,
140        body: Option<IntoBody>,
141    ) -> Result<Request, BoxError> {
142        let mut router_request = http::Request::builder()
143            .uri(uri.unwrap_or_else(|| http::Uri::from_static("http://example.com/")))
144            .method(method.unwrap_or(Method::GET))
145            .body(body.map_or_else(body::empty, |constructed| constructed.0))?;
146        *router_request.headers_mut() = header_map(headers)?;
147        Ok(Self {
148            router_request,
149            context: context.unwrap_or_default(),
150        })
151    }
152}
153
/// Errors returned when converting a `supergraph::Request` back into a router `Request`.
#[derive(Error, Display, Debug)]
pub enum ParseError {
    /// couldn't create a valid http GET uri '{0}'
    InvalidUri(http::uri::InvalidUri),
    /// couldn't urlencode the GraphQL request body '{0}'
    UrlEncodeError(serde_urlencoded::ser::Error),
    /// couldn't serialize the GraphQL request body '{0}'
    SerializationError(serde_json::Error),
}
163
164/// This is handy for tests.
165impl TryFrom<supergraph::Request> for Request {
166    type Error = ParseError;
167    fn try_from(request: supergraph::Request) -> Result<Self, Self::Error> {
168        let supergraph::Request {
169            context,
170            supergraph_request,
171            ..
172        } = request;
173
174        let (mut parts, request) = supergraph_request.into_parts();
175
176        let router_request = if parts.method == Method::GET {
177            // get request
178            let get_path = serde_urlencoded::to_string([
179                ("query", request.query),
180                ("operationName", request.operation_name),
181                (
182                    "extensions",
183                    serde_json::to_string(&request.extensions).ok(),
184                ),
185                ("variables", serde_json::to_string(&request.variables).ok()),
186            ])
187            .map_err(ParseError::UrlEncodeError)?;
188
189            parts.uri = format!("{}?{}", parts.uri, get_path)
190                .parse()
191                .map_err(ParseError::InvalidUri)?;
192
193            http::Request::from_parts(parts, body::empty())
194        } else {
195            http::Request::from_parts(
196                parts,
197                body::from_bytes(
198                    serde_json::to_vec(&request).map_err(ParseError::SerializationError)?,
199                ),
200            )
201        };
202        Ok(Self {
203            router_request,
204            context,
205        })
206    }
207}
208
// Compile-time check that `Response` is `Send`.
assert_impl_all!(Response: Send);
/// Output of the router processing step: an HTTP response plus the request context.
#[non_exhaustive]
#[derive(Debug)]
pub struct Response {
    /// HTTP response produced by this stage.
    pub response: http::Response<Body>,
    /// Context shared with extensions.
    pub context: Context,
}
216
217#[buildstructor::buildstructor]
218impl Response {
219    fn stash_the_body_in_extensions(&mut self, body_string: String) {
220        self.context.extensions().with_lock(|ext| {
221            ext.insert(RouterResponseBodyExtensionType(body_string));
222        });
223    }
224
225    pub async fn next_response(&mut self) -> Option<Result<Bytes, axum::Error>> {
226        self.response.body_mut().into_data_stream().next().await
227    }
228
229    /// This is the constructor (or builder) to use when constructing a real Response.
230    ///
231    /// Required parameters are required in non-testing code to create a Response.
232    #[builder(visibility = "pub")]
233    fn new(
234        label: Option<String>,
235        data: Option<serde_json_bytes::Value>,
236        path: Option<Path>,
237        errors: Vec<graphql::Error>,
238        // Skip the `Object` type alias to use buildstructor’s map special-casing
239        extensions: JsonMap<ByteString, serde_json_bytes::Value>,
240        status_code: Option<StatusCode>,
241        headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
242        context: Context,
243    ) -> Result<Self, BoxError> {
244        if !errors.is_empty() {
245            Self::add_errors_to_context(&errors, &context);
246        }
247
248        // Build a response
249        let b = graphql::Response::builder()
250            .and_label(label)
251            .and_path(path)
252            .errors(errors)
253            .extensions(extensions);
254        let res = match data {
255            Some(data) => b.data(data).build(),
256            None => b.build(),
257        };
258
259        // Build an HTTP Response
260        let mut builder = http::Response::builder().status(status_code.unwrap_or(StatusCode::OK));
261        for (key, values) in headers {
262            let header_name: HeaderName = key.try_into()?;
263            for value in values {
264                let header_value: HeaderValue = value.try_into()?;
265                builder = builder.header(header_name.clone(), header_value);
266            }
267        }
268
269        let body_string = serde_json::to_string(&res)?;
270
271        let body = body::from_bytes(body_string.clone());
272        let response = builder.body(body)?;
273        // Stash the body in the extensions so we can access it later
274        let mut response = Self { response, context };
275        response.stash_the_body_in_extensions(body_string);
276
277        Ok(response)
278    }
279
280    #[builder(visibility = "pub")]
281    fn http_response_new(
282        response: http::Response<Body>,
283        context: Context,
284        body_to_stash: Option<String>,
285        errors_for_context: Option<Vec<graphql::Error>>,
286    ) -> Result<Self, BoxError> {
287        // There are instances where we have errors that need to be counted for telemetry in this
288        // layer, but we don't want to deserialize the body. In these cases we can pass in the
289        // list of errors to add to context for counting later in the telemetry plugin.
290        if let Some(errors) = errors_for_context
291            && !errors.is_empty()
292        {
293            Self::add_errors_to_context(&errors, &context);
294        }
295        let mut res = Self { response, context };
296        if let Some(body_to_stash) = body_to_stash {
297            res.stash_the_body_in_extensions(body_to_stash)
298        }
299        Ok(res)
300    }
301
302    /// This is the constructor (or builder) to use when constructing a Response that represents a global error.
303    /// It has no path and no response data.
304    /// This is useful for things such as authentication errors.
305    #[builder(visibility = "pub")]
306    fn error_new(
307        errors: Vec<graphql::Error>,
308        status_code: Option<StatusCode>,
309        headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
310        context: Context,
311    ) -> Result<Self, BoxError> {
312        Response::new(
313            Default::default(),
314            Default::default(),
315            None,
316            errors,
317            Default::default(),
318            status_code,
319            headers,
320            context,
321        )
322    }
323
324    /// This is the constructor (or builder) to use when constructing a real Response.
325    ///
326    /// Required parameters are required in non-testing code to create a Response.
327    #[builder(visibility = "pub(crate)")]
328    fn infallible_new(
329        label: Option<String>,
330        data: Option<serde_json_bytes::Value>,
331        path: Option<Path>,
332        errors: Vec<graphql::Error>,
333        // Skip the `Object` type alias to use buildstructor’s map special-casing
334        extensions: JsonMap<ByteString, serde_json_bytes::Value>,
335        status_code: Option<StatusCode>,
336        headers: MultiMap<HeaderName, HeaderValue>,
337        context: Context,
338    ) -> Self {
339        if !errors.is_empty() {
340            Self::add_errors_to_context(&errors, &context);
341        }
342
343        // Build a response
344        let b = graphql::Response::builder()
345            .and_label(label)
346            .and_path(path)
347            .errors(errors)
348            .extensions(extensions);
349        let res = match data {
350            Some(data) => b.data(data).build(),
351            None => b.build(),
352        };
353
354        // Build an http Response
355        let mut builder = http::Response::builder().status(status_code.unwrap_or(StatusCode::OK));
356        for (header_name, values) in headers {
357            for header_value in values {
358                builder = builder.header(header_name.clone(), header_value);
359            }
360        }
361
362        let body_string = serde_json::to_string(&res).expect("JSON is always a valid string");
363
364        let body = body::from_bytes(body_string.clone());
365        let response = builder.body(body).expect("RouterBody is always valid");
366
367        Self { response, context }
368    }
369
370    fn add_errors_to_context(errors: &[graphql::Error], context: &Context) {
371        context.insert_json_value(CONTAINS_GRAPHQL_ERROR, Value::Bool(true));
372        context.insert_json_value(CHUNK_CONTAINS_GRAPHQL_ERROR, Value::Bool(true));
373        // This is ONLY guaranteed to capture errors if any were added during router service
374        // processing. We will sometimes avoid this path if no router service errors exist, even
375        // if errors were passed from the supergraph service, because that path builds the
376        // router::Response using parts_new(). This is ok because we only need this context to
377        // count errors introduced in the router service; however, it means that we handle error
378        // counting differently in this layer than others.
379        context
380            .insert(
381                ROUTER_RESPONSE_ERRORS,
382                // We can't serialize the apollo_id, so make a map with id as the key
383                errors
384                    .iter()
385                    .cloned()
386                    .map(|err| (err.apollo_id(), err))
387                    .collect::<HashMap<Uuid, graphql::Error>>(),
388            )
389            .expect("Unable to serialize router response errors list for context");
390    }
391
392    /// EXPERIMENTAL: THIS FUNCTION IS EXPERIMENTAL AND SUBJECT TO POTENTIAL CHANGE.
393    pub async fn into_graphql_response_stream(
394        self,
395    ) -> impl Stream<Item = Result<graphql::Response, serde_json::Error>> {
396        Box::pin(
397            if self
398                .response
399                .headers()
400                .get(CONTENT_TYPE)
401                .iter()
402                .any(|value| {
403                    *value == MULTIPART_DEFER_CONTENT_TYPE_HEADER_VALUE
404                        || *value == MULTIPART_SUBSCRIPTION_CONTENT_TYPE_HEADER_VALUE
405                })
406            {
407                let multipart = Multipart::new(
408                    http_body_util::BodyDataStream::new(self.response.into_body()),
409                    "graphql",
410                );
411
412                Either::Left(futures::stream::unfold(multipart, |mut m| async {
413                    if let Ok(Some(response)) = m.next_field().await
414                        && let Ok(bytes) = response.bytes().await
415                    {
416                        return Some((serde_json::from_slice::<graphql::Response>(&bytes), m));
417                    }
418                    None
419                }))
420            } else {
421                let mut body = http_body_util::BodyDataStream::new(self.response.into_body());
422                let res = body.next().await.and_then(|res| res.ok());
423
424                Either::Right(
425                    futures::stream::iter(res)
426                        .map(|bytes| serde_json::from_slice::<graphql::Response>(&bytes)),
427                )
428            },
429        )
430    }
431
432    /// This is the constructor (or builder) to use when constructing a fake Response.
433    ///
434    /// Required parameters are required in non-testing code to create a Response.
435    #[builder(visibility = "pub")]
436    fn fake_new(
437        label: Option<String>,
438        data: Option<serde_json_bytes::Value>,
439        path: Option<Path>,
440        errors: Vec<graphql::Error>,
441        // Skip the `Object` type alias to use buildstructor’s map special-casing
442        extensions: JsonMap<ByteString, serde_json_bytes::Value>,
443        status_code: Option<StatusCode>,
444        headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
445        context: Option<Context>,
446    ) -> Result<Self, BoxError> {
447        // Build a response
448        Self::new(
449            label,
450            data,
451            path,
452            errors,
453            extensions,
454            status_code,
455            headers,
456            context.unwrap_or_default(),
457        )
458    }
459}
460
/// Flags describing which response formats the client accepts. Every flag is `false` by default.
///
/// NOTE(review): the meaning of each field is inferred from its name (multipart defer, multipart
/// subscription, JSON, wildcard). The flags are presumably derived from the request's `Accept`
/// header elsewhere in the crate; confirm against the code that populates this struct.
#[derive(Clone, Default, Debug)]
pub(crate) struct ClientRequestAccepts {
    pub(crate) multipart_defer: bool,
    pub(crate) multipart_subscription: bool,
    pub(crate) json: bool,
    pub(crate) wildcard: bool,
}
468
469impl<T> From<http::Response<T>> for Response
470where
471    T: http_body::Body<Data = Bytes> + Send + 'static,
472    <T as http_body::Body>::Error: Into<BoxError>,
473{
474    fn from(response: http::Response<T>) -> Self {
475        let context: Context = response.extensions().get().cloned().unwrap_or_default();
476
477        Self {
478            response: response.map(convert_to_body),
479            context,
480        }
481    }
482}
483
484impl From<Response> for http::Response<Body> {
485    fn from(mut response: Response) -> Self {
486        response.response.extensions_mut().insert(response.context);
487        response.response
488    }
489}
490
491impl<T> From<http::Request<T>> for Request
492where
493    T: http_body::Body<Data = Bytes> + Send + 'static,
494    <T as http_body::Body>::Error: Into<BoxError>,
495{
496    fn from(request: http::Request<T>) -> Self {
497        let context: Context = request.extensions().get().cloned().unwrap_or_default();
498
499        Self {
500            router_request: request.map(convert_to_body),
501            context,
502        }
503    }
504}
505
506impl From<Request> for http::Request<Body> {
507    fn from(mut request: Request) -> Self {
508        request
509            .router_request
510            .extensions_mut()
511            .insert(request.context);
512        request.router_request
513    }
514}
515
516/// This function is used to convert an `http_body::Body` into a `Body`.
517/// It does a downcast check to see if the body is already a `Body` and if it is, then it just returns it.
518/// There is zero overhead if the body is already a `Body`.
519/// Note that ALL graphql responses are already a stream as they may be part of a deferred or stream response,
520/// therefore, if a body has to be wrapped, the cost is minimal.
521fn convert_to_body<T>(mut b: T) -> Body
522where
523    T: http_body::Body<Data = Bytes> + Send + 'static,
524    <T as http_body::Body>::Error: Into<BoxError>,
525{
526    let val_any = &mut b as &mut dyn Any;
527    match val_any.downcast_mut::<Body>() {
528        Some(body) => mem::take(body),
529        None => Body::new(http_body_util::BodyStream::new(b.map_err(axum::Error::new))),
530    }
531}
532
#[cfg(test)]
mod test {
    use std::pin::Pin;
    use std::task::Context;
    use std::task::Poll;

    use tower::BoxError;

    use super::convert_to_body;
    use crate::services::router;

    /// Minimal `http_body::Body` that yields its payload as one data frame and then ends.
    struct MockBody {
        data: Option<&'static str>,
    }

    impl http_body::Body for MockBody {
        type Data = bytes::Bytes;
        type Error = BoxError;

        fn poll_frame(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
            let next = self.get_mut().data.take().map(|chunk| {
                Ok(http_body::Frame::data(bytes::Bytes::from_static(
                    chunk.as_bytes(),
                )))
            });
            Poll::Ready(next)
        }
    }

    /// Drains a router body and decodes it as UTF-8.
    async fn read_to_string(body: router::Body) -> String {
        let collected = router::body::into_bytes(body).await.unwrap();
        String::from_utf8(collected.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_convert_from_http_body() {
        let body = convert_to_body(MockBody { data: Some("test") });
        assert_eq!(read_to_string(body).await, "test");
    }

    #[tokio::test]
    async fn test_convert_from_hyper_body() {
        let body = convert_to_body(String::from("test"));
        assert_eq!(read_to_string(body).await, "test");
    }
}
576            &String::from_utf8(router::body::into_bytes(body).await.unwrap().to_vec()).unwrap(),
577            "test"
578        );
579    }
580}