use crate::{
HeaderName, Request,
headers::{self, HeaderMapExt},
};
use rama_core::{Context, Layer, Service};
use rama_utils::macros::define_inner_service_accessors;
use std::fmt::{self, Debug};
pub use rama_ua::{
DeviceKind, HttpAgent, PlatformKind, TlsAgent, UserAgent, UserAgentInfo, UserAgentKind,
UserAgentOverwrites,
};
/// Middleware service that classifies the user agent of a request.
///
/// When a user agent can be determined, a [`UserAgent`] is inserted into the
/// [`Context`] before the request is handed to the inner service.
pub struct UserAgentClassifier<S> {
    /// The wrapped service that receives the request after classification.
    inner: S,
    /// Optional header whose (form-encoded) value is parsed as
    /// [`UserAgentOverwrites`] to override the detected user agent.
    overwrite_header: Option<HeaderName>,
}
impl<S> UserAgentClassifier<S> {
    /// Creates a new [`UserAgentClassifier`] wrapping `inner`.
    ///
    /// When `overwrite_header` is `Some`, requests carrying that header have its
    /// value parsed as [`UserAgentOverwrites`], which take precedence over the
    /// request's own `User-Agent` header.
    pub const fn new(inner: S, overwrite_header: Option<HeaderName>) -> Self {
        Self {
            overwrite_header,
            inner,
        }
    }

    define_inner_service_accessors!();
}
impl<S> Debug for UserAgentClassifier<S>
where
    S: Debug,
{
    /// Formats the classifier with both its inner service and its configured
    /// overwrite header.
    ///
    /// Printing the overwrite header matches the derived `Debug` of
    /// `UserAgentClassifierLayer`. It is also the configuration that matters most
    /// when diagnosing why a user agent was (or was not) overridden.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("UserAgentClassifier")
            .field("inner", &self.inner)
            .field("overwrite_header", &self.overwrite_header)
            .finish()
    }
}
impl<S: Clone> Clone for UserAgentClassifier<S> {
    /// Clones the inner service and the configured overwrite header.
    ///
    /// Implemented by hand so that only `S: Clone` is required.
    fn clone(&self) -> Self {
        Self::new(self.inner.clone(), self.overwrite_header.clone())
    }
}
impl<S: Default> Default for UserAgentClassifier<S> {
    /// Wraps a default-constructed inner service, with no overwrite header configured.
    fn default() -> Self {
        Self::new(S::default(), None)
    }
}
impl<S, State, Body> Service<State, Request<Body>> for UserAgentClassifier<S>
where
    S: Service<State, Request<Body>>,
    State: Clone + Send + Sync + 'static,
{
    type Response = S::Response;
    type Error = S::Error;

    /// Classifies the request's user agent, stores it in `ctx`, then calls the
    /// inner service.
    ///
    /// The user agent string comes from one of two places:
    /// 1. the `ua` field of the overwrite header, when one is configured, present,
    ///    parseable and has `ua` set;
    /// 2. otherwise, the request's `User-Agent` header.
    ///
    /// The `http` and `tls` overwrites are applied only when a user agent string
    /// was found by one of those two routes. If neither yields a string, nothing
    /// is inserted into the context. The request itself is forwarded unchanged.
    fn serve(
        &self,
        mut ctx: Context<State>,
        req: Request<Body>,
    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
        // Overwrites are opt-in and best-effort. A header value that fails to
        // deserialize is ignored rather than causing the request to be rejected.
        let overwrites = self
            .overwrite_header
            .as_ref()
            .and_then(|header| req.headers().get(header))
            .and_then(|value| {
                serde_html_form::from_bytes::<UserAgentOverwrites>(value.as_bytes()).ok()
            });

        // An explicit `ua` overwrite takes precedence over the request's own header.
        let user_agent = overwrites
            .as_ref()
            .and_then(|o| o.ua.as_deref())
            .map(UserAgent::new)
            .or_else(|| {
                req.headers()
                    .typed_get::<headers::UserAgent>()
                    .map(|ua| UserAgent::new(ua.to_string()))
            });

        if let Some(mut ua) = user_agent {
            if let Some(overwrites) = overwrites {
                if let Some(http_agent) = overwrites.http {
                    ua.set_http_agent(http_agent);
                }
                if let Some(tls_agent) = overwrites.tls {
                    ua.set_tls_agent(tls_agent);
                }
            }
            ctx.insert(ua);
        }

        self.inner.serve(ctx, req)
    }
}
/// A [`Layer`] that wraps services in a [`UserAgentClassifier`].
///
/// By default no overwrite header is configured, so only the request's
/// `User-Agent` header is used for classification.
#[derive(Default, Clone, Debug)]
pub struct UserAgentClassifierLayer {
    /// Overwrite header handed to every [`UserAgentClassifier`] this layer builds.
    overwrite_header: Option<HeaderName>,
}
impl UserAgentClassifierLayer {
    /// Creates a layer with no overwrite header configured.
    pub const fn new() -> Self {
        Self {
            overwrite_header: None,
        }
    }

    /// Sets the header used to overwrite user-agent information (builder style).
    ///
    /// The header value is parsed as form-encoded [`UserAgentOverwrites`].
    pub fn overwrite_header(mut self, header: HeaderName) -> Self {
        self.set_overwrite_header(header);
        self
    }

    /// Sets the header used to overwrite user-agent information, in place.
    ///
    /// The header value is parsed as form-encoded [`UserAgentOverwrites`].
    pub fn set_overwrite_header(&mut self, header: HeaderName) -> &mut Self {
        self.overwrite_header = Some(header);
        self
    }
}
impl<S> Layer<S> for UserAgentClassifierLayer {
    type Service = UserAgentClassifier<S>;

    /// Wraps `inner`, cloning the configured overwrite header so the layer can be reused.
    fn layer(&self, inner: S) -> Self::Service {
        let overwrite_header = self.overwrite_header.clone();
        UserAgentClassifier::new(inner, overwrite_header)
    }

    /// Wraps `inner`, moving the configured overwrite header out of the layer (no clone).
    fn into_layer(self, inner: S) -> Self::Service {
        let Self { overwrite_header } = self;
        UserAgentClassifier::new(inner, overwrite_header)
    }
}
// Tests for the user-agent classifier. Each test sends a request through the
// layer using the HTTP client extension, and its handler asserts on the
// `UserAgent` value it reads from the `Context`.
#[cfg(test)]
mod tests {
use super::*;
use crate::layer::required_header::AddRequiredRequestHeadersLayer;
use crate::service::client::HttpClientExt;
use crate::service::web::response::IntoResponse;
use crate::{Response, StatusCode, headers};
use rama_core::Context;
use rama_core::service::service_fn;
use std::convert::Infallible;
// No `User-Agent` is set explicitly. With `AddRequiredRequestHeadersLayer` in
// front, the test expects the header to be `NAME/VERSION` of rama, and that
// string is not classified into any UA info or platform.
#[tokio::test]
async fn test_user_agent_classifier_layer_ua_rama() {
async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
let ua: &UserAgent = ctx.get().unwrap();
assert_eq!(
ua.header_str(),
format!("{}/{}", rama_utils::info::NAME, rama_utils::info::VERSION).as_str(),
);
assert!(ua.info().is_none());
assert!(ua.platform().is_none());
Ok(StatusCode::OK.into_response())
}
let service = (
AddRequiredRequestHeadersLayer::default(),
UserAgentClassifierLayer::new(),
)
.into_layer(service_fn(handle));
let _ = service
.get("http://www.example.com")
.send(Context::default())
.await
.unwrap();
}
// A custom app UA string. The test expects no UA info, an iOS platform, and
// no HTTP/TLS agents (no overwrite header is configured here).
#[tokio::test]
async fn test_user_agent_classifier_layer_ua_iphone_app() {
const UA: &str = "iPhone App/1.0";
async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
let ua: &UserAgent = ctx.get().unwrap();
assert_eq!(ua.header_str(), UA);
assert!(ua.info().is_none());
assert_eq!(ua.platform(), Some(PlatformKind::IOS));
assert_eq!(ua.http_agent(), None);
assert_eq!(ua.tls_agent(), None);
Ok(StatusCode::OK.into_response())
}
let service = UserAgentClassifierLayer::new().into_layer(service_fn(handle));
let _ = service
.get("http://www.example.com")
.typed_header(headers::UserAgent::from_static(UA))
.send(Context::default())
.await
.unwrap();
}
// A desktop Edge UA sent in the regular `User-Agent` header. The test expects
// it to be classified as Chromium 124 on Windows.
#[tokio::test]
async fn test_user_agent_classifier_layer_ua_chrome() {
const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67";
async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
let ua: &UserAgent = ctx.get().unwrap();
assert_eq!(ua.header_str(), UA);
let ua_info = ua.info().unwrap();
assert_eq!(ua_info.kind, UserAgentKind::Chromium);
assert_eq!(ua_info.version, Some(124));
assert_eq!(ua.platform(), Some(PlatformKind::Windows));
Ok(StatusCode::OK.into_response())
}
let service = UserAgentClassifierLayer::new().into_layer(service_fn(handle));
let _ = service
.get("http://www.example.com")
.typed_header(headers::UserAgent::from_static(UA))
.send(Context::default())
.await
.unwrap();
}
// The same Edge UA as above, but supplied only through the `x-proxy-ua`
// overwrite header as form-encoded `UserAgentOverwrites` (no `User-Agent`
// header is set). The classification result must be the same.
#[tokio::test]
async fn test_user_agent_classifier_layer_overwrite_ua() {
const UA: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.2478.67";
async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
let ua: &UserAgent = ctx.get().unwrap();
assert_eq!(ua.header_str(), UA);
let ua_info = ua.info().unwrap();
assert_eq!(ua_info.kind, UserAgentKind::Chromium);
assert_eq!(ua_info.version, Some(124));
assert_eq!(ua.platform(), Some(PlatformKind::Windows));
Ok(StatusCode::OK.into_response())
}
let service = UserAgentClassifierLayer::new()
.overwrite_header(HeaderName::from_static("x-proxy-ua"))
.into_layer(service_fn(handle));
let _ = service
.get("http://www.example.com")
.header(
"x-proxy-ua",
serde_html_form::to_string(&UserAgentOverwrites {
ua: Some(UA.to_owned()),
..Default::default()
})
.unwrap(),
)
.send(Context::default())
.await
.unwrap();
}
// Every overwrite field is set: the UA string plus explicit HTTP (Firefox) and
// TLS (Boringssl) agents, which the classifier applies on top of the parsed UA.
#[tokio::test]
async fn test_user_agent_classifier_layer_overwrite_ua_all() {
const UA: &str = "iPhone App/1.0";
async fn handle<S>(ctx: Context<S>, _req: Request) -> Result<Response, Infallible> {
let ua: &UserAgent = ctx.get().unwrap();
assert_eq!(ua.header_str(), UA);
assert!(ua.info().is_none());
assert_eq!(ua.platform(), Some(PlatformKind::IOS));
assert_eq!(ua.http_agent(), Some(HttpAgent::Firefox));
assert_eq!(ua.tls_agent(), Some(TlsAgent::Boringssl));
Ok(StatusCode::OK.into_response())
}
let service = UserAgentClassifierLayer::new()
.overwrite_header(HeaderName::from_static("x-proxy-ua"))
.into_layer(service_fn(handle));
let _ = service
.get("http://www.example.com")
.header(
"x-proxy-ua",
serde_html_form::to_string(&UserAgentOverwrites {
ua: Some(UA.to_owned()),
http: Some(HttpAgent::Firefox),
tls: Some(TlsAgent::Boringssl),
})
.unwrap(),
)
.send(Context::default())
.await
.unwrap();
}
}