rama_http/layer/
ua.rs

//! HTTP layer support for classifying the User-Agent of requests (see also `rama-ua`).
2//!
3//! # Example
4//!
5//! ```
6//! use rama_http::{
7//!     service::client::HttpClientExt, Request, Response, StatusCode,
8//!     layer::ua::{PlatformKind, UserAgent, UserAgentClassifierLayer, UserAgentKind, UserAgentInfo},
9//!     service::web::response::IntoResponse,
10//! };
11//! use rama_core::{Context, Layer, service::service_fn};
12//! use std::convert::Infallible;
13//!
14//! const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67";
15//!
16//! async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
17//!     let ua: &UserAgent = ctx.get().unwrap();
18//!
19//!     assert_eq!(ua.header_str(), UA);
20//!     assert_eq!(ua.info(), Some(UserAgentInfo{ kind: UserAgentKind::Chromium, version: Some(124) }));
21//!     assert_eq!(ua.platform(), Some(PlatformKind::Windows));
22//!
23//!     Ok(StatusCode::OK.into_response())
24//! }
25//!
26//! # #[tokio::main]
27//! # async fn main() {
28//! let service = UserAgentClassifierLayer::new().into_layer(service_fn(handle));
29//!
30//! let _ = service
31//!     .get("http://www.example.com")
32//!     .typed_header(rama_http::headers::UserAgent::from_static(UA))
33//!     .send(Context::default())
34//!     .await
35//!     .unwrap();
36//! # }
37//! ```
38
39use crate::{
40    HeaderName, Request,
41    headers::{self, HeaderMapExt},
42};
43use rama_core::{Context, Layer, Service};
44use rama_utils::macros::define_inner_service_accessors;
45use std::fmt::{self, Debug};
46
47pub use rama_ua::{
48    DeviceKind, HttpAgent, PlatformKind, TlsAgent, UserAgent, UserAgentInfo, UserAgentKind,
49    UserAgentOverwrites,
50};
51
52/// A [`Service`] that classifies the [`UserAgent`] of incoming [`Request`]s.
53///
54/// The [`Extensions`] of the [`Context`] is updated with the [`UserAgent`]
55/// if the [`Request`] contains a valid [`UserAgent`] header.
56///
57/// [`Extensions`]: rama_core::context::Extensions
58/// [`Context`]: rama_core::Context
59pub struct UserAgentClassifier<S> {
60    inner: S,
61    overwrite_header: Option<HeaderName>,
62}
63
64impl<S> UserAgentClassifier<S> {
65    /// Create a new [`UserAgentClassifier`] [`Service`].
66    pub const fn new(inner: S, overwrite_header: Option<HeaderName>) -> Self {
67        Self {
68            inner,
69            overwrite_header,
70        }
71    }
72
73    define_inner_service_accessors!();
74}
75
76impl<S> Debug for UserAgentClassifier<S>
77where
78    S: Debug,
79{
80    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
81        f.debug_struct("UserAgentClassifier")
82            .field("inner", &self.inner)
83            .finish()
84    }
85}
86
87impl<S> Clone for UserAgentClassifier<S>
88where
89    S: Clone,
90{
91    fn clone(&self) -> Self {
92        Self {
93            inner: self.inner.clone(),
94            overwrite_header: self.overwrite_header.clone(),
95        }
96    }
97}
98
99impl<S> Default for UserAgentClassifier<S>
100where
101    S: Default,
102{
103    fn default() -> Self {
104        Self {
105            inner: S::default(),
106            overwrite_header: None,
107        }
108    }
109}
110
111impl<S, State, Body> Service<State, Request<Body>> for UserAgentClassifier<S>
112where
113    S: Service<State, Request<Body>>,
114    State: Clone + Send + Sync + 'static,
115{
116    type Response = S::Response;
117    type Error = S::Error;
118
119    fn serve(
120        &self,
121        mut ctx: Context<State>,
122        req: Request<Body>,
123    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
124        let overwrites = self
125            .overwrite_header
126            .as_ref()
127            .and_then(|header| req.headers().get(header))
128            .map(|header| header.as_bytes())
129            .and_then(|value| serde_html_form::from_bytes::<UserAgentOverwrites>(value).ok());
130
131        let mut user_agent = overwrites
132            .as_ref()
133            .and_then(|o| o.ua.as_deref())
134            .map(UserAgent::new)
135            .or_else(|| {
136                req.headers()
137                    .typed_get::<headers::UserAgent>()
138                    .map(|ua| UserAgent::new(ua.to_string()))
139            });
140
141        if let Some(mut ua) = user_agent.take() {
142            if let Some(overwrites) = overwrites {
143                if let Some(http_agent) = overwrites.http {
144                    ua.set_http_agent(http_agent);
145                }
146                if let Some(tls_agent) = overwrites.tls {
147                    ua.set_tls_agent(tls_agent);
148                }
149            }
150
151            ctx.insert(ua);
152        }
153
154        self.inner.serve(ctx, req)
155    }
156}
157
158#[derive(Debug, Clone, Default)]
159/// A [`Layer`] that wraps a [`Service`] with a [`UserAgentClassifier`].
160///
161/// This [`Layer`] is used to classify the [`UserAgent`] of incoming [`Request`]s.
162pub struct UserAgentClassifierLayer {
163    overwrite_header: Option<HeaderName>,
164}
165
166impl UserAgentClassifierLayer {
167    /// Create a new [`UserAgentClassifierLayer`].
168    pub const fn new() -> Self {
169        Self {
170            overwrite_header: None,
171        }
172    }
173
174    /// Define a custom header to allow overwriting certain
175    /// [`UserAgent`] information.
176    pub fn overwrite_header(mut self, header: HeaderName) -> Self {
177        self.overwrite_header = Some(header);
178        self
179    }
180
181    /// Define a custom header to allow overwriting certain
182    /// [`UserAgent`] information.
183    pub fn set_overwrite_header(&mut self, header: HeaderName) -> &mut Self {
184        self.overwrite_header = Some(header);
185        self
186    }
187}
188
189impl<S> Layer<S> for UserAgentClassifierLayer {
190    type Service = UserAgentClassifier<S>;
191
192    fn layer(&self, inner: S) -> Self::Service {
193        UserAgentClassifier::new(inner, self.overwrite_header.clone())
194    }
195
196    fn into_layer(self, inner: S) -> Self::Service {
197        UserAgentClassifier::new(inner, self.overwrite_header)
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::required_header::AddRequiredRequestHeadersLayer;
    use crate::service::client::HttpClientExt;
    use crate::service::web::response::IntoResponse;
    use crate::{Response, StatusCode, headers};
    use rama_core::Context;
    use rama_core::service::service_fn;
    use std::convert::Infallible;

    #[tokio::test]
    async fn test_user_agent_classifier_layer_ua_rama() {
        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(
                ua.header_str(),
                format!("{}/{}", rama_utils::info::NAME, rama_utils::info::VERSION).as_str(),
            );
            assert!(ua.info().is_none());
            assert!(ua.platform().is_none());

            Ok(StatusCode::OK.into_response())
        }

        let service = (
            AddRequiredRequestHeadersLayer::default(),
            UserAgentClassifierLayer::new(),
        )
            .into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .send(Context::default())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_user_agent_classifier_layer_ua_iphone_app() {
        const UA: &str = "iPhone App/1.0";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), UA);
            assert!(ua.info().is_none());
            assert_eq!(ua.platform(), Some(PlatformKind::IOS));
            assert_eq!(ua.http_agent(), None);
            assert_eq!(ua.tls_agent(), None);

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new().into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .typed_header(headers::UserAgent::from_static(UA))
            .send(Context::default())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_user_agent_classifier_layer_ua_chrome() {
        const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), UA);
            let ua_info = ua.info().unwrap();
            assert_eq!(ua_info.kind, UserAgentKind::Chromium);
            assert_eq!(ua_info.version, Some(124));
            assert_eq!(ua.platform(), Some(PlatformKind::Windows));

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new().into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .typed_header(headers::UserAgent::from_static(UA))
            .send(Context::default())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_user_agent_classifier_layer_overwrite_ua() {
        const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), UA);
            let ua_info = ua.info().unwrap();
            assert_eq!(ua_info.kind, UserAgentKind::Chromium);
            assert_eq!(ua_info.version, Some(124));
            assert_eq!(ua.platform(), Some(PlatformKind::Windows));

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new()
            .overwrite_header(HeaderName::from_static("x-proxy-ua"))
            .into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .header(
                "x-proxy-ua",
                serde_html_form::to_string(&UserAgentOverwrites {
                    ua: Some(UA.to_owned()),
                    ..Default::default()
                })
                .unwrap(),
            )
            .send(Context::default())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_user_agent_classifier_layer_overwrite_ua_all() {
        const UA: &str = "iPhone App/1.0";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), UA);
            assert!(ua.info().is_none());
            assert_eq!(ua.platform(), Some(PlatformKind::IOS));
            assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox));
            assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl));

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new()
            .overwrite_header(HeaderName::from_static("x-proxy-ua"))
            .into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .header(
                "x-proxy-ua",
                serde_html_form::to_string(&UserAgentOverwrites {
                    ua: Some(UA.to_owned()),
                    http: Some(HttpAgent::Firefox),
                    tls: Some(TlsAgent::Boringssl),
                })
                .unwrap(),
            )
            .send(Context::default())
            .await
            .unwrap();
    }

    /// The `ua` from the overwrite header must win over a regular
    /// `User-Agent` header sent in the same request.
    #[tokio::test]
    async fn test_user_agent_classifier_layer_overwrite_ua_takes_precedence() {
        const HEADER_UA: &str = "iPhone App/1.0";
        const OVERWRITE_UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), OVERWRITE_UA);
            assert_eq!(ua.platform(), Some(PlatformKind::Windows));

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new()
            .overwrite_header(HeaderName::from_static("x-proxy-ua"))
            .into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .typed_header(headers::UserAgent::from_static(HEADER_UA))
            .header(
                "x-proxy-ua",
                serde_html_form::to_string(&UserAgentOverwrites {
                    ua: Some(OVERWRITE_UA.to_owned()),
                    ..Default::default()
                })
                .unwrap(),
            )
            .send(Context::default())
            .await
            .unwrap();
    }

    /// An overwrite header that carries no `ua` field (or does not decode at
    /// all) must fall back to the regular `User-Agent` header.
    #[tokio::test]
    async fn test_user_agent_classifier_layer_overwrite_without_ua_falls_back() {
        const UA: &str = "iPhone App/1.0";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), UA);
            assert_eq!(ua.platform(), Some(PlatformKind::IOS));
            assert_eq!(ua.http_agent(), None);
            assert_eq!(ua.tls_agent(), None);

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new()
            .overwrite_header(HeaderName::from_static("x-proxy-ua"))
            .into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .typed_header(headers::UserAgent::from_static(UA))
            .header("x-proxy-ua", "unrelated=1")
            .send(Context::default())
            .await
            .unwrap();
    }
}