1#![allow(missing_docs)] use std::any::Any;
4use std::mem;
5
6use ahash::HashMap;
7use bytes::Bytes;
8use displaydoc::Display;
9use futures::Stream;
10use futures::StreamExt;
11use futures::future::Either;
12use http::HeaderValue;
13use http::Method;
14use http::StatusCode;
15use http::header::CONTENT_TYPE;
16use http::header::HeaderName;
17use http_body_util::BodyExt;
18use multer::Multipart;
19use multimap::MultiMap;
20use serde_json_bytes::ByteString;
21use serde_json_bytes::Map as JsonMap;
22use serde_json_bytes::Value;
23use static_assertions::assert_impl_all;
24use thiserror::Error;
25use tower::BoxError;
26use uuid::Uuid;
27
28use self::body::RouterBody;
29use self::service::MULTIPART_DEFER_CONTENT_TYPE_HEADER_VALUE;
30use self::service::MULTIPART_SUBSCRIPTION_CONTENT_TYPE_HEADER_VALUE;
31use super::supergraph;
32use crate::Context;
33use crate::context::CONTAINS_GRAPHQL_ERROR;
34use crate::context::ROUTER_RESPONSE_ERRORS;
35use crate::graphql;
36use crate::http_ext::header_map;
37use crate::json_ext::Path;
38use crate::plugins::telemetry::config_new::router::events::RouterResponseBodyExtensionType;
39use crate::services::TryIntoHeaderName;
40use crate::services::TryIntoHeaderValue;
41
/// Boxed (type-erased) tower service mapping a router [`Request`] to a [`Response`].
pub type BoxService = tower::util::BoxService<Request, Response, BoxError>;
/// Cloneable counterpart of [`BoxService`].
pub type BoxCloneService = tower::util::BoxCloneService<Request, Response, BoxError>;
/// Shorthand for `Result<Response, BoxError>`, the outcome of calling a router service.
pub type ServiceResult = Result<Response, BoxError>;

/// HTTP body type carried by router requests and responses (alias of `RouterBody`).
pub type Body = RouterBody;
/// Alias of `hyper::Error`.
pub type Error = hyper::Error;

// Submodules of the router stage. `tests` is only compiled under `cfg(test)`;
// `body` is public, `pipeline_handle` and `service` are crate-visible.
mod batching;
pub mod body;
pub(crate) mod pipeline_handle;
pub(crate) mod service;
#[cfg(test)]
mod tests;
mod tower_compat;
56
// Compile-time assertion (static_assertions) that `Request` is `Send`.
assert_impl_all!(Request: Send);
/// A request at the router stage: the raw HTTP request together with its [`Context`].
#[non_exhaustive]
pub struct Request {
    /// The underlying HTTP request, whose payload is a router [`Body`].
    pub router_request: http::Request<Body>,

    /// Per-request context that travels with this request.
    pub context: Context,
}
69
70impl From<(http::Request<Body>, Context)> for Request {
71 fn from((router_request, context): (http::Request<Body>, Context)) -> Self {
72 Self {
73 router_request,
74 context,
75 }
76 }
77}
78
79pub struct IntoBody(Body);
84
85impl From<Body> for IntoBody {
86 fn from(value: Body) -> Self {
87 Self(value)
88 }
89}
90impl From<String> for IntoBody {
91 fn from(value: String) -> Self {
92 Self(body::from_bytes(value))
93 }
94}
95impl From<Bytes> for IntoBody {
96 fn from(value: Bytes) -> Self {
97 Self(body::from_bytes(value))
98 }
99}
100impl From<Vec<u8>> for IntoBody {
101 fn from(value: Vec<u8>) -> Self {
102 Self(body::from_bytes(value))
103 }
104}
105
106#[buildstructor::buildstructor]
107impl Request {
108 #[builder(visibility = "pub")]
112 fn new(
113 context: Context,
114 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
115 uri: http::Uri,
116 method: Method,
117 body: Body,
118 ) -> Result<Request, BoxError> {
119 let mut router_request = http::Request::builder()
120 .uri(uri)
121 .method(method)
122 .body(body)?;
123 *router_request.headers_mut() = header_map(headers)?;
124 Ok(Self {
125 router_request,
126 context,
127 })
128 }
129
130 #[builder(visibility = "pub")]
134 fn fake_new(
135 context: Option<Context>,
136 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
137 uri: Option<http::Uri>,
138 method: Option<Method>,
139 body: Option<IntoBody>,
140 ) -> Result<Request, BoxError> {
141 let mut router_request = http::Request::builder()
142 .uri(uri.unwrap_or_else(|| http::Uri::from_static("http://example.com/")))
143 .method(method.unwrap_or(Method::GET))
144 .body(body.map_or_else(body::empty, |constructed| constructed.0))?;
145 *router_request.headers_mut() = header_map(headers)?;
146 Ok(Self {
147 router_request,
148 context: context.unwrap_or_default(),
149 })
150 }
151}
152
153#[derive(Error, Display, Debug)]
154pub enum ParseError {
155 InvalidUri(http::uri::InvalidUri),
157 UrlEncodeError(serde_urlencoded::ser::Error),
159 SerializationError(serde_json::Error),
161}
162
163impl TryFrom<supergraph::Request> for Request {
165 type Error = ParseError;
166 fn try_from(request: supergraph::Request) -> Result<Self, Self::Error> {
167 let supergraph::Request {
168 context,
169 supergraph_request,
170 ..
171 } = request;
172
173 let (mut parts, request) = supergraph_request.into_parts();
174
175 let router_request = if parts.method == Method::GET {
176 let get_path = serde_urlencoded::to_string([
178 ("query", request.query),
179 ("operationName", request.operation_name),
180 (
181 "extensions",
182 serde_json::to_string(&request.extensions).ok(),
183 ),
184 ("variables", serde_json::to_string(&request.variables).ok()),
185 ])
186 .map_err(ParseError::UrlEncodeError)?;
187
188 parts.uri = format!("{}?{}", parts.uri, get_path)
189 .parse()
190 .map_err(ParseError::InvalidUri)?;
191
192 http::Request::from_parts(parts, body::empty())
193 } else {
194 http::Request::from_parts(
195 parts,
196 body::from_bytes(
197 serde_json::to_vec(&request).map_err(ParseError::SerializationError)?,
198 ),
199 )
200 };
201 Ok(Self {
202 router_request,
203 context,
204 })
205 }
206}
207
// Compile-time assertion (static_assertions) that `Response` is `Send`.
assert_impl_all!(Response: Send);
/// A response at the router stage: the raw HTTP response together with its [`Context`].
#[non_exhaustive]
#[derive(Debug)]
pub struct Response {
    /// The underlying HTTP response, whose payload is a router [`Body`].
    pub response: http::Response<Body>,
    /// Per-request context associated with this response.
    pub context: Context,
}
215
216#[buildstructor::buildstructor]
217impl Response {
218 fn stash_the_body_in_extensions(&mut self, body_string: String) {
219 self.context.extensions().with_lock(|ext| {
220 ext.insert(RouterResponseBodyExtensionType(body_string));
221 });
222 }
223
224 pub async fn next_response(&mut self) -> Option<Result<Bytes, axum::Error>> {
225 self.response.body_mut().into_data_stream().next().await
226 }
227
228 #[builder(visibility = "pub")]
232 fn new(
233 label: Option<String>,
234 data: Option<serde_json_bytes::Value>,
235 path: Option<Path>,
236 errors: Vec<graphql::Error>,
237 extensions: JsonMap<ByteString, serde_json_bytes::Value>,
239 status_code: Option<StatusCode>,
240 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
241 context: Context,
242 ) -> Result<Self, BoxError> {
243 if !errors.is_empty() {
244 Self::add_errors_to_context(&errors, &context);
245 }
246
247 let b = graphql::Response::builder()
249 .and_label(label)
250 .and_path(path)
251 .errors(errors)
252 .extensions(extensions);
253 let res = match data {
254 Some(data) => b.data(data).build(),
255 None => b.build(),
256 };
257
258 let mut builder = http::Response::builder().status(status_code.unwrap_or(StatusCode::OK));
260 for (key, values) in headers {
261 let header_name: HeaderName = key.try_into()?;
262 for value in values {
263 let header_value: HeaderValue = value.try_into()?;
264 builder = builder.header(header_name.clone(), header_value);
265 }
266 }
267
268 let body_string = serde_json::to_string(&res)?;
269
270 let body = body::from_bytes(body_string.clone());
271 let response = builder.body(body)?;
272 let mut response = Self { response, context };
274 response.stash_the_body_in_extensions(body_string);
275
276 Ok(response)
277 }
278
279 #[builder(visibility = "pub")]
280 fn http_response_new(
281 response: http::Response<Body>,
282 context: Context,
283 body_to_stash: Option<String>,
284 errors_for_context: Option<Vec<graphql::Error>>,
285 ) -> Result<Self, BoxError> {
286 if let Some(errors) = errors_for_context
290 && !errors.is_empty()
291 {
292 Self::add_errors_to_context(&errors, &context);
293 }
294 let mut res = Self { response, context };
295 if let Some(body_to_stash) = body_to_stash {
296 res.stash_the_body_in_extensions(body_to_stash)
297 }
298 Ok(res)
299 }
300
301 #[builder(visibility = "pub")]
305 fn error_new(
306 errors: Vec<graphql::Error>,
307 status_code: Option<StatusCode>,
308 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
309 context: Context,
310 ) -> Result<Self, BoxError> {
311 Response::new(
312 Default::default(),
313 Default::default(),
314 None,
315 errors,
316 Default::default(),
317 status_code,
318 headers,
319 context,
320 )
321 }
322
323 #[builder(visibility = "pub(crate)")]
327 fn infallible_new(
328 label: Option<String>,
329 data: Option<serde_json_bytes::Value>,
330 path: Option<Path>,
331 errors: Vec<graphql::Error>,
332 extensions: JsonMap<ByteString, serde_json_bytes::Value>,
334 status_code: Option<StatusCode>,
335 headers: MultiMap<HeaderName, HeaderValue>,
336 context: Context,
337 ) -> Self {
338 if !errors.is_empty() {
339 Self::add_errors_to_context(&errors, &context);
340 }
341
342 let b = graphql::Response::builder()
344 .and_label(label)
345 .and_path(path)
346 .errors(errors)
347 .extensions(extensions);
348 let res = match data {
349 Some(data) => b.data(data).build(),
350 None => b.build(),
351 };
352
353 let mut builder = http::Response::builder().status(status_code.unwrap_or(StatusCode::OK));
355 for (header_name, values) in headers {
356 for header_value in values {
357 builder = builder.header(header_name.clone(), header_value);
358 }
359 }
360
361 let body_string = serde_json::to_string(&res).expect("JSON is always a valid string");
362
363 let body = body::from_bytes(body_string.clone());
364 let response = builder.body(body).expect("RouterBody is always valid");
365
366 Self { response, context }
367 }
368
369 fn add_errors_to_context(errors: &[graphql::Error], context: &Context) {
370 context.insert_json_value(CONTAINS_GRAPHQL_ERROR, Value::Bool(true));
371 context
378 .insert(
379 ROUTER_RESPONSE_ERRORS,
380 errors
382 .iter()
383 .cloned()
384 .map(|err| (err.apollo_id(), err))
385 .collect::<HashMap<Uuid, graphql::Error>>(),
386 )
387 .expect("Unable to serialize router response errors list for context");
388 }
389
390 pub async fn into_graphql_response_stream(
392 self,
393 ) -> impl Stream<Item = Result<graphql::Response, serde_json::Error>> {
394 Box::pin(
395 if self
396 .response
397 .headers()
398 .get(CONTENT_TYPE)
399 .iter()
400 .any(|value| {
401 *value == MULTIPART_DEFER_CONTENT_TYPE_HEADER_VALUE
402 || *value == MULTIPART_SUBSCRIPTION_CONTENT_TYPE_HEADER_VALUE
403 })
404 {
405 let multipart = Multipart::new(
406 http_body_util::BodyDataStream::new(self.response.into_body()),
407 "graphql",
408 );
409
410 Either::Left(futures::stream::unfold(multipart, |mut m| async {
411 if let Ok(Some(response)) = m.next_field().await
412 && let Ok(bytes) = response.bytes().await
413 {
414 return Some((serde_json::from_slice::<graphql::Response>(&bytes), m));
415 }
416 None
417 }))
418 } else {
419 let mut body = http_body_util::BodyDataStream::new(self.response.into_body());
420 let res = body.next().await.and_then(|res| res.ok());
421
422 Either::Right(
423 futures::stream::iter(res)
424 .map(|bytes| serde_json::from_slice::<graphql::Response>(&bytes)),
425 )
426 },
427 )
428 }
429
430 #[builder(visibility = "pub")]
434 fn fake_new(
435 label: Option<String>,
436 data: Option<serde_json_bytes::Value>,
437 path: Option<Path>,
438 errors: Vec<graphql::Error>,
439 extensions: JsonMap<ByteString, serde_json_bytes::Value>,
441 status_code: Option<StatusCode>,
442 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
443 context: Option<Context>,
444 ) -> Result<Self, BoxError> {
445 Self::new(
447 label,
448 data,
449 path,
450 errors,
451 extensions,
452 status_code,
453 headers,
454 context.unwrap_or_default(),
455 )
456 }
457}
458
/// Flags recording which response formats a client accepts; all `false` by default.
///
/// NOTE(review): presumably populated from the client's `Accept` header elsewhere in the
/// crate — the field meanings below are inferred from their names; confirm.
#[derive(Clone, Default, Debug)]
pub(crate) struct ClientRequestAccepts {
    // Multipart responses used for deferred (`@defer`) results.
    pub(crate) multipart_defer: bool,
    // Multipart responses used for subscriptions.
    pub(crate) multipart_subscription: bool,
    // Plain JSON responses.
    pub(crate) json: bool,
    // A wildcard media range (e.g. `*/*`).
    pub(crate) wildcard: bool,
}
466
467impl<T> From<http::Response<T>> for Response
468where
469 T: http_body::Body<Data = Bytes> + Send + 'static,
470 <T as http_body::Body>::Error: Into<BoxError>,
471{
472 fn from(response: http::Response<T>) -> Self {
473 let context: Context = response.extensions().get().cloned().unwrap_or_default();
474
475 Self {
476 response: response.map(convert_to_body),
477 context,
478 }
479 }
480}
481
482impl From<Response> for http::Response<Body> {
483 fn from(mut response: Response) -> Self {
484 response.response.extensions_mut().insert(response.context);
485 response.response
486 }
487}
488
489impl<T> From<http::Request<T>> for Request
490where
491 T: http_body::Body<Data = Bytes> + Send + 'static,
492 <T as http_body::Body>::Error: Into<BoxError>,
493{
494 fn from(request: http::Request<T>) -> Self {
495 let context: Context = request.extensions().get().cloned().unwrap_or_default();
496
497 Self {
498 router_request: request.map(convert_to_body),
499 context,
500 }
501 }
502}
503
504impl From<Request> for http::Request<Body> {
505 fn from(mut request: Request) -> Self {
506 request
507 .router_request
508 .extensions_mut()
509 .insert(request.context);
510 request.router_request
511 }
512}
513
514fn convert_to_body<T>(mut b: T) -> Body
520where
521 T: http_body::Body<Data = Bytes> + Send + 'static,
522 <T as http_body::Body>::Error: Into<BoxError>,
523{
524 let val_any = &mut b as &mut dyn Any;
525 match val_any.downcast_mut::<Body>() {
526 Some(body) => mem::take(body),
527 None => Body::new(http_body_util::BodyStream::new(b.map_err(axum::Error::new))),
528 }
529}
530
#[cfg(test)]
mod test {
    use std::pin::Pin;
    use std::task::Context;
    use std::task::Poll;

    use tower::BoxError;

    use super::convert_to_body;
    use crate::services::router;

    /// Minimal `http_body::Body` that yields its data (if any) as a single frame.
    struct MockBody {
        data: Option<&'static str>,
    }
    impl http_body::Body for MockBody {
        type Data = bytes::Bytes;
        type Error = BoxError;

        fn poll_frame(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
            if let Some(data) = self.get_mut().data.take() {
                Poll::Ready(Some(Ok(http_body::Frame::data(bytes::Bytes::from(data)))))
            } else {
                Poll::Ready(None)
            }
        }
    }

    /// Drains a router body and decodes it as UTF-8.
    async fn body_text(body: router::Body) -> String {
        String::from_utf8(router::body::into_bytes(body).await.unwrap().to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_convert_from_http_body() {
        let body = convert_to_body(MockBody { data: Some("test") });
        assert_eq!(body_text(body).await, "test");
    }

    #[tokio::test]
    async fn test_convert_from_empty_http_body() {
        let body = convert_to_body(MockBody { data: None });
        assert_eq!(body_text(body).await, "");
    }

    #[tokio::test]
    async fn test_convert_from_hyper_body() {
        // Despite the name, this exercises a plain `String` body (adapter branch).
        let body = convert_to_body(String::from("test"));
        assert_eq!(body_text(body).await, "test");
    }

    #[tokio::test]
    async fn test_convert_from_router_body_passes_through() {
        // A value that already is a router `Body` takes the downcast branch of
        // `convert_to_body`; its content must survive the `mem::take`.
        let body = convert_to_body(router::body::from_bytes(String::from("test")));
        assert_eq!(body_text(body).await, "test");
    }
}