apollo_router/layers/
map_first_graphql_response.rs1use std::future::ready;
6use std::task::Poll;
7
8use futures::FutureExt;
9use futures::StreamExt;
10use futures::future::BoxFuture;
11use futures::stream::once;
12use tower::Layer;
13use tower::Service;
14
15use crate::Context;
16use crate::graphql;
17use crate::services::supergraph;
18
19pub struct MapFirstGraphqlResponseLayer<Callback> {
21 pub(super) callback: Callback,
22}
23
/// Tower [`Service`] built by `MapFirstGraphqlResponseLayer`.
///
/// Requests are forwarded to `inner`. Once the inner service answers, `callback` may rewrite
/// the HTTP response parts together with the first GraphQL response of the stream.
pub struct MapFirstGraphqlResponseService<InnerService, Callback> {
    /// The wrapped service that actually handles the request.
    inner: InnerService,
    /// Template callback. It is cloned once per request and invoked at most once with that clone.
    callback: Callback,
}
29
30impl<InnerService, Callback> Layer<InnerService> for MapFirstGraphqlResponseLayer<Callback>
31where
32 Callback: Clone,
33{
34 type Service = MapFirstGraphqlResponseService<InnerService, Callback>;
35
36 fn layer(&self, inner: InnerService) -> Self::Service {
37 MapFirstGraphqlResponseService {
38 inner,
39 callback: self.callback.clone(),
40 }
41 }
42}
43
44impl<InnerService, Callback, Request> Service<Request>
45 for MapFirstGraphqlResponseService<InnerService, Callback>
46where
47 InnerService: Service<Request, Response = supergraph::Response>,
48 InnerService::Future: Send + 'static,
49 Callback: FnOnce(
50 Context,
51 http::response::Parts,
52 graphql::Response,
53 ) -> (http::response::Parts, graphql::Response)
54 + Clone
55 + Send
56 + 'static,
57{
58 type Response = supergraph::Response;
59 type Error = InnerService::Error;
60 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
61
62 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
63 self.inner.poll_ready(cx)
64 }
65
66 fn call(&mut self, request: Request) -> Self::Future {
67 let future = self.inner.call(request);
68 let callback = self.callback.clone();
69 async move {
70 let supergraph_response = future.await?;
71 let context = supergraph_response.context;
72 let (mut parts, mut stream) = supergraph_response.response.into_parts();
73 if let Some(first) = stream.next().await {
74 let (new_parts, new_first) = callback(context.clone(), parts, first);
75 parts = new_parts;
76 stream = once(ready(new_first)).chain(stream).boxed();
77 };
78 Ok(supergraph::Response {
79 context,
80 response: http::Response::from_parts(parts, stream),
81 })
82 }
83 .boxed()
84 }
85}
86
#[cfg(test)]
mod tests {
    use tower::ServiceExt;

    use super::*;
    use crate::layers::ServiceExt as _;

    /// Installs a callback on the execution service that appends an error to the first
    /// GraphQL response. The test then checks that the client sees that error first.
    #[tokio::test]
    async fn test_map_first_graphql_response() {
        let supergraph_service = crate::TestHarness::builder()
            .execution_hook(|service| {
                service
                    .map_first_graphql_response(|_context, http_parts, mut graphql_response| {
                        let injected = graphql::Error::builder()
                            .message("oh no!")
                            .extension_code("FOO".to_string())
                            .build();
                        graphql_response.errors.push(injected);
                        (http_parts, graphql_response)
                    })
                    .boxed()
            })
            .build_supergraph()
            .await
            .unwrap();

        let request = supergraph::Request::canned_builder().build().unwrap();
        let mut response = supergraph_service.oneshot(request).await.unwrap();
        let first = response.next_response().await.unwrap();

        assert_eq!(first.errors[0].message, "oh no!");
    }
}