1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
//! Extension of the `map_future` layer that maps the first GraphQL response of a response stream
//! (together with the HTTP response parts). Useful when working with a stream of responses.
//!
//! See [`Layer`] and [`Service`] for more details.

use std::future::ready;
use std::task::Poll;

use futures::future::BoxFuture;
use futures::stream::once;
use futures::FutureExt;
use futures::StreamExt;
use tower::Layer;
use tower::Service;

use crate::graphql;
use crate::services::supergraph;
use crate::Context;

/// [`Layer`] for mapping first graphql responses. See [`ServiceBuilderExt::map_first_graphql_response()`](crate::layers::ServiceBuilderExt::map_first_graphql_response()).
///
/// Every service produced by this layer receives its own clone of `callback`.
/// The layer itself is `Clone` whenever the callback is, so it can be part of
/// cloneable service-builder stacks.
#[derive(Clone)]
pub struct MapFirstGraphqlResponseLayer<Callback> {
    /// Callback applied to the context, the HTTP response parts and the first GraphQL response.
    pub(super) callback: Callback,
}

/// [`Service`] for mapping first graphql responses. See [`ServiceBuilderExt::map_first_graphql_response()`](crate::layers::ServiceBuilderExt::map_first_graphql_response()).
///
/// Wraps `inner` and applies `callback` to the first GraphQL response (and the HTTP
/// response parts) of each response that `inner` produces. The service is `Clone`
/// whenever both the inner service and the callback are.
#[derive(Clone)]
pub struct MapFirstGraphqlResponseService<InnerService, Callback> {
    /// The wrapped service whose responses are mapped.
    inner: InnerService,
    /// Cloned for every call and invoked at most once, on the first response of the stream.
    callback: Callback,
}

impl<InnerService, Callback> Layer<InnerService> for MapFirstGraphqlResponseLayer<Callback>
where
    Callback: Clone,
{
    type Service = MapFirstGraphqlResponseService<InnerService, Callback>;

    /// Wraps `inner`, handing the resulting service its own copy of the callback.
    fn layer(&self, inner: InnerService) -> Self::Service {
        let callback = Clone::clone(&self.callback);
        MapFirstGraphqlResponseService { inner, callback }
    }
}

impl<InnerService, Callback, Request> Service<Request>
    for MapFirstGraphqlResponseService<InnerService, Callback>
where
    InnerService: Service<Request, Response = supergraph::Response>,
    InnerService::Future: Send + 'static,
    Callback: FnOnce(
            Context,
            http::response::Parts,
            graphql::Response,
        ) -> (http::response::Parts, graphql::Response)
        + Clone
        + Send
        + 'static,
{
    type Response = supergraph::Response;
    type Error = InnerService::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Readiness is delegated unchanged to the wrapped service.
    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Calls the wrapped service, then waits for the first GraphQL response of the
    /// resulting stream. That response is passed to a clone of the callback, along with
    /// a clone of the context and the HTTP response parts. The callback's output
    /// replaces the original parts and first response. The rest of the stream is
    /// forwarded untouched.
    ///
    /// If the stream yields nothing, the callback is not invoked and the parts are kept
    /// as they are. Errors from the wrapped service are returned unchanged.
    fn call(&mut self, request: Request) -> Self::Future {
        let pending = self.inner.call(request);
        let map_first = self.callback.clone();
        async move {
            let supergraph::Response { context, response } = pending.await?;
            let (parts, mut body) = response.into_parts();
            let parts = match body.next().await {
                None => parts,
                Some(first) => {
                    let (mapped_parts, mapped_first) = map_first(context.clone(), parts, first);
                    // Put the mapped first response back in front of the remaining stream.
                    body = once(ready(mapped_first)).chain(body).boxed();
                    mapped_parts
                }
            };
            Ok(supergraph::Response {
                context,
                response: http::Response::from_parts(parts, body),
            })
        }
        .boxed()
    }
}

#[cfg(test)]
mod tests {
    use tower::ServiceExt;

    use super::*;
    use crate::layers::ServiceExt as _;

    /// The callback's modification of the first response must be visible to the client.
    #[tokio::test]
    async fn test_map_first_graphql_response() {
        let supergraph_service = crate::TestHarness::builder()
            .execution_hook(|execution_service| {
                execution_service
                    .map_first_graphql_response(|_context, parts, mut first_response| {
                        let injected = graphql::Error::builder()
                            .message("oh no!")
                            .extension_code("FOO".to_string())
                            .build();
                        first_response.errors.push(injected);
                        (parts, first_response)
                    })
                    .boxed()
            })
            .build_supergraph()
            .await
            .unwrap();

        let request = supergraph::Request::canned_builder().build().unwrap();
        let mut response = supergraph_service.oneshot(request).await.unwrap();
        let first = response.next_response().await.unwrap();

        assert_eq!(first.errors[0].message, "oh no!");
    }
}