Skip to main content

tower_request_guard/
service.rs

1use crate::body::{check_content_length, is_bodyless_method};
2use crate::content_type::matches_content_type;
3use crate::guard::RequestGuard;
4use crate::headers::find_missing_header;
5use crate::response::violation_response;
6use crate::route::RouteGuardConfig;
7use crate::violation::{OnViolation, Violation, ViolationAction};
8use http::{Request, Response};
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13use tower_service::Service;
14
15/// Tower Service that validates requests before forwarding.
16pub struct RequestGuardService<S> {
17    pub(crate) inner: S,
18    pub(crate) guard: Arc<RequestGuard>,
19}
20
21impl<S: Clone> Clone for RequestGuardService<S> {
22    fn clone(&self) -> Self {
23        Self {
24            inner: self.inner.clone(),
25            guard: self.guard.clone(),
26        }
27    }
28}
29
30impl<S, B, ResBody> Service<Request<B>> for RequestGuardService<S>
31where
32    S: Service<Request<B>, Response = Response<ResBody>> + Clone + Send + 'static,
33    S::Future: Send,
34    S::Error: Send,
35    B: Send + 'static,
36    ResBody: From<String> + Send,
37{
38    type Response = Response<ResBody>;
39    type Error = S::Error;
40    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
41
42    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
43        self.inner.poll_ready(cx)
44    }
45
46    fn call(&mut self, req: Request<B>) -> Self::Future {
47        let guard = self.guard.clone();
48        let mut inner = self.inner.clone();
49        std::mem::swap(&mut self.inner, &mut inner);
50
51        Box::pin(async move {
52            // Resolve effective config (merge route overrides)
53            let effective = match req.extensions().get::<RouteGuardConfig>() {
54                Some(route_config) => route_config.merge_with(&guard.config),
55                None => guard.config.clone(),
56            };
57
58            let is_bodyless = is_bodyless_method(req.method());
59
60            // 1. Content-Type check (skip for bodyless methods)
61            if !is_bodyless {
62                if let Some(ref allowed) = effective.allowed_content_types {
63                    let content_type = req
64                        .headers()
65                        .get("content-type")
66                        .and_then(|v| v.to_str().ok())
67                        .unwrap_or("");
68                    if !matches_content_type(content_type, allowed) {
69                        let violation = Violation::InvalidContentType {
70                            received: content_type.to_string(),
71                            allowed: allowed.clone(),
72                        };
73                        if let Some(resp) = handle_violation(&violation, &guard.on_violation) {
74                            return Ok(resp.map(Into::into));
75                        }
76                    }
77                }
78            }
79
80            // 2. Required headers check
81            if !effective.required_headers.is_empty() {
82                if let Some(missing) =
83                    find_missing_header(req.headers(), &effective.required_headers)
84                {
85                    let violation = Violation::MissingHeader { header: missing };
86                    if let Some(resp) = handle_violation(&violation, &guard.on_violation) {
87                        return Ok(resp.map(Into::into));
88                    }
89                }
90            }
91
92            // 3. Content-Length pre-check (skip for bodyless methods)
93            if !is_bodyless {
94                if let Some(max) = effective.max_body_size {
95                    if let Some(received) = check_content_length(req.headers(), max) {
96                        let violation = Violation::BodyTooLarge { max, received };
97                        if let Some(resp) = handle_violation(&violation, &guard.on_violation) {
98                            return Ok(resp.map(Into::into));
99                        }
100                    }
101                }
102            }
103
104            // Steps 4-5 (JSON depth + full body buffering with size check)
105            // are handled by BufferedRequestGuardService when json feature is enabled.
106            //
107            // Body size enforcement strategy for non-buffered path:
108            // - Content-Length present: rejected in step 3 above (O(1), no body read)
109            // - Content-Length absent (chunked): not enforced here (would require
110            //   changing the body type to LimitedBody<B>, breaking the generic service
111            //   contract). For full stream limiting of chunked bodies, enable the
112            //   `json` feature (which buffers and checks) or combine with
113            //   tower-http::RequestBodyLimitLayer.
114
115            // 6. Timeout wrap
116            if let Some(timeout_duration) = effective.timeout {
117                match tokio::time::timeout(timeout_duration, inner.call(req)).await {
118                    Ok(result) => result,
119                    Err(_elapsed) => {
120                        let violation = Violation::RequestTimeout {
121                            timeout_ms: u64::try_from(timeout_duration.as_millis())
122                                .unwrap_or(u64::MAX),
123                        };
124                        let resp = handle_timeout_violation(&violation, &guard.on_violation);
125                        Ok(resp.map(Into::into))
126                    }
127                }
128            } else {
129                inner.call(req).await
130            }
131        })
132    }
133}
134
135/// Handle a pre-handler violation according to the OnViolation policy.
136/// Returns Some(response) if the request should be rejected, None if it should pass.
137pub(crate) fn handle_violation(
138    violation: &Violation,
139    policy: &OnViolation,
140) -> Option<Response<String>> {
141    match policy {
142        OnViolation::Reject => Some(violation_response(violation)),
143        OnViolation::LogAndPass => {
144            tracing::warn!(?violation, "request guard violation (log-and-pass)");
145            None
146        }
147        OnViolation::Custom(callback) => match callback(violation) {
148            ViolationAction::Reject => Some(violation_response(violation)),
149            ViolationAction::Pass => None,
150            ViolationAction::RespondWith(resp) => Some(resp),
151        },
152    }
153}
154
155/// Handle a timeout violation. LogAndPass is ignored for timeouts.
156pub(crate) fn handle_timeout_violation(
157    violation: &Violation,
158    policy: &OnViolation,
159) -> Response<String> {
160    match policy {
161        OnViolation::Custom(callback) => match callback(violation) {
162            ViolationAction::RespondWith(resp) => resp,
163            _ => violation_response(violation),
164        },
165        _ => violation_response(violation),
166    }
167}