// apollo_router/layers/map_first_graphql_response.rs

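//! Extension of the `map_future` layer for streamed responses: it lets a callback map the HTTP
//! response parts together with the *first* GraphQL response of a supergraph response stream,
//! leaving the remaining responses of the stream untouched.
//!
//! See [`Layer`] and [`Service`] for more details.
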
5use std::future::ready;
6use std::task::Poll;
7
8use futures::FutureExt;
9use futures::StreamExt;
10use futures::future::BoxFuture;
11use futures::stream::once;
12use tower::Layer;
13use tower::Service;
14
15use crate::Context;
16use crate::graphql;
17use crate::services::supergraph;
18
19/// [`Layer`] for mapping first graphql responses. See [`ServiceBuilderExt::map_first_graphql_response()`](crate::layers::ServiceBuilderExt::map_first_graphql_response()).
20pub struct MapFirstGraphqlResponseLayer<Callback> {
21    pub(super) callback: Callback,
22}
23
/// [`Service`] that calls an inner service and maps the HTTP response parts and the first GraphQL
/// response of the resulting stream through a callback.
///
/// See [`ServiceBuilderExt::map_first_graphql_response()`](crate::layers::ServiceBuilderExt::map_first_graphql_response()).
///
/// The service is [`Clone`] whenever both the inner service and the callback are, so it can be
/// shared or boxed as a cloneable service the same way `tower`'s map services can.
#[derive(Clone)]
pub struct MapFirstGraphqlResponseService<InnerService, Callback> {
    /// The wrapped service whose responses are mapped.
    inner: InnerService,
    /// Callback cloned on every call and applied to the first GraphQL response.
    callback: Callback,
}
29
30impl<InnerService, Callback> Layer<InnerService> for MapFirstGraphqlResponseLayer<Callback>
31where
32    Callback: Clone,
33{
34    type Service = MapFirstGraphqlResponseService<InnerService, Callback>;
35
36    fn layer(&self, inner: InnerService) -> Self::Service {
37        MapFirstGraphqlResponseService {
38            inner,
39            callback: self.callback.clone(),
40        }
41    }
42}
43
44impl<InnerService, Callback, Request> Service<Request>
45    for MapFirstGraphqlResponseService<InnerService, Callback>
46where
47    InnerService: Service<Request, Response = supergraph::Response>,
48    InnerService::Future: Send + 'static,
49    Callback: FnOnce(
50            Context,
51            http::response::Parts,
52            graphql::Response,
53        ) -> (http::response::Parts, graphql::Response)
54        + Clone
55        + Send
56        + 'static,
57{
58    type Response = supergraph::Response;
59    type Error = InnerService::Error;
60    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
61
62    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
63        self.inner.poll_ready(cx)
64    }
65
66    fn call(&mut self, request: Request) -> Self::Future {
67        let future = self.inner.call(request);
68        let callback = self.callback.clone();
69        async move {
70            let supergraph_response = future.await?;
71            let context = supergraph_response.context;
72            let (mut parts, mut stream) = supergraph_response.response.into_parts();
73            if let Some(first) = stream.next().await {
74                let (new_parts, new_first) = callback(context.clone(), parts, first);
75                parts = new_parts;
76                stream = once(ready(new_first)).chain(stream).boxed();
77            };
78            Ok(supergraph::Response {
79                context,
80                response: http::Response::from_parts(parts, stream),
81            })
82        }
83        .boxed()
84    }
85}
86
#[cfg(test)]
mod tests {
    use tower::ServiceExt;

    use super::*;
    use crate::layers::ServiceExt as _;

    /// An error pushed onto the first GraphQL response by the callback must reach the client.
    #[tokio::test]
    async fn test_map_first_graphql_response() {
        let supergraph = crate::TestHarness::builder()
            .execution_hook(|service| {
                service
                    .map_first_graphql_response(|_context, http_parts, mut graphql_response| {
                        let injected_error = graphql::Error::builder()
                            .message("oh no!")
                            .extension_code("FOO".to_string())
                            .build();
                        graphql_response.errors.push(injected_error);
                        (http_parts, graphql_response)
                    })
                    .boxed()
            })
            .build_supergraph()
            .await
            .unwrap();

        let request = supergraph::Request::canned_builder().build().unwrap();
        let mut response = supergraph.oneshot(request).await.unwrap();
        let first_response = response.next_response().await.unwrap();

        assert_eq!(first_response.errors[0].message, "oh no!");
    }
}