//! Tower layer that maps an inner service's response future using data
//! extracted from the incoming request before it is forwarded.

use std::future::Future;
use std::task::Context;
use std::task::Poll;

use tower::Layer;
use tower::Service;

/// [`Layer`] that wraps a service in a [`MapFutureWithRequestDataService`].
///
/// Holds the two user-supplied functions; every service produced by this layer
/// receives its own clone of each.
#[derive(Clone)]
pub struct MapFutureWithRequestDataLayer<RF, MF> {
    /// Inspects each request and returns the data handed to `map_fn`.
    req_fn: RF,
    /// Combines that data with the inner service's future into the outer future.
    map_fn: MF,
}

19impl<RF, MF> MapFutureWithRequestDataLayer<RF, MF> {
20 pub fn new(req_fn: RF, map_fn: MF) -> Self {
22 Self { req_fn, map_fn }
23 }
24}
25
26impl<S, RF, MF> Layer<S> for MapFutureWithRequestDataLayer<RF, MF>
27where
28 RF: Clone,
29 MF: Clone,
30{
31 type Service = MapFutureWithRequestDataService<S, RF, MF>;
32
33 fn layer(&self, inner: S) -> Self::Service {
34 MapFutureWithRequestDataService::new(inner, self.req_fn.clone(), self.map_fn.clone())
35 }
36}
37
/// [`Service`] that runs `req_fn` on each request, forwards the request to
/// `inner`, and lets `map_fn` turn the inner future into the returned future.
#[derive(Clone)]
pub struct MapFutureWithRequestDataService<S, RF, MF> {
    /// The wrapped service.
    inner: S,
    /// Extracts data from a borrowed request before it is forwarded.
    req_fn: RF,
    /// Maps `(extracted data, inner future)` to the outer future.
    map_fn: MF,
}

46impl<S, RF, MF> MapFutureWithRequestDataService<S, RF, MF> {
47 pub fn new(inner: S, req_fn: RF, map_fn: MF) -> MapFutureWithRequestDataService<S, RF, MF>
49 where
50 RF: Clone,
51 MF: Clone,
52 {
53 MapFutureWithRequestDataService {
54 inner,
55 req_fn,
56 map_fn,
57 }
58 }
59}
60
61impl<R, S, MF, RF, T, E, Fut, ReqData> Service<R> for MapFutureWithRequestDataService<S, RF, MF>
62where
63 S: Service<R>,
64 RF: FnMut(&R) -> ReqData,
65 MF: FnMut(ReqData, S::Future) -> Fut,
66 E: From<S::Error>,
67 Fut: Future<Output = Result<T, E>>,
68{
69 type Response = T;
70 type Error = E;
71 type Future = Fut;
72
73 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
74 self.inner.poll_ready(cx).map_err(From::from)
75 }
76
77 fn call(&mut self, req: R) -> Self::Future {
78 let data = (self.req_fn)(&req);
79 (self.map_fn)(data, self.inner.call(req))
80 }
81}
82
#[cfg(test)]
mod test {
    use http::HeaderValue;
    use tower::BoxError;
    use tower::Service;
    use tower::ServiceBuilder;
    use tower::ServiceExt;

    use crate::layers::ServiceBuilderExt;
    use crate::plugin::test::MockSupergraphService;
    use crate::services::SupergraphRequest;
    use crate::services::SupergraphResponse;

    /// The `hello` header captured from the request by `req_fn` must be copied
    /// onto the response by `map_fn`.
    #[tokio::test]
    async fn test_layer() -> Result<(), BoxError> {
        let mut mock = MockSupergraphService::new();
        mock.expect_call()
            .once()
            .returning(|_| Ok(SupergraphResponse::fake_builder().build().unwrap()));

        let mut svc = ServiceBuilder::new()
            .map_future_with_request_data(
                // Request side: pull out the `hello` header (panics if it is absent).
                |req: &SupergraphRequest| {
                    let headers = req.supergraph_request.headers();
                    headers.get("hello").cloned().unwrap()
                },
                // Response side: once the inner future resolves, echo the captured value back.
                |value: HeaderValue, resp| async move {
                    let resp: Result<SupergraphResponse, BoxError> = resp.await;
                    resp.map(|mut response| {
                        response.response.headers_mut().insert("hello", value);
                        response
                    })
                },
            )
            .service(mock);

        let request = SupergraphRequest::fake_builder()
            .header("hello", "world")
            .build()
            .unwrap();
        let response = svc.ready().await.unwrap().call(request).await?;

        assert_eq!(
            response.response.headers().get("hello"),
            Some(&HeaderValue::from_static("world"))
        );
        Ok(())
    }
}