1#![allow(missing_docs)] use std::any::Any;
4use std::mem;
5
6use bytes::Bytes;
7use futures::Stream;
8use futures::StreamExt;
9use futures::future::Either;
10use http::HeaderValue;
11use http::Method;
12use http::StatusCode;
13use http::header::CONTENT_TYPE;
14use http::header::HeaderName;
15use multer::Multipart;
16use multimap::MultiMap;
17use serde_json_bytes::ByteString;
18use serde_json_bytes::Map as JsonMap;
19use serde_json_bytes::Value;
20use static_assertions::assert_impl_all;
21use tower::BoxError;
22
23use self::body::RouterBody;
24use self::service::MULTIPART_DEFER_CONTENT_TYPE_HEADER_VALUE;
25use self::service::MULTIPART_SUBSCRIPTION_CONTENT_TYPE_HEADER_VALUE;
26use super::supergraph;
27use crate::Context;
28use crate::graphql;
29use crate::http_ext::header_map;
30use crate::json_ext::Path;
31use crate::services;
32use crate::services::TryIntoHeaderName;
33use crate::services::TryIntoHeaderValue;
34
/// Boxed router-stage service: takes a [`Request`], yields a [`Response`], fails with a [`BoxError`].
pub type BoxService = tower::util::BoxService<Request, Response, BoxError>;
/// Cloneable counterpart of [`BoxService`].
pub type BoxCloneService = tower::util::BoxCloneService<Request, Response, BoxError>;
/// Result of calling a router-stage service.
pub type ServiceResult = Result<Response, BoxError>;
/// HTTP body type carried by router requests and responses.
pub type Body = hyper::Body;
/// Error type of [`Body`] chunks, as returned by [`Response::next_response`].
pub type Error = hyper::Error;

/// Body helpers for this stage (this module imports `RouterBody` from here).
pub mod body;
pub(crate) mod pipeline_handle;
/// Defines, among other things, the multipart content-type values that
/// [`Response::into_graphql_response_stream`] compares against.
pub(crate) mod service;
#[cfg(test)]
mod tests;
47
// Compile-time assertion that `Request` is `Send`.
assert_impl_all!(Request: Send);
/// A request at the router stage: the incoming HTTP request, whose body is still a
/// raw [`Body`] stream, paired with its [`Context`].
#[non_exhaustive]
pub struct Request {
    /// The HTTP request; its body has not been parsed as GraphQL here.
    pub router_request: http::Request<Body>,

    /// Context of this request. Conversions to/from `http::Request` store it in /
    /// read it from the HTTP request extensions.
    pub context: Context,
}
60
61impl From<(http::Request<Body>, Context)> for Request {
62 fn from((router_request, context): (http::Request<Body>, Context)) -> Self {
63 Self {
64 router_request,
65 context,
66 }
67 }
68}
69
70#[buildstructor::buildstructor]
71impl Request {
72 #[builder(visibility = "pub")]
76 fn new(
77 context: Context,
78 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
79 uri: http::Uri,
80 method: Method,
81 body: Body,
82 ) -> Result<Request, BoxError> {
83 let mut router_request = http::Request::builder()
84 .uri(uri)
85 .method(method)
86 .body(body)?;
87 *router_request.headers_mut() = header_map(headers)?;
88 Ok(Self {
89 router_request,
90 context,
91 })
92 }
93
94 #[builder(visibility = "pub")]
98 fn fake_new(
99 context: Option<Context>,
100 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
101 uri: Option<http::Uri>,
102 method: Option<Method>,
103 body: Option<Body>,
104 ) -> Result<Request, BoxError> {
105 let mut router_request = http::Request::builder()
106 .uri(uri.unwrap_or_else(|| http::Uri::from_static("http://example.com/")))
107 .method(method.unwrap_or(Method::GET))
108 .body(body.unwrap_or_else(Body::empty))?;
109 *router_request.headers_mut() = header_map(headers)?;
110 Ok(Self {
111 router_request,
112 context: context.unwrap_or_default(),
113 })
114 }
115}
116
117use displaydoc::Display;
118use thiserror::Error;
119
120#[derive(Error, Display, Debug)]
121pub enum ParseError {
122 InvalidUri(http::uri::InvalidUri),
124 UrlEncodeError(serde_urlencoded::ser::Error),
126 SerializationError(serde_json::Error),
128}
129
130impl TryFrom<supergraph::Request> for Request {
132 type Error = ParseError;
133 fn try_from(request: supergraph::Request) -> Result<Self, Self::Error> {
134 let supergraph::Request {
135 context,
136 supergraph_request,
137 ..
138 } = request;
139
140 let (mut parts, request) = supergraph_request.into_parts();
141
142 let router_request = if parts.method == Method::GET {
143 let get_path = serde_urlencoded::to_string([
145 ("query", request.query),
146 ("operationName", request.operation_name),
147 (
148 "extensions",
149 serde_json::to_string(&request.extensions).ok(),
150 ),
151 ("variables", serde_json::to_string(&request.variables).ok()),
152 ])
153 .map_err(ParseError::UrlEncodeError)?;
154
155 parts.uri = format!("{}?{}", parts.uri, get_path)
156 .parse()
157 .map_err(ParseError::InvalidUri)?;
158
159 http::Request::from_parts(parts, RouterBody::empty().into_inner())
160 } else {
161 http::Request::from_parts(
162 parts,
163 RouterBody::from(
164 serde_json::to_vec(&request).map_err(ParseError::SerializationError)?,
165 )
166 .into_inner(),
167 )
168 };
169 Ok(Self {
170 router_request,
171 context,
172 })
173 }
174}
175
// Compile-time assertion that `Response` is `Send`.
assert_impl_all!(Response: Send);
/// A response at the router stage: the HTTP response, with a streamed [`Body`],
/// paired with the [`Context`] of the request it answers.
#[non_exhaustive]
#[derive(Debug)]
pub struct Response {
    /// The HTTP response. Its body holds either a single JSON payload or a multipart
    /// stream (see `into_graphql_response_stream`).
    pub response: http::Response<Body>,
    /// Context of the request this response belongs to.
    pub context: Context,
}
183
184#[buildstructor::buildstructor]
185impl Response {
186 pub async fn next_response(&mut self) -> Option<Result<Bytes, Error>> {
187 self.response.body_mut().next().await
188 }
189
190 #[deprecated]
191 pub fn map<F>(self, f: F) -> Response
192 where
193 F: FnOnce(Body) -> Body,
194 {
195 Response {
196 context: self.context,
197 response: self.response.map(f),
198 }
199 }
200
201 #[builder(visibility = "pub")]
205 fn new(
206 label: Option<String>,
207 data: Option<Value>,
208 path: Option<Path>,
209 errors: Vec<graphql::Error>,
210 extensions: JsonMap<ByteString, Value>,
212 status_code: Option<StatusCode>,
213 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
214 context: Context,
215 ) -> Result<Self, BoxError> {
216 let b = graphql::Response::builder()
218 .and_label(label)
219 .and_path(path)
220 .errors(errors)
221 .extensions(extensions);
222 let res = match data {
223 Some(data) => b.data(data).build(),
224 None => b.build(),
225 };
226
227 let mut builder = http::Response::builder().status(status_code.unwrap_or(StatusCode::OK));
229 for (key, values) in headers {
230 let header_name: HeaderName = key.try_into()?;
231 for value in values {
232 let header_value: HeaderValue = value.try_into()?;
233 builder = builder.header(header_name.clone(), header_value);
234 }
235 }
236
237 let response = builder.body(RouterBody::from(serde_json::to_vec(&res)?).into_inner())?;
240
241 Ok(Self { response, context })
242 }
243
244 #[builder(visibility = "pub")]
248 fn error_new(
249 errors: Vec<graphql::Error>,
250 status_code: Option<StatusCode>,
251 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
252 context: Context,
253 ) -> Result<Self, BoxError> {
254 Response::new(
255 Default::default(),
256 Default::default(),
257 None,
258 errors,
259 Default::default(),
260 status_code,
261 headers,
262 context,
263 )
264 }
265
266 #[builder(visibility = "pub(crate)")]
270 fn infallible_new(
271 label: Option<String>,
272 data: Option<Value>,
273 path: Option<Path>,
274 errors: Vec<graphql::Error>,
275 extensions: JsonMap<ByteString, Value>,
277 status_code: Option<StatusCode>,
278 headers: MultiMap<HeaderName, HeaderValue>,
279 context: Context,
280 ) -> Self {
281 let b = graphql::Response::builder()
283 .and_label(label)
284 .and_path(path)
285 .errors(errors)
286 .extensions(extensions);
287 let res = match data {
288 Some(data) => b.data(data).build(),
289 None => b.build(),
290 };
291
292 let mut builder = http::Response::builder().status(status_code.unwrap_or(StatusCode::OK));
294 for (header_name, values) in headers {
295 for header_value in values {
296 builder = builder.header(header_name.clone(), header_value);
297 }
298 }
299
300 let response = builder
301 .body(RouterBody::from(serde_json::to_vec(&res).expect("can't fail")).into_inner())
302 .expect("can't fail");
303
304 Self { response, context }
305 }
306
307 pub async fn into_graphql_response_stream(
309 self,
310 ) -> impl Stream<Item = Result<crate::graphql::Response, serde_json::Error>> {
311 Box::pin(
312 if self
313 .response
314 .headers()
315 .get(CONTENT_TYPE)
316 .iter()
317 .any(|value| {
318 *value == MULTIPART_DEFER_CONTENT_TYPE_HEADER_VALUE
319 || *value == MULTIPART_SUBSCRIPTION_CONTENT_TYPE_HEADER_VALUE
320 })
321 {
322 let multipart = Multipart::new(self.response.into_body(), "graphql");
323
324 Either::Left(futures::stream::unfold(multipart, |mut m| async {
325 if let Ok(Some(response)) = m.next_field().await {
326 if let Ok(bytes) = response.bytes().await {
327 return Some((
328 serde_json::from_slice::<crate::graphql::Response>(&bytes),
329 m,
330 ));
331 }
332 }
333 None
334 }))
335 } else {
336 let mut body = self.response.into_body();
337 let res = body.next().await.and_then(|res| res.ok());
338
339 Either::Right(
340 futures::stream::iter(res.into_iter())
341 .map(|bytes| serde_json::from_slice::<crate::graphql::Response>(&bytes)),
342 )
343 },
344 )
345 }
346
347 #[builder(visibility = "pub")]
351 fn fake_new(
352 label: Option<String>,
353 data: Option<Value>,
354 path: Option<Path>,
355 errors: Vec<graphql::Error>,
356 extensions: JsonMap<ByteString, Value>,
358 status_code: Option<StatusCode>,
359 headers: MultiMap<TryIntoHeaderName, TryIntoHeaderValue>,
360 context: Option<Context>,
361 ) -> Result<Self, BoxError> {
362 Self::new(
364 label,
365 data,
366 path,
367 errors,
368 extensions,
369 status_code,
370 headers,
371 context.unwrap_or_default(),
372 )
373 }
374}
375
/// Flags recording which response formats a client accepts.
///
/// NOTE(review): presumably filled from the client's `Accept` header elsewhere in
/// the crate. The parsing is not in this file; confirm.
#[derive(Clone, Default, Debug)]
pub(crate) struct ClientRequestAccepts {
    /// The multipart "defer" response format is accepted.
    pub(crate) multipart_defer: bool,
    /// The multipart "subscription" response format is accepted.
    pub(crate) multipart_subscription: bool,
    /// A plain JSON response is accepted.
    pub(crate) json: bool,
    /// A wildcard media type was accepted (presumably `*/*`; confirm).
    pub(crate) wildcard: bool,
}
383
384impl<T> From<http::Response<T>> for Response
385where
386 T: http_body::Body<Data = Bytes> + Send + 'static,
387 <T as http_body::Body>::Error: Into<BoxError>,
388{
389 fn from(response: http::Response<T>) -> Self {
390 let context: Context = response.extensions().get().cloned().unwrap_or_default();
391
392 Self {
393 response: response.map(convert_to_body),
394 context,
395 }
396 }
397}
398
399impl From<Response> for http::Response<Body> {
400 fn from(mut response: Response) -> Self {
401 response.response.extensions_mut().insert(response.context);
402 response.response
403 }
404}
405
406impl<T> From<http::Request<T>> for Request
407where
408 T: http_body::Body<Data = Bytes> + Send + 'static,
409 <T as http_body::Body>::Error: Into<BoxError>,
410{
411 fn from(request: http::Request<T>) -> Self {
412 let context: Context = request.extensions().get().cloned().unwrap_or_default();
413
414 Self {
415 router_request: request.map(convert_to_body),
416 context,
417 }
418 }
419}
420
421impl From<Request> for http::Request<Body> {
422 fn from(mut request: Request) -> Self {
423 request
424 .router_request
425 .extensions_mut()
426 .insert(request.context);
427 request.router_request
428 }
429}
430
431fn convert_to_body<T>(mut b: T) -> Body
437where
438 T: http_body::Body<Data = Bytes> + Send + 'static,
439 <T as http_body::Body>::Error: Into<BoxError>,
440{
441 let val_any = &mut b as &mut dyn Any;
442 match val_any.downcast_mut::<Body>() {
443 Some(body) => mem::take(body),
444 None => Body::wrap_stream(services::http::body_stream::BodyStream::new(
445 b.map_err(Into::into),
446 )),
447 }
448}
449
#[cfg(test)]
mod test {
    use std::pin::Pin;
    use std::task::Context;
    use std::task::Poll;

    use http::HeaderMap;
    use tower::BoxError;

    use crate::services::router::body::get_body_bytes;
    use crate::services::router::convert_to_body;

    /// Body that yields its text (if any) as a single chunk, then ends.
    ///
    /// It is not a `hyper::Body`, so `convert_to_body` must take its wrapping path.
    struct MockBody {
        data: Option<&'static str>,
    }

    impl http_body::Body for MockBody {
        type Data = bytes::Bytes;
        type Error = BoxError;

        fn poll_data(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            Poll::Ready(self.data.take().map(|text| Ok(bytes::Bytes::from(text))))
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
            Poll::Ready(Ok(None))
        }
    }

    /// Drains a converted body and decodes it as UTF-8.
    async fn body_text(body: hyper::Body) -> String {
        let collected = get_body_bytes(body).await.unwrap();
        String::from_utf8(collected.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_convert_from_http_body() {
        let converted = convert_to_body(MockBody { data: Some("test") });
        assert_eq!(body_text(converted).await, "test");
    }

    #[tokio::test]
    async fn test_convert_from_hyper_body() {
        let converted = convert_to_body(hyper::Body::from("test"));
        assert_eq!(body_text(converted).await, "test");
    }
}