mod body;
mod methods;
mod params;
use crate::extensions::Extensions;
use bytes::Bytes;
use hyper::{HeaderMap, Method, Uri, Version};
#[cfg(feature = "parsers")]
use reinhardt_core::parsers::parser::{ParsedData, Parser};
use std::collections::HashMap;
use std::collections::HashSet;
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Mutex};
/// Set of proxy addresses whose forwarding headers (`X-Forwarded-For`,
/// `X-Real-IP`) may be believed when working out a request's client IP.
///
/// Addresses are stored and compared in canonical form
/// ([`IpAddr::to_canonical`]). An IPv4-mapped IPv6 address such as
/// `::ffff:10.0.0.1` (how IPv4 peers appear on a dual-stack IPv6 socket) is
/// therefore treated as the same host as `10.0.0.1`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TrustedProxies {
    /// Trusted addresses, already canonicalized.
    trusted_ips: HashSet<IpAddr>,
}

impl TrustedProxies {
    /// Returns an empty configuration in which no proxy is trusted.
    pub fn none() -> Self {
        Self::default()
    }

    /// Trusts exactly the given proxy addresses (duplicates are collapsed).
    pub fn new(ips: impl IntoIterator<Item = IpAddr>) -> Self {
        Self {
            // Canonicalize on the way in so lookups only need to canonicalize the probe.
            trusted_ips: ips.into_iter().map(|ip| ip.to_canonical()).collect(),
        }
    }

    /// Returns `true` if `addr`, compared in canonical form, is a trusted proxy.
    pub fn is_trusted(&self, addr: &IpAddr) -> bool {
        self.trusted_ips.contains(&addr.to_canonical())
    }

    /// Returns `true` if at least one proxy is trusted.
    pub fn has_trusted_proxies(&self) -> bool {
        !self.trusted_ips.is_empty()
    }
}
/// An HTTP request as seen by handlers; construct one with [`Request::builder`].
pub struct Request {
    /// HTTP method (e.g. `GET`, `POST`).
    pub method: Method,
    /// Request target URI; `query_params` is derived from it when the request is built.
    pub uri: Uri,
    /// HTTP protocol version.
    pub version: Version,
    /// Request headers.
    pub headers: HeaderMap,
    /// Raw request body bytes. Private, so it is read through methods defined
    /// outside this view (presumably the `body` module — confirm).
    body: Bytes,
    /// Path parameters as supplied to `RequestBuilder::path_params`
    /// (presumably captured by the router — confirm).
    pub path_params: HashMap<String, String>,
    /// Query-string parameters, one value per key, filled at build time by
    /// `Request::parse_query_params`.
    pub query_params: HashMap<String, String>,
    /// Value passed to `RequestBuilder::secure` (`false` by default); presumably
    /// whether the request arrived over TLS — confirm with the server code.
    pub is_secure: bool,
    /// Address of the directly connected peer, if known. `get_client_ip` uses it
    /// to decide whether forwarding headers are trusted, and falls back to it.
    pub remote_addr: Option<SocketAddr>,
    /// Body parsers registered through `RequestBuilder::parser`.
    #[cfg(feature = "parsers")]
    parsers: Vec<Box<dyn Parser>>,
    /// Shared slot initialised to `None` at build time; presumably caches the
    /// parsed body — TODO confirm against the code that fills it.
    #[cfg(feature = "parsers")]
    parsed_data: Arc<Mutex<Option<ParsedData>>>,
    /// Shared flag initialised to `false` at build time; presumably set once the
    /// body has been read — TODO confirm.
    body_consumed: Arc<AtomicBool>,
    /// Typed extension storage. Holds the DI context (`set_di_context`) and the
    /// `TrustedProxies` configuration (`set_trusted_proxies`).
    pub extensions: Extensions,
}
/// Incremental builder for [`Request`], obtained from `Request::builder()`.
///
/// Conversion failures in `uri` and `header` are recorded instead of being
/// returned immediately, and are reported by `build`.
pub struct RequestBuilder {
    /// HTTP method; `GET` by default.
    method: Method,
    /// Target URI; `build` fails while this is still `None`.
    uri: Option<Uri>,
    /// HTTP version; HTTP/1.1 by default.
    version: Version,
    /// Headers for the built request.
    headers: HeaderMap,
    /// Raw body; empty by default.
    body: Bytes,
    /// Copied into `Request::is_secure`.
    is_secure: bool,
    /// Copied into `Request::remote_addr`.
    remote_addr: Option<SocketAddr>,
    /// Copied into `Request::path_params`.
    path_params: HashMap<String, String>,
    /// Message from a failed `uri` conversion; returned as the `build` error.
    uri_error: Option<String>,
    /// Message from a failed `header` value conversion; returned as the `build` error.
    header_error: Option<String>,
    /// Body parsers handed to the built request.
    #[cfg(feature = "parsers")]
    parsers: Vec<Box<dyn Parser>>,
}
impl Default for RequestBuilder {
fn default() -> Self {
Self {
method: Method::GET,
uri: None,
version: Version::HTTP_11,
headers: HeaderMap::new(),
body: Bytes::new(),
is_secure: false,
remote_addr: None,
path_params: HashMap::new(),
uri_error: None,
header_error: None,
#[cfg(feature = "parsers")]
parsers: Vec::new(),
}
}
}
impl RequestBuilder {
    /// Sets the HTTP method (the default is `GET`).
    pub fn method(mut self, method: Method) -> Self {
        self.method = method;
        self
    }

    /// Sets the request URI from anything convertible into a [`Uri`].
    ///
    /// A failed conversion is not reported here. Its message is stored and returned
    /// as the `Err` of [`RequestBuilder::build`]. The last call wins, so a successful
    /// call also discards an error left behind by an earlier invalid URI.
    pub fn uri<T>(mut self, uri: T) -> Self
    where
        T: TryInto<Uri>,
        T::Error: std::fmt::Display,
    {
        match uri.try_into() {
            Ok(uri) => {
                self.uri = Some(uri);
                // Without this, `.uri(bad).uri(good)` would still fail in `build`.
                self.uri_error = None;
            }
            Err(e) => {
                self.uri_error = Some(format!("Invalid URI: {}", e));
            }
        }
        self
    }

    /// Sets the HTTP version (the default is HTTP/1.1).
    pub fn version(mut self, version: Version) -> Self {
        self.version = version;
        self
    }

    /// Replaces the whole header map, discarding headers added by earlier
    /// [`RequestBuilder::header`] calls.
    ///
    /// An error already recorded for an invalid `header` value is kept.
    pub fn headers(mut self, headers: HeaderMap) -> Self {
        self.headers = headers;
        self
    }

    /// Inserts a single header, replacing any existing values for `key`.
    ///
    /// An invalid value is not reported here. Its message is stored and returned as
    /// the `Err` of [`RequestBuilder::build`]. If several values are invalid, the last
    /// one is reported.
    pub fn header<K, V>(mut self, key: K, value: V) -> Self
    where
        K: hyper::header::IntoHeaderName,
        V: TryInto<hyper::header::HeaderValue>,
        V::Error: std::fmt::Display,
    {
        match value.try_into() {
            Ok(val) => {
                self.headers.insert(key, val);
            }
            Err(e) => {
                self.header_error = Some(format!("Invalid header value: {}", e));
            }
        }
        self
    }

    /// Sets the raw request body.
    pub fn body(mut self, body: Bytes) -> Self {
        self.body = body;
        self
    }

    /// Sets the value stored in `Request::is_secure` (the default is `false`).
    pub fn secure(mut self, is_secure: bool) -> Self {
        self.is_secure = is_secure;
        self
    }

    /// Sets the address of the directly connected peer, used by
    /// [`Request::get_client_ip`].
    pub fn remote_addr(mut self, addr: SocketAddr) -> Self {
        self.remote_addr = Some(addr);
        self
    }

    /// Registers a body parser for the built request.
    #[cfg(feature = "parsers")]
    pub fn parser<P: Parser + 'static>(mut self, parser: P) -> Self {
        self.parsers.push(Box::new(parser));
        self
    }

    /// Sets the path parameters, replacing any set previously.
    pub fn path_params(mut self, params: HashMap<String, String>) -> Self {
        self.path_params = params;
        self
    }

    /// Builds the [`Request`].
    ///
    /// Query parameters are derived from the URI via `Request::parse_query_params`.
    /// Extensions start empty and the `body_consumed` flag starts `false`.
    ///
    /// # Errors
    ///
    /// Returns the recorded message if a `uri` or `header` conversion failed (URI
    /// errors take precedence). Returns `"URI is required"` if no URI was set.
    pub fn build(self) -> Result<Request, String> {
        if let Some(err) = self.uri_error {
            return Err(err);
        }
        if let Some(err) = self.header_error {
            return Err(err);
        }
        let uri = self.uri.ok_or_else(|| "URI is required".to_string())?;
        let query_params = Request::parse_query_params(&uri);
        Ok(Request {
            method: self.method,
            uri,
            version: self.version,
            headers: self.headers,
            body: self.body,
            path_params: self.path_params,
            query_params,
            is_secure: self.is_secure,
            remote_addr: self.remote_addr,
            #[cfg(feature = "parsers")]
            parsers: self.parsers,
            #[cfg(feature = "parsers")]
            parsed_data: Arc::new(Mutex::new(None)),
            body_consumed: Arc::new(AtomicBool::new(false)),
            extensions: Extensions::new(),
        })
    }
}
impl Request {
pub fn builder() -> RequestBuilder {
RequestBuilder::default()
}
pub fn set_di_context<T: Send + Sync + 'static>(&mut self, ctx: T) {
self.extensions.insert(Arc::new(ctx));
}
pub fn get_di_context<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
self.extensions.get::<Arc<T>>()
}
pub fn extract_bearer_token(&self) -> Option<String> {
self.headers
.get(hyper::header::AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.and_then(|auth_str| auth_str.strip_prefix("Bearer ").map(|s| s.to_string()))
}
pub fn get_header(&self, name: &str) -> Option<String> {
self.headers
.get(name)
.and_then(|value| value.to_str().ok())
.map(|s| s.to_string())
}
pub fn get_client_ip(&self) -> Option<std::net::IpAddr> {
if self.is_from_trusted_proxy() {
if let Some(forwarded) = self.get_header("x-forwarded-for") {
if let Some(first_ip) = forwarded.split(',').next()
&& let Ok(ip) = first_ip.trim().parse()
{
return Some(ip);
}
}
if let Some(real_ip) = self.get_header("x-real-ip")
&& let Ok(ip) = real_ip.parse()
{
return Some(ip);
}
}
self.remote_addr.map(|addr| addr.ip())
}
fn is_from_trusted_proxy(&self) -> bool {
if let Some(trusted) = self.extensions.get::<TrustedProxies>()
&& let Some(addr) = self.remote_addr
{
return trusted.is_trusted(&addr.ip());
}
false
}
pub fn set_trusted_proxies(&self, proxies: TrustedProxies) {
self.extensions.insert(proxies);
}
pub fn validate_content_type(&self, expected: &str) -> crate::Result<()> {
match self.get_header("content-type") {
Some(content_type) if content_type.starts_with(expected) => Ok(()),
Some(content_type) => Err(crate::Error::Http(format!(
"Invalid Content-Type: expected '{}', got '{}'",
expected, content_type
))),
None => Err(crate::Error::Http(
"Missing Content-Type header".to_string(),
)),
}
}
pub fn query_as<T: serde::de::DeserializeOwned>(&self) -> crate::Result<T> {
let params: Vec<(String, String)> = self
.query_params
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
let encoded = serde_urlencoded::to_string(¶ms)
.map_err(|e| crate::Error::Http(format!("Failed to encode query parameters: {}", e)))?;
serde_urlencoded::from_str(&encoded)
.map_err(|e| crate::Error::Http(format!("Failed to parse query parameters: {}", e)))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version, header};
    use rstest::rstest;
    use std::net::{IpAddr, SocketAddr};

    /// Builds a `/` request over HTTP/1.1 with an empty body, the given method and
    /// headers, and optionally a peer address.
    fn build_request(method: Method, headers: HeaderMap, peer: Option<SocketAddr>) -> Request {
        let builder = Request::builder()
            .method(method)
            .uri("/")
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new());
        let builder = match peer {
            Some(addr) => builder.remote_addr(addr),
            None => builder,
        };
        builder.build().unwrap()
    }

    /// Returns a header map holding exactly one `name: value` pair.
    fn one_header(name: header::HeaderName, value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(name, header::HeaderValue::from_static(value));
        headers
    }

    #[rstest]
    fn test_extract_bearer_token() {
        let request = build_request(
            Method::GET,
            one_header(header::AUTHORIZATION, "Bearer test_token_123"),
            None,
        );
        assert_eq!(request.extract_bearer_token(), Some("test_token_123".to_string()));
    }

    #[rstest]
    fn test_extract_bearer_token_missing() {
        let request = build_request(Method::GET, HeaderMap::new(), None);
        assert_eq!(request.extract_bearer_token(), None);
    }

    #[rstest]
    fn test_extract_bearer_token_other_scheme() {
        let request = build_request(
            Method::GET,
            one_header(header::AUTHORIZATION, "Basic dXNlcjpwYXNz"),
            None,
        );
        assert_eq!(request.extract_bearer_token(), None);
    }

    #[rstest]
    fn test_get_header() {
        let request = build_request(
            Method::GET,
            one_header(header::USER_AGENT, "TestClient/1.0"),
            None,
        );
        assert_eq!(
            request.get_header("user-agent"),
            Some("TestClient/1.0".to_string())
        );
    }

    #[rstest]
    fn test_get_header_missing() {
        let request = build_request(Method::GET, HeaderMap::new(), None);
        let value = request.get_header("x-custom-header");
        assert_eq!(value, None);
    }

    #[rstest]
    fn test_get_client_ip_forwarded_for_with_trusted_proxy() {
        let proxy_ip: IpAddr = "10.0.0.254".parse().unwrap();
        let request = build_request(
            Method::GET,
            one_header(
                header::HeaderName::from_static("x-forwarded-for"),
                "192.168.1.1, 10.0.0.1",
            ),
            Some(SocketAddr::new(proxy_ip, 8080)),
        );
        request.set_trusted_proxies(TrustedProxies::new(vec![proxy_ip]));
        assert_eq!(request.get_client_ip(), Some("192.168.1.1".parse().unwrap()));
    }

    #[rstest]
    fn test_get_client_ip_forwarded_for_without_trusted_proxy() {
        let remote_ip: IpAddr = "10.0.0.254".parse().unwrap();
        let request = build_request(
            Method::GET,
            one_header(
                header::HeaderName::from_static("x-forwarded-for"),
                "192.168.1.1, 10.0.0.1",
            ),
            Some(SocketAddr::new(remote_ip, 8080)),
        );
        // No TrustedProxies installed, so the forwarding header must be ignored.
        assert_eq!(request.get_client_ip(), Some(remote_ip));
    }

    #[rstest]
    fn test_get_client_ip_real_ip_with_trusted_proxy() {
        let proxy_ip: IpAddr = "10.0.0.254".parse().unwrap();
        let request = build_request(
            Method::GET,
            one_header(header::HeaderName::from_static("x-real-ip"), "203.0.113.5"),
            Some(SocketAddr::new(proxy_ip, 8080)),
        );
        request.set_trusted_proxies(TrustedProxies::new(vec![proxy_ip]));
        assert_eq!(request.get_client_ip(), Some("203.0.113.5".parse().unwrap()));
    }

    #[rstest]
    fn test_get_client_ip_none() {
        let request = build_request(Method::GET, HeaderMap::new(), None);
        assert_eq!(request.get_client_ip(), None);
    }

    #[rstest]
    fn test_validate_content_type_valid() {
        let request = build_request(
            Method::POST,
            one_header(header::CONTENT_TYPE, "application/json"),
            None,
        );
        assert!(request.validate_content_type("application/json").is_ok());
    }

    #[rstest]
    fn test_validate_content_type_with_parameters() {
        let request = build_request(
            Method::POST,
            one_header(header::CONTENT_TYPE, "application/json; charset=utf-8"),
            None,
        );
        assert!(request.validate_content_type("application/json").is_ok());
    }

    #[rstest]
    fn test_validate_content_type_invalid() {
        let request = build_request(
            Method::POST,
            one_header(header::CONTENT_TYPE, "text/plain"),
            None,
        );
        assert!(request.validate_content_type("application/json").is_err());
    }

    #[rstest]
    fn test_validate_content_type_missing() {
        let request = build_request(Method::POST, HeaderMap::new(), None);
        assert!(request.validate_content_type("application/json").is_err());
    }

    #[rstest]
    fn test_build_without_uri_fails() {
        let result = Request::builder().method(Method::GET).build();
        assert_eq!(result.err(), Some("URI is required".to_string()));
    }

    #[rstest]
    fn test_build_with_invalid_uri_fails() {
        let err = Request::builder().uri("/has space").build().err();
        assert!(err.is_some_and(|msg| msg.starts_with("Invalid URI")));
    }

    #[rstest]
    fn test_build_with_invalid_header_value_fails() {
        let err = Request::builder()
            .uri("/")
            .header(header::USER_AGENT, "bad\nvalue")
            .build()
            .err();
        assert!(err.is_some_and(|msg| msg.starts_with("Invalid header value")));
    }
}