gateway_runtime/layers/
governance.rs1use crate::alloc::boxed::Box;
15use crate::{GatewayError, GatewayRequest, GatewayResponse};
16use core::task::{Context, Poll};
17use core::time::Duration;
18use http::StatusCode;
19use std::future::Future;
20use std::pin::Pin;
21use tower::Service;
22
/// Optional governance settings for the gateway.
///
/// Every setting is opt-in: the derived `Default` yields `None` for each limit
/// and `false` for the load-shedding flag. This type only carries the values;
/// the code that applies them is not visible in this file, so the per-field
/// notes describe the intent suggested by the field names and should be
/// confirmed against the consumer.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GovernanceConfig {
    /// Presumably the maximum number of concurrent connections — TODO confirm.
    pub connection_limit: Option<usize>,
    /// Presumably the per-request deadline — TODO confirm.
    pub request_timeout: Option<Duration>,
    /// Presumably the request body limit in bytes (cf. `BodyLimitLayer`) — TODO confirm.
    pub max_request_body_size: Option<usize>,
    /// Presumably the response body limit in bytes (cf. `BodyLimitLayer`) — TODO confirm.
    pub max_response_body_size: Option<usize>,
    /// Presumably the permitted number of requests per second — TODO confirm.
    pub rate_limit_per_second: Option<u64>,
    /// Presumably toggles load shedding; `false` by default.
    pub enable_load_shedding: bool,
    /// Presumably the retry budget handed to `GatewayRetryPolicy::new` — TODO confirm.
    pub retry_count: Option<usize>,
}
41
/// Retry budget used by the gateway's `tower::retry::Policy` implementation.
///
/// The policy stores how many retries may still be granted; the `Policy`
/// implementation in this file decrements the counter each time it approves
/// a retry and refuses once it reaches zero.
#[derive(Clone, Debug)]
pub struct GatewayRetryPolicy {
    /// Number of retries that may still be granted.
    remaining_attempts: usize,
}

impl GatewayRetryPolicy {
    /// Creates a policy that grants at most `attempts` retries.
    ///
    /// Passing `0` produces a policy that never retries.
    #[must_use]
    pub fn new(attempts: usize) -> Self {
        Self {
            remaining_attempts: attempts,
        }
    }
}
59
60impl tower::retry::Policy<GatewayRequest, GatewayResponse, GatewayError> for GatewayRetryPolicy {
61 type Future = futures::future::Ready<()>;
62
63 fn retry(
64 &mut self,
65 _req: &mut GatewayRequest,
66 result: &mut Result<GatewayResponse, GatewayError>,
67 ) -> Option<Self::Future> {
68 if self.remaining_attempts == 0 {
69 return None;
70 }
71
72 match result {
73 Ok(resp) => {
74 if resp.status() == StatusCode::SERVICE_UNAVAILABLE
76 || resp.status() == StatusCode::BAD_GATEWAY
77 {
78 self.remaining_attempts -= 1;
79 Some(futures::future::ready(()))
80 } else {
81 None
82 }
83 }
84 Err(GatewayError::Upstream(_)) => {
85 self.remaining_attempts -= 1;
86 Some(futures::future::ready(()))
87 }
88 Err(_) => None,
89 }
90 }
91
92 fn clone_request(&mut self, req: &GatewayRequest) -> Option<GatewayRequest> {
93 let mut new_req = http::Request::builder()
98 .method(req.method().clone())
99 .uri(req.uri().clone())
100 .version(req.version());
101
102 for (k, v) in req.headers() {
103 new_req.headers_mut().unwrap().insert(k, v.clone());
104 }
105
106 new_req.body(req.body().clone()).ok()
107 }
108}
109
/// Middleware that enforces optional size limits on request and response bodies.
///
/// Despite the `Layer` suffix, this type wraps an inner service directly and
/// is itself the `tower::Service` (see the `Service` implementation in this
/// file).
#[derive(Clone, Debug)]
pub struct BodyLimitLayer<S> {
    /// The wrapped service that handles requests passing the request limit.
    inner: S,
    /// Maximum request body length in bytes; `None` disables the check.
    max_req: Option<usize>,
    /// Maximum total response body data in bytes; `None` disables the check.
    max_resp: Option<usize>,
}

impl<S> BodyLimitLayer<S> {
    /// Wraps `inner` with the given request (`max_req`) and response
    /// (`max_resp`) body limits, both in bytes. `None` means "no limit".
    #[must_use]
    pub fn new(inner: S, max_req: Option<usize>, max_resp: Option<usize>) -> Self {
        Self {
            inner,
            max_req,
            max_resp,
        }
    }
}
131
132impl<S> Service<GatewayRequest> for BodyLimitLayer<S>
133where
134 S: Service<GatewayRequest, Response = GatewayResponse, Error = GatewayError>,
135 S::Future: Send + 'static,
136{
137 type Response = GatewayResponse;
138 type Error = GatewayError;
139 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
140
141 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
142 self.inner.poll_ready(cx)
143 }
144
145 fn call(&mut self, req: GatewayRequest) -> Self::Future {
146 if let Some(limit) = self.max_req {
148 if req.body().len() > limit {
149 return Box::pin(async move {
150 Err(GatewayError::Custom(
151 StatusCode::PAYLOAD_TOO_LARGE,
152 "Request body too large".into(),
153 ))
154 });
155 }
156 }
157
158 let max_resp = self.max_resp;
159 let fut = self.inner.call(req);
160
161 Box::pin(async move {
162 let resp = fut.await?;
163
164 if let Some(limit) = max_resp {
166 let (parts, body) = resp.into_parts();
169 let limited_body = LimitedBody {
170 inner: body,
171 remaining: limit,
172 };
173
174 let boxed_body =
176 http_body_util::combinators::UnsyncBoxBody::new(Box::new(limited_body));
177 Ok(http::Response::from_parts(parts, boxed_body))
178 } else {
179 Ok(resp)
180 }
181 })
182 }
183}
184
185struct LimitedBody<B> {
187 inner: B,
188 remaining: usize,
189}
190
191impl<B> http_body::Body for LimitedBody<B>
192where
193 B: http_body::Body<Data = crate::bytes::Bytes, Error = GatewayError> + Unpin,
194{
195 type Data = crate::bytes::Bytes;
196 type Error = GatewayError;
197
198 fn poll_frame(
199 mut self: Pin<&mut Self>,
200 cx: &mut Context<'_>,
201 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
202 let res = Pin::new(&mut self.inner).poll_frame(cx);
203 match res {
204 Poll::Ready(Some(Ok(frame))) => {
205 if let Some(data) = frame.data_ref() {
206 if data.len() > self.remaining {
207 return Poll::Ready(Some(Err(GatewayError::Custom(
208 StatusCode::PAYLOAD_TOO_LARGE,
209 "Response body limit exceeded".into(),
210 ))));
211 }
212 self.remaining -= data.len();
213 }
214 Poll::Ready(Some(Ok(frame)))
215 }
216 other => other,
217 }
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::{BodyExt, Full};
    use tower::retry::Policy;

    /// Builds a response with the given status and a fixed in-memory body.
    fn response_with(status: StatusCode, payload: &'static [u8]) -> GatewayResponse {
        http::Response::builder()
            .status(status)
            .body(BodyExt::boxed_unsync(
                Full::new(crate::bytes::Bytes::from_static(payload)).map_err(|_| unreachable!()),
            ))
            .unwrap()
    }

    /// Builds a request carrying `len` zero bytes as its body.
    fn request_with_body(len: usize) -> GatewayRequest {
        http::Request::builder().body(vec![0u8; len]).unwrap()
    }

    #[tokio::test]
    async fn test_req_body_limit_exceeded() {
        let service = tower::service_fn(|_req: GatewayRequest| async {
            Ok::<GatewayResponse, GatewayError>(response_with(StatusCode::OK, b""))
        });

        let mut layer = BodyLimitLayer::new(service, Some(5), None);
        let req = request_with_body(10);

        let err = layer.call(req).await.unwrap_err();
        match err {
            GatewayError::Custom(status, msg) => {
                assert_eq!(status, StatusCode::PAYLOAD_TOO_LARGE);
                assert_eq!(msg, "Request body too large");
            }
            _ => panic!("Unexpected error type"),
        }
    }

    #[tokio::test]
    async fn test_req_body_limit_ok() {
        let service = tower::service_fn(|_req: GatewayRequest| async {
            Ok::<GatewayResponse, GatewayError>(response_with(StatusCode::OK, b""))
        });

        let mut layer = BodyLimitLayer::new(service, Some(15), None);
        let req = request_with_body(10);

        assert!(layer.call(req).await.is_ok());
    }

    #[tokio::test]
    async fn test_resp_body_limit_exceeded() {
        let service = tower::service_fn(|_req: GatewayRequest| async {
            Ok::<GatewayResponse, GatewayError>(response_with(StatusCode::OK, b"0123456789"))
        });

        let mut layer = BodyLimitLayer::new(service, None, Some(5));
        let req = request_with_body(0);

        // The response head is returned; the limit surfaces while reading the body.
        let resp = match layer.call(req).await {
            Ok(resp) => resp,
            Err(_) => panic!("response head should be returned"),
        };

        match resp.into_body().collect().await {
            Err(GatewayError::Custom(status, msg)) => {
                assert_eq!(status, StatusCode::PAYLOAD_TOO_LARGE);
                assert_eq!(msg, "Response body limit exceeded");
            }
            Err(_) => panic!("Unexpected error type"),
            Ok(_) => panic!("body over the limit must not be collected successfully"),
        }
    }

    #[tokio::test]
    async fn test_resp_body_limit_ok() {
        let service = tower::service_fn(|_req: GatewayRequest| async {
            Ok::<GatewayResponse, GatewayError>(response_with(StatusCode::OK, b"0123456789"))
        });

        let mut layer = BodyLimitLayer::new(service, None, Some(15));
        let req = request_with_body(0);

        let resp = match layer.call(req).await {
            Ok(resp) => resp,
            Err(_) => panic!("response head should be returned"),
        };

        match resp.into_body().collect().await {
            Ok(collected) => assert_eq!(collected.to_bytes().len(), 10),
            Err(_) => panic!("body within the limit must be collected successfully"),
        }
    }

    #[test]
    fn test_retry_policy() {
        // 503 is retried while budget remains.
        {
            let mut policy = GatewayRetryPolicy::new(1);
            let mut req: GatewayRequest = request_with_body(0);
            let mut resp_res: Result<GatewayResponse, GatewayError> =
                Ok(response_with(StatusCode::SERVICE_UNAVAILABLE, b""));

            assert!(policy.retry(&mut req, &mut resp_res).is_some());
        }

        // A successful response is never retried.
        {
            let mut policy = GatewayRetryPolicy::new(1);
            let mut req: GatewayRequest = request_with_body(0);
            let mut resp_ok_res: Result<GatewayResponse, GatewayError> =
                Ok(response_with(StatusCode::OK, b""));

            assert!(policy.retry(&mut req, &mut resp_ok_res).is_none());
        }

        // Upstream errors are retried.
        {
            let mut policy = GatewayRetryPolicy::new(1);
            let mut req: GatewayRequest = request_with_body(0);
            let mut err_res: Result<GatewayResponse, GatewayError> =
                Err(GatewayError::Upstream(tonic::Status::unavailable("fail")));

            assert!(policy.retry(&mut req, &mut err_res).is_some());
        }

        // 502 is retried, and the budget is consumed by each approved retry.
        {
            let mut policy = GatewayRetryPolicy::new(1);
            let mut req: GatewayRequest = request_with_body(0);
            let mut resp_res: Result<GatewayResponse, GatewayError> =
                Ok(response_with(StatusCode::BAD_GATEWAY, b""));

            assert!(policy.retry(&mut req, &mut resp_res).is_some());
            assert!(policy.retry(&mut req, &mut resp_res).is_none());
        }

        // A zero budget never retries, even for retryable outcomes.
        {
            let mut policy = GatewayRetryPolicy::new(0);
            let mut req: GatewayRequest = request_with_body(0);
            let mut resp_res: Result<GatewayResponse, GatewayError> =
                Ok(response_with(StatusCode::SERVICE_UNAVAILABLE, b""));

            assert!(policy.retry(&mut req, &mut resp_res).is_none());
        }

        // Non-upstream errors are final.
        {
            let mut policy = GatewayRetryPolicy::new(1);
            let mut req: GatewayRequest = request_with_body(0);
            let mut err_res: Result<GatewayResponse, GatewayError> = Err(GatewayError::Custom(
                StatusCode::BAD_REQUEST,
                "bad request".into(),
            ));

            assert!(policy.retry(&mut req, &mut err_res).is_none());
        }
    }
}