use axum::extract::FromRequestParts;
use http::{HeaderMap, request::Parts};
use crate::error::Error;
use crate::ip::ClientIp;
use super::device::{parse_device_name, parse_device_type};
use super::fingerprint::compute_fingerprint;
/// Metadata describing the client behind a request: its address, user agent, and values derived
/// from request headers (device name/type and a fingerprint).
///
/// Every field is optional; [`ClientInfo::default`] (equivalently [`ClientInfo::new`]) leaves all
/// of them unset. Values are populated through the builder-style setters,
/// [`ClientInfo::from_headers`], or the request extractor.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ClientInfo {
    /// Client IP address rendered as a string.
    ip: Option<String>,
    /// User-agent string, normally taken from the `User-Agent` request header.
    user_agent: Option<String>,
    /// Human-readable device description derived from the user agent (e.g. `"Chrome on macOS"`).
    device_name: Option<String>,
    /// Device category derived from the user agent (e.g. `"desktop"`, `"mobile"`).
    device_type: Option<String>,
    /// Fingerprint computed from the user agent and the `Accept-Language` / `Accept-Encoding`
    /// header values.
    fingerprint: Option<String>,
}
impl ClientInfo {
    /// Creates a `ClientInfo` with every field unset (same as [`ClientInfo::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the client IP address, replacing any previous value.
    #[must_use]
    pub fn ip(mut self, ip: impl Into<String>) -> Self {
        self.ip = Some(ip.into());
        self
    }

    /// Sets the user-agent string, replacing any previous value.
    ///
    /// This only stores the string; it does not re-derive `device_name`, `device_type`, or
    /// `fingerprint`. Set those explicitly or use [`ClientInfo::from_headers`].
    #[must_use]
    pub fn user_agent(mut self, ua: impl Into<String>) -> Self {
        self.user_agent = Some(ua.into());
        self
    }

    /// Sets the human-readable device name, replacing any previous value.
    #[must_use]
    pub fn device_name(mut self, name: impl Into<String>) -> Self {
        self.device_name = Some(name.into());
        self
    }

    /// Sets the device category, replacing any previous value.
    #[must_use]
    pub fn device_type(mut self, kind: impl Into<String>) -> Self {
        self.device_type = Some(kind.into());
        self
    }

    /// Sets the client fingerprint, replacing any previous value.
    #[must_use]
    pub fn fingerprint(mut self, fp: impl Into<String>) -> Self {
        self.fingerprint = Some(fp.into());
        self
    }

    /// Builds a fully populated `ClientInfo` from already-extracted header values.
    ///
    /// `device_name` and `device_type` are derived from `user_agent`. `fingerprint` is computed
    /// from `user_agent`, `accept_language`, and `accept_encoding`.
    ///
    /// Every field except `ip` is always `Some`. In particular, an empty `user_agent` is stored
    /// as `Some("")` rather than treated as missing. The request extractor instead leaves the
    /// user-agent and device fields unset when the `User-Agent` header is absent.
    #[must_use]
    pub fn from_headers(
        ip: Option<String>,
        user_agent: &str,
        accept_language: &str,
        accept_encoding: &str,
    ) -> Self {
        Self {
            ip,
            user_agent: Some(user_agent.to_owned()),
            device_name: Some(parse_device_name(user_agent)),
            device_type: Some(parse_device_type(user_agent)),
            fingerprint: Some(compute_fingerprint(
                user_agent,
                accept_language,
                accept_encoding,
            )),
        }
    }

    /// Returns the client IP address, if one is set.
    pub fn ip_value(&self) -> Option<&str> {
        self.ip.as_deref()
    }

    /// Returns the user-agent string, if one is set.
    pub fn user_agent_value(&self) -> Option<&str> {
        self.user_agent.as_deref()
    }

    /// Returns the device name, if one is set.
    pub fn device_name_value(&self) -> Option<&str> {
        self.device_name.as_deref()
    }

    /// Returns the device category, if one is set.
    pub fn device_type_value(&self) -> Option<&str> {
        self.device_type.as_deref()
    }

    /// Returns the client fingerprint, if one is set.
    pub fn fingerprint_value(&self) -> Option<&str> {
        self.fingerprint.as_deref()
    }
}
impl<S: Send + Sync> FromRequestParts<S> for ClientInfo {
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let ip = parts.extensions.get::<ClientIp>().map(|c| c.0.to_string());
let user_agent = parts
.headers
.get(http::header::USER_AGENT)
.and_then(|v| v.to_str().ok())
.map(str::to_string);
let accept_lang = header_str(&parts.headers, "accept-language");
let accept_enc = header_str(&parts.headers, "accept-encoding");
let (device_name, device_type) = match user_agent.as_deref() {
Some(ua) => (Some(parse_device_name(ua)), Some(parse_device_type(ua))),
None => (None, None),
};
let fingerprint = Some(compute_fingerprint(
user_agent.as_deref().unwrap_or(""),
accept_lang,
accept_enc,
));
Ok(Self {
ip,
user_agent,
device_name,
device_type,
fingerprint,
})
}
}
/// Returns the value of header `name` as a string slice borrowed from `headers`.
///
/// Falls back to `""` when the header is absent or when its value cannot be viewed as a string
/// (i.e. `HeaderValue::to_str` fails because the value is not visible ASCII).
pub fn header_str<'a>(headers: &'a HeaderMap, name: &str) -> &'a str {
    let Some(raw) = headers.get(name) else {
        return "";
    };
    raw.to_str().unwrap_or_default()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::IpAddr;

    #[test]
    fn default_has_all_none() {
        let info = ClientInfo::new();
        assert!(info.ip_value().is_none());
        assert!(info.user_agent_value().is_none());
        assert!(info.device_name_value().is_none());
        assert!(info.device_type_value().is_none());
        assert!(info.fingerprint_value().is_none());
    }

    #[test]
    fn builder_sets_fields() {
        let info = ClientInfo::new()
            .ip("1.2.3.4")
            .user_agent("Mozilla/5.0")
            .device_name("Chrome on macOS")
            .device_type("desktop")
            .fingerprint("abc123");
        assert_eq!(info.ip_value(), Some("1.2.3.4"));
        assert_eq!(info.user_agent_value(), Some("Mozilla/5.0"));
        assert_eq!(info.device_name_value(), Some("Chrome on macOS"));
        assert_eq!(info.device_type_value(), Some("desktop"));
        assert_eq!(info.fingerprint_value(), Some("abc123"));
    }

    #[test]
    fn from_headers_populates_derived_fields() {
        let info = ClientInfo::from_headers(
            Some("10.0.0.1".to_string()),
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) Chrome/120.0",
            "en-US",
            "gzip",
        );
        assert_eq!(info.ip_value(), Some("10.0.0.1"));
        assert_eq!(info.device_name_value(), Some("Chrome on macOS"));
        assert_eq!(info.device_type_value(), Some("desktop"));
        let fp = info.fingerprint_value().unwrap();
        assert_eq!(fp.len(), 64);
        assert!(fp.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn from_headers_keeps_empty_user_agent() {
        // `from_headers` cannot tell "missing" from "empty", so it always stores the UA.
        let info = ClientInfo::from_headers(None, "", "", "");
        assert!(info.ip_value().is_none());
        assert_eq!(info.user_agent_value(), Some(""));
        assert!(info.device_name_value().is_some());
        assert!(info.device_type_value().is_some());
        assert_eq!(
            info.fingerprint_value(),
            Some(compute_fingerprint("", "", "").as_str())
        );
    }

    #[test]
    fn header_str_returns_value_or_empty() {
        let mut headers = HeaderMap::new();
        headers.insert(
            http::header::ACCEPT_LANGUAGE,
            http::HeaderValue::from_static("en-US"),
        );
        // 0xE9 is a legal header byte but not visible ASCII, so `to_str` rejects it.
        headers.insert(
            "x-opaque",
            http::HeaderValue::from_bytes(b"caf\xe9").unwrap(),
        );
        assert_eq!(header_str(&headers, "accept-language"), "en-US");
        assert_eq!(header_str(&headers, "accept-encoding"), "");
        assert_eq!(header_str(&headers, "x-opaque"), "");
    }

    #[tokio::test]
    async fn extracts_from_request_parts() {
        let mut req = http::Request::builder()
            .header("user-agent", "Mozilla/5.0 (iPhone) Safari/605")
            .header("accept-language", "en-US")
            .header("accept-encoding", "gzip")
            .body(())
            .unwrap();
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        req.extensions_mut().insert(ClientIp(ip));
        let (mut parts, _) = req.into_parts();
        let info = ClientInfo::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(info.ip_value(), Some("10.0.0.1"));
        assert_eq!(
            info.user_agent_value(),
            Some("Mozilla/5.0 (iPhone) Safari/605")
        );
        assert_eq!(info.device_name_value(), Some("Safari on iPhone"));
        assert_eq!(info.device_type_value(), Some("mobile"));
        assert!(info.fingerprint_value().is_some());
    }

    #[tokio::test]
    async fn extractor_agrees_with_from_headers() {
        let ua = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) Chrome/120.0";
        let ip: IpAddr = "192.168.1.7".parse().unwrap();
        let mut req = http::Request::builder()
            .header("user-agent", ua)
            .header("accept-language", "de-DE")
            .header("accept-encoding", "br")
            .body(())
            .unwrap();
        req.extensions_mut().insert(ClientIp(ip));
        let (mut parts, _) = req.into_parts();
        let extracted = ClientInfo::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        let built = ClientInfo::from_headers(Some(ip.to_string()), ua, "de-DE", "br");
        assert_eq!(extracted.ip_value(), built.ip_value());
        assert_eq!(extracted.user_agent_value(), built.user_agent_value());
        assert_eq!(extracted.device_name_value(), built.device_name_value());
        assert_eq!(extracted.device_type_value(), built.device_type_value());
        assert_eq!(extracted.fingerprint_value(), built.fingerprint_value());
    }

    #[tokio::test]
    async fn extracts_with_missing_fields() {
        let req = http::Request::builder().body(()).unwrap();
        let (mut parts, _) = req.into_parts();
        let info = ClientInfo::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert!(info.ip_value().is_none());
        assert!(info.user_agent_value().is_none());
        assert!(info.device_name_value().is_none());
        assert!(info.device_type_value().is_none());
        // Missing headers are fingerprinted as empty strings.
        assert_eq!(
            info.fingerprint_value(),
            Some(compute_fingerprint("", "", "").as_str())
        );
    }

    #[tokio::test]
    async fn extracts_with_only_accept_headers() {
        let req = http::Request::builder()
            .header("accept-language", "en-US")
            .header("accept-encoding", "gzip")
            .body(())
            .unwrap();
        let (mut parts, _) = req.into_parts();
        let info = ClientInfo::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert!(info.user_agent_value().is_none());
        assert!(info.device_name_value().is_none());
        assert!(info.device_type_value().is_none());
        assert_eq!(info.fingerprint_value().map(str::len), Some(64));
        assert_eq!(
            info.fingerprint_value(),
            Some(compute_fingerprint("", "en-US", "gzip").as_str())
        );
    }
}