Skip to main content

modo/middleware/
user_agent.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use axum::body::Body;
5use http::header::USER_AGENT;
6use http::{HeaderValue, Request};
7use tower::{Layer, Service};
8
/// Byte cap used by [`UserAgentLayer::new`] and [`UserAgentLayer::default`].
///
/// Typical real-world `User-Agent` strings are 100–300 bytes. 512 leaves room
/// for unusually verbose but legitimate values, while cutting down the
/// pathological multi-kilobyte values an upstream HTTP server would otherwise
/// accept.
const DEFAULT_MAX_LEN: usize = 512;
16
17/// Tower layer that sanitizes the inbound `User-Agent` header in place.
18///
19/// Reads the request's `User-Agent` once, applies a small set of
20/// normalization rules (length cap on a UTF-8 char boundary, ASCII control
21/// stripping, whitespace collapsing, trimming), and rewrites the header value
22/// before delegating to the inner service. If the sanitized result is empty
23/// the header is removed entirely so downstream consumers see the same
24/// "missing" state they handle today.
25///
26/// Because the layer mutates the request header itself, every downstream
27/// reader — `ClientInfo`, the cookie session middleware, audit logging,
28/// fingerprint hashing — observes the cleaned value with no further plumbing.
29///
30/// # Layer ordering
31///
32/// Install `UserAgentLayer` **before** any layer or handler that reads
33/// `User-Agent`. In axum that means it must be added with a `.layer(...)`
34/// call that comes **after** (i.e. wraps) the consumer:
35///
36/// ```rust,no_run
37/// use axum::{Router, routing::get};
38/// use modo::middleware::UserAgentLayer;
39///
40/// let app: Router = Router::new()
41///     .route("/", get(|| async { "ok" }))
42///     .layer(UserAgentLayer::new());
43/// ```
44#[derive(Debug, Clone, Copy)]
45pub struct UserAgentLayer {
46    max_len: usize,
47}
48
49impl UserAgentLayer {
50    /// Create a layer with the default 512-byte sanitization cap.
51    pub fn new() -> Self {
52        Self {
53            max_len: DEFAULT_MAX_LEN,
54        }
55    }
56
57    /// Create a layer with a custom byte-length cap.
58    pub fn with_max_length(max_len: usize) -> Self {
59        Self { max_len }
60    }
61}
62
63impl Default for UserAgentLayer {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69impl<S> Layer<S> for UserAgentLayer {
70    type Service = UserAgentMiddleware<S>;
71
72    fn layer(&self, inner: S) -> Self::Service {
73        UserAgentMiddleware {
74            inner,
75            max_len: self.max_len,
76        }
77    }
78}
79
80/// Tower service produced by [`UserAgentLayer`].
81pub struct UserAgentMiddleware<S> {
82    inner: S,
83    max_len: usize,
84}
85
86impl<S: Clone> Clone for UserAgentMiddleware<S> {
87    fn clone(&self) -> Self {
88        Self {
89            inner: self.inner.clone(),
90            max_len: self.max_len,
91        }
92    }
93}
94
95impl<S, ReqBody> Service<Request<ReqBody>> for UserAgentMiddleware<S>
96where
97    S: Service<Request<ReqBody>, Response = http::Response<Body>> + Clone + Send + 'static,
98    S::Future: Send + 'static,
99    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
100    ReqBody: Send + 'static,
101{
102    type Response = http::Response<Body>;
103    type Error = S::Error;
104    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
105
106    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
107        self.inner.poll_ready(cx)
108    }
109
110    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
111        let max_len = self.max_len;
112        let mut inner = self.inner.clone();
113        std::mem::swap(&mut self.inner, &mut inner);
114
115        Box::pin(async move {
116            // Snapshot the existing header value as an owned string so we can
117            // compare it to the sanitized output without juggling lifetimes
118            // against `headers_mut()`.
119            let raw = request
120                .headers()
121                .get(USER_AGENT)
122                .and_then(|v| v.to_str().ok())
123                .map(str::to_string);
124
125            if let Some(raw) = raw {
126                match sanitize_user_agent(&raw, max_len) {
127                    Some(clean) => {
128                        // Sanitization output is a subset of an already
129                        // visible-ASCII input, so `from_str` cannot fail.
130                        let value = HeaderValue::from_str(&clean)
131                            .expect("sanitized user-agent must be a valid header value");
132                        // `insert` replaces every existing entry, so a request
133                        // arriving with multiple `User-Agent` headers (rare but
134                        // valid HTTP) is normalized down to a single value.
135                        request.headers_mut().insert(USER_AGENT, value);
136                    }
137                    None => {
138                        request.headers_mut().remove(USER_AGENT);
139                    }
140                }
141            }
142
143            inner.call(request).await
144        })
145    }
146}
147
/// Normalize a raw `User-Agent` string, returning `None` when nothing useful
/// remains.
///
/// Steps, in order:
///
/// 1. Cut `raw` to at most `max_len` bytes, backing off to the previous UTF-8
///    character boundary so no character is split. Truncation runs first, so
///    the cap bounds the raw input. The cleaned output can only be shorter.
/// 2. Replace each run of ASCII whitespace ([`char::is_ascii_whitespace`]:
///    space, tab, line feed, form feed, carriage return) with one space. An
///    ASCII control character sitting inside such a run does not split it.
/// 3. Drop every other ASCII control character, DEL included.
/// 4. Trim both ends with [`str::trim`], which uses the Unicode definition of
///    whitespace.
/// 5. Return `None` if the result is empty.
///
/// When called from the layer, the input has already passed
/// `HeaderValue::to_str`, which admits only visible ASCII plus tab. The wider
/// handling here exists because the function is `pub(crate)` and may be
/// called directly. Non-ASCII characters are kept as they are.
pub(crate) fn sanitize_user_agent(raw: &str, max_len: usize) -> Option<String> {
    // Walk back from the cap to the nearest char boundary. Offset 0 always
    // qualifies, so the fallback is never actually used.
    let cap = max_len.min(raw.len());
    let cut = (0..=cap)
        .rev()
        .find(|&offset| raw.is_char_boundary(offset))
        .unwrap_or(0);
    let head = &raw[..cut];

    let mut cleaned = String::with_capacity(head.len());
    // True while inside a whitespace run that has already emitted its space.
    let mut in_gap = false;
    for ch in head.chars() {
        match ch {
            c if c.is_ascii_whitespace() => {
                if !in_gap {
                    cleaned.push(' ');
                    in_gap = true;
                }
            }
            // Skipped without touching `in_gap`, so "a \x01 b" still yields "a b".
            c if c.is_ascii_control() => {}
            c => {
                cleaned.push(c);
                in_gap = false;
            }
        }
    }

    let trimmed = cleaned.trim();
    if trimmed.is_empty() {
        None
    } else if trimmed.len() < cleaned.len() {
        Some(trimmed.to_owned())
    } else {
        // Nothing was trimmed: hand back the buffer without copying it.
        Some(cleaned)
    }
}
195
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use http::{Request, Response, StatusCode};
    use std::convert::Infallible;
    use tower::ServiceExt;

    // ---------- sanitize_user_agent ----------

    #[test]
    fn passes_clean_short_ua() {
        assert_eq!(
            sanitize_user_agent("Mozilla/5.0", 512).as_deref(),
            Some("Mozilla/5.0"),
        );
    }

    #[test]
    fn truncates_to_max_len_ascii() {
        let raw: String = "A".repeat(1024);
        let out = sanitize_user_agent(&raw, 64).unwrap();
        assert_eq!(out.len(), 64);
        assert!(out.chars().all(|c| c == 'A'));
    }

    #[test]
    fn truncates_at_char_boundary_multibyte() {
        // "ñ" is 2 bytes in UTF-8.
        let raw: String = "ñ".repeat(20);
        // Cap at an odd byte count that lands inside a multibyte char. Slicing
        // there would panic, so the sanitizer must back off to byte 4.
        let out = sanitize_user_agent(&raw, 5).unwrap();
        assert!(out.len() <= 5);
        assert!(out.chars().all(|c| c == 'ñ'));
        assert_eq!(out.len() % 2, 0);
        assert_eq!(out, "ññ");
    }

    #[test]
    fn strips_ascii_control_chars() {
        let out = sanitize_user_agent("Mozilla\x01/\x07X", 512).unwrap();
        assert_eq!(out, "Mozilla/X");
    }

    #[test]
    fn strips_del_and_vertical_tab() {
        // DEL (0x7F) and VT (0x0B) are ASCII controls but not ASCII whitespace
        // per `char::is_ascii_whitespace`, so they are dropped, not collapsed.
        assert_eq!(
            sanitize_user_agent("a\x7fb\x0bc", 512).as_deref(),
            Some("abc"),
        );
    }

    #[test]
    fn collapses_whitespace_runs() {
        let out = sanitize_user_agent("Mozilla   \t  /5.0", 512).unwrap();
        assert_eq!(out, "Mozilla /5.0");
    }

    #[test]
    fn collapses_line_breaks_and_form_feed() {
        // CR, LF and FF count as ASCII whitespace and join one run.
        assert_eq!(
            sanitize_user_agent("a\r\n\x0cb", 512).as_deref(),
            Some("a b"),
        );
    }

    #[test]
    fn control_char_inside_whitespace_run_collapses_to_one_space() {
        assert_eq!(
            sanitize_user_agent("a \x01 b", 512).as_deref(),
            Some("a b"),
        );
    }

    #[test]
    fn trims_leading_and_trailing_whitespace() {
        assert_eq!(
            sanitize_user_agent("   UA-Test   ", 512).as_deref(),
            Some("UA-Test"),
        );
    }

    #[test]
    fn truncation_applies_before_cleanup() {
        // The cap applies to raw bytes: "a   " survives truncation, collapses
        // to "a " and is then trimmed to "a".
        assert_eq!(sanitize_user_agent("a    b", 4).as_deref(), Some("a"));
    }

    #[test]
    fn keeps_non_ascii_chars() {
        // Non-ASCII passes through untouched. HeaderValue::to_str() at the
        // call site already rejects anything that isn't visible ASCII, so in
        // practice this is defensive. The sanitizer itself stays permissive
        // for non-ASCII so it can be reused outside the layer.
        assert_eq!(
            sanitize_user_agent("клиент/1.0", 512).as_deref(),
            Some("клиент/1.0"),
        );
    }

    #[test]
    fn returns_none_for_empty_input() {
        assert!(sanitize_user_agent("", 512).is_none());
    }

    #[test]
    fn returns_none_for_only_whitespace() {
        assert!(sanitize_user_agent("   \t  ", 512).is_none());
    }

    #[test]
    fn returns_none_for_only_controls() {
        assert!(sanitize_user_agent("\x01\x02\x03", 512).is_none());
    }

    #[test]
    fn zero_max_len_returns_none() {
        assert!(sanitize_user_agent("Mozilla/5.0", 0).is_none());
    }

    // ---------- UserAgentLayer / Service ----------

    /// Echo the `User-Agent` the inner service sees, or `<absent>`.
    async fn echo_ua(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let ua = req
            .headers()
            .get(USER_AGENT)
            .and_then(|v| v.to_str().ok())
            .map(str::to_string)
            .unwrap_or_else(|| "<absent>".to_string());
        Ok(Response::new(Body::from(ua)))
    }

    /// Run `req` through `svc_layer` wrapped around [`echo_ua`] and return the body.
    async fn run(svc_layer: UserAgentLayer, req: Request<Body>) -> String {
        let svc = svc_layer.layer(tower::service_fn(echo_ua));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(body.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn passes_clean_ua_unchanged() {
        let req = Request::builder()
            .header(USER_AGENT, "Mozilla/5.0")
            .body(Body::empty())
            .unwrap();
        assert_eq!(run(UserAgentLayer::new(), req).await, "Mozilla/5.0");
    }

    #[tokio::test]
    async fn truncates_long_ua() {
        let long = "A".repeat(2000);
        let req = Request::builder()
            .header(USER_AGENT, long)
            .body(Body::empty())
            .unwrap();
        let out = run(UserAgentLayer::with_max_length(64), req).await;
        assert_eq!(out.len(), 64);
        assert!(out.chars().all(|c| c == 'A'));
    }

    #[tokio::test]
    async fn default_layer_caps_at_512_bytes() {
        let req = Request::builder()
            .header(USER_AGENT, "A".repeat(600))
            .body(Body::empty())
            .unwrap();
        let out = run(UserAgentLayer::default(), req).await;
        assert_eq!(out.len(), 512);
        assert!(out.chars().all(|c| c == 'A'));
    }

    #[tokio::test]
    async fn strips_controls_and_collapses_whitespace() {
        // Tab is the only control char `HeaderValue::to_str()` accepts in
        // addition to visible ASCII, so this is the realistic shape of an
        // attacker-controlled header.
        let req = Request::builder()
            .header(USER_AGENT, "Mozilla/5.0\t\t (foo) bar")
            .body(Body::empty())
            .unwrap();
        assert_eq!(
            run(UserAgentLayer::new(), req).await,
            "Mozilla/5.0 (foo) bar",
        );
    }

    #[tokio::test]
    async fn removes_header_when_only_whitespace() {
        let req = Request::builder()
            .header(USER_AGENT, "   \t  ")
            .body(Body::empty())
            .unwrap();
        assert_eq!(run(UserAgentLayer::new(), req).await, "<absent>");
    }

    #[tokio::test]
    async fn removes_header_when_empty() {
        let req = Request::builder()
            .header(USER_AGENT, "")
            .body(Body::empty())
            .unwrap();
        assert_eq!(run(UserAgentLayer::new(), req).await, "<absent>");
    }

    #[tokio::test]
    async fn leaves_absent_header_alone() {
        let req = Request::builder().body(Body::empty()).unwrap();
        assert_eq!(run(UserAgentLayer::new(), req).await, "<absent>");
    }

    #[tokio::test]
    async fn respects_with_max_length() {
        let req = Request::builder()
            .header(USER_AGENT, "abcdefghijklmnop")
            .body(Body::empty())
            .unwrap();
        assert_eq!(
            run(UserAgentLayer::with_max_length(8), req).await,
            "abcdefgh"
        );
    }

    #[tokio::test]
    async fn zero_max_length_removes_header() {
        let req = Request::builder()
            .header(USER_AGENT, "Mozilla/5.0")
            .body(Body::empty())
            .unwrap();
        assert_eq!(
            run(UserAgentLayer::with_max_length(0), req).await,
            "<absent>"
        );
    }

    #[tokio::test]
    async fn normalizes_duplicate_user_agent_headers() {
        // Multiple `User-Agent` headers are rare but valid HTTP. The layer
        // should always reduce them to a single value so downstream consumers
        // never see duplicates, regardless of whether the first value needed
        // sanitization.
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.headers_mut()
            .append(USER_AGENT, "Mozilla/5.0".parse().unwrap());
        req.headers_mut()
            .append(USER_AGENT, "Other/1.0".parse().unwrap());

        let svc = UserAgentLayer::new().layer(tower::service_fn(|req: Request<Body>| async move {
            let count = req.headers().get_all(USER_AGENT).iter().count();
            let first = req
                .headers()
                .get(USER_AGENT)
                .and_then(|v| v.to_str().ok())
                .unwrap_or("")
                .to_string();
            Ok::<_, Infallible>(Response::new(Body::from(format!("{count}|{first}"))))
        }));
        let resp = svc.oneshot(req).await.unwrap();
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body.as_ref(), b"1|Mozilla/5.0");
    }
}