1use crate::{
39 HeaderName, Request,
40 headers::{self, HeaderMapExt},
41};
42use rama_core::{Context, Layer, Service};
43use rama_utils::macros::define_inner_service_accessors;
44use std::fmt::{self, Debug};
45
46pub use rama_ua::{
47 DeviceKind, HttpAgent, PlatformKind, TlsAgent, UserAgent, UserAgentInfo, UserAgentKind,
48 UserAgentOverwrites,
49};
50
51pub struct UserAgentClassifier<S> {
59 inner: S,
60 overwrite_header: Option<HeaderName>,
61}
62
63impl<S> UserAgentClassifier<S> {
64 pub const fn new(inner: S, overwrite_header: Option<HeaderName>) -> Self {
66 Self {
67 inner,
68 overwrite_header,
69 }
70 }
71
72 define_inner_service_accessors!();
73}
74
75impl<S> Debug for UserAgentClassifier<S>
76where
77 S: Debug,
78{
79 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
80 f.debug_struct("UserAgentClassifier")
81 .field("inner", &self.inner)
82 .finish()
83 }
84}
85
86impl<S> Clone for UserAgentClassifier<S>
87where
88 S: Clone,
89{
90 fn clone(&self) -> Self {
91 Self {
92 inner: self.inner.clone(),
93 overwrite_header: self.overwrite_header.clone(),
94 }
95 }
96}
97
98impl<S> Default for UserAgentClassifier<S>
99where
100 S: Default,
101{
102 fn default() -> Self {
103 Self {
104 inner: S::default(),
105 overwrite_header: None,
106 }
107 }
108}
109
110impl<S, State, Body> Service<State, Request<Body>> for UserAgentClassifier<S>
111where
112 S: Service<State, Request<Body>>,
113 State: Clone + Send + Sync + 'static,
114{
115 type Response = S::Response;
116 type Error = S::Error;
117
118 fn serve(
119 &self,
120 mut ctx: Context<State>,
121 req: Request<Body>,
122 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
123 let overwrites = self
124 .overwrite_header
125 .as_ref()
126 .and_then(|header| req.headers().get(header))
127 .map(|header| header.as_bytes())
128 .and_then(|value| serde_html_form::from_bytes::<UserAgentOverwrites>(value).ok());
129
130 let mut user_agent = overwrites
131 .as_ref()
132 .and_then(|o| o.ua.as_deref())
133 .map(UserAgent::new)
134 .or_else(|| {
135 req.headers()
136 .typed_get::<headers::UserAgent>()
137 .map(|ua| UserAgent::new(ua.to_string()))
138 });
139
140 if let Some(mut ua) = user_agent.take() {
141 if let Some(overwrites) = overwrites {
142 if let Some(http_agent) = overwrites.http {
143 ua.set_http_agent(http_agent);
144 }
145 if let Some(tls_agent) = overwrites.tls {
146 ua.set_tls_agent(tls_agent);
147 }
148 }
149
150 ctx.insert(ua);
151 }
152
153 self.inner.serve(ctx, req)
154 }
155}
156
157#[derive(Debug, Clone, Default)]
158pub struct UserAgentClassifierLayer {
162 overwrite_header: Option<HeaderName>,
163}
164
165impl UserAgentClassifierLayer {
166 pub const fn new() -> Self {
168 Self {
169 overwrite_header: None,
170 }
171 }
172
173 pub fn overwrite_header(mut self, header: HeaderName) -> Self {
176 self.overwrite_header = Some(header);
177 self
178 }
179
180 pub fn set_overwrite_header(&mut self, header: HeaderName) -> &mut Self {
183 self.overwrite_header = Some(header);
184 self
185 }
186}
187
188impl<S> Layer<S> for UserAgentClassifierLayer {
189 type Service = UserAgentClassifier<S>;
190
191 fn layer(&self, inner: S) -> Self::Service {
192 UserAgentClassifier::new(inner, self.overwrite_header.clone())
193 }
194
195 fn into_layer(self, inner: S) -> Self::Service {
196 UserAgentClassifier::new(inner, self.overwrite_header)
197 }
198}
199
#[cfg(test)]
mod tests {
    //! Tests for [`UserAgentClassifier`] / [`UserAgentClassifierLayer`].
    //!
    //! Each test installs a handler that asserts on the [`UserAgent`] the
    //! classifier inserted into the [`Context`].

    use super::*;
    use crate::layer::required_header::AddRequiredRequestHeadersLayer;
    use crate::service::client::HttpClientExt;
    use crate::{IntoResponse, Response, StatusCode, headers};
    use rama_core::Context;
    use rama_core::service::service_fn;
    use std::convert::Infallible;

    #[tokio::test]
    async fn test_user_agent_classifier_layer_ua_rama() {
        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(
                ua.header_str(),
                format!("{}/{}", rama_utils::info::NAME, rama_utils::info::VERSION).as_str(),
            );
            assert!(ua.info().is_none());
            assert!(ua.platform().is_none());

            Ok(StatusCode::OK.into_response())
        }

        let service = (
            AddRequiredRequestHeadersLayer::default(),
            UserAgentClassifierLayer::new(),
        )
            .into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .send(Context::default())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_user_agent_classifier_layer_ua_iphone_app() {
        const UA: &str = "iPhone App/1.0";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), UA);
            assert!(ua.info().is_none());
            assert_eq!(ua.platform(), Some(PlatformKind::IOS));
            assert_eq!(ua.http_agent(), None);
            assert_eq!(ua.tls_agent(), None);

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new().into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .typed_header(headers::UserAgent::from_static(UA))
            .send(Context::default())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_user_agent_classifier_layer_ua_chrome() {
        const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), UA);
            let ua_info = ua.info().unwrap();
            assert_eq!(ua_info.kind, UserAgentKind::Chromium);
            assert_eq!(ua_info.version, Some(124));
            assert_eq!(ua.platform(), Some(PlatformKind::Windows));

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new().into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .typed_header(headers::UserAgent::from_static(UA))
            .send(Context::default())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_user_agent_classifier_layer_overwrite_ua() {
        const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), UA);
            let ua_info = ua.info().unwrap();
            assert_eq!(ua_info.kind, UserAgentKind::Chromium);
            assert_eq!(ua_info.version, Some(124));
            assert_eq!(ua.platform(), Some(PlatformKind::Windows));

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new()
            .overwrite_header(HeaderName::from_static("x-proxy-ua"))
            .into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .header(
                "x-proxy-ua",
                serde_html_form::to_string(&UserAgentOverwrites {
                    ua: Some(UA.to_owned()),
                    ..Default::default()
                })
                .unwrap(),
            )
            .send(Context::default())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_user_agent_classifier_layer_overwrite_ua_all() {
        const UA: &str = "iPhone App/1.0";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            assert_eq!(ua.header_str(), UA);
            assert!(ua.info().is_none());
            assert_eq!(ua.platform(), Some(PlatformKind::IOS));
            assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox));
            assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl));

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new()
            .overwrite_header(HeaderName::from_static("x-proxy-ua"))
            .into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .header(
                "x-proxy-ua",
                serde_html_form::to_string(&UserAgentOverwrites {
                    ua: Some(UA.to_owned()),
                    http: Some(HttpAgent::Firefox),
                    tls: Some(TlsAgent::Boringssl),
                })
                .unwrap(),
            )
            .send(Context::default())
            .await
            .unwrap();
    }

    /// The overwrite header carries only agent overwrites (no `ua`).
    ///
    /// The UA string must fall back to the regular `User-Agent` header, while
    /// the `http` / `tls` overwrites are still applied on top of it.
    #[tokio::test]
    async fn test_user_agent_classifier_layer_overwrite_agents_only() {
        const UA: &str = "iPhone App/1.0";

        async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
            let ua: &UserAgent = ctx.get().unwrap();

            // UA string comes from the `User-Agent` header...
            assert_eq!(ua.header_str(), UA);
            assert_eq!(ua.platform(), Some(PlatformKind::IOS));
            // ...while the agents come from the overwrite header.
            assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox));
            assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl));

            Ok(StatusCode::OK.into_response())
        }

        let service = UserAgentClassifierLayer::new()
            .overwrite_header(HeaderName::from_static("x-proxy-ua"))
            .into_layer(service_fn(handle));

        let _ = service
            .get("http://www.example.com")
            .typed_header(headers::UserAgent::from_static(UA))
            .header(
                "x-proxy-ua",
                serde_html_form::to_string(&UserAgentOverwrites {
                    ua: None,
                    http: Some(HttpAgent::Firefox),
                    tls: Some(TlsAgent::Boringssl),
                })
                .unwrap(),
            )
            .send(Context::default())
            .await
            .unwrap();
    }
}