apollo_router/layers/map_future_with_request_data.rs

//! Extension of the `map_future` layer that allows the future to be mapped using information
//! obtained from the request.
//!
//! See [`Layer`] and [`Service`] for more details.
4
5use std::future::Future;
6use std::task::Context;
7use std::task::Poll;
8
9use tower::Layer;
10use tower::Service;
11
/// [`Layer`] that wraps a service so its response future can be transformed using data taken
/// from the incoming request.
///
/// `req_fn` is called with a borrowed request, and its output is later handed to `map_fn`
/// together with the inner service's future.
/// See [`ServiceBuilderExt::map_future_with_request_data()`](crate::layers::ServiceBuilderExt::map_future_with_request_data()).
#[derive(Clone)]
pub struct MapFutureWithRequestDataLayer<RF, MF> {
    /// Extracts data from each request.
    req_fn: RF,
    /// Combines the extracted data with the inner service's future.
    map_fn: MF,
}

impl<RF, MF> MapFutureWithRequestDataLayer<RF, MF> {
    /// Builds a layer from a request-data extractor (`req_fn`) and a future mapper (`map_fn`).
    pub fn new(req_fn: RF, map_fn: MF) -> Self {
        MapFutureWithRequestDataLayer { req_fn, map_fn }
    }
}
25
26impl<S, RF, MF> Layer<S> for MapFutureWithRequestDataLayer<RF, MF>
27where
28    RF: Clone,
29    MF: Clone,
30{
31    type Service = MapFutureWithRequestDataService<S, RF, MF>;
32
33    fn layer(&self, inner: S) -> Self::Service {
34        MapFutureWithRequestDataService::new(inner, self.req_fn.clone(), self.map_fn.clone())
35    }
36}
37
/// [`Service`] for mapping futures with request data.
///
/// On every call it extracts data from the borrowed request with `req_fn`, forwards the request
/// to `inner`, and passes the extracted data plus the inner future to `map_fn`.
/// See [`ServiceBuilderExt::map_future_with_request_data()`](crate::layers::ServiceBuilderExt::map_future_with_request_data()).
#[derive(Clone)]
pub struct MapFutureWithRequestDataService<S, RF, MF> {
    /// The wrapped service.
    inner: S,
    /// Extracts data from each request before it is forwarded.
    req_fn: RF,
    /// Maps the inner service's future using the extracted data.
    map_fn: MF,
}

impl<S, RF, MF> MapFutureWithRequestDataService<S, RF, MF> {
    /// Create a new instance wrapping `inner`.
    ///
    /// Construction only moves the arguments, so no `Clone` bounds are required here.
    /// `Clone` is needed only where the functions are duplicated, e.g. by the layer or when
    /// cloning the service itself.
    pub fn new(inner: S, req_fn: RF, map_fn: MF) -> Self {
        Self {
            inner,
            req_fn,
            map_fn,
        }
    }
}
60
61impl<R, S, MF, RF, T, E, Fut, ReqData> Service<R> for MapFutureWithRequestDataService<S, RF, MF>
62where
63    S: Service<R>,
64    RF: FnMut(&R) -> ReqData,
65    MF: FnMut(ReqData, S::Future) -> Fut,
66    E: From<S::Error>,
67    Fut: Future<Output = Result<T, E>>,
68{
69    type Response = T;
70    type Error = E;
71    type Future = Fut;
72
73    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
74        self.inner.poll_ready(cx).map_err(From::from)
75    }
76
77    fn call(&mut self, req: R) -> Self::Future {
78        let data = (self.req_fn)(&req);
79        (self.map_fn)(data, self.inner.call(req))
80    }
81}
82
#[cfg(test)]
mod test {
    use http::HeaderValue;
    use tower::BoxError;
    use tower::Service;
    use tower::ServiceBuilder;
    use tower::ServiceExt;

    use crate::layers::ServiceBuilderExt;
    use crate::plugin::test::MockSupergraphService;
    use crate::services::SupergraphRequest;
    use crate::services::SupergraphResponse;

    /// Copies the `hello` request header onto the response via
    /// `map_future_with_request_data` and checks that it arrives intact.
    #[tokio::test]
    async fn test_layer() -> Result<(), BoxError> {
        let mut mock_service = MockSupergraphService::new();
        mock_service
            .expect_call()
            .once()
            .returning(|_| Ok(SupergraphResponse::fake_builder().build().unwrap()));

        let mut service = ServiceBuilder::new()
            .map_future_with_request_data(
                // Runs against the borrowed request, before the inner service consumes it.
                |req: &SupergraphRequest| {
                    req.supergraph_request
                        .headers()
                        .get("hello")
                        .cloned()
                        .expect("the test request always carries a `hello` header")
                },
                |value: HeaderValue, resp| async move {
                    let resp: Result<SupergraphResponse, BoxError> = resp.await;
                    // `Result::map` takes an `FnOnce`, so the owned `value` is moved in; no clone needed.
                    resp.map(|mut response| {
                        response.response.headers_mut().insert("hello", value);
                        response
                    })
                },
            )
            .service(mock_service);

        let result = service
            .ready()
            .await?
            .call(
                SupergraphRequest::fake_builder()
                    .header("hello", "world")
                    .build()
                    .unwrap(),
            )
            .await?;
        assert_eq!(
            result.response.headers().get("hello"),
            Some(&HeaderValue::from_static("world"))
        );
        Ok(())
    }
}