// tower_request_guard/route.rs

1use crate::guard::GuardConfig;
2use http::Request;
3use std::task::{Context, Poll};
4use std::time::Duration;
5use tower_layer::Layer;
6use tower_service::Service;
7
/// Per-route overrides for the global [`GuardConfig`].
///
/// A value of this type is inserted into request extensions by the layer
/// returned from [`route_guard`]. Unset fields (`None` / empty) fall back to the
/// global configuration when combined via [`RouteGuardConfig::merge_with`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RouteGuardConfig {
    /// Replacement for the global maximum body size; `None` inherits the global value.
    pub(crate) max_body_size: Option<u64>,
    /// Replacement for the global timeout; `None` inherits the global value.
    pub(crate) timeout: Option<Duration>,
    /// Replacement (not an extension) for the global allowed Content-Type list;
    /// `None` inherits the global list.
    pub(crate) allowed_content_types: Option<Vec<String>>,
    /// Globally-required header names that this route does not require.
    /// Matched against the global list ASCII case-insensitively.
    pub(crate) skip_headers: Vec<String>,
    /// Header names required on this route in addition to the global list.
    pub(crate) extra_required_headers: Vec<String>,
    /// When `true`, every validation is disabled for this route and all other
    /// overrides (and global values) are ignored during merging.
    pub(crate) skip_all: bool,
    /// Replacement for the global maximum JSON depth; `None` inherits the global value.
    #[cfg(feature = "json")]
    pub(crate) max_json_depth: Option<u32>,
}
21
22impl RouteGuardConfig {
23    /// Override the maximum body size for this route.
24    pub fn max_body_size(mut self, size: u64) -> Self {
25        self.max_body_size = Some(size);
26        self
27    }
28
29    /// Override the timeout duration for this route.
30    pub fn timeout(mut self, duration: Duration) -> Self {
31        self.timeout = Some(duration);
32        self
33    }
34
35    /// Override the allowed Content-Type list for this route.
36    pub fn allowed_content_types<I, S>(mut self, types: I) -> Self
37    where
38        I: IntoIterator<Item = S>,
39        S: Into<String>,
40    {
41        self.allowed_content_types = Some(types.into_iter().map(Into::into).collect());
42        self
43    }
44
45    /// Skip a globally-required header for this route.
46    pub fn skip_header(mut self, name: impl Into<String>) -> Self {
47        self.skip_headers.push(name.into());
48        self
49    }
50
51    /// Add an extra required header for this route.
52    pub fn require_header(mut self, name: impl Into<String>) -> Self {
53        self.extra_required_headers.push(name.into());
54        self
55    }
56
57    /// Skip all validations for this route.
58    pub fn skip_all(mut self) -> Self {
59        self.skip_all = true;
60        self
61    }
62
63    /// Override the maximum JSON depth for this route (requires `json` feature).
64    #[cfg(feature = "json")]
65    pub fn max_json_depth(mut self, depth: u32) -> Self {
66        self.max_json_depth = Some(depth);
67        self
68    }
69
70    /// Merge this route config with the global config.
71    /// Route values override globals; unset values inherit from global.
72    pub fn merge_with(&self, global: &GuardConfig) -> GuardConfig {
73        if self.skip_all {
74            return GuardConfig {
75                max_body_size: None,
76                timeout: None,
77                allowed_content_types: None,
78                required_headers: Vec::new(),
79                #[cfg(feature = "json")]
80                max_json_depth: None,
81            };
82        }
83
84        // Required headers: start with global, remove skipped, add extras
85        let mut required_headers = global.required_headers.clone();
86        required_headers.retain(|h| !self.skip_headers.iter().any(|s| s.eq_ignore_ascii_case(h)));
87        for extra in &self.extra_required_headers {
88            if !required_headers
89                .iter()
90                .any(|h| h.eq_ignore_ascii_case(extra))
91            {
92                required_headers.push(extra.clone());
93            }
94        }
95
96        GuardConfig {
97            max_body_size: self.max_body_size.or(global.max_body_size),
98            timeout: self.timeout.or(global.timeout),
99            allowed_content_types: self
100                .allowed_content_types
101                .clone()
102                .or_else(|| global.allowed_content_types.clone()),
103            required_headers,
104            #[cfg(feature = "json")]
105            max_json_depth: self.max_json_depth.or(global.max_json_depth),
106        }
107    }
108}
109
110/// Create a per-route guard override layer.
111/// The closure receives a `RouteGuardConfig` to configure route-specific overrides.
112pub fn route_guard<F>(f: F) -> RouteGuardLayer
113where
114    F: FnOnce(RouteGuardConfig) -> RouteGuardConfig,
115{
116    RouteGuardLayer(f(RouteGuardConfig::default()))
117}
118
119/// Layer that inserts RouteGuardConfig into request extensions.
120#[derive(Debug, Clone)]
121pub struct RouteGuardLayer(RouteGuardConfig);
122
123impl<S> Layer<S> for RouteGuardLayer {
124    type Service = RouteGuardInsertService<S>;
125
126    fn layer(&self, inner: S) -> Self::Service {
127        RouteGuardInsertService {
128            inner,
129            config: self.0.clone(),
130        }
131    }
132}
133
134/// Service that inserts RouteGuardConfig into request extensions.
135#[derive(Debug, Clone)]
136pub struct RouteGuardInsertService<S> {
137    inner: S,
138    config: RouteGuardConfig,
139}
140
141impl<S, B> Service<Request<B>> for RouteGuardInsertService<S>
142where
143    S: Service<Request<B>>,
144{
145    type Response = S::Response;
146    type Error = S::Error;
147    type Future = S::Future;
148
149    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
150        self.inner.poll_ready(cx)
151    }
152
153    fn call(&mut self, mut req: Request<B>) -> Self::Future {
154        req.extensions_mut().insert(self.config.clone());
155        self.inner.call(req)
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::guard::GuardConfig;
    use std::time::Duration;

    /// Global config shared by the merge tests. Every field is set so that both
    /// inheritance and overriding are observable.
    fn base_config() -> GuardConfig {
        GuardConfig {
            max_body_size: Some(1024),
            timeout: Some(Duration::from_secs(30)),
            allowed_content_types: Some(vec!["application/json".into()]),
            required_headers: vec!["Authorization".into(), "X-Request-Id".into()],
            #[cfg(feature = "json")]
            max_json_depth: Some(32),
        }
    }

    #[test]
    fn builder_methods_record_overrides() {
        let route = RouteGuardConfig::default()
            .max_body_size(10)
            .timeout(Duration::from_secs(5))
            .allowed_content_types(vec!["text/plain"])
            .skip_header("X-Request-Id")
            .require_header("X-Tenant-Id");
        assert_eq!(route.max_body_size, Some(10));
        assert_eq!(route.timeout, Some(Duration::from_secs(5)));
        assert_eq!(
            route.allowed_content_types,
            Some(vec!["text/plain".to_string()])
        );
        assert_eq!(route.skip_headers, vec!["X-Request-Id".to_string()]);
        assert_eq!(route.extra_required_headers, vec!["X-Tenant-Id".to_string()]);
        assert!(!route.skip_all);
    }

    #[test]
    fn merge_overrides_numeric_values() {
        let route = RouteGuardConfig {
            max_body_size: Some(2048),
            timeout: Some(Duration::from_secs(60)),
            ..Default::default()
        };
        let merged = route.merge_with(&base_config());
        assert_eq!(merged.max_body_size, Some(2048));
        assert_eq!(merged.timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn merge_replaces_content_types() {
        let route = RouteGuardConfig {
            allowed_content_types: Some(vec!["multipart/form-data".into()]),
            ..Default::default()
        };
        let merged = route.merge_with(&base_config());
        assert_eq!(
            merged.allowed_content_types,
            Some(vec!["multipart/form-data".into()])
        );
    }

    #[test]
    fn merge_skip_header_removes() {
        let route = RouteGuardConfig {
            skip_headers: vec!["Authorization".into()],
            ..Default::default()
        };
        let merged = route.merge_with(&base_config());
        assert_eq!(merged.required_headers, vec!["X-Request-Id".to_string()]);
    }

    #[test]
    fn merge_skip_header_is_case_insensitive() {
        let route = RouteGuardConfig::default().skip_header("authorization");
        let merged = route.merge_with(&base_config());
        assert_eq!(merged.required_headers, vec!["X-Request-Id".to_string()]);
    }

    #[test]
    fn merge_require_header_adds() {
        let route = RouteGuardConfig {
            extra_required_headers: vec!["X-Tenant-Id".into()],
            ..Default::default()
        };
        let merged = route.merge_with(&base_config());
        assert!(merged.required_headers.contains(&"X-Tenant-Id".to_string()));
        assert!(merged
            .required_headers
            .contains(&"Authorization".to_string()));
    }

    #[test]
    fn merge_require_header_does_not_duplicate() {
        // Extras matching a global header, or each other, in any ASCII casing are dropped.
        let route = RouteGuardConfig::default()
            .require_header("x-request-id")
            .require_header("X-Tenant-Id")
            .require_header("x-tenant-id");
        let merged = route.merge_with(&base_config());
        assert_eq!(
            merged.required_headers,
            vec![
                "Authorization".to_string(),
                "X-Request-Id".to_string(),
                "X-Tenant-Id".to_string(),
            ]
        );
    }

    #[test]
    fn merge_skip_all_clears_everything() {
        let route = RouteGuardConfig {
            skip_all: true,
            ..Default::default()
        };
        let merged = route.merge_with(&base_config());
        assert_eq!(merged.max_body_size, None);
        assert_eq!(merged.timeout, None);
        assert!(merged.allowed_content_types.is_none());
        assert!(merged.required_headers.is_empty());
    }

    #[test]
    fn merge_skip_all_ignores_other_overrides() {
        let route = RouteGuardConfig::default()
            .max_body_size(4096)
            .require_header("X-Tenant-Id")
            .skip_all();
        let merged = route.merge_with(&base_config());
        assert_eq!(merged.max_body_size, None);
        assert!(merged.required_headers.is_empty());
    }

    #[test]
    fn merge_inherits_unset_values() {
        let merged = RouteGuardConfig::default().merge_with(&base_config());
        assert_eq!(merged.max_body_size, Some(1024));
        assert_eq!(merged.timeout, Some(Duration::from_secs(30)));
        assert_eq!(
            merged.allowed_content_types,
            Some(vec!["application/json".to_string()])
        );
        assert_eq!(
            merged.required_headers,
            vec!["Authorization".to_string(), "X-Request-Id".to_string()]
        );
    }

    #[cfg(feature = "json")]
    #[test]
    fn merge_json_depth_override_inherit_and_skip_all() {
        let overridden = RouteGuardConfig::default()
            .max_json_depth(8)
            .merge_with(&base_config());
        assert_eq!(overridden.max_json_depth, Some(8));

        let inherited = RouteGuardConfig::default().merge_with(&base_config());
        assert_eq!(inherited.max_json_depth, Some(32));

        let skipped = RouteGuardConfig::default()
            .skip_all()
            .merge_with(&base_config());
        assert_eq!(skipped.max_json_depth, None);
    }
}