Skip to main content

gateway_runtime/layers/
governance.rs

1//! # Governance Layer
2//!
3//! This layer provides advanced governance and protection mechanisms for the gateway.
4//!
5//! ## Features
6//! *   **Global Concurrency Limits**: Restricts the total number of simultaneous connections/requests.
7//! *   **Request Timeouts**: Enforces a strict time limit on request processing.
8//! *   **Request Body Limits**: Rejects requests with payloads exceeding a configured size.
9//! *   **Response Body Limits**: Terminates responses that exceed a configured size (preventing large downloads).
10//! *   **Rate Limiting**: Token-bucket based request rate limiting (Requests Per Second).
11//! *   **Load Shedding**: Automatically rejects requests when the service is overloaded (not ready).
12//! *   **Retries**: Configurable retry policy for transient failures (e.g., 503, upstream errors).
13
14use crate::alloc::boxed::Box;
15use crate::{GatewayError, GatewayRequest, GatewayResponse};
16use core::task::{Context, Poll};
17use core::time::Duration;
18use http::StatusCode;
19use std::future::Future;
20use std::pin::Pin;
21use tower::Service;
22
/// Tunable knobs for the gateway's governance features.
///
/// Every setting is optional: the derived [`Default`] yields `None` for each
/// `Option` field and `false` for the load-shedding flag.
#[derive(Clone, Debug, Default)]
pub struct GovernanceConfig {
    /// Upper bound on the number of requests that may be in flight at once.
    pub connection_limit: Option<usize>,
    /// Longest time a single request is allowed to take before it is aborted.
    pub request_timeout: Option<Duration>,
    /// Largest accepted request body, in bytes.
    pub max_request_body_size: Option<usize>,
    /// Largest permitted response body, in bytes.
    pub max_response_body_size: Option<usize>,
    /// Allowed request rate, expressed in requests per second.
    pub rate_limit_per_second: Option<u64>,
    /// When `true`, requests are rejected immediately while the service is overloaded.
    pub enable_load_shedding: bool,
    /// How many times a request that failed transiently is retried.
    pub retry_count: Option<usize>,
}
41
/// A retry policy for the Gateway.
///
/// A request is retried when, and only when, retry budget is left and either:
/// 1. the call failed with an upstream error (`GatewayError::Upstream`), or
/// 2. the response status is 503 (Service Unavailable) or 502 (Bad Gateway).
#[derive(Debug, Clone)]
pub struct GatewayRetryPolicy {
    /// Number of retries that may still be issued; each retry consumes one.
    remaining_attempts: usize,
}

impl GatewayRetryPolicy {
    /// Creates a policy that permits up to `attempts` retries.
    ///
    /// `attempts` counts *retries*, not total tries: the initial request is
    /// always sent, so `new(0)` disables retrying altogether.
    pub fn new(attempts: usize) -> Self {
        Self {
            remaining_attempts: attempts,
        }
    }
}
59
60impl tower::retry::Policy<GatewayRequest, GatewayResponse, GatewayError> for GatewayRetryPolicy {
61    type Future = futures::future::Ready<()>;
62
63    fn retry(
64        &mut self,
65        _req: &mut GatewayRequest,
66        result: &mut Result<GatewayResponse, GatewayError>,
67    ) -> Option<Self::Future> {
68        if self.remaining_attempts == 0 {
69            return None;
70        }
71
72        match result {
73            Ok(resp) => {
74                // Retry on server errors
75                if resp.status() == StatusCode::SERVICE_UNAVAILABLE
76                    || resp.status() == StatusCode::BAD_GATEWAY
77                {
78                    self.remaining_attempts -= 1;
79                    Some(futures::future::ready(()))
80                } else {
81                    None
82                }
83            }
84            Err(GatewayError::Upstream(_)) => {
85                self.remaining_attempts -= 1;
86                Some(futures::future::ready(()))
87            }
88            Err(_) => None,
89        }
90    }
91
92    fn clone_request(&mut self, req: &GatewayRequest) -> Option<GatewayRequest> {
93        // GatewayRequest is http::Request<Vec<u8>>. Vec<u8> is cloneable.
94        // Cloning the body is expensive but necessary for retries.
95        // Since we are buffering the body in VecBodyToVecService anyway, this is the cost of retries.
96
97        let mut new_req = http::Request::builder()
98            .method(req.method().clone())
99            .uri(req.uri().clone())
100            .version(req.version());
101
102        for (k, v) in req.headers() {
103            new_req.headers_mut().unwrap().insert(k, v.clone());
104        }
105
106        new_req.body(req.body().clone()).ok()
107    }
108}
109
/// A service wrapper that enforces body size limits on requests and responses.
///
/// Despite the name, this type wraps an inner service directly rather than
/// acting as a factory for wrapped services.
///
/// Note: Timeout and Concurrency limits are applied using standard `tower` layers
/// constructed in `Gateway::into_service`, but this wrapper handles the logic
/// specifically for `Vec<u8>` request bodies and boxed response body streams.
#[derive(Debug, Clone)]
pub struct BodyLimitLayer<S> {
    /// The wrapped service that receives requests passing the size check.
    inner: S,
    /// Maximum accepted request body length in bytes; `None` disables the check.
    max_req: Option<usize>,
    /// Maximum response body length in bytes; `None` leaves responses untouched.
    max_resp: Option<usize>,
}

impl<S> BodyLimitLayer<S> {
    /// Wraps `inner`, limiting request bodies to `max_req` bytes and response
    /// bodies to `max_resp` bytes. Passing `None` disables the respective limit.
    pub fn new(inner: S, max_req: Option<usize>, max_resp: Option<usize>) -> Self {
        Self {
            inner,
            max_req,
            max_resp,
        }
    }
}
131
132impl<S> Service<GatewayRequest> for BodyLimitLayer<S>
133where
134    S: Service<GatewayRequest, Response = GatewayResponse, Error = GatewayError>,
135    S::Future: Send + 'static,
136{
137    type Response = GatewayResponse;
138    type Error = GatewayError;
139    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
140
141    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
142        self.inner.poll_ready(cx)
143    }
144
145    fn call(&mut self, req: GatewayRequest) -> Self::Future {
146        // Enforce Request Body Limit
147        if let Some(limit) = self.max_req {
148            if req.body().len() > limit {
149                return Box::pin(async move {
150                    Err(GatewayError::Custom(
151                        StatusCode::PAYLOAD_TOO_LARGE,
152                        "Request body too large".into(),
153                    ))
154                });
155            }
156        }
157
158        let max_resp = self.max_resp;
159        let fut = self.inner.call(req);
160
161        Box::pin(async move {
162            let resp = fut.await?;
163
164            // Enforce Response Body Limit
165            if let Some(limit) = max_resp {
166                // We map the body to a LimitedBody wrapper.
167                // Since GatewayResponse uses UnsyncBoxBody, we need to wrap and re-box.
168                let (parts, body) = resp.into_parts();
169                let limited_body = LimitedBody {
170                    inner: body,
171                    remaining: limit,
172                };
173
174                // Re-box safely
175                let boxed_body =
176                    http_body_util::combinators::UnsyncBoxBody::new(Box::new(limited_body));
177                Ok(http::Response::from_parts(parts, boxed_body))
178            } else {
179                Ok(resp)
180            }
181        })
182    }
183}
184
185/// A wrapper body that enforces a maximum byte limit.
186struct LimitedBody<B> {
187    inner: B,
188    remaining: usize,
189}
190
191impl<B> http_body::Body for LimitedBody<B>
192where
193    B: http_body::Body<Data = crate::bytes::Bytes, Error = GatewayError> + Unpin,
194{
195    type Data = crate::bytes::Bytes;
196    type Error = GatewayError;
197
198    fn poll_frame(
199        mut self: Pin<&mut Self>,
200        cx: &mut Context<'_>,
201    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
202        let res = Pin::new(&mut self.inner).poll_frame(cx);
203        match res {
204            Poll::Ready(Some(Ok(frame))) => {
205                if let Some(data) = frame.data_ref() {
206                    if data.len() > self.remaining {
207                        return Poll::Ready(Some(Err(GatewayError::Custom(
208                            StatusCode::PAYLOAD_TOO_LARGE,
209                            "Response body limit exceeded".into(),
210                        ))));
211                    }
212                    self.remaining -= data.len();
213                }
214                Poll::Ready(Some(Ok(frame)))
215            }
216            other => other,
217        }
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::{BodyExt, Full};
    use tower::retry::Policy;

    /// Ten-byte payload used to probe limits just below and above its size.
    const TEN_BYTES: &[u8] = &[0u8; 10];

    /// Builds a `GatewayResponse` with the given status and a fixed in-memory body.
    fn response_with(status: StatusCode, payload: &'static [u8]) -> GatewayResponse {
        http::Response::builder()
            .status(status)
            .body(BodyExt::boxed_unsync(
                Full::new(crate::bytes::Bytes::from_static(payload)).map_err(|_| unreachable!()),
            ))
            .unwrap()
    }

    #[tokio::test]
    async fn test_req_body_limit_exceeded() {
        let service = tower::service_fn(|_| async {
            Ok::<GatewayResponse, GatewayError>(response_with(StatusCode::OK, b""))
        });

        let mut layer = BodyLimitLayer::new(service, Some(5), None);
        let req = http::Request::builder().body(vec![0u8; 10]).unwrap(); // 10 bytes > 5

        let err = layer.call(req).await.unwrap_err();
        match err {
            GatewayError::Custom(status, msg) => {
                assert_eq!(status, StatusCode::PAYLOAD_TOO_LARGE);
                assert_eq!(msg, "Request body too large");
            }
            _ => panic!("Unexpected error type"),
        }
    }

    #[tokio::test]
    async fn test_req_body_limit_ok() {
        let service = tower::service_fn(|_| async {
            Ok::<GatewayResponse, GatewayError>(response_with(StatusCode::OK, b""))
        });

        let mut layer = BodyLimitLayer::new(service, Some(15), None);
        let req = http::Request::builder().body(vec![0u8; 10]).unwrap();

        assert!(layer.call(req).await.is_ok());
    }

    #[tokio::test]
    async fn test_resp_body_limit_exceeded() {
        let service = tower::service_fn(|_| async {
            Ok::<GatewayResponse, GatewayError>(response_with(StatusCode::OK, TEN_BYTES))
        });

        let mut layer = BodyLimitLayer::new(service, None, Some(5));
        let req = http::Request::builder().body(Vec::<u8>::new()).unwrap();

        // The response head is returned; the limit only trips while streaming.
        let resp = match layer.call(req).await {
            Ok(resp) => resp,
            Err(_) => panic!("Response head should be returned before the body is read"),
        };
        match resp.into_body().collect().await {
            Err(GatewayError::Custom(status, msg)) => {
                assert_eq!(status, StatusCode::PAYLOAD_TOO_LARGE);
                assert_eq!(msg, "Response body limit exceeded");
            }
            _ => panic!("Expected the response body limit to trip"),
        }
    }

    #[tokio::test]
    async fn test_resp_body_limit_ok() {
        let service = tower::service_fn(|_| async {
            Ok::<GatewayResponse, GatewayError>(response_with(StatusCode::OK, TEN_BYTES))
        });

        let mut layer = BodyLimitLayer::new(service, None, Some(15));
        let req = http::Request::builder().body(Vec::<u8>::new()).unwrap();

        let resp = match layer.call(req).await {
            Ok(resp) => resp,
            Err(_) => panic!("Call should succeed"),
        };
        match resp.into_body().collect().await {
            Ok(collected) => assert_eq!(collected.to_bytes().len(), 10),
            Err(_) => panic!("A body within the limit must pass through"),
        }
    }

    #[test]
    fn test_retry_policy() {
        // Test case 1: Retry on Service Unavailable
        {
            let mut policy = GatewayRetryPolicy::new(1);
            let mut req = http::Request::builder().body(Vec::new()).unwrap();
            let mut resp_res: Result<GatewayResponse, GatewayError> =
                Ok(response_with(StatusCode::SERVICE_UNAVAILABLE, b""));

            assert!(policy.retry(&mut req, &mut resp_res).is_some());
        }

        // Test case 2: Should NOT retry on OK
        {
            let mut policy = GatewayRetryPolicy::new(1);
            let mut req = http::Request::builder().body(Vec::new()).unwrap();
            let mut resp_ok_res: Result<GatewayResponse, GatewayError> =
                Ok(response_with(StatusCode::OK, b""));

            assert!(policy.retry(&mut req, &mut resp_ok_res).is_none());
        }

        // Test case 3: Should retry on Upstream error
        {
            let mut policy = GatewayRetryPolicy::new(1);
            let mut req = http::Request::builder().body(Vec::new()).unwrap();
            let mut err_res = Err(GatewayError::Upstream(tonic::Status::unavailable("fail")));
            assert!(policy.retry(&mut req, &mut err_res).is_some());
        }
    }

    #[test]
    fn test_retry_policy_exhausted() {
        // A zero budget never retries, even for a transient failure.
        {
            let mut policy = GatewayRetryPolicy::new(0);
            let mut req = http::Request::builder().body(Vec::new()).unwrap();
            let mut resp_res: Result<GatewayResponse, GatewayError> =
                Ok(response_with(StatusCode::BAD_GATEWAY, b""));

            assert!(policy.retry(&mut req, &mut resp_res).is_none());
        }

        // Each granted retry consumes budget until none is left.
        {
            let mut policy = GatewayRetryPolicy::new(1);
            let mut req = http::Request::builder().body(Vec::new()).unwrap();
            let mut err_res: Result<GatewayResponse, GatewayError> =
                Err(GatewayError::Upstream(tonic::Status::unavailable("fail")));

            assert!(policy.retry(&mut req, &mut err_res).is_some());
            assert!(policy.retry(&mut req, &mut err_res).is_none());
        }
    }
}