1#![allow(missing_docs)] use std::any::Any;
4use std::mem;
5
6use ahash::HashMap;
7use bytes::Bytes;
8use displaydoc::Display;
9use futures::Stream;
10use futures::StreamExt;
11use futures::future::Either;
12use http::HeaderValue;
13use http::Method;
14use http::StatusCode;
15use http::header::CONTENT_TYPE;
16use http::header::HeaderName;
17use http_body_util::BodyExt;
18use multer::Multipart;
19use multimap::MultiMap;
20use serde_json_bytes::ByteString;
21use serde_json_bytes::Map as JsonMap;
22use serde_json_bytes::Value;
23use static_assertions::assert_impl_all;
24use thiserror::Error;
25use tower::BoxError;
26use uuid::Uuid;
27
28use self::body::RouterBody;
29use self::service::MULTIPART_DEFER_CONTENT_TYPE_HEADER_VALUE;
30use self::service::MULTIPART_SUBSCRIPTION_CONTENT_TYPE_HEADER_VALUE;
31use super::supergraph;
32use crate::Context;
33use crate::context::CONTAINS_GRAPHQL_ERROR;
34use crate::context::ROUTER_RESPONSE_ERRORS;
35use crate::graphql;
36use crate::http_ext::header_map;
37use crate::json_ext::Path;
38use crate::plugins::telemetry::config_new::router::events::RouterResponseBodyExtensionType;
39use crate::services::TryIntoHeaderName;
40use crate::services::TryIntoHeaderValue;
41
/// Type-erased router service taking a [`Request`] and producing a [`Response`].
pub type BoxService = tower::util::BoxService<Request, Response, BoxError>;
/// Cloneable, type-erased router service taking a [`Request`] and producing a [`Response`].
pub type BoxCloneService = tower::util::BoxCloneService<Request, Response, BoxError>;
/// Result type produced by router services.
pub type ServiceResult = Result<Response, BoxError>;

/// HTTP body type carried by router requests and responses.
pub type Body = RouterBody;
/// Alias for [`hyper::Error`].
pub type Error = hyper::Error;

/// Helpers for building and consuming [`Body`] values
/// (for example `from_bytes`, `empty` and `into_bytes`).
pub mod body;
pub(crate) mod pipeline_handle;
// Also defines the multipart content-type header values compared against in
// `Response::into_graphql_response_stream`.
pub(crate) mod service;
#[cfg(test)]
mod tests;
54
assert_impl_all!(Request: Send);
/// A request at the router stage: the raw HTTP request plus its per-request [`Context`].
#[non_exhaustive]
pub struct Request {
    /// The underlying HTTP request, including its body.
    pub router_request: http::Request<Body>,

    /// Per-request context carried alongside the HTTP request.
    pub context: Context,
}
67
68impl From<(http::Request<Body>, Context)> for Request {
69 fn from((router_request, context): (http::Request<Body>, Context)) -> Self {
70 Self {
71 router_request,
72 context,
73 }
74 }
75}
76
77pub struct IntoBody(Body);
82
83impl From<Body> for IntoBody {
84 fn from(value: Body) -> Self {
85 Self(value)
86 }
87}
88impl From<String> for IntoBody {
89 fn from(value: String) -> Self {
90 Self(body::from_bytes(value))
91 }
92}
93impl From<Bytes> for IntoBody {
94 fn from(value: Bytes) -> Self {
95 Self(body::from_bytes(value))
96 }
97}
98impl From<Vec<u8>> for IntoBody {
99 fn from(value: Vec<u8>) -> Self {
100 Self(body::from_bytes(value))
101 }
102}
103
104#[buildstructor::buildstructor]
105impl Request {
106 #[builder(visibility = "pub")]
110 fn new(
111 context: Context,
112 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
113 uri: http::Uri,
114 method: Method,
115 body: Body,
116 ) -> Result<Request, BoxError> {
117 let mut router_request = http::Request::builder()
118 .uri(uri)
119 .method(method)
120 .body(body)?;
121 *router_request.headers_mut() = header_map(headers)?;
122 Ok(Self {
123 router_request,
124 context,
125 })
126 }
127
128 #[builder(visibility = "pub")]
132 fn fake_new(
133 context: Option<Context>,
134 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
135 uri: Option<http::Uri>,
136 method: Option<Method>,
137 body: Option<IntoBody>,
138 ) -> Result<Request, BoxError> {
139 let mut router_request = http::Request::builder()
140 .uri(uri.unwrap_or_else(|| http::Uri::from_static("http://example.com/")))
141 .method(method.unwrap_or(Method::GET))
142 .body(body.map_or_else(body::empty, |constructed| constructed.0))?;
143 *router_request.headers_mut() = header_map(headers)?;
144 Ok(Self {
145 router_request,
146 context: context.unwrap_or_default(),
147 })
148 }
149}
150
// Errors raised while converting a `supergraph::Request` into a router `Request`.
// `displaydoc::Display` uses each variant's doc comment as its Display format string.
#[derive(Error, Display, Debug)]
pub enum ParseError {
    /// couldn't create a valid http GET uri '{0}'
    InvalidUri(http::uri::InvalidUri),
    /// couldn't urlencode the GraphQL request body '{0}'
    UrlEncodeError(serde_urlencoded::ser::Error),
    /// couldn't serialize the GraphQL request body '{0}'
    SerializationError(serde_json::Error),
}
160
161impl TryFrom<supergraph::Request> for Request {
163 type Error = ParseError;
164 fn try_from(request: supergraph::Request) -> Result<Self, Self::Error> {
165 let supergraph::Request {
166 context,
167 supergraph_request,
168 ..
169 } = request;
170
171 let (mut parts, request) = supergraph_request.into_parts();
172
173 let router_request = if parts.method == Method::GET {
174 let get_path = serde_urlencoded::to_string([
176 ("query", request.query),
177 ("operationName", request.operation_name),
178 (
179 "extensions",
180 serde_json::to_string(&request.extensions).ok(),
181 ),
182 ("variables", serde_json::to_string(&request.variables).ok()),
183 ])
184 .map_err(ParseError::UrlEncodeError)?;
185
186 parts.uri = format!("{}?{}", parts.uri, get_path)
187 .parse()
188 .map_err(ParseError::InvalidUri)?;
189
190 http::Request::from_parts(parts, body::empty())
191 } else {
192 http::Request::from_parts(
193 parts,
194 body::from_bytes(
195 serde_json::to_vec(&request).map_err(ParseError::SerializationError)?,
196 ),
197 )
198 };
199 Ok(Self {
200 router_request,
201 context,
202 })
203 }
204}
205
assert_impl_all!(Response: Send);
/// A response at the router stage: the HTTP response plus its per-request [`Context`].
#[non_exhaustive]
#[derive(Debug)]
pub struct Response {
    /// The underlying HTTP response, including its body.
    pub response: http::Response<Body>,
    /// Per-request context carried alongside the HTTP response.
    pub context: Context,
}
213
214#[buildstructor::buildstructor]
215impl Response {
216 fn stash_the_body_in_extensions(&mut self, body_string: String) {
217 self.context.extensions().with_lock(|ext| {
218 ext.insert(RouterResponseBodyExtensionType(body_string));
219 });
220 }
221
222 pub async fn next_response(&mut self) -> Option<Result<Bytes, axum::Error>> {
223 self.response.body_mut().into_data_stream().next().await
224 }
225
226 #[builder(visibility = "pub")]
230 fn new(
231 label: Option<String>,
232 data: Option<serde_json_bytes::Value>,
233 path: Option<Path>,
234 errors: Vec<graphql::Error>,
235 extensions: JsonMap<ByteString, serde_json_bytes::Value>,
237 status_code: Option<StatusCode>,
238 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
239 context: Context,
240 ) -> Result<Self, BoxError> {
241 if !errors.is_empty() {
242 Self::add_errors_to_context(&errors, &context);
243 }
244
245 let b = graphql::Response::builder()
247 .and_label(label)
248 .and_path(path)
249 .errors(errors)
250 .extensions(extensions);
251 let res = match data {
252 Some(data) => b.data(data).build(),
253 None => b.build(),
254 };
255
256 let mut builder = http::Response::builder().status(status_code.unwrap_or(StatusCode::OK));
258 for (key, values) in headers {
259 let header_name: HeaderName = key.try_into()?;
260 for value in values {
261 let header_value: HeaderValue = value.try_into()?;
262 builder = builder.header(header_name.clone(), header_value);
263 }
264 }
265
266 let body_string = serde_json::to_string(&res)?;
267
268 let body = body::from_bytes(body_string.clone());
269 let response = builder.body(body)?;
270 let mut response = Self { response, context };
272 response.stash_the_body_in_extensions(body_string);
273
274 Ok(response)
275 }
276
277 #[builder(visibility = "pub")]
278 fn http_response_new(
279 response: http::Response<Body>,
280 context: Context,
281 body_to_stash: Option<String>,
282 errors_for_context: Option<Vec<graphql::Error>>,
283 ) -> Result<Self, BoxError> {
284 if let Some(errors) = errors_for_context
288 && !errors.is_empty()
289 {
290 Self::add_errors_to_context(&errors, &context);
291 }
292 let mut res = Self { response, context };
293 if let Some(body_to_stash) = body_to_stash {
294 res.stash_the_body_in_extensions(body_to_stash)
295 }
296 Ok(res)
297 }
298
299 #[builder(visibility = "pub")]
303 fn error_new(
304 errors: Vec<graphql::Error>,
305 status_code: Option<StatusCode>,
306 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
307 context: Context,
308 ) -> Result<Self, BoxError> {
309 Response::new(
310 Default::default(),
311 Default::default(),
312 None,
313 errors,
314 Default::default(),
315 status_code,
316 headers,
317 context,
318 )
319 }
320
321 #[builder(visibility = "pub(crate)")]
325 fn infallible_new(
326 label: Option<String>,
327 data: Option<serde_json_bytes::Value>,
328 path: Option<Path>,
329 errors: Vec<graphql::Error>,
330 extensions: JsonMap<ByteString, serde_json_bytes::Value>,
332 status_code: Option<StatusCode>,
333 headers: MultiMap<HeaderName, HeaderValue>,
334 context: Context,
335 ) -> Self {
336 if !errors.is_empty() {
337 Self::add_errors_to_context(&errors, &context);
338 }
339
340 let b = graphql::Response::builder()
342 .and_label(label)
343 .and_path(path)
344 .errors(errors)
345 .extensions(extensions);
346 let res = match data {
347 Some(data) => b.data(data).build(),
348 None => b.build(),
349 };
350
351 let mut builder = http::Response::builder().status(status_code.unwrap_or(StatusCode::OK));
353 for (header_name, values) in headers {
354 for header_value in values {
355 builder = builder.header(header_name.clone(), header_value);
356 }
357 }
358
359 let body_string = serde_json::to_string(&res).expect("JSON is always a valid string");
360
361 let body = body::from_bytes(body_string.clone());
362 let response = builder.body(body).expect("RouterBody is always valid");
363
364 Self { response, context }
365 }
366
367 fn add_errors_to_context(errors: &[graphql::Error], context: &Context) {
368 context.insert_json_value(CONTAINS_GRAPHQL_ERROR, Value::Bool(true));
369 context
376 .insert(
377 ROUTER_RESPONSE_ERRORS,
378 errors
380 .iter()
381 .cloned()
382 .map(|err| (err.apollo_id(), err))
383 .collect::<HashMap<Uuid, graphql::Error>>(),
384 )
385 .expect("Unable to serialize router response errors list for context");
386 }
387
388 pub async fn into_graphql_response_stream(
390 self,
391 ) -> impl Stream<Item = Result<graphql::Response, serde_json::Error>> {
392 Box::pin(
393 if self
394 .response
395 .headers()
396 .get(CONTENT_TYPE)
397 .iter()
398 .any(|value| {
399 *value == MULTIPART_DEFER_CONTENT_TYPE_HEADER_VALUE
400 || *value == MULTIPART_SUBSCRIPTION_CONTENT_TYPE_HEADER_VALUE
401 })
402 {
403 let multipart = Multipart::new(
404 http_body_util::BodyDataStream::new(self.response.into_body()),
405 "graphql",
406 );
407
408 Either::Left(futures::stream::unfold(multipart, |mut m| async {
409 if let Ok(Some(response)) = m.next_field().await
410 && let Ok(bytes) = response.bytes().await
411 {
412 return Some((serde_json::from_slice::<graphql::Response>(&bytes), m));
413 }
414 None
415 }))
416 } else {
417 let mut body = http_body_util::BodyDataStream::new(self.response.into_body());
418 let res = body.next().await.and_then(|res| res.ok());
419
420 Either::Right(
421 futures::stream::iter(res.into_iter())
422 .map(|bytes| serde_json::from_slice::<graphql::Response>(&bytes)),
423 )
424 },
425 )
426 }
427
428 #[builder(visibility = "pub")]
432 fn fake_new(
433 label: Option<String>,
434 data: Option<serde_json_bytes::Value>,
435 path: Option<Path>,
436 errors: Vec<graphql::Error>,
437 extensions: JsonMap<ByteString, serde_json_bytes::Value>,
439 status_code: Option<StatusCode>,
440 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
441 context: Option<Context>,
442 ) -> Result<Self, BoxError> {
443 Self::new(
445 label,
446 data,
447 path,
448 errors,
449 extensions,
450 status_code,
451 headers,
452 context.unwrap_or_default(),
453 )
454 }
455}
456
/// Response formats a client has indicated it accepts.
///
/// NOTE(review): presumably populated from the request's `Accept` header; the field
/// meanings below are inferred from their names. Confirm at the construction site.
#[derive(Clone, Default, Debug)]
pub(crate) struct ClientRequestAccepts {
    // Multipart responses for deferred results.
    pub(crate) multipart_defer: bool,
    // Multipart responses for subscriptions.
    pub(crate) multipart_subscription: bool,
    // Plain JSON responses.
    pub(crate) json: bool,
    // A wildcard media range.
    pub(crate) wildcard: bool,
}
464
465impl<T> From<http::Response<T>> for Response
466where
467 T: http_body::Body<Data = Bytes> + Send + 'static,
468 <T as http_body::Body>::Error: Into<BoxError>,
469{
470 fn from(response: http::Response<T>) -> Self {
471 let context: Context = response.extensions().get().cloned().unwrap_or_default();
472
473 Self {
474 response: response.map(convert_to_body),
475 context,
476 }
477 }
478}
479
480impl From<Response> for http::Response<Body> {
481 fn from(mut response: Response) -> Self {
482 response.response.extensions_mut().insert(response.context);
483 response.response
484 }
485}
486
487impl<T> From<http::Request<T>> for Request
488where
489 T: http_body::Body<Data = Bytes> + Send + 'static,
490 <T as http_body::Body>::Error: Into<BoxError>,
491{
492 fn from(request: http::Request<T>) -> Self {
493 let context: Context = request.extensions().get().cloned().unwrap_or_default();
494
495 Self {
496 router_request: request.map(convert_to_body),
497 context,
498 }
499 }
500}
501
502impl From<Request> for http::Request<Body> {
503 fn from(mut request: Request) -> Self {
504 request
505 .router_request
506 .extensions_mut()
507 .insert(request.context);
508 request.router_request
509 }
510}
511
512fn convert_to_body<T>(mut b: T) -> Body
518where
519 T: http_body::Body<Data = Bytes> + Send + 'static,
520 <T as http_body::Body>::Error: Into<BoxError>,
521{
522 let val_any = &mut b as &mut dyn Any;
523 match val_any.downcast_mut::<Body>() {
524 Some(body) => mem::take(body),
525 None => Body::new(http_body_util::BodyStream::new(b.map_err(axum::Error::new))),
526 }
527}
528
#[cfg(test)]
mod test {
    use std::pin::Pin;
    use std::task::Context;
    use std::task::Poll;

    use tower::BoxError;

    use super::convert_to_body;
    use crate::services::router;

    /// Minimal `http_body::Body` that yields its data once as a single frame and
    /// then ends. It exercises the generic (non-downcast) path of `convert_to_body`.
    struct MockBody {
        data: Option<&'static str>,
    }
    impl http_body::Body for MockBody {
        type Data = bytes::Bytes;
        type Error = BoxError;

        fn poll_frame(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
            if let Some(data) = self.get_mut().data.take() {
                Poll::Ready(Some(Ok(http_body::Frame::data(bytes::Bytes::from(data)))))
            } else {
                Poll::Ready(None)
            }
        }
    }

    #[tokio::test]
    async fn test_convert_from_http_body() {
        let body = convert_to_body(MockBody { data: Some("test") });
        assert_eq!(
            &String::from_utf8(router::body::into_bytes(body).await.unwrap().to_vec()).unwrap(),
            "test"
        );
    }

    /// Uses a `String`, which implements `http_body::Body` directly.
    #[tokio::test]
    async fn test_convert_from_hyper_body() {
        let body = convert_to_body(String::from("test"));
        assert_eq!(
            &String::from_utf8(router::body::into_bytes(body).await.unwrap().to_vec()).unwrap(),
            "test"
        );
    }

    /// A value that already is a router `Body` takes the downcast fast path and
    /// must come back with its contents intact.
    #[tokio::test]
    async fn test_convert_from_router_body() {
        let body = convert_to_body(router::body::from_bytes(String::from("test")));
        assert_eq!(
            &String::from_utf8(router::body::into_bytes(body).await.unwrap().to_vec()).unwrap(),
            "test"
        );
    }
}