modo/middleware/
user_agent.rs1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use axum::body::Body;
5use http::header::USER_AGENT;
6use http::{HeaderValue, Request};
7use tower::{Layer, Service};
8
/// Byte limit applied to the `User-Agent` value by layers built with
/// [`UserAgentLayer::new`] (and therefore by `UserAgentLayer::default()`).
const DEFAULT_MAX_LEN: usize = 512;
16
/// Tower layer that wraps a service in [`UserAgentMiddleware`].
///
/// The layer only carries configuration, so it is cheap to copy around.
#[derive(Clone, Copy, Debug)]
pub struct UserAgentLayer {
    /// Upper bound, in bytes, on the portion of the raw `User-Agent` value
    /// that is kept before normalization.
    max_len: usize,
}
48
49impl UserAgentLayer {
50 pub fn new() -> Self {
52 Self {
53 max_len: DEFAULT_MAX_LEN,
54 }
55 }
56
57 pub fn with_max_length(max_len: usize) -> Self {
59 Self { max_len }
60 }
61}
62
63impl Default for UserAgentLayer {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl<S> Layer<S> for UserAgentLayer {
70 type Service = UserAgentMiddleware<S>;
71
72 fn layer(&self, inner: S) -> Self::Service {
73 UserAgentMiddleware {
74 inner,
75 max_len: self.max_len,
76 }
77 }
78}
79
/// Service produced by [`UserAgentLayer`]; it normalizes the `User-Agent`
/// request header before forwarding the request to `inner`.
///
/// `Clone` is derived rather than hand-written: the derive yields the same
/// `impl<S: Clone> Clone` the manual implementation provided. `Debug` is
/// derived as well (available whenever `S: Debug`) so the public type can be
/// inspected in logs and assertions.
#[derive(Debug, Clone)]
pub struct UserAgentMiddleware<S> {
    /// The wrapped service that receives the request after normalization.
    inner: S,
    /// Upper bound, in bytes, on the portion of the raw header value kept
    /// before normalization.
    max_len: usize,
}
94
95impl<S, ReqBody> Service<Request<ReqBody>> for UserAgentMiddleware<S>
96where
97 S: Service<Request<ReqBody>, Response = http::Response<Body>> + Clone + Send + 'static,
98 S::Future: Send + 'static,
99 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
100 ReqBody: Send + 'static,
101{
102 type Response = http::Response<Body>;
103 type Error = S::Error;
104 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
105
106 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
107 self.inner.poll_ready(cx)
108 }
109
110 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
111 let max_len = self.max_len;
112 let mut inner = self.inner.clone();
113 std::mem::swap(&mut self.inner, &mut inner);
114
115 Box::pin(async move {
116 let raw = request
120 .headers()
121 .get(USER_AGENT)
122 .and_then(|v| v.to_str().ok())
123 .map(str::to_string);
124
125 if let Some(raw) = raw {
126 match sanitize_user_agent(&raw, max_len) {
127 Some(clean) => {
128 let value = HeaderValue::from_str(&clean)
131 .expect("sanitized user-agent must be a valid header value");
132 request.headers_mut().insert(USER_AGENT, value);
136 }
137 None => {
138 request.headers_mut().remove(USER_AGENT);
139 }
140 }
141 }
142
143 inner.call(request).await
144 })
145 }
146}
147
/// Normalizes a raw `User-Agent` string.
///
/// Processing order:
/// 1. `raw` is cut to at most `max_len` bytes, backing off to the nearest
///    UTF-8 character boundary so a multi-byte character is never split.
/// 2. Every run of ASCII whitespace becomes a single space and all other
///    ASCII control characters are dropped; non-ASCII characters are kept.
///    A dropped control character does not end a whitespace run.
/// 3. Leading and trailing whitespace is trimmed.
///
/// Returns `None` when nothing remains. Because truncation happens first,
/// the result is never longer than `max_len` bytes.
pub(crate) fn sanitize_user_agent(raw: &str, max_len: usize) -> Option<String> {
    // Index 0 is always a char boundary, so the search cannot come up empty.
    let limit = raw.len().min(max_len);
    let cut = (0..=limit)
        .rev()
        .find(|&idx| raw.is_char_boundary(idx))
        .unwrap_or(0);
    let head = &raw[..cut];

    let mut cleaned = String::with_capacity(head.len());
    let mut in_ws_run = false;
    for ch in head.chars() {
        match ch {
            ws if ws.is_ascii_whitespace() => {
                if !in_ws_run {
                    cleaned.push(' ');
                }
                in_ws_run = true;
            }
            // Dropped silently; deliberately leaves `in_ws_run` untouched.
            ctl if ctl.is_ascii_control() => {}
            other => {
                cleaned.push(other);
                in_ws_run = false;
            }
        }
    }

    let kept = cleaned.trim();
    if kept.is_empty() {
        return None;
    }
    if kept.len() != cleaned.len() {
        return Some(kept.to_owned());
    }
    // Nothing was trimmed: hand back the buffer without copying it.
    Some(cleaned)
}
195
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use http::{Request, Response, StatusCode};
    use std::convert::Infallible;
    use tower::ServiceExt;

    // ---- `sanitize_user_agent`: pure-function checks ----

    #[test]
    fn passes_clean_short_ua() {
        let cleaned = sanitize_user_agent("Mozilla/5.0", 512);
        assert_eq!(cleaned.as_deref(), Some("Mozilla/5.0"));
    }

    #[test]
    fn truncates_to_max_len_ascii() {
        let input = "A".repeat(1024);
        let cleaned = sanitize_user_agent(&input, 64).unwrap();
        assert_eq!(cleaned.len(), 64);
        assert!(cleaned.chars().all(|c| c == 'A'));
    }

    #[test]
    fn truncates_at_char_boundary_multibyte() {
        // 'ñ' takes two bytes in UTF-8, so a 5-byte cap lands inside the
        // third character and must back off to a whole number of characters.
        let input = "ñ".repeat(20);
        let cleaned = sanitize_user_agent(&input, 5).unwrap();
        assert!(cleaned.len() <= 5);
        assert!(cleaned.chars().all(|c| c == 'ñ'));
        assert_eq!(cleaned.len() % 2, 0);
    }

    #[test]
    fn strips_ascii_control_chars() {
        let cleaned = sanitize_user_agent("Mozilla\x01/\x07X", 512).unwrap();
        assert_eq!(cleaned, "Mozilla/X");
    }

    #[test]
    fn collapses_whitespace_runs() {
        let cleaned = sanitize_user_agent("Mozilla \t /5.0", 512).unwrap();
        assert_eq!(cleaned, "Mozilla /5.0");
    }

    #[test]
    fn trims_leading_and_trailing_whitespace() {
        let cleaned = sanitize_user_agent(" UA-Test ", 512);
        assert_eq!(cleaned.as_deref(), Some("UA-Test"));
    }

    #[test]
    fn keeps_non_ascii_chars() {
        let cleaned = sanitize_user_agent("клиент/1.0", 512);
        assert_eq!(cleaned.as_deref(), Some("клиент/1.0"));
    }

    #[test]
    fn returns_none_for_empty_input() {
        assert!(sanitize_user_agent("", 512).is_none());
    }

    #[test]
    fn returns_none_for_only_whitespace() {
        assert!(sanitize_user_agent(" \t ", 512).is_none());
    }

    #[test]
    fn returns_none_for_only_controls() {
        assert!(sanitize_user_agent("\x01\x02\x03", 512).is_none());
    }

    #[test]
    fn zero_max_len_returns_none() {
        assert!(sanitize_user_agent("Mozilla/5.0", 0).is_none());
    }

    // ---- middleware: end-to-end checks through a tower service ----

    /// Builds an empty-bodied request carrying a single `User-Agent` value.
    fn request_with_ua(value: &str) -> Request<Body> {
        Request::builder()
            .header(USER_AGENT, value)
            .body(Body::empty())
            .unwrap()
    }

    /// Inner service that answers with the `User-Agent` it received, or the
    /// literal `<absent>` when the header is missing or not readable as text.
    async fn echo_ua(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let seen = match req.headers().get(USER_AGENT).and_then(|v| v.to_str().ok()) {
            Some(ua) => ua.to_string(),
            None => "<absent>".to_string(),
        };
        Ok(Response::new(Body::from(seen)))
    }

    /// Sends `req` through `svc_layer` wrapped around [`echo_ua`] and returns
    /// the response body as text.
    async fn run(svc_layer: UserAgentLayer, req: Request<Body>) -> String {
        let response = svc_layer
            .layer(tower::service_fn(echo_ua))
            .oneshot(req)
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn passes_clean_ua_unchanged() {
        let echoed = run(UserAgentLayer::new(), request_with_ua("Mozilla/5.0")).await;
        assert_eq!(echoed, "Mozilla/5.0");
    }

    #[tokio::test]
    async fn truncates_long_ua() {
        let long = "A".repeat(2000);
        let echoed = run(UserAgentLayer::with_max_length(64), request_with_ua(&long)).await;
        assert_eq!(echoed.len(), 64);
        assert!(echoed.chars().all(|c| c == 'A'));
    }

    #[tokio::test]
    async fn strips_controls_and_collapses_whitespace() {
        let req = request_with_ua("Mozilla/5.0\t\t (foo) bar");
        let echoed = run(UserAgentLayer::new(), req).await;
        assert_eq!(echoed, "Mozilla/5.0 (foo) bar");
    }

    #[tokio::test]
    async fn removes_header_when_only_whitespace() {
        let echoed = run(UserAgentLayer::new(), request_with_ua(" \t ")).await;
        assert_eq!(echoed, "<absent>");
    }

    #[tokio::test]
    async fn leaves_absent_header_alone() {
        let req = Request::builder().body(Body::empty()).unwrap();
        let echoed = run(UserAgentLayer::new(), req).await;
        assert_eq!(echoed, "<absent>");
    }

    #[tokio::test]
    async fn respects_with_max_length() {
        let req = request_with_ua("abcdefghijklmnop");
        let echoed = run(UserAgentLayer::with_max_length(8), req).await;
        assert_eq!(echoed, "abcdefgh");
    }

    #[tokio::test]
    async fn normalizes_duplicate_user_agent_headers() {
        // Two separate `User-Agent` entries; `append` keeps both.
        let mut req = Request::builder().body(Body::empty()).unwrap();
        for ua in ["Mozilla/5.0", "Other/1.0"] {
            req.headers_mut().append(USER_AGENT, ua.parse().unwrap());
        }

        // Inner service reports "<number of values>|<first value>".
        let inspect = tower::service_fn(|req: Request<Body>| async move {
            let count = req.headers().get_all(USER_AGENT).iter().count();
            let first = req
                .headers()
                .get(USER_AGENT)
                .and_then(|v| v.to_str().ok())
                .unwrap_or("")
                .to_string();
            Ok::<_, Infallible>(Response::new(Body::from(format!("{count}|{first}"))))
        });

        let resp = UserAgentLayer::new()
            .layer(inspect)
            .oneshot(req)
            .await
            .unwrap();
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body.as_ref(), b"1|Mozilla/5.0");
    }
}