apollo_router/services/
router.rs

1#![allow(missing_docs)] // FIXME
2
3use std::any::Any;
4use std::mem;
5
6use ahash::HashMap;
7use bytes::Bytes;
8use displaydoc::Display;
9use futures::Stream;
10use futures::StreamExt;
11use futures::future::Either;
12use http::HeaderValue;
13use http::Method;
14use http::StatusCode;
15use http::header::CONTENT_TYPE;
16use http::header::HeaderName;
17use http_body_util::BodyExt;
18use multer::Multipart;
19use multimap::MultiMap;
20use serde_json_bytes::ByteString;
21use serde_json_bytes::Map as JsonMap;
22use serde_json_bytes::Value;
23use static_assertions::assert_impl_all;
24use thiserror::Error;
25use tower::BoxError;
26use uuid::Uuid;
27
28use self::body::RouterBody;
29use self::service::MULTIPART_DEFER_CONTENT_TYPE_HEADER_VALUE;
30use self::service::MULTIPART_SUBSCRIPTION_CONTENT_TYPE_HEADER_VALUE;
31use super::supergraph;
32use crate::Context;
33use crate::context::CONTAINS_GRAPHQL_ERROR;
34use crate::context::ROUTER_RESPONSE_ERRORS;
35use crate::graphql;
36use crate::http_ext::header_map;
37use crate::json_ext::Path;
38use crate::plugins::telemetry::config_new::router::events::RouterResponseBodyExtensionType;
39use crate::services::TryIntoHeaderName;
40use crate::services::TryIntoHeaderValue;
41
/// Boxed, type-erased router service taking a router [`Request`] and producing a [`Response`].
pub type BoxService = tower::util::BoxService<Request, Response, BoxError>;
/// Clonable variant of [`BoxService`].
pub type BoxCloneService = tower::util::BoxCloneService<Request, Response, BoxError>;
/// Result type produced by router services.
pub type ServiceResult = Result<Response, BoxError>;

/// Body type carried by router requests and responses.
pub type Body = RouterBody;
/// Alias for [`hyper::Error`].
pub type Error = hyper::Error;

pub mod body;
pub(crate) mod pipeline_handle;
pub(crate) mod service;
#[cfg(test)]
mod tests;
54
// Compile-time assertion that `Request` is `Send`.
assert_impl_all!(Request: Send);
/// Represents the router processing step of the processing pipeline.
///
/// This consists of the incoming HTTP request (method, URI, headers and a not-yet-parsed
/// [`Body`]) and contextual data for extensions.
#[non_exhaustive]
pub struct Request {
    /// Original request to the Router.
    pub router_request: http::Request<Body>,

    /// Context for extension
    pub context: Context,
}
67
68impl From<(http::Request<Body>, Context)> for Request {
69    fn from((router_request, context): (http::Request<Body>, Context)) -> Self {
70        Self {
71            router_request,
72            context,
73        }
74    }
75}
76
77/// Helper type to conveniently construct a body from several types used commonly in tests.
78///
79/// It's only meant for integration tests, as the "real" router should create bodies explicitly accounting for
80/// streaming, size limits, etc.
81pub struct IntoBody(Body);
82
83impl From<Body> for IntoBody {
84    fn from(value: Body) -> Self {
85        Self(value)
86    }
87}
88impl From<String> for IntoBody {
89    fn from(value: String) -> Self {
90        Self(body::from_bytes(value))
91    }
92}
93impl From<Bytes> for IntoBody {
94    fn from(value: Bytes) -> Self {
95        Self(body::from_bytes(value))
96    }
97}
98impl From<Vec<u8>> for IntoBody {
99    fn from(value: Vec<u8>) -> Self {
100        Self(body::from_bytes(value))
101    }
102}
103
104#[buildstructor::buildstructor]
105impl Request {
106    /// This is the constructor (or builder) to use when constructing a real Request.
107    ///
108    /// Required parameters are required in non-testing code to create a Request.
109    #[builder(visibility = "pub")]
110    fn new(
111        context: Context,
112        headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
113        uri: http::Uri,
114        method: Method,
115        body: Body,
116    ) -> Result<Request, BoxError> {
117        let mut router_request = http::Request::builder()
118            .uri(uri)
119            .method(method)
120            .body(body)?;
121        *router_request.headers_mut() = header_map(headers)?;
122        Ok(Self {
123            router_request,
124            context,
125        })
126    }
127
128    /// This is the constructor (or builder) to use when constructing a fake Request.
129    ///
130    /// Required parameters are required in non-testing code to create a Request.
131    #[builder(visibility = "pub")]
132    fn fake_new(
133        context: Option<Context>,
134        headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
135        uri: Option<http::Uri>,
136        method: Option<Method>,
137        body: Option<IntoBody>,
138    ) -> Result<Request, BoxError> {
139        let mut router_request = http::Request::builder()
140            .uri(uri.unwrap_or_else(|| http::Uri::from_static("http://example.com/")))
141            .method(method.unwrap_or(Method::GET))
142            .body(body.map_or_else(body::empty, |constructed| constructed.0))?;
143        *router_request.headers_mut() = header_map(headers)?;
144        Ok(Self {
145            router_request,
146            context: context.unwrap_or_default(),
147        })
148    }
149}
150
// Errors raised when converting a `supergraph::Request` into a router `Request`
// (see the `TryFrom<supergraph::Request>` impl in this file).
//
// NOTE: `displaydoc` turns each variant's `///` comment into its `Display` message, so those
// lines are runtime text, not just documentation. The wrapped errors carry no
// `#[source]`/`#[from]` attribute; their message is embedded in `Display` through `{0}`
// instead of being exposed via `Error::source()`.
#[derive(Error, Display, Debug)]
pub enum ParseError {
    // The GET query string could not be appended to the request URI.
    /// couldn't create a valid http GET uri '{0}'
    InvalidUri(http::uri::InvalidUri),
    // The GET query parameters could not be URL-encoded.
    /// couldn't urlencode the GraphQL request body '{0}'
    UrlEncodeError(serde_urlencoded::ser::Error),
    // The GraphQL request could not be serialized to JSON for a non-GET body.
    /// couldn't serialize the GraphQL request body '{0}'
    SerializationError(serde_json::Error),
}
160
161/// This is handy for tests.
162impl TryFrom<supergraph::Request> for Request {
163    type Error = ParseError;
164    fn try_from(request: supergraph::Request) -> Result<Self, Self::Error> {
165        let supergraph::Request {
166            context,
167            supergraph_request,
168            ..
169        } = request;
170
171        let (mut parts, request) = supergraph_request.into_parts();
172
173        let router_request = if parts.method == Method::GET {
174            // get request
175            let get_path = serde_urlencoded::to_string([
176                ("query", request.query),
177                ("operationName", request.operation_name),
178                (
179                    "extensions",
180                    serde_json::to_string(&request.extensions).ok(),
181                ),
182                ("variables", serde_json::to_string(&request.variables).ok()),
183            ])
184            .map_err(ParseError::UrlEncodeError)?;
185
186            parts.uri = format!("{}?{}", parts.uri, get_path)
187                .parse()
188                .map_err(ParseError::InvalidUri)?;
189
190            http::Request::from_parts(parts, body::empty())
191        } else {
192            http::Request::from_parts(
193                parts,
194                body::from_bytes(
195                    serde_json::to_vec(&request).map_err(ParseError::SerializationError)?,
196                ),
197            )
198        };
199        Ok(Self {
200            router_request,
201            context,
202        })
203    }
204}
205
// Compile-time assertion that `Response` is `Send`.
assert_impl_all!(Response: Send);
/// Output of the router processing step: the HTTP response with a router [`Body`], together
/// with the context for extensions.
#[non_exhaustive]
#[derive(Debug)]
pub struct Response {
    /// HTTP response produced by the router step.
    pub response: http::Response<Body>,
    /// Context for extensions.
    pub context: Context,
}
213
214#[buildstructor::buildstructor]
215impl Response {
216    fn stash_the_body_in_extensions(&mut self, body_string: String) {
217        self.context.extensions().with_lock(|ext| {
218            ext.insert(RouterResponseBodyExtensionType(body_string));
219        });
220    }
221
222    pub async fn next_response(&mut self) -> Option<Result<Bytes, axum::Error>> {
223        self.response.body_mut().into_data_stream().next().await
224    }
225
226    /// This is the constructor (or builder) to use when constructing a real Response.
227    ///
228    /// Required parameters are required in non-testing code to create a Response.
229    #[builder(visibility = "pub")]
230    fn new(
231        label: Option<String>,
232        data: Option<serde_json_bytes::Value>,
233        path: Option<Path>,
234        errors: Vec<graphql::Error>,
235        // Skip the `Object` type alias to use buildstructor’s map special-casing
236        extensions: JsonMap<ByteString, serde_json_bytes::Value>,
237        status_code: Option<StatusCode>,
238        headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
239        context: Context,
240    ) -> Result<Self, BoxError> {
241        if !errors.is_empty() {
242            Self::add_errors_to_context(&errors, &context);
243        }
244
245        // Build a response
246        let b = graphql::Response::builder()
247            .and_label(label)
248            .and_path(path)
249            .errors(errors)
250            .extensions(extensions);
251        let res = match data {
252            Some(data) => b.data(data).build(),
253            None => b.build(),
254        };
255
256        // Build an HTTP Response
257        let mut builder = http::Response::builder().status(status_code.unwrap_or(StatusCode::OK));
258        for (key, values) in headers {
259            let header_name: HeaderName = key.try_into()?;
260            for value in values {
261                let header_value: HeaderValue = value.try_into()?;
262                builder = builder.header(header_name.clone(), header_value);
263            }
264        }
265
266        let body_string = serde_json::to_string(&res)?;
267
268        let body = body::from_bytes(body_string.clone());
269        let response = builder.body(body)?;
270        // Stash the body in the extensions so we can access it later
271        let mut response = Self { response, context };
272        response.stash_the_body_in_extensions(body_string);
273
274        Ok(response)
275    }
276
277    #[builder(visibility = "pub")]
278    fn http_response_new(
279        response: http::Response<Body>,
280        context: Context,
281        body_to_stash: Option<String>,
282        errors_for_context: Option<Vec<graphql::Error>>,
283    ) -> Result<Self, BoxError> {
284        // There are instances where we have errors that need to be counted for telemetry in this
285        // layer, but we don't want to deserialize the body. In these cases we can pass in the
286        // list of errors to add to context for counting later in the telemetry plugin.
287        if let Some(errors) = errors_for_context
288            && !errors.is_empty()
289        {
290            Self::add_errors_to_context(&errors, &context);
291        }
292        let mut res = Self { response, context };
293        if let Some(body_to_stash) = body_to_stash {
294            res.stash_the_body_in_extensions(body_to_stash)
295        }
296        Ok(res)
297    }
298
299    /// This is the constructor (or builder) to use when constructing a Response that represents a global error.
300    /// It has no path and no response data.
301    /// This is useful for things such as authentication errors.
302    #[builder(visibility = "pub")]
303    fn error_new(
304        errors: Vec<graphql::Error>,
305        status_code: Option<StatusCode>,
306        headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
307        context: Context,
308    ) -> Result<Self, BoxError> {
309        Response::new(
310            Default::default(),
311            Default::default(),
312            None,
313            errors,
314            Default::default(),
315            status_code,
316            headers,
317            context,
318        )
319    }
320
321    /// This is the constructor (or builder) to use when constructing a real Response.
322    ///
323    /// Required parameters are required in non-testing code to create a Response.
324    #[builder(visibility = "pub(crate)")]
325    fn infallible_new(
326        label: Option<String>,
327        data: Option<serde_json_bytes::Value>,
328        path: Option<Path>,
329        errors: Vec<graphql::Error>,
330        // Skip the `Object` type alias to use buildstructor’s map special-casing
331        extensions: JsonMap<ByteString, serde_json_bytes::Value>,
332        status_code: Option<StatusCode>,
333        headers: MultiMap<HeaderName, HeaderValue>,
334        context: Context,
335    ) -> Self {
336        if !errors.is_empty() {
337            Self::add_errors_to_context(&errors, &context);
338        }
339
340        // Build a response
341        let b = graphql::Response::builder()
342            .and_label(label)
343            .and_path(path)
344            .errors(errors)
345            .extensions(extensions);
346        let res = match data {
347            Some(data) => b.data(data).build(),
348            None => b.build(),
349        };
350
351        // Build an http Response
352        let mut builder = http::Response::builder().status(status_code.unwrap_or(StatusCode::OK));
353        for (header_name, values) in headers {
354            for header_value in values {
355                builder = builder.header(header_name.clone(), header_value);
356            }
357        }
358
359        let body_string = serde_json::to_string(&res).expect("JSON is always a valid string");
360
361        let body = body::from_bytes(body_string.clone());
362        let response = builder.body(body).expect("RouterBody is always valid");
363
364        Self { response, context }
365    }
366
367    fn add_errors_to_context(errors: &[graphql::Error], context: &Context) {
368        context.insert_json_value(CONTAINS_GRAPHQL_ERROR, Value::Bool(true));
369        // This is ONLY guaranteed to capture errors if any were added during router service
370        // processing. We will sometimes avoid this path if no router service errors exist, even
371        // if errors were passed from the supergraph service, because that path builds the
372        // router::Response using parts_new(). This is ok because we only need this context to
373        // count errors introduced in the router service; however, it means that we handle error
374        // counting differently in this layer than others.
375        context
376            .insert(
377                ROUTER_RESPONSE_ERRORS,
378                // We can't serialize the apollo_id, so make a map with id as the key
379                errors
380                    .iter()
381                    .cloned()
382                    .map(|err| (err.apollo_id(), err))
383                    .collect::<HashMap<Uuid, graphql::Error>>(),
384            )
385            .expect("Unable to serialize router response errors list for context");
386    }
387
388    /// EXPERIMENTAL: THIS FUNCTION IS EXPERIMENTAL AND SUBJECT TO POTENTIAL CHANGE.
389    pub async fn into_graphql_response_stream(
390        self,
391    ) -> impl Stream<Item = Result<graphql::Response, serde_json::Error>> {
392        Box::pin(
393            if self
394                .response
395                .headers()
396                .get(CONTENT_TYPE)
397                .iter()
398                .any(|value| {
399                    *value == MULTIPART_DEFER_CONTENT_TYPE_HEADER_VALUE
400                        || *value == MULTIPART_SUBSCRIPTION_CONTENT_TYPE_HEADER_VALUE
401                })
402            {
403                let multipart = Multipart::new(
404                    http_body_util::BodyDataStream::new(self.response.into_body()),
405                    "graphql",
406                );
407
408                Either::Left(futures::stream::unfold(multipart, |mut m| async {
409                    if let Ok(Some(response)) = m.next_field().await
410                        && let Ok(bytes) = response.bytes().await
411                    {
412                        return Some((serde_json::from_slice::<graphql::Response>(&bytes), m));
413                    }
414                    None
415                }))
416            } else {
417                let mut body = http_body_util::BodyDataStream::new(self.response.into_body());
418                let res = body.next().await.and_then(|res| res.ok());
419
420                Either::Right(
421                    futures::stream::iter(res.into_iter())
422                        .map(|bytes| serde_json::from_slice::<graphql::Response>(&bytes)),
423                )
424            },
425        )
426    }
427
428    /// This is the constructor (or builder) to use when constructing a fake Response.
429    ///
430    /// Required parameters are required in non-testing code to create a Response.
431    #[builder(visibility = "pub")]
432    fn fake_new(
433        label: Option<String>,
434        data: Option<serde_json_bytes::Value>,
435        path: Option<Path>,
436        errors: Vec<graphql::Error>,
437        // Skip the `Object` type alias to use buildstructor’s map special-casing
438        extensions: JsonMap<ByteString, serde_json_bytes::Value>,
439        status_code: Option<StatusCode>,
440        headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
441        context: Option<Context>,
442    ) -> Result<Self, BoxError> {
443        // Build a response
444        Self::new(
445            label,
446            data,
447            path,
448            errors,
449            extensions,
450            status_code,
451            headers,
452            context.unwrap_or_default(),
453        )
454    }
455}
456
/// Flags describing which response formats a client request accepts.
///
/// All flags default to `false`. NOTE(review): presumably populated from the request's
/// `Accept` header by the router service; confirm at the construction site.
#[derive(Clone, Default, Debug)]
pub(crate) struct ClientRequestAccepts {
    // Presumably: the client accepts multipart responses used for `@defer`.
    pub(crate) multipart_defer: bool,
    // Presumably: the client accepts multipart responses used for subscriptions.
    pub(crate) multipart_subscription: bool,
    // Presumably: the client accepts a plain JSON response.
    pub(crate) json: bool,
    // Presumably: the client accepts any content type (e.g. `*/*`).
    pub(crate) wildcard: bool,
}
464
465impl<T> From<http::Response<T>> for Response
466where
467    T: http_body::Body<Data = Bytes> + Send + 'static,
468    <T as http_body::Body>::Error: Into<BoxError>,
469{
470    fn from(response: http::Response<T>) -> Self {
471        let context: Context = response.extensions().get().cloned().unwrap_or_default();
472
473        Self {
474            response: response.map(convert_to_body),
475            context,
476        }
477    }
478}
479
480impl From<Response> for http::Response<Body> {
481    fn from(mut response: Response) -> Self {
482        response.response.extensions_mut().insert(response.context);
483        response.response
484    }
485}
486
487impl<T> From<http::Request<T>> for Request
488where
489    T: http_body::Body<Data = Bytes> + Send + 'static,
490    <T as http_body::Body>::Error: Into<BoxError>,
491{
492    fn from(request: http::Request<T>) -> Self {
493        let context: Context = request.extensions().get().cloned().unwrap_or_default();
494
495        Self {
496            router_request: request.map(convert_to_body),
497            context,
498        }
499    }
500}
501
502impl From<Request> for http::Request<Body> {
503    fn from(mut request: Request) -> Self {
504        request
505            .router_request
506            .extensions_mut()
507            .insert(request.context);
508        request.router_request
509    }
510}
511
512/// This function is used to convert an `http_body::Body` into a `Body`.
513/// It does a downcast check to see if the body is already a `Body` and if it is, then it just returns it.
514/// There is zero overhead if the body is already a `Body`.
515/// Note that ALL graphql responses are already a stream as they may be part of a deferred or stream response,
516/// therefore, if a body has to be wrapped, the cost is minimal.
517fn convert_to_body<T>(mut b: T) -> Body
518where
519    T: http_body::Body<Data = Bytes> + Send + 'static,
520    <T as http_body::Body>::Error: Into<BoxError>,
521{
522    let val_any = &mut b as &mut dyn Any;
523    match val_any.downcast_mut::<Body>() {
524        Some(body) => mem::take(body),
525        None => Body::new(http_body_util::BodyStream::new(b.map_err(axum::Error::new))),
526    }
527}
528
#[cfg(test)]
mod test {
    use std::pin::Pin;
    use std::task::Context;
    use std::task::Poll;

    use tower::BoxError;

    use super::convert_to_body;
    use crate::services::router;

    /// Single-frame `http_body::Body` that is not a router `Body`, so `convert_to_body` must
    /// wrap it.
    struct MockBody {
        data: Option<&'static str>,
    }
    impl http_body::Body for MockBody {
        type Data = bytes::Bytes;
        type Error = BoxError;

        fn poll_frame(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
            // Yield the data once, then report end-of-stream.
            if let Some(data) = self.get_mut().data.take() {
                Poll::Ready(Some(Ok(http_body::Frame::data(bytes::Bytes::from(data)))))
            } else {
                Poll::Ready(None)
            }
        }
    }

    /// Drains a router body and decodes it as UTF-8.
    async fn read_to_string(body: router::Body) -> String {
        String::from_utf8(router::body::into_bytes(body).await.unwrap().to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_convert_from_http_body() {
        let body = convert_to_body(MockBody { data: Some("test") });
        assert_eq!(read_to_string(body).await, "test");
    }

    #[tokio::test]
    async fn test_convert_from_hyper_body() {
        // `String` implements `http_body::Body`, so this also goes through the wrapping path.
        let body = convert_to_body(String::from("test"));
        assert_eq!(read_to_string(body).await, "test");
    }

    #[tokio::test]
    async fn test_convert_from_router_body() {
        // A router `Body` takes the downcast fast path and must come out with its content intact.
        let body = convert_to_body(router::body::from_bytes(String::from("test")));
        assert_eq!(read_to_string(body).await, "test");
    }
}