Skip to main content

gateway_runtime/layers/
response.rs

//! # Response Modification Layer
//!
//! This layer executes registered [`ResponseModifier`] hooks after the request has been
//! successfully processed by the inner service.
//!
//! This is typically used to inject headers (e.g., standard security headers),
//! rewrite status codes (e.g., mapping gRPC metadata to HTTP status), or perform
//! post-processing logging.

10use crate::alloc::boxed::Box;
11use crate::alloc::sync::Arc;
12use crate::alloc::vec::Vec;
13use crate::{GatewayRequest, GatewayResponse};
14use core::task::{Context, Poll};
15use std::future::Future;
16use std::pin::Pin;
17use tower::Service;
18
19/// A handler for modifying HTTP responses before they are sent.
20///
21/// Functions of this type are invoked after the inner service returns a successful response.
22/// They receive a read-only view of the original request (headers/metadata) and a mutable
23/// reference to the response, allowing for header injection or status modification.
24pub type ResponseModifier = Arc<dyn Fn(&GatewayRequest, &mut GatewayResponse) + Send + Sync>;
25
26/// A Tower middleware that applies response modifiers.
27#[derive(Clone)]
28pub struct ResponseLayer<S> {
29    inner: S,
30    modifiers: Vec<ResponseModifier>,
31}
32
33impl<S> ResponseLayer<S> {
34    /// Creates a new `ResponseLayer`.
35    ///
36    /// # Parameters
37    /// *   `inner`: The inner service.
38    /// *   `modifiers`: A list of functions that can mutate the response.
39    pub fn new(inner: S, modifiers: Vec<ResponseModifier>) -> Self {
40        Self { inner, modifiers }
41    }
42}
43
44impl<S> Service<GatewayRequest> for ResponseLayer<S>
45where
46    S: Service<GatewayRequest, Response = GatewayResponse>,
47    S::Future: Send + 'static,
48{
49    type Response = GatewayResponse;
50    type Error = S::Error;
51    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
52
53    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
54        self.inner.poll_ready(cx)
55    }
56
57    fn call(&mut self, req: GatewayRequest) -> Self::Future {
58        // Capture request context (Method, URI, Headers) for use by the modifiers.
59        // The body is not available as it is consumed by the inner service.
60        let method = req.method().clone();
61        let uri = req.uri().clone();
62        let headers = req.headers().clone();
63
64        let modifiers = self.modifiers.clone();
65        let fut = self.inner.call(req);
66
67        Box::pin(async move {
68            let mut resp = fut.await?;
69
70            // Execute all modifiers on the successful response
71            if !modifiers.is_empty() {
72                let mut partial_req = http::Request::builder()
73                    .method(method)
74                    .uri(uri)
75                    .body(Vec::new())
76                    .unwrap();
77                *partial_req.headers_mut() = headers;
78
79                for modifier in &modifiers {
80                    modifier(&partial_req, &mut resp);
81                }
82            }
83
84            Ok(resp)
85        })
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::GatewayError;
    use http::StatusCode;
    use http_body_util::BodyExt;
    use http_body_util::Full;

    /// Builds the empty `200 OK` response returned by every test inner service.
    fn empty_response() -> GatewayResponse {
        http::Response::new(BodyExt::boxed_unsync(
            Full::new(crate::bytes::Bytes::new()).map_err(|_| unreachable!()),
        ))
    }

    /// Builds a request with default method/URI, no headers, and an empty body.
    fn empty_request() -> GatewayRequest {
        http::Request::builder()
            .body(crate::alloc::vec::Vec::new())
            .unwrap()
    }

    #[tokio::test]
    async fn test_response_layer_modifies() {
        let modifier: ResponseModifier = Arc::new(|_, resp| {
            resp.headers_mut()
                .insert("x-modified", "true".parse().unwrap());
        });

        let service = tower::service_fn(|_req: GatewayRequest| async {
            Ok::<GatewayResponse, GatewayError>(empty_response())
        });

        let mut layer = ResponseLayer::new(service, vec![modifier]);
        let resp = layer.call(empty_request()).await.unwrap();
        assert_eq!(resp.headers().get("x-modified").unwrap(), "true");
    }

    #[tokio::test]
    async fn test_response_layer_multiple_modifiers() {
        let m1: ResponseModifier = Arc::new(|_, resp| {
            resp.headers_mut().insert("h1", "v1".parse().unwrap());
        });
        let m2: ResponseModifier = Arc::new(|_, resp| {
            resp.headers_mut().insert("h2", "v2".parse().unwrap());
        });

        let service = tower::service_fn(|_req: GatewayRequest| async {
            Ok::<GatewayResponse, GatewayError>(empty_response())
        });

        let mut layer = ResponseLayer::new(service, vec![m1, m2]);
        let resp = layer.call(empty_request()).await.unwrap();
        assert_eq!(resp.headers().get("h1").unwrap(), "v1");
        assert_eq!(resp.headers().get("h2").unwrap(), "v2");
    }

    #[tokio::test]
    async fn test_response_layer_runs_modifiers_in_order() {
        // Both hooks write the same header; the last registered one must win.
        let first: ResponseModifier = Arc::new(|_, resp| {
            resp.headers_mut().insert("x-order", "first".parse().unwrap());
        });
        let second: ResponseModifier = Arc::new(|_, resp| {
            resp.headers_mut().insert("x-order", "second".parse().unwrap());
        });

        let service = tower::service_fn(|_req: GatewayRequest| async {
            Ok::<GatewayResponse, GatewayError>(empty_response())
        });

        let mut layer = ResponseLayer::new(service, vec![first, second]);
        let resp = layer.call(empty_request()).await.unwrap();
        assert_eq!(resp.headers().get("x-order").unwrap(), "second");
    }

    #[tokio::test]
    async fn test_response_layer_without_modifiers_passes_through() {
        let service = tower::service_fn(|_req: GatewayRequest| async {
            Ok::<GatewayResponse, GatewayError>(empty_response())
        });

        let mut layer = ResponseLayer::new(service, Vec::new());
        let resp = layer.call(empty_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(resp.headers().is_empty());
    }

    #[tokio::test]
    async fn test_response_layer_access_request_context() {
        let modifier: ResponseModifier = Arc::new(|req, resp| {
            if req.headers().contains_key("x-trigger") {
                *resp.status_mut() = StatusCode::ACCEPTED;
            }
        });

        let service = tower::service_fn(|_req: GatewayRequest| async {
            Ok::<GatewayResponse, GatewayError>(empty_response())
        });

        let mut layer = ResponseLayer::new(service, vec![modifier]);

        // Request with trigger
        let req = http::Request::builder()
            .header("x-trigger", "1")
            .body(crate::alloc::vec::Vec::new())
            .unwrap();
        let resp = layer.call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::ACCEPTED);

        // Request without trigger
        let resp = layer.call(empty_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_response_layer_sees_request_method_and_uri() {
        let modifier: ResponseModifier = Arc::new(|req, resp| {
            if *req.method() == http::Method::POST && req.uri().path() == "/submit" {
                *resp.status_mut() = StatusCode::CREATED;
            }
        });

        let service = tower::service_fn(|_req: GatewayRequest| async {
            Ok::<GatewayResponse, GatewayError>(empty_response())
        });

        let mut layer = ResponseLayer::new(service, vec![modifier]);

        let req = http::Request::builder()
            .method("POST")
            .uri("/submit")
            .body(crate::alloc::vec::Vec::new())
            .unwrap();
        let resp = layer.call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::CREATED);

        // Default GET / must not match.
        let resp = layer.call(empty_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }
}