gateway_runtime/layers/
response.rs1use crate::alloc::boxed::Box;
11use crate::alloc::sync::Arc;
12use crate::alloc::vec::Vec;
13use crate::{GatewayRequest, GatewayResponse};
14use core::task::{Context, Poll};
15use std::future::Future;
16use std::pin::Pin;
17use tower::Service;
18
19pub type ResponseModifier = Arc<dyn Fn(&GatewayRequest, &mut GatewayResponse) + Send + Sync>;
25
26#[derive(Clone)]
28pub struct ResponseLayer<S> {
29 inner: S,
30 modifiers: Vec<ResponseModifier>,
31}
32
33impl<S> ResponseLayer<S> {
34 pub fn new(inner: S, modifiers: Vec<ResponseModifier>) -> Self {
40 Self { inner, modifiers }
41 }
42}
43
44impl<S> Service<GatewayRequest> for ResponseLayer<S>
45where
46 S: Service<GatewayRequest, Response = GatewayResponse>,
47 S::Future: Send + 'static,
48{
49 type Response = GatewayResponse;
50 type Error = S::Error;
51 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
52
53 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
54 self.inner.poll_ready(cx)
55 }
56
57 fn call(&mut self, req: GatewayRequest) -> Self::Future {
58 let method = req.method().clone();
61 let uri = req.uri().clone();
62 let headers = req.headers().clone();
63
64 let modifiers = self.modifiers.clone();
65 let fut = self.inner.call(req);
66
67 Box::pin(async move {
68 let mut resp = fut.await?;
69
70 if !modifiers.is_empty() {
72 let mut partial_req = http::Request::builder()
73 .method(method)
74 .uri(uri)
75 .body(Vec::new())
76 .unwrap();
77 *partial_req.headers_mut() = headers;
78
79 for modifier in &modifiers {
80 modifier(&partial_req, &mut resp);
81 }
82 }
83
84 Ok(resp)
85 })
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::GatewayError;
    use http::{HeaderValue, Method, StatusCode};
    use http_body_util::BodyExt;
    use http_body_util::Full;

    /// Builds the empty `200 OK` response every stub service returns.
    fn empty_response() -> GatewayResponse {
        http::Response::new(BodyExt::boxed_unsync(
            Full::new(crate::bytes::Bytes::new()).map_err(|_| unreachable!()),
        ))
    }

    /// Stub inner service: ignores the request and answers with [`empty_response`].
    async fn ok_handler(_req: GatewayRequest) -> Result<GatewayResponse, GatewayError> {
        Ok(empty_response())
    }

    /// Builds a request with default method/URI, no headers and an empty body.
    fn bare_request() -> GatewayRequest {
        http::Request::builder()
            .body(crate::alloc::vec::Vec::new())
            .unwrap()
    }

    /// Sends one request through `svc`, calling `poll_ready` first as tower's
    /// `Service` contract requires.
    async fn send<S>(svc: &mut S, req: GatewayRequest) -> GatewayResponse
    where
        S: Service<GatewayRequest, Response = GatewayResponse>,
        S::Error: core::fmt::Debug,
    {
        core::future::poll_fn(|cx| svc.poll_ready(cx)).await.unwrap();
        svc.call(req).await.unwrap()
    }

    #[tokio::test]
    async fn test_response_layer_modifies() {
        let modifier: ResponseModifier = Arc::new(|_, resp| {
            resp.headers_mut()
                .insert("x-modified", HeaderValue::from_static("true"));
        });

        let mut layer = ResponseLayer::new(tower::service_fn(ok_handler), vec![modifier]);

        let resp = send(&mut layer, bare_request()).await;
        assert_eq!(resp.headers().get("x-modified").unwrap(), "true");
    }

    #[tokio::test]
    async fn test_response_layer_multiple_modifiers() {
        let m1: ResponseModifier = Arc::new(|_, resp| {
            resp.headers_mut().insert("h1", HeaderValue::from_static("v1"));
        });
        let m2: ResponseModifier = Arc::new(|_, resp| {
            resp.headers_mut().insert("h2", HeaderValue::from_static("v2"));
        });

        let mut layer = ResponseLayer::new(tower::service_fn(ok_handler), vec![m1, m2]);

        let resp = send(&mut layer, bare_request()).await;
        assert_eq!(resp.headers().get("h1").unwrap(), "v1");
        assert_eq!(resp.headers().get("h2").unwrap(), "v2");
    }

    #[tokio::test]
    async fn test_response_layer_applies_modifiers_in_order() {
        let first: ResponseModifier = Arc::new(|_, resp| {
            resp.headers_mut()
                .insert("x-order", HeaderValue::from_static("first"));
        });
        let second: ResponseModifier = Arc::new(|_, resp| {
            resp.headers_mut()
                .insert("x-order", HeaderValue::from_static("second"));
        });

        let mut layer = ResponseLayer::new(tower::service_fn(ok_handler), vec![first, second]);

        let resp = send(&mut layer, bare_request()).await;
        // The later modifier runs last, so its value wins.
        assert_eq!(resp.headers().get("x-order").unwrap(), "second");
    }

    #[tokio::test]
    async fn test_response_layer_without_modifiers_passes_through() {
        let mut layer = ResponseLayer::new(tower::service_fn(ok_handler), Vec::new());

        let resp = send(&mut layer, bare_request()).await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(resp.headers().is_empty());
    }

    #[tokio::test]
    async fn test_response_layer_exposes_request_method_and_uri() {
        let modifier: ResponseModifier = Arc::new(|req, resp| {
            if *req.method() == Method::POST && req.uri().path() == "/echo" {
                *resp.status_mut() = StatusCode::CREATED;
            }
        });

        let mut layer = ResponseLayer::new(tower::service_fn(ok_handler), vec![modifier]);

        let req = http::Request::builder()
            .method(Method::POST)
            .uri("/echo")
            .body(crate::alloc::vec::Vec::new())
            .unwrap();
        let resp = send(&mut layer, req).await;
        assert_eq!(resp.status(), StatusCode::CREATED);
    }

    #[tokio::test]
    async fn test_response_layer_access_request_context() {
        let modifier: ResponseModifier = Arc::new(|req, resp| {
            if req.headers().contains_key("x-trigger") {
                *resp.status_mut() = StatusCode::ACCEPTED;
            }
        });

        let mut layer = ResponseLayer::new(tower::service_fn(ok_handler), vec![modifier]);

        let req = http::Request::builder()
            .header("x-trigger", "1")
            .body(crate::alloc::vec::Vec::new())
            .unwrap();
        let resp = send(&mut layer, req).await;
        assert_eq!(resp.status(), StatusCode::ACCEPTED);

        let resp = send(&mut layer, bare_request()).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }
}