modkit_http/layers/
user_agent.rs1use crate::error::HttpError;
2use http::{HeaderValue, Request, Response};
3use std::task::{Context, Poll};
4use tower::{Layer, Service};
5
6#[derive(Clone)]
8pub struct UserAgentLayer {
9 user_agent: HeaderValue,
10}
11
12impl UserAgentLayer {
13 pub fn try_new(user_agent: impl AsRef<str>) -> Result<Self, HttpError> {
18 let user_agent =
19 HeaderValue::from_str(user_agent.as_ref()).map_err(HttpError::InvalidHeaderValue)?;
20 Ok(Self { user_agent })
21 }
22}
23
24impl<S> Layer<S> for UserAgentLayer {
25 type Service = UserAgentService<S>;
26
27 fn layer(&self, inner: S) -> Self::Service {
28 UserAgentService {
29 inner,
30 user_agent: self.user_agent.clone(),
31 }
32 }
33}
34
35#[derive(Clone)]
37pub struct UserAgentService<S> {
38 inner: S,
39 user_agent: HeaderValue,
40}
41
42impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for UserAgentService<S>
43where
44 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
45{
46 type Response = S::Response;
47 type Error = S::Error;
48 type Future = S::Future;
49
50 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
51 self.inner.poll_ready(cx)
52 }
53
54 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
55 if !req.headers().contains_key(http::header::USER_AGENT) {
57 req.headers_mut()
58 .insert(http::header::USER_AGENT, self.user_agent.clone());
59 }
60 self.inner.call(req)
61 }
62}
63
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http::{Method, Request, Response, StatusCode};
    use http_body_util::Full;
    use tower::ServiceExt;

    /// User agent configured on the layer in the request-level tests.
    const LAYER_UA: &str = "test-agent/1.0";

    /// Terminal test service: asserts that every request it receives carries
    /// `expected_ua` as its `User-Agent`, then answers `200 OK` with an empty
    /// body.
    #[derive(Clone)]
    struct CheckUaService {
        expected_ua: HeaderValue,
    }

    impl Service<Request<Full<Bytes>>> for CheckUaService {
        type Response = Response<Full<Bytes>>;
        type Error = Box<dyn std::error::Error + Send + Sync>;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
            let actual = req.headers().get(http::header::USER_AGENT);
            assert_eq!(actual, Some(&self.expected_ua));

            let response = Response::builder()
                .status(StatusCode::OK)
                .body(Full::new(Bytes::new()))
                .unwrap();
            std::future::ready(Ok(response))
        }
    }

    /// Stacks a `UserAgentLayer` configured with `LAYER_UA` on top of a
    /// checker that expects `expected` to reach it.
    fn service_expecting(expected: &'static str) -> UserAgentService<CheckUaService> {
        let checker = CheckUaService {
            expected_ua: HeaderValue::from_static(expected),
        };
        UserAgentLayer::try_new(LAYER_UA).unwrap().layer(checker)
    }

    /// Starts a `GET http://example.com` request builder shared by the tests.
    fn get_example() -> http::request::Builder {
        Request::builder()
            .method(Method::GET)
            .uri("http://example.com")
    }

    #[tokio::test]
    async fn test_user_agent_added() {
        let mut service = service_expecting("test-agent/1.0");

        // No User-Agent on the request: the layer's value must show up.
        let req = get_example().body(Full::new(Bytes::new())).unwrap();

        let ready = service.ready().await.unwrap();
        ready.call(req).await.unwrap();
    }

    #[tokio::test]
    async fn test_user_agent_not_overwritten() {
        let mut service = service_expecting("custom-agent/2.0");

        // The request's own User-Agent must reach the inner service intact.
        let req = get_example()
            .header(http::header::USER_AGENT, "custom-agent/2.0")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let ready = service.ready().await.unwrap();
        ready.call(req).await.unwrap();
    }

    #[test]
    fn test_user_agent_layer_invalid_value() {
        // A value containing a NUL byte is expected to be rejected.
        let outcome = UserAgentLayer::try_new("invalid\x00agent");
        assert!(outcome.is_err());
    }
}
141}