Skip to main content

modkit_http/layers/
user_agent.rs

1use crate::error::HttpError;
2use http::{HeaderValue, Request, Response};
3use std::task::{Context, Poll};
4use tower::{Layer, Service};
5
6/// Tower layer that adds User-Agent header to all requests
7#[derive(Clone)]
8pub struct UserAgentLayer {
9    user_agent: HeaderValue,
10}
11
12impl UserAgentLayer {
13    /// Create a new `UserAgentLayer` with the specified user agent string
14    ///
15    /// # Errors
16    /// Returns `HttpError::InvalidHeaderValue` if the user agent string is not valid
17    pub fn try_new(user_agent: impl AsRef<str>) -> Result<Self, HttpError> {
18        let user_agent =
19            HeaderValue::from_str(user_agent.as_ref()).map_err(HttpError::InvalidHeaderValue)?;
20        Ok(Self { user_agent })
21    }
22}
23
24impl<S> Layer<S> for UserAgentLayer {
25    type Service = UserAgentService<S>;
26
27    fn layer(&self, inner: S) -> Self::Service {
28        UserAgentService {
29            inner,
30            user_agent: self.user_agent.clone(),
31        }
32    }
33}
34
35/// Service that adds User-Agent header to requests
36#[derive(Clone)]
37pub struct UserAgentService<S> {
38    inner: S,
39    user_agent: HeaderValue,
40}
41
42impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for UserAgentService<S>
43where
44    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
45{
46    type Response = S::Response;
47    type Error = S::Error;
48    type Future = S::Future;
49
50    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
51        self.inner.poll_ready(cx)
52    }
53
54    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
55        // Only add User-Agent if not already present
56        if !req.headers().contains_key(http::header::USER_AGENT) {
57            req.headers_mut()
58                .insert(http::header::USER_AGENT, self.user_agent.clone());
59        }
60        self.inner.call(req)
61    }
62}
63
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    // `Request`, `Response`, `HeaderValue`, `Service`, `Layer`, `Context`,
    // `Poll` and `HttpError` all come from the parent module via this glob.
    use super::*;
    use bytes::Bytes;
    use http::{Method, StatusCode};
    use http_body_util::Full;
    use tower::ServiceExt;

    /// Test service that asserts the request carries exactly one User-Agent
    /// header with the expected value.
    #[derive(Clone)]
    struct CheckUaService {
        expected_ua: HeaderValue,
    }

    impl Service<Request<Full<Bytes>>> for CheckUaService {
        type Response = Response<Full<Bytes>>;
        type Error = Box<dyn std::error::Error + Send + Sync>;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
            let headers = req.headers();
            // Guards against the layer appending a second value instead of
            // leaving an existing one alone.
            assert_eq!(
                headers.get_all(http::header::USER_AGENT).iter().count(),
                1,
                "expected exactly one User-Agent header"
            );
            assert_eq!(
                headers.get(http::header::USER_AGENT),
                Some(&self.expected_ua),
                "unexpected User-Agent header value"
            );
            std::future::ready(Ok(Response::builder()
                .status(StatusCode::OK)
                .body(Full::new(Bytes::new()))
                .unwrap()))
        }
    }

    #[tokio::test]
    async fn test_user_agent_added() {
        let check_service = CheckUaService {
            expected_ua: HeaderValue::from_static("test-agent/1.0"),
        };

        let layer = UserAgentLayer::try_new("test-agent/1.0").unwrap();
        let mut service = layer.layer(check_service);

        let req = Request::builder()
            .method(Method::GET)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        service.ready().await.unwrap().call(req).await.unwrap();
    }

    #[tokio::test]
    async fn test_user_agent_not_overwritten() {
        let check_service = CheckUaService {
            expected_ua: HeaderValue::from_static("custom-agent/2.0"),
        };

        let layer = UserAgentLayer::try_new("test-agent/1.0").unwrap();
        let mut service = layer.layer(check_service);

        let req = Request::builder()
            .method(Method::GET)
            .uri("http://example.com")
            .header(http::header::USER_AGENT, "custom-agent/2.0")
            .body(Full::new(Bytes::new()))
            .unwrap();

        service.ready().await.unwrap().call(req).await.unwrap();
    }

    #[test]
    fn test_user_agent_layer_invalid_value() {
        // Control characters are invalid in header values
        let result = UserAgentLayer::try_new("invalid\x00agent");
        assert!(matches!(result, Err(HttpError::InvalidHeaderValue(_))));
    }

    #[test]
    fn test_user_agent_layer_rejects_header_injection() {
        // CR/LF must be rejected so a configured user agent cannot smuggle in
        // additional header lines.
        let result = UserAgentLayer::try_new("agent/1.0\r\nX-Injected: 1");
        assert!(matches!(result, Err(HttpError::InvalidHeaderValue(_))));
    }
}