rama_http/layer/
ua.rs

1//! User-Agent (see also `rama-ua`) http layer support
2//!
3//! # Example
4//!
5//! ```
6//! use rama_http::{
7//!     service::client::HttpClientExt, IntoResponse, Request, Response, StatusCode,
8//!     layer::ua::{PlatformKind, UserAgent, UserAgentClassifierLayer, UserAgentKind, UserAgentInfo},
9//! };
10//! use rama_core::{Context, Layer, service::service_fn};
11//! use std::convert::Infallible;
12//!
13//! const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67";
14//!
15//! async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
16//!     let ua: &UserAgent = ctx.get().unwrap();
17//!
18//!     assert_eq!(ua.header_str(), UA);
19//!     assert_eq!(ua.info(), Some(UserAgentInfo{ kind: UserAgentKind::Chromium, version: Some(124) }));
20//!     assert_eq!(ua.platform(), Some(PlatformKind::Windows));
21//!
22//!     Ok(StatusCode::OK.into_response())
23//! }
24//!
25//! # #[tokio::main]
26//! # async fn main() {
27//! let service = UserAgentClassifierLayer::new().into_layer(service_fn(handle));
28//!
29//! let _ = service
30//!     .get("http://www.example.com")
31//!     .typed_header(rama_http_types::headers::UserAgent::from_static(UA))
32//!     .send(Context::default())
33//!     .await
34//!     .unwrap();
35//! # }
36//! ```
37
38use crate::{
39    HeaderName, Request,
40    headers::{self, HeaderMapExt},
41};
42use rama_core::{Context, Layer, Service};
43use rama_utils::macros::define_inner_service_accessors;
44use std::fmt::{self, Debug};
45
46pub use rama_ua::{
47    DeviceKind, HttpAgent, PlatformKind, TlsAgent, UserAgent, UserAgentInfo, UserAgentKind,
48    UserAgentOverwrites,
49};
50
51/// A [`Service`] that classifies the [`UserAgent`] of incoming [`Request`]s.
52///
53/// The [`Extensions`] of the [`Context`] is updated with the [`UserAgent`]
54/// if the [`Request`] contains a valid [`UserAgent`] header.
55///
56/// [`Extensions`]: rama_core::context::Extensions
57/// [`Context`]: rama_core::Context
58pub struct UserAgentClassifier<S> {
59    inner: S,
60    overwrite_header: Option<HeaderName>,
61}
62
63impl<S> UserAgentClassifier<S> {
64    /// Create a new [`UserAgentClassifier`] [`Service`].
65    pub const fn new(inner: S, overwrite_header: Option<HeaderName>) -> Self {
66        Self {
67            inner,
68            overwrite_header,
69        }
70    }
71
72    define_inner_service_accessors!();
73}
74
75impl<S> Debug for UserAgentClassifier<S>
76where
77    S: Debug,
78{
79    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
80        f.debug_struct("UserAgentClassifier")
81            .field("inner", &self.inner)
82            .finish()
83    }
84}
85
86impl<S> Clone for UserAgentClassifier<S>
87where
88    S: Clone,
89{
90    fn clone(&self) -> Self {
91        Self {
92            inner: self.inner.clone(),
93            overwrite_header: self.overwrite_header.clone(),
94        }
95    }
96}
97
98impl<S> Default for UserAgentClassifier<S>
99where
100    S: Default,
101{
102    fn default() -> Self {
103        Self {
104            inner: S::default(),
105            overwrite_header: None,
106        }
107    }
108}
109
110impl<S, State, Body> Service<State, Request<Body>> for UserAgentClassifier<S>
111where
112    S: Service<State, Request<Body>>,
113    State: Clone + Send + Sync + 'static,
114{
115    type Response = S::Response;
116    type Error = S::Error;
117
118    fn serve(
119        &self,
120        mut ctx: Context<State>,
121        req: Request<Body>,
122    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
123        let overwrites = self
124            .overwrite_header
125            .as_ref()
126            .and_then(|header| req.headers().get(header))
127            .map(|header| header.as_bytes())
128            .and_then(|value| serde_html_form::from_bytes::<UserAgentOverwrites>(value).ok());
129
130        let mut user_agent = overwrites
131            .as_ref()
132            .and_then(|o| o.ua.as_deref())
133            .map(UserAgent::new)
134            .or_else(|| {
135                req.headers()
136                    .typed_get::<headers::UserAgent>()
137                    .map(|ua| UserAgent::new(ua.to_string()))
138            });
139
140        if let Some(mut ua) = user_agent.take() {
141            if let Some(overwrites) = overwrites {
142                if let Some(http_agent) = overwrites.http {
143                    ua.set_http_agent(http_agent);
144                }
145                if let Some(tls_agent) = overwrites.tls {
146                    ua.set_tls_agent(tls_agent);
147                }
148            }
149
150            ctx.insert(ua);
151        }
152
153        self.inner.serve(ctx, req)
154    }
155}
156
157#[derive(Debug, Clone, Default)]
158/// A [`Layer`] that wraps a [`Service`] with a [`UserAgentClassifier`].
159///
160/// This [`Layer`] is used to classify the [`UserAgent`] of incoming [`Request`]s.
161pub struct UserAgentClassifierLayer {
162    overwrite_header: Option<HeaderName>,
163}
164
165impl UserAgentClassifierLayer {
166    /// Create a new [`UserAgentClassifierLayer`].
167    pub const fn new() -> Self {
168        Self {
169            overwrite_header: None,
170        }
171    }
172
173    /// Define a custom header to allow overwriting certain
174    /// [`UserAgent`] information.
175    pub fn overwrite_header(mut self, header: HeaderName) -> Self {
176        self.overwrite_header = Some(header);
177        self
178    }
179
180    /// Define a custom header to allow overwriting certain
181    /// [`UserAgent`] information.
182    pub fn set_overwrite_header(&mut self, header: HeaderName) -> &mut Self {
183        self.overwrite_header = Some(header);
184        self
185    }
186}
187
188impl<S> Layer<S> for UserAgentClassifierLayer {
189    type Service = UserAgentClassifier<S>;
190
191    fn layer(&self, inner: S) -> Self::Service {
192        UserAgentClassifier::new(inner, self.overwrite_header.clone())
193    }
194
195    fn into_layer(self, inner: S) -> Self::Service {
196        UserAgentClassifier::new(inner, self.overwrite_header)
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::required_header::AddRequiredRequestHeadersLayer;
    use crate::service::client::HttpClientExt;
    use crate::{IntoResponse, Response, StatusCode, headers};
    use rama_core::Context;
    use rama_core::service::service_fn;
    use std::convert::Infallible;

    /// User-Agent string used by the Chromium-on-Windows tests.
    const CHROME_UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67";

    /// User-Agent string used by the iPhone app tests.
    const IPHONE_APP_UA: &str = "iPhone App/1.0";

    /// Name of the header used to transport [`UserAgentOverwrites`].
    const OVERWRITE_HEADER: &str = "x-proxy-ua";

    /// Build a classifier layer that honours [`OVERWRITE_HEADER`].
    fn overwrite_layer() -> UserAgentClassifierLayer {
        UserAgentClassifierLayer::new().overwrite_header(HeaderName::from_static(OVERWRITE_HEADER))
    }

    /// Handler asserting that the classified agent is [`CHROME_UA`].
    async fn expect_chrome_on_windows<S>(
        ctx: Context<S>,
        _req: Request,
    ) -> Result<Response, Infallible> {
        let ua: &UserAgent = ctx.get().unwrap();

        assert_eq!(ua.header_str(), CHROME_UA);
        let info = ua.info().unwrap();
        assert_eq!(info.kind, UserAgentKind::Chromium);
        assert_eq!(info.version, Some(124));
        assert_eq!(ua.platform(), Some(PlatformKind::Windows));

        Ok(StatusCode::OK.into_response())
    }

    #[tokio::test]
    async fn test_user_agent_classifier_layer_ua_rama() {
        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            let expected = format!("{}/{}", rama_utils::info::NAME, rama_utils::info::VERSION);
            assert_eq!(ua.header_str(), expected.as_str());
            assert!(ua.info().is_none());
            assert!(ua.platform().is_none());

            Ok(StatusCode::OK.into_response())
        }

        // No explicit User-Agent is set on the request: the header seen by the
        // classifier is whatever the required-headers layer puts in place.
        let layers = (
            AddRequiredRequestHeadersLayer::default(),
            UserAgentClassifierLayer::new(),
        );
        let service = layers.into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .send(Context::default())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_user_agent_classifier_layer_ua_iphone_app() {
        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), IPHONE_APP_UA);
            assert!(ua.info().is_none());
            assert_eq!(ua.platform(), Some(PlatformKind::IOS));
            // Without overwrites no explicit http/tls agent is set.
            assert_eq!(ua.http_agent(), None);
            assert_eq!(ua.tls_agent(), None);

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new().into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .typed_header(headers::UserAgent::from_static(IPHONE_APP_UA))
            .send(Context::default())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_user_agent_classifier_layer_ua_chrome() {
        let service =
            UserAgentClassifierLayer::new().into_layer(service_fn(expect_chrome_on_windows));

        let _ = service
            .get("http://www.example.com")
            .typed_header(headers::UserAgent::from_static(CHROME_UA))
            .send(Context::default())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_user_agent_classifier_layer_overwrite_ua() {
        let service = overwrite_layer().into_layer(service_fn(expect_chrome_on_windows));

        // Only the UA string is overwritten; the agents keep their defaults.
        let overwrites = UserAgentOverwrites {
            ua: Some(CHROME_UA.to_owned()),
            ..Default::default()
        };

        let _ = service
            .get("http://www.example.com")
            .header(
                OVERWRITE_HEADER,
                serde_html_form::to_string(&overwrites).unwrap(),
            )
            .send(Context::default())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_user_agent_classifier_layer_overwrite_ua_all() {
        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), IPHONE_APP_UA);
            assert!(ua.info().is_none());
            assert_eq!(ua.platform(), Some(PlatformKind::IOS));
            // Both agents come from the overwrite header.
            assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox));
            assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl));

            Ok(StatusCode::OK.into_response())
        }

        let service = overwrite_layer().into_layer(service_fn(handle));

        let overwrites = UserAgentOverwrites {
            ua: Some(IPHONE_APP_UA.to_owned()),
            http: Some(HttpAgent::Firefox),
            tls: Some(TlsAgent::Boringssl),
        };

        let _ = service
            .get("http://www.example.com")
            .header(
                OVERWRITE_HEADER,
                serde_html_form::to_string(&overwrites).unwrap(),
            )
            .send(Context::default())
            .await
            .unwrap();
    }
}