1use super::parse_http_user_agent_header;
2use rama_core::error::OpaqueError;
3use rama_utils::macros::match_ignore_ascii_case_str;
4use serde::{Deserialize, Deserializer, Serialize};
5use std::{convert::Infallible, fmt, str::FromStr, sync::Arc};
6
7#[derive(Debug, Clone)]
11pub struct UserAgent {
12 pub(super) header: Arc<str>,
13 pub(super) data: UserAgentData,
14 pub(super) http_agent_overwrite: Option<HttpAgent>,
15 pub(super) tls_agent_overwrite: Option<TlsAgent>,
16}
17
18impl fmt::Display for UserAgent {
19 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20 write!(f, "{}", self.header)
21 }
22}
23
24#[derive(Debug, Clone)]
26pub(super) enum UserAgentData {
27 Standard {
28 info: UserAgentInfo,
29 platform_like: Option<PlatformLike>,
30 },
31 Platform(PlatformKind),
32 Device(DeviceKind),
33 Unknown,
34}
35
36#[derive(Debug, Clone)]
37pub(super) enum PlatformLike {
38 Platform(PlatformKind),
39 Device(DeviceKind),
40}
41
42impl PlatformLike {
43 pub(super) fn device(&self) -> DeviceKind {
44 match self {
45 PlatformLike::Platform(platform_kind) => platform_kind.device(),
46 PlatformLike::Device(device_kind) => *device_kind,
47 }
48 }
49}
50
51#[derive(Debug, Clone, PartialEq, Eq, Hash)]
53pub struct UserAgentInfo {
54 pub kind: UserAgentKind,
56 pub version: Option<usize>,
58}
59
60impl UserAgent {
61 pub fn new(header: impl Into<Arc<str>>) -> Self {
63 parse_http_user_agent_header(header.into())
64 }
65
66 pub fn with_http_agent(mut self, http_agent: HttpAgent) -> Self {
68 self.http_agent_overwrite = Some(http_agent);
69 self
70 }
71
72 pub fn set_http_agent(&mut self, http_agent: HttpAgent) -> &mut Self {
74 self.http_agent_overwrite = Some(http_agent);
75 self
76 }
77
78 pub fn with_tls_agent(mut self, tls_agent: TlsAgent) -> Self {
80 self.tls_agent_overwrite = Some(tls_agent);
81 self
82 }
83
84 pub fn set_tls_agent(&mut self, tls_agent: TlsAgent) -> &mut Self {
86 self.tls_agent_overwrite = Some(tls_agent);
87 self
88 }
89
90 pub fn header_str(&self) -> &str {
92 &self.header
93 }
94
95 pub fn device(&self) -> Option<DeviceKind> {
97 match &self.data {
98 UserAgentData::Standard { platform_like, .. } => {
99 platform_like.as_ref().map(|p| p.device())
100 }
101 UserAgentData::Platform(platform) => Some(platform.device()),
102 UserAgentData::Device(kind) => Some(*kind),
103 UserAgentData::Unknown => None,
104 }
105 }
106
107 pub fn info(&self) -> Option<UserAgentInfo> {
110 if let UserAgentData::Standard { info, .. } = &self.data {
111 Some(info.clone())
112 } else {
113 None
114 }
115 }
116
117 pub fn ua_kind(&self) -> Option<UserAgentKind> {
119 match self.http_agent_overwrite {
120 Some(HttpAgent::Chromium) => Some(UserAgentKind::Chromium),
121 Some(HttpAgent::Safari) => Some(UserAgentKind::Safari),
122 Some(HttpAgent::Firefox) => Some(UserAgentKind::Firefox),
123 Some(HttpAgent::Preserve) => None,
124 None => match &self.data {
125 UserAgentData::Standard {
126 info: UserAgentInfo { kind, .. },
127 ..
128 } => Some(*kind),
129 _ => None,
130 },
131 }
132 }
133
134 pub fn ua_version(&self) -> Option<usize> {
136 match &self.data {
137 UserAgentData::Standard { info, .. } => info.version,
138 _ => None,
139 }
140 }
141
142 pub fn platform(&self) -> Option<PlatformKind> {
146 match &self.data {
147 UserAgentData::Standard { platform_like, .. } => match platform_like {
148 Some(PlatformLike::Platform(platform)) => Some(*platform),
149 None | Some(PlatformLike::Device(_)) => None,
150 },
151 UserAgentData::Platform(platform) => Some(*platform),
152 _ => None,
153 }
154 }
155
156 pub fn http_agent(&self) -> Option<HttpAgent> {
160 match self.http_agent_overwrite {
161 Some(agent) => Some(agent),
162 None => match &self.data {
163 UserAgentData::Standard { info, .. } => Some(match info.kind {
164 UserAgentKind::Chromium => HttpAgent::Chromium,
165 UserAgentKind::Firefox => HttpAgent::Firefox,
166 UserAgentKind::Safari => HttpAgent::Safari,
167 }),
168 UserAgentData::Platform(_) | UserAgentData::Device(_) | UserAgentData::Unknown => {
169 None
170 }
171 },
172 }
173 }
174
175 pub fn tls_agent(&self) -> Option<TlsAgent> {
179 match self.tls_agent_overwrite {
180 Some(agent) => Some(agent),
181 None => match &self.data {
182 UserAgentData::Standard { info, .. } => Some(match info.kind {
183 UserAgentKind::Chromium => TlsAgent::Boringssl,
184 UserAgentKind::Firefox => TlsAgent::Nss,
185 UserAgentKind::Safari => TlsAgent::Rustls,
186 }),
187 UserAgentData::Device(_) | UserAgentData::Platform(_) | UserAgentData::Unknown => {
188 None
189 }
190 },
191 }
192 }
193}
194
195impl FromStr for UserAgent {
196 type Err = Infallible;
197
198 fn from_str(s: &str) -> Result<Self, Self::Err> {
199 Ok(UserAgent::new(s))
200 }
201}
202
/// The kind of user agent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UserAgentKind {
    /// Named `"Chromium"`.
    Chromium,
    /// Named `"Firefox"`.
    Firefox,
    /// Named `"Safari"`.
    Safari,
}

impl UserAgentKind {
    /// Returns the static name of the kind, as used by `Display`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Chromium => "Chromium",
            Self::Firefox => "Firefox",
            Self::Safari => "Safari",
        }
    }
}

/// Displays the kind as the name returned by [`UserAgentKind::as_str`].
impl fmt::Display for UserAgentKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
229
230impl FromStr for UserAgentKind {
231 type Err = OpaqueError;
232
233 fn from_str(s: &str) -> Result<Self, Self::Err> {
234 match_ignore_ascii_case_str! {
235 match (s) {
236 "chromium" => Ok(UserAgentKind::Chromium),
237 "firefox" => Ok(UserAgentKind::Firefox),
238 "safari" => Ok(UserAgentKind::Safari),
239 _ => Err(OpaqueError::from_display(format!("invalid user agent kind: {}", s))),
240 }
241 }
242 }
243}
244
245impl Serialize for UserAgentKind {
246 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
247 where
248 S: serde::ser::Serializer,
249 {
250 serializer.serialize_str(self.as_str())
251 }
252}
253
254impl<'de> Deserialize<'de> for UserAgentKind {
255 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
256 where
257 D: Deserializer<'de>,
258 {
259 let s = <std::borrow::Cow<'de, str>>::deserialize(deserializer)?;
260 s.parse::<UserAgentKind>().map_err(serde::de::Error::custom)
261 }
262}
263
/// The kind of device a user agent runs on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceKind {
    /// Named `"Desktop"`.
    Desktop,
    /// Named `"Mobile"`.
    Mobile,
}

impl DeviceKind {
    /// Returns the static name of the device kind, as used by `Display`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Desktop => "Desktop",
            Self::Mobile => "Mobile",
        }
    }
}

/// Displays the device kind as the name returned by [`DeviceKind::as_str`].
impl fmt::Display for DeviceKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
287
288#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
290pub enum PlatformKind {
291 Windows,
293 MacOS,
295 Linux,
297 Android,
299 IOS,
301}
302
303impl PlatformKind {
304 pub fn as_str(&self) -> &'static str {
305 match self {
306 PlatformKind::Windows => "Windows",
307 PlatformKind::MacOS => "MacOS",
308 PlatformKind::Linux => "Linux",
309 PlatformKind::Android => "Android",
310 PlatformKind::IOS => "iOS",
311 }
312 }
313
314 pub fn device(&self) -> DeviceKind {
315 match self {
316 PlatformKind::Windows | PlatformKind::MacOS | PlatformKind::Linux => {
317 DeviceKind::Desktop
318 }
319 PlatformKind::Android | PlatformKind::IOS => DeviceKind::Mobile,
320 }
321 }
322}
323
324impl FromStr for PlatformKind {
325 type Err = OpaqueError;
326
327 fn from_str(s: &str) -> Result<Self, Self::Err> {
328 match_ignore_ascii_case_str! {
329 match (s) {
330 "windows" => Ok(PlatformKind::Windows),
331 "macos" => Ok(PlatformKind::MacOS),
332 "linux" => Ok(PlatformKind::Linux),
333 "android" => Ok(PlatformKind::Android),
334 "ios" => Ok(PlatformKind::IOS),
335 _ => Err(OpaqueError::from_display(format!("invalid platform: {}", s))),
336 }
337 }
338 }
339}
340
341impl Serialize for PlatformKind {
342 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
343 where
344 S: serde::ser::Serializer,
345 {
346 serializer.serialize_str(self.as_str())
347 }
348}
349
350impl<'de> Deserialize<'de> for PlatformKind {
351 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
352 where
353 D: Deserializer<'de>,
354 {
355 let s = <std::borrow::Cow<'de, str>>::deserialize(deserializer)?;
356 s.parse::<PlatformKind>().map_err(serde::de::Error::custom)
357 }
358}
359
360impl fmt::Display for PlatformKind {
361 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
362 write!(f, "{}", self.as_str())
363 }
364}
365
/// The HTTP agent associated with a user agent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HttpAgent {
    /// Named `"Chromium"`; derived for Chromium user agents.
    Chromium,
    /// Named `"Firefox"`; derived for Firefox user agents.
    Firefox,
    /// Named `"Safari"`; derived for Safari user agents.
    Safari,
    /// Named `"Preserve"`. As an overwrite it makes `UserAgent::ua_kind`
    /// return `None`; any further handling is defined outside this file.
    Preserve,
}

impl HttpAgent {
    /// Returns the static name of the agent, as used by `Display` and
    /// serialization.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Chromium => "Chromium",
            Self::Firefox => "Firefox",
            Self::Safari => "Safari",
            Self::Preserve => "Preserve",
        }
    }
}

/// Displays the agent as the name returned by [`HttpAgent::as_str`].
impl fmt::Display for HttpAgent {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
398
399impl Serialize for HttpAgent {
400 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
401 where
402 S: serde::ser::Serializer,
403 {
404 serializer.serialize_str(self.as_str())
405 }
406}
407
408impl<'de> Deserialize<'de> for HttpAgent {
409 fn deserialize<D>(deserializer: D) -> Result<HttpAgent, D::Error>
410 where
411 D: Deserializer<'de>,
412 {
413 let s = <std::borrow::Cow<'de, str>>::deserialize(deserializer)?;
414 s.parse::<HttpAgent>().map_err(serde::de::Error::custom)
415 }
416}
417
418impl FromStr for HttpAgent {
419 type Err = OpaqueError;
420
421 fn from_str(s: &str) -> Result<Self, Self::Err> {
422 match_ignore_ascii_case_str! {
423 match (s) {
424 "chrome" | "chromium" => Ok(HttpAgent::Chromium),
425 "Firefox" => Ok(HttpAgent::Firefox),
426 "Safari" => Ok(HttpAgent::Safari),
427 "preserve" => Ok(HttpAgent::Preserve),
428 _ => Err(OpaqueError::from_display(format!("invalid http agent: {}", s))),
429 }
430 }
431 }
432}
433
/// The TLS agent associated with a user agent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TlsAgent {
    /// Named `"Rustls"`; derived for Safari user agents.
    Rustls,
    /// Named `"Boringssl"`; derived for Chromium user agents.
    Boringssl,
    /// Named `"NSS"`; derived for Firefox user agents.
    Nss,
    /// Named `"Preserve"`; its handling is defined outside this file.
    Preserve,
}

impl TlsAgent {
    /// Returns the static name of the agent, as used by `Display` and
    /// serialization.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Rustls => "Rustls",
            Self::Boringssl => "Boringssl",
            Self::Nss => "NSS",
            Self::Preserve => "Preserve",
        }
    }
}

/// Displays the agent as the name returned by [`TlsAgent::as_str`].
impl fmt::Display for TlsAgent {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
468
469impl Serialize for TlsAgent {
470 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
471 where
472 S: serde::ser::Serializer,
473 {
474 serializer.serialize_str(self.as_str())
475 }
476}
477
478impl<'de> Deserialize<'de> for TlsAgent {
479 fn deserialize<D>(deserializer: D) -> Result<TlsAgent, D::Error>
480 where
481 D: Deserializer<'de>,
482 {
483 let s = <std::borrow::Cow<'de, str>>::deserialize(deserializer)?;
484 s.parse::<TlsAgent>().map_err(serde::de::Error::custom)
485 }
486}
487
488impl FromStr for TlsAgent {
489 type Err = OpaqueError;
490
491 fn from_str(s: &str) -> Result<Self, Self::Err> {
492 match_ignore_ascii_case_str! {
493 match (s) {
494 "rustls" => Ok(TlsAgent::Rustls),
495 "boring" | "boringssl" => Ok(TlsAgent::Boringssl),
496 "nss" => Ok(TlsAgent::Nss),
497 "preserve" => Ok(TlsAgent::Preserve),
498 _ => Err(OpaqueError::from_display(format!("invalid tls agent: {}", s))),
499 }
500 }
501 }
502}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// Header shared by the `UserAgent` tests: Chrome 124 on macOS.
    const CHROME_MACOS_UA: &str = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36";

    /// Checks every property the tests expect from a parsed `CHROME_MACOS_UA`.
    fn assert_chrome_on_macos(ua: &UserAgent) {
        assert_eq!(ua.header_str(), CHROME_MACOS_UA);
        assert_eq!(
            ua.info(),
            Some(UserAgentInfo {
                kind: UserAgentKind::Chromium,
                version: Some(124)
            })
        );
        assert_eq!(ua.platform(), Some(PlatformKind::MacOS));
        assert_eq!(ua.device(), Some(DeviceKind::Desktop));
        assert_eq!(ua.http_agent(), Some(HttpAgent::Chromium));
        assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl));
    }

    #[test]
    fn test_user_agent_new() {
        // Deliberately passes an owned `String` to exercise that conversion.
        let ua = UserAgent::new(CHROME_MACOS_UA.to_owned());
        assert_chrome_on_macos(&ua);
    }

    #[test]
    fn test_user_agent_parse() {
        let ua: UserAgent = CHROME_MACOS_UA.parse().unwrap();
        assert_chrome_on_macos(&ua);
    }

    #[test]
    fn test_user_agent_display() {
        let ua: UserAgent = CHROME_MACOS_UA.parse().unwrap();
        assert_eq!(ua.to_string(), CHROME_MACOS_UA);
    }

    #[test]
    fn test_tls_agent_parse() {
        let accepted = [
            ("rustls", TlsAgent::Rustls),
            ("rUsTlS", TlsAgent::Rustls),
            ("boring", TlsAgent::Boringssl),
            ("BoRiNg", TlsAgent::Boringssl),
            ("nss", TlsAgent::Nss),
            ("NSS", TlsAgent::Nss),
            ("preserve", TlsAgent::Preserve),
            ("Preserve", TlsAgent::Preserve),
        ];
        for (input, expected) in accepted {
            assert_eq!(input.parse::<TlsAgent>().unwrap(), expected);
        }

        for rejected in ["", "invalid"] {
            assert!(rejected.parse::<TlsAgent>().is_err());
        }
    }

    #[test]
    fn test_tls_agent_deserialize() {
        let accepted = [
            (r#""rustls""#, TlsAgent::Rustls),
            (r#""RuStLs""#, TlsAgent::Rustls),
            (r#""boringssl""#, TlsAgent::Boringssl),
            (r#""BoringSSL""#, TlsAgent::Boringssl),
            (r#""nss""#, TlsAgent::Nss),
            (r#""NsS""#, TlsAgent::Nss),
            (r#""preserve""#, TlsAgent::Preserve),
            (r#""PreSeRvE""#, TlsAgent::Preserve),
        ];
        for (json, expected) in accepted {
            assert_eq!(serde_json::from_str::<TlsAgent>(json).unwrap(), expected);
        }

        // Unknown name, empty string, and a non-string JSON value.
        for rejected in [r#""invalid""#, r#""""#, "1"] {
            assert!(serde_json::from_str::<TlsAgent>(rejected).is_err());
        }
    }

    #[test]
    fn test_http_agent_parse() {
        let accepted = [
            ("chrome", HttpAgent::Chromium),
            ("ChRoMe", HttpAgent::Chromium),
            ("firefox", HttpAgent::Firefox),
            ("FiRefoX", HttpAgent::Firefox),
            ("safari", HttpAgent::Safari),
            ("SaFaRi", HttpAgent::Safari),
            ("preserve", HttpAgent::Preserve),
            ("Preserve", HttpAgent::Preserve),
        ];
        for (input, expected) in accepted {
            assert_eq!(input.parse::<HttpAgent>().unwrap(), expected);
        }

        for rejected in ["", "invalid"] {
            assert!(rejected.parse::<HttpAgent>().is_err());
        }
    }

    #[test]
    fn test_http_agent_deserialize() {
        let accepted = [
            (r#""chrome""#, HttpAgent::Chromium),
            (r#""ChRoMe""#, HttpAgent::Chromium),
            (r#""firefox""#, HttpAgent::Firefox),
            (r#""FirEfOx""#, HttpAgent::Firefox),
            (r#""safari""#, HttpAgent::Safari),
            (r#""SafArI""#, HttpAgent::Safari),
            (r#""preserve""#, HttpAgent::Preserve),
            (r#""PreSeRve""#, HttpAgent::Preserve),
        ];
        for (json, expected) in accepted {
            assert_eq!(serde_json::from_str::<HttpAgent>(json).unwrap(), expected);
        }

        // A non-string JSON value, empty string, and unknown name.
        for rejected in ["1", r#""""#, r#""invalid""#] {
            assert!(serde_json::from_str::<HttpAgent>(rejected).is_err());
        }
    }
}