1#![allow(missing_docs)] use std::any::Any;
4use std::mem;
5
6use ahash::HashMap;
7use bytes::Bytes;
8use displaydoc::Display;
9use futures::Stream;
10use futures::StreamExt;
11use futures::future::Either;
12use http::HeaderValue;
13use http::Method;
14use http::StatusCode;
15use http::header::CONTENT_TYPE;
16use http::header::HeaderName;
17use http_body_util::BodyExt;
18use multer::Multipart;
19use multimap::MultiMap;
20use serde_json_bytes::ByteString;
21use serde_json_bytes::Map as JsonMap;
22use serde_json_bytes::Value;
23use static_assertions::assert_impl_all;
24use thiserror::Error;
25use tower::BoxError;
26use uuid::Uuid;
27
28use self::body::RouterBody;
29use self::service::MULTIPART_DEFER_CONTENT_TYPE_HEADER_VALUE;
30use self::service::MULTIPART_SUBSCRIPTION_CONTENT_TYPE_HEADER_VALUE;
31use super::supergraph;
32use crate::Context;
33use crate::context::CHUNK_CONTAINS_GRAPHQL_ERROR;
34use crate::context::CONTAINS_GRAPHQL_ERROR;
35use crate::context::ROUTER_RESPONSE_ERRORS;
36use crate::graphql;
37use crate::http_ext::header_map;
38use crate::json_ext::Path;
39use crate::plugins::telemetry::config_new::router::events::RouterResponseBodyExtensionType;
40use crate::services::TryIntoHeaderName;
41use crate::services::TryIntoHeaderValue;
42
/// Boxed router service mapping a router [`Request`] to a router [`Response`].
pub type BoxService = tower::util::BoxService<Request, Response, BoxError>;
/// Cloneable boxed router service; same request/response/error types as [`BoxService`].
pub type BoxCloneService = tower::util::BoxCloneService<Request, Response, BoxError>;
/// Result of calling a router service.
pub type ServiceResult = Result<Response, BoxError>;

/// HTTP body type carried by router requests and responses.
pub type Body = RouterBody;
/// Alias for [`hyper::Error`].
pub type Error = hyper::Error;

mod batching;
// Router body type and constructors (`RouterBody`, `from_bytes`, `empty`, ...).
pub mod body;
pub(crate) mod pipeline_handle;
// Router service internals; also exports the multipart content-type header values used here.
pub(crate) mod service;
// Unit tests for this module, compiled only for test builds.
#[cfg(test)]
mod tests;
mod tower_compat;
57
// Compile-time assertion that `Request` is `Send`.
assert_impl_all!(Request: Send);
#[non_exhaustive]
/// Router-stage request: an HTTP request with a router [`Body`], plus its [`Context`].
pub struct Request {
    /// The underlying HTTP request.
    pub router_request: http::Request<Body>,

    /// Per-request context carried alongside the HTTP request.
    pub context: Context,
}
70
71impl From<(http::Request<Body>, Context)> for Request {
72 fn from((router_request, context): (http::Request<Body>, Context)) -> Self {
73 Self {
74 router_request,
75 context,
76 }
77 }
78}
79
80pub struct IntoBody(Body);
85
86impl From<Body> for IntoBody {
87 fn from(value: Body) -> Self {
88 Self(value)
89 }
90}
91impl From<String> for IntoBody {
92 fn from(value: String) -> Self {
93 Self(body::from_bytes(value))
94 }
95}
96impl From<Bytes> for IntoBody {
97 fn from(value: Bytes) -> Self {
98 Self(body::from_bytes(value))
99 }
100}
101impl From<Vec<u8>> for IntoBody {
102 fn from(value: Vec<u8>) -> Self {
103 Self(body::from_bytes(value))
104 }
105}
106
107#[buildstructor::buildstructor]
108impl Request {
109 #[builder(visibility = "pub")]
113 fn new(
114 context: Context,
115 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
116 uri: http::Uri,
117 method: Method,
118 body: Body,
119 ) -> Result<Request, BoxError> {
120 let mut router_request = http::Request::builder()
121 .uri(uri)
122 .method(method)
123 .body(body)?;
124 *router_request.headers_mut() = header_map(headers)?;
125 Ok(Self {
126 router_request,
127 context,
128 })
129 }
130
131 #[builder(visibility = "pub")]
135 fn fake_new(
136 context: Option<Context>,
137 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
138 uri: Option<http::Uri>,
139 method: Option<Method>,
140 body: Option<IntoBody>,
141 ) -> Result<Request, BoxError> {
142 let mut router_request = http::Request::builder()
143 .uri(uri.unwrap_or_else(|| http::Uri::from_static("http://example.com/")))
144 .method(method.unwrap_or(Method::GET))
145 .body(body.map_or_else(body::empty, |constructed| constructed.0))?;
146 *router_request.headers_mut() = header_map(headers)?;
147 Ok(Self {
148 router_request,
149 context: context.unwrap_or_default(),
150 })
151 }
152}
153
154#[derive(Error, Display, Debug)]
155pub enum ParseError {
156 InvalidUri(http::uri::InvalidUri),
158 UrlEncodeError(serde_urlencoded::ser::Error),
160 SerializationError(serde_json::Error),
162}
163
164impl TryFrom<supergraph::Request> for Request {
166 type Error = ParseError;
167 fn try_from(request: supergraph::Request) -> Result<Self, Self::Error> {
168 let supergraph::Request {
169 context,
170 supergraph_request,
171 ..
172 } = request;
173
174 let (mut parts, request) = supergraph_request.into_parts();
175
176 let router_request = if parts.method == Method::GET {
177 let get_path = serde_urlencoded::to_string([
179 ("query", request.query),
180 ("operationName", request.operation_name),
181 (
182 "extensions",
183 serde_json::to_string(&request.extensions).ok(),
184 ),
185 ("variables", serde_json::to_string(&request.variables).ok()),
186 ])
187 .map_err(ParseError::UrlEncodeError)?;
188
189 parts.uri = format!("{}?{}", parts.uri, get_path)
190 .parse()
191 .map_err(ParseError::InvalidUri)?;
192
193 http::Request::from_parts(parts, body::empty())
194 } else {
195 http::Request::from_parts(
196 parts,
197 body::from_bytes(
198 serde_json::to_vec(&request).map_err(ParseError::SerializationError)?,
199 ),
200 )
201 };
202 Ok(Self {
203 router_request,
204 context,
205 })
206 }
207}
208
// Compile-time assertion that `Response` is `Send`.
assert_impl_all!(Response: Send);
#[non_exhaustive]
#[derive(Debug)]
/// Router-stage response: an HTTP response with a router [`Body`], plus its [`Context`].
pub struct Response {
    /// The underlying HTTP response.
    pub response: http::Response<Body>,
    /// Per-request context carried alongside the HTTP response.
    pub context: Context,
}
216
217#[buildstructor::buildstructor]
218impl Response {
219 fn stash_the_body_in_extensions(&mut self, body_string: String) {
220 self.context.extensions().with_lock(|ext| {
221 ext.insert(RouterResponseBodyExtensionType(body_string));
222 });
223 }
224
225 pub async fn next_response(&mut self) -> Option<Result<Bytes, axum::Error>> {
226 self.response.body_mut().into_data_stream().next().await
227 }
228
229 #[builder(visibility = "pub")]
233 fn new(
234 label: Option<String>,
235 data: Option<serde_json_bytes::Value>,
236 path: Option<Path>,
237 errors: Vec<graphql::Error>,
238 extensions: JsonMap<ByteString, serde_json_bytes::Value>,
240 status_code: Option<StatusCode>,
241 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
242 context: Context,
243 ) -> Result<Self, BoxError> {
244 if !errors.is_empty() {
245 Self::add_errors_to_context(&errors, &context);
246 }
247
248 let b = graphql::Response::builder()
250 .and_label(label)
251 .and_path(path)
252 .errors(errors)
253 .extensions(extensions);
254 let res = match data {
255 Some(data) => b.data(data).build(),
256 None => b.build(),
257 };
258
259 let mut builder = http::Response::builder().status(status_code.unwrap_or(StatusCode::OK));
261 for (key, values) in headers {
262 let header_name: HeaderName = key.try_into()?;
263 for value in values {
264 let header_value: HeaderValue = value.try_into()?;
265 builder = builder.header(header_name.clone(), header_value);
266 }
267 }
268
269 let body_string = serde_json::to_string(&res)?;
270
271 let body = body::from_bytes(body_string.clone());
272 let response = builder.body(body)?;
273 let mut response = Self { response, context };
275 response.stash_the_body_in_extensions(body_string);
276
277 Ok(response)
278 }
279
280 #[builder(visibility = "pub")]
281 fn http_response_new(
282 response: http::Response<Body>,
283 context: Context,
284 body_to_stash: Option<String>,
285 errors_for_context: Option<Vec<graphql::Error>>,
286 ) -> Result<Self, BoxError> {
287 if let Some(errors) = errors_for_context
291 && !errors.is_empty()
292 {
293 Self::add_errors_to_context(&errors, &context);
294 }
295 let mut res = Self { response, context };
296 if let Some(body_to_stash) = body_to_stash {
297 res.stash_the_body_in_extensions(body_to_stash)
298 }
299 Ok(res)
300 }
301
302 #[builder(visibility = "pub")]
306 fn error_new(
307 errors: Vec<graphql::Error>,
308 status_code: Option<StatusCode>,
309 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
310 context: Context,
311 ) -> Result<Self, BoxError> {
312 Response::new(
313 Default::default(),
314 Default::default(),
315 None,
316 errors,
317 Default::default(),
318 status_code,
319 headers,
320 context,
321 )
322 }
323
324 #[builder(visibility = "pub(crate)")]
328 fn infallible_new(
329 label: Option<String>,
330 data: Option<serde_json_bytes::Value>,
331 path: Option<Path>,
332 errors: Vec<graphql::Error>,
333 extensions: JsonMap<ByteString, serde_json_bytes::Value>,
335 status_code: Option<StatusCode>,
336 headers: MultiMap<HeaderName, HeaderValue>,
337 context: Context,
338 ) -> Self {
339 if !errors.is_empty() {
340 Self::add_errors_to_context(&errors, &context);
341 }
342
343 let b = graphql::Response::builder()
345 .and_label(label)
346 .and_path(path)
347 .errors(errors)
348 .extensions(extensions);
349 let res = match data {
350 Some(data) => b.data(data).build(),
351 None => b.build(),
352 };
353
354 let mut builder = http::Response::builder().status(status_code.unwrap_or(StatusCode::OK));
356 for (header_name, values) in headers {
357 for header_value in values {
358 builder = builder.header(header_name.clone(), header_value);
359 }
360 }
361
362 let body_string = serde_json::to_string(&res).expect("JSON is always a valid string");
363
364 let body = body::from_bytes(body_string.clone());
365 let response = builder.body(body).expect("RouterBody is always valid");
366
367 Self { response, context }
368 }
369
370 fn add_errors_to_context(errors: &[graphql::Error], context: &Context) {
371 context.insert_json_value(CONTAINS_GRAPHQL_ERROR, Value::Bool(true));
372 context.insert_json_value(CHUNK_CONTAINS_GRAPHQL_ERROR, Value::Bool(true));
373 context
380 .insert(
381 ROUTER_RESPONSE_ERRORS,
382 errors
384 .iter()
385 .cloned()
386 .map(|err| (err.apollo_id(), err))
387 .collect::<HashMap<Uuid, graphql::Error>>(),
388 )
389 .expect("Unable to serialize router response errors list for context");
390 }
391
392 pub async fn into_graphql_response_stream(
394 self,
395 ) -> impl Stream<Item = Result<graphql::Response, serde_json::Error>> {
396 Box::pin(
397 if self
398 .response
399 .headers()
400 .get(CONTENT_TYPE)
401 .iter()
402 .any(|value| {
403 *value == MULTIPART_DEFER_CONTENT_TYPE_HEADER_VALUE
404 || *value == MULTIPART_SUBSCRIPTION_CONTENT_TYPE_HEADER_VALUE
405 })
406 {
407 let multipart = Multipart::new(
408 http_body_util::BodyDataStream::new(self.response.into_body()),
409 "graphql",
410 );
411
412 Either::Left(futures::stream::unfold(multipart, |mut m| async {
413 if let Ok(Some(response)) = m.next_field().await
414 && let Ok(bytes) = response.bytes().await
415 {
416 return Some((serde_json::from_slice::<graphql::Response>(&bytes), m));
417 }
418 None
419 }))
420 } else {
421 let mut body = http_body_util::BodyDataStream::new(self.response.into_body());
422 let res = body.next().await.and_then(|res| res.ok());
423
424 Either::Right(
425 futures::stream::iter(res)
426 .map(|bytes| serde_json::from_slice::<graphql::Response>(&bytes)),
427 )
428 },
429 )
430 }
431
432 #[builder(visibility = "pub")]
436 fn fake_new(
437 label: Option<String>,
438 data: Option<serde_json_bytes::Value>,
439 path: Option<Path>,
440 errors: Vec<graphql::Error>,
441 extensions: JsonMap<ByteString, serde_json_bytes::Value>,
443 status_code: Option<StatusCode>,
444 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
445 context: Option<Context>,
446 ) -> Result<Self, BoxError> {
447 Self::new(
449 label,
450 data,
451 path,
452 errors,
453 extensions,
454 status_code,
455 headers,
456 context.unwrap_or_default(),
457 )
458 }
459}
460
/// Set of response formats the client accepts; every flag defaults to `false`.
///
/// NOTE(review): the field meanings below are inferred from their names, and this struct is
/// presumably populated from the request's `Accept` header. Confirm at the construction
/// site.
#[derive(Clone, Default, Debug)]
pub(crate) struct ClientRequestAccepts {
    /// Multipart responses for `@defer` are acceptable.
    pub(crate) multipart_defer: bool,
    /// Multipart responses for subscriptions are acceptable.
    pub(crate) multipart_subscription: bool,
    /// JSON responses are acceptable.
    pub(crate) json: bool,
    /// A wildcard media type was accepted.
    pub(crate) wildcard: bool,
}
468
469impl<T> From<http::Response<T>> for Response
470where
471 T: http_body::Body<Data = Bytes> + Send + 'static,
472 <T as http_body::Body>::Error: Into<BoxError>,
473{
474 fn from(response: http::Response<T>) -> Self {
475 let context: Context = response.extensions().get().cloned().unwrap_or_default();
476
477 Self {
478 response: response.map(convert_to_body),
479 context,
480 }
481 }
482}
483
484impl From<Response> for http::Response<Body> {
485 fn from(mut response: Response) -> Self {
486 response.response.extensions_mut().insert(response.context);
487 response.response
488 }
489}
490
491impl<T> From<http::Request<T>> for Request
492where
493 T: http_body::Body<Data = Bytes> + Send + 'static,
494 <T as http_body::Body>::Error: Into<BoxError>,
495{
496 fn from(request: http::Request<T>) -> Self {
497 let context: Context = request.extensions().get().cloned().unwrap_or_default();
498
499 Self {
500 router_request: request.map(convert_to_body),
501 context,
502 }
503 }
504}
505
506impl From<Request> for http::Request<Body> {
507 fn from(mut request: Request) -> Self {
508 request
509 .router_request
510 .extensions_mut()
511 .insert(request.context);
512 request.router_request
513 }
514}
515
516fn convert_to_body<T>(mut b: T) -> Body
522where
523 T: http_body::Body<Data = Bytes> + Send + 'static,
524 <T as http_body::Body>::Error: Into<BoxError>,
525{
526 let val_any = &mut b as &mut dyn Any;
527 match val_any.downcast_mut::<Body>() {
528 Some(body) => mem::take(body),
529 None => Body::new(http_body_util::BodyStream::new(b.map_err(axum::Error::new))),
530 }
531}
532
#[cfg(test)]
mod test {
    use std::pin::Pin;
    use std::task::Context;
    use std::task::Poll;

    use tower::BoxError;

    use super::convert_to_body;
    use crate::services::router;

    /// Minimal `http_body::Body` that yields its payload as one data frame, then ends.
    struct MockBody {
        data: Option<&'static str>,
    }

    impl http_body::Body for MockBody {
        type Data = bytes::Bytes;
        type Error = BoxError;

        fn poll_frame(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
            let frame = self
                .get_mut()
                .data
                .take()
                .map(|chunk| Ok(http_body::Frame::data(bytes::Bytes::from(chunk))));
            Poll::Ready(frame)
        }
    }

    /// Drains a router body and decodes it as UTF-8.
    async fn body_text(body: router::Body) -> String {
        let bytes = router::body::into_bytes(body).await.unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_convert_from_http_body() {
        let converted = convert_to_body(MockBody { data: Some("test") });
        assert_eq!(body_text(converted).await, "test");
    }

    #[tokio::test]
    async fn test_convert_from_hyper_body() {
        let converted = convert_to_body(String::from("test"));
        assert_eq!(body_text(converted).await, "test");
    }
}
580}