1use crate::{
40 HeaderName, Request,
41 headers::{self, HeaderMapExt},
42};
43use rama_core::{Context, Layer, Service};
44use rama_utils::macros::define_inner_service_accessors;
45use std::fmt::{self, Debug};
46
47pub use rama_ua::{
48 DeviceKind, HttpAgent, PlatformKind, TlsAgent, UserAgent, UserAgentInfo, UserAgentKind,
49 UserAgentOverwrites,
50};
51
52pub struct UserAgentClassifier<S> {
60 inner: S,
61 overwrite_header: Option<HeaderName>,
62}
63
64impl<S> UserAgentClassifier<S> {
65 pub const fn new(inner: S, overwrite_header: Option<HeaderName>) -> Self {
67 Self {
68 inner,
69 overwrite_header,
70 }
71 }
72
73 define_inner_service_accessors!();
74}
75
76impl<S> Debug for UserAgentClassifier<S>
77where
78 S: Debug,
79{
80 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
81 f.debug_struct("UserAgentClassifier")
82 .field("inner", &self.inner)
83 .finish()
84 }
85}
86
87impl<S> Clone for UserAgentClassifier<S>
88where
89 S: Clone,
90{
91 fn clone(&self) -> Self {
92 Self {
93 inner: self.inner.clone(),
94 overwrite_header: self.overwrite_header.clone(),
95 }
96 }
97}
98
99impl<S> Default for UserAgentClassifier<S>
100where
101 S: Default,
102{
103 fn default() -> Self {
104 Self {
105 inner: S::default(),
106 overwrite_header: None,
107 }
108 }
109}
110
111impl<S, State, Body> Service<State, Request<Body>> for UserAgentClassifier<S>
112where
113 S: Service<State, Request<Body>>,
114 State: Clone + Send + Sync + 'static,
115{
116 type Response = S::Response;
117 type Error = S::Error;
118
119 fn serve(
120 &self,
121 mut ctx: Context<State>,
122 req: Request<Body>,
123 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
124 let overwrites = self
125 .overwrite_header
126 .as_ref()
127 .and_then(|header| req.headers().get(header))
128 .map(|header| header.as_bytes())
129 .and_then(|value| serde_html_form::from_bytes::<UserAgentOverwrites>(value).ok());
130
131 let mut user_agent = overwrites
132 .as_ref()
133 .and_then(|o| o.ua.as_deref())
134 .map(UserAgent::new)
135 .or_else(|| {
136 req.headers()
137 .typed_get::<headers::UserAgent>()
138 .map(|ua| UserAgent::new(ua.to_string()))
139 });
140
141 if let Some(mut ua) = user_agent.take() {
142 if let Some(overwrites) = overwrites {
143 if let Some(http_agent) = overwrites.http {
144 ua.set_http_agent(http_agent);
145 }
146 if let Some(tls_agent) = overwrites.tls {
147 ua.set_tls_agent(tls_agent);
148 }
149 }
150
151 ctx.insert(ua);
152 }
153
154 self.inner.serve(ctx, req)
155 }
156}
157
158#[derive(Debug, Clone, Default)]
159pub struct UserAgentClassifierLayer {
163 overwrite_header: Option<HeaderName>,
164}
165
166impl UserAgentClassifierLayer {
167 pub const fn new() -> Self {
169 Self {
170 overwrite_header: None,
171 }
172 }
173
174 pub fn overwrite_header(mut self, header: HeaderName) -> Self {
177 self.overwrite_header = Some(header);
178 self
179 }
180
181 pub fn set_overwrite_header(&mut self, header: HeaderName) -> &mut Self {
184 self.overwrite_header = Some(header);
185 self
186 }
187}
188
189impl<S> Layer<S> for UserAgentClassifierLayer {
190 type Service = UserAgentClassifier<S>;
191
192 fn layer(&self, inner: S) -> Self::Service {
193 UserAgentClassifier::new(inner, self.overwrite_header.clone())
194 }
195
196 fn into_layer(self, inner: S) -> Self::Service {
197 UserAgentClassifier::new(inner, self.overwrite_header)
198 }
199}
200
// Tests for `UserAgentClassifier` / `UserAgentClassifierLayer`.
// Each test installs a handler that asserts on the `UserAgent` the classifier
// inserted into the request `Context`, then sends one request through the
// layered service.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::required_header::AddRequiredRequestHeadersLayer;
    use crate::service::client::HttpClientExt;
    use crate::service::web::response::IntoResponse;
    use crate::{Response, StatusCode, headers};
    use rama_core::Context;
    use rama_core::service::service_fn;
    use std::convert::Infallible;

    // No explicit User-Agent is sent. The handler expects `{NAME}/{VERSION}`
    // of rama, presumably inserted by `AddRequiredRequestHeadersLayer`
    // (NOTE(review): confirm against that layer). The value is not a known
    // browser UA, so no info or platform is expected.
    #[tokio::test]
    async fn test_user_agent_classifier_layer_ua_rama() {
        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(
                ua.header_str(),
                format!("{}/{}", rama_utils::info::NAME, rama_utils::info::VERSION).as_str(),
            );
            assert!(ua.info().is_none());
            assert!(ua.platform().is_none());

            Ok(StatusCode::OK.into_response())
        }

        let service = (
            AddRequiredRequestHeadersLayer::default(),
            UserAgentClassifierLayer::new(),
        )
            .into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .send(Context::default())
            .await
            .unwrap();
    }

    // A custom app UA: no browser info, but the platform is detected as iOS.
    // With no overwrite header, the HTTP and TLS agents stay unset.
    #[tokio::test]
    async fn test_user_agent_classifier_layer_ua_iphone_app() {
        const UA: &str = "iPhone App/1.0";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), UA);
            assert!(ua.info().is_none());
            assert_eq!(ua.platform(), Some(PlatformKind::IOS));
            assert_eq!(ua.http_agent(), None);
            assert_eq!(ua.tls_agent(), None);

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new().into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .typed_header(headers::UserAgent::from_static(UA))
            .send(Context::default())
            .await
            .unwrap();
    }

    // An Edge-on-Windows UA sent via the regular User-Agent header is
    // classified as Chromium 124 on Windows.
    #[tokio::test]
    async fn test_user_agent_classifier_layer_ua_chrome() {
        const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), UA);
            let ua_info = ua.info().unwrap();
            assert_eq!(ua_info.kind, UserAgentKind::Chromium);
            assert_eq!(ua_info.version, Some(124));
            assert_eq!(ua.platform(), Some(PlatformKind::Windows));

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new().into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .typed_header(headers::UserAgent::from_static(UA))
            .send(Context::default())
            .await
            .unwrap();
    }

    // Same UA as above, but delivered only through the form-encoded
    // `x-proxy-ua` overwrite header (`ua` field only). No User-Agent header is sent.
    #[tokio::test]
    async fn test_user_agent_classifier_layer_overwrite_ua() {
        const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), UA);
            let ua_info = ua.info().unwrap();
            assert_eq!(ua_info.kind, UserAgentKind::Chromium);
            assert_eq!(ua_info.version, Some(124));
            assert_eq!(ua.platform(), Some(PlatformKind::Windows));

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new()
            .overwrite_header(HeaderName::from_static("x-proxy-ua"))
            .into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .header(
                "x-proxy-ua",
                serde_html_form::to_string(&UserAgentOverwrites {
                    ua: Some(UA.to_owned()),
                    ..Default::default()
                })
                .unwrap(),
            )
            .send(Context::default())
            .await
            .unwrap();
    }

    // The overwrite header sets every field. The UA string comes from `ua`,
    // and the HTTP/TLS agents are forced to Firefox/Boringssl.
    #[tokio::test]
    async fn test_user_agent_classifier_layer_overwrite_ua_all() {
        const UA: &str = "iPhone App/1.0";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), UA);
            assert!(ua.info().is_none());
            assert_eq!(ua.platform(), Some(PlatformKind::IOS));
            assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox));
            assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl));

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new()
            .overwrite_header(HeaderName::from_static("x-proxy-ua"))
            .into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .header(
                "x-proxy-ua",
                serde_html_form::to_string(&UserAgentOverwrites {
                    ua: Some(UA.to_owned()),
                    http: Some(HttpAgent::Firefox),
                    tls: Some(TlsAgent::Boringssl),
                })
                .unwrap(),
            )
            .send(Context::default())
            .await
            .unwrap();
    }
}