use crate::http::endpoints::{args::FromRequestRef, route::RoutePipeline};
use crate::{
HttpRequest, HttpRequestMut, HttpResult,
error::Error,
http::cors::{CorsHeaders, CorsOverride},
status,
};
use std::sync::Arc;
#[cfg(feature = "di")]
use crate::di::Container;
#[cfg(feature = "rate-limiting")]
use crate::rate_limiting::{GlobalRateLimiter, RateLimiter};
#[cfg(feature = "rate-limiting")]
use crate::http::request_scope::HttpRequestScope;
/// Per-request context that is passed to a route's pipeline.
///
/// Bundles the incoming request with the pipeline that should process it and
/// the route-level CORS override.
pub struct HttpContext {
    /// The incoming request, wrapped in `HttpRequestMut` so it can be modified.
    request: HttpRequestMut,
    /// Pipeline run by `HttpContext::execute`; when `None`, `execute` responds with `status!(405)`.
    pipeline: Option<RoutePipeline>,
    /// Route-level CORS override used by `HttpContext::resolve_cors` to pick the effective configuration.
    cors: CorsOverride,
}
/// Opaque `Debug` output: the context's contents (request body, pipeline,
/// CORS data) are intentionally not printed.
impl std::fmt::Debug for HttpContext {
    /// Always writes the fixed marker `HttpContext(..)`.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "HttpContext(..)")
    }
}
impl HttpContext {
    /// Creates a context for `request`, wrapping it in an [`HttpRequestMut`].
    ///
    /// `pipeline` is what [`execute`](Self::execute) will run, and `cors` is the
    /// route-level override consulted by [`resolve_cors`](Self::resolve_cors).
    #[inline]
    pub(crate) fn new(
        request: HttpRequest,
        pipeline: Option<RoutePipeline>,
        cors: CorsOverride,
    ) -> Self {
        Self::from_parts(HttpRequestMut::new(request), pipeline, cors)
    }

    /// Splits the context into its request, pipeline and CORS override.
    ///
    /// This is the inverse of [`from_parts`](Self::from_parts).
    #[inline]
    // NOTE(review): `execute` below calls this, so the allow looks redundant;
    // confirm no build configuration still warns before removing it.
    #[allow(dead_code)]
    pub(crate) fn into_parts(self) -> (HttpRequestMut, Option<RoutePipeline>, CorsOverride) {
        (self.request, self.pipeline, self.cors)
    }

    /// Reassembles a context from an already-wrapped request, a pipeline and a
    /// CORS override.
    #[inline]
    pub(crate) fn from_parts(
        request: HttpRequestMut,
        pipeline: Option<RoutePipeline>,
        cors: CorsOverride,
    ) -> Self {
        Self {
            request,
            pipeline,
            cors,
        }
    }

    /// Extracts a `T` from the request through its [`FromRequestRef`] implementation.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] when the request cannot be extracted into `T`.
    #[inline]
    pub fn extract<T: FromRequestRef>(&self) -> Result<T, Error> {
        self.request.extract()
    }

    /// Returns the DI [`Container`] stored in the request's extensions.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] when the extensions cannot be converted into a
    /// container, e.g. when none was attached to the request.
    #[inline]
    #[cfg(feature = "di")]
    pub(crate) fn container(&self) -> Result<&Container, Error> {
        self.request.extensions().try_into().map_err(Into::into)
    }

    /// Resolves a value of type `T` from the request's DI container.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] when no container is attached to the request or
    /// when the container fails to resolve `T`.
    #[inline]
    #[cfg(feature = "di")]
    pub fn resolve<T: Send + Sync + Clone + 'static>(&self) -> Result<T, Error> {
        self.container()?.resolve::<T>().map_err(Into::into)
    }

    /// Resolves `T` from the request's DI container as a shared [`Arc<T>`].
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] when no container is attached to the request or
    /// when the container fails to resolve `T`.
    #[inline]
    #[cfg(feature = "di")]
    pub fn resolve_shared<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, Error> {
        self.container()?.resolve_shared::<T>().map_err(Into::into)
    }

    /// Returns the global rate limiter held by the request's [`HttpRequestScope`].
    ///
    /// `None` when the request has no scope or the scope carries no limiter.
    #[inline]
    #[cfg(feature = "rate-limiting")]
    fn rate_limiter(&self) -> Option<&Arc<GlobalRateLimiter>> {
        self.request
            .extensions()
            .get::<HttpRequestScope>()?
            .rate_limiter
            .as_ref()
    }

    /// Looks up the fixed-window limiter for `policy` on the global rate limiter.
    ///
    /// `policy` is forwarded unchanged; returns `None` when there is no global
    /// limiter or the lookup itself yields nothing.
    #[inline]
    #[cfg(feature = "rate-limiting")]
    pub(crate) fn fixed_window_rate_limiter(
        &self,
        policy: Option<&str>,
    ) -> Option<&impl RateLimiter> {
        self.rate_limiter()?.fixed_window(policy)
    }

    /// Looks up the sliding-window limiter for `policy` on the global rate limiter.
    ///
    /// `policy` is forwarded unchanged; returns `None` when there is no global
    /// limiter or the lookup itself yields nothing.
    #[inline]
    #[cfg(feature = "rate-limiting")]
    pub(crate) fn sliding_window_rate_limiter(
        &self,
        policy: Option<&str>,
    ) -> Option<&impl RateLimiter> {
        self.rate_limiter()?.sliding_window(policy)
    }

    /// Looks up the token-bucket limiter for `policy` on the global rate limiter.
    ///
    /// `policy` is forwarded unchanged; returns `None` when there is no global
    /// limiter or the lookup itself yields nothing.
    #[inline]
    #[cfg(feature = "rate-limiting")]
    pub(crate) fn token_bucket_rate_limiter(
        &self,
        policy: Option<&str>,
    ) -> Option<&impl RateLimiter> {
        self.rate_limiter()?.token_bucket(policy)
    }

    /// Looks up the GCRA limiter for `policy` on the global rate limiter.
    ///
    /// `policy` is forwarded unchanged; returns `None` when there is no global
    /// limiter or the lookup itself yields nothing.
    #[inline]
    #[cfg(feature = "rate-limiting")]
    pub(crate) fn gcra_rate_limiter(&self, policy: Option<&str>) -> Option<&impl RateLimiter> {
        self.rate_limiter()?.gcra(policy)
    }

    /// Returns a read-only view of the request.
    #[inline]
    pub fn request(&self) -> &HttpRequest {
        self.request.as_read_only()
    }

    /// Returns mutable access to the request.
    #[inline]
    pub fn request_mut(&mut self) -> &mut HttpRequestMut {
        &mut self.request
    }

    /// Resolves the CORS configuration that applies to this request.
    ///
    /// - [`CorsOverride::Named`]: the route's own configuration, regardless of `default`.
    /// - [`CorsOverride::Inherit`]: `default`, if one was supplied.
    /// - [`CorsOverride::Disabled`]: `None`, even when `default` is set.
    #[inline]
    pub(crate) fn resolve_cors(
        &self,
        default: Option<&Arc<CorsHeaders>>,
    ) -> Option<Arc<CorsHeaders>> {
        match &self.cors {
            // Cheap refcount bump; the CORS data itself is shared, not copied.
            CorsOverride::Named(cors) => Some(Arc::clone(cors)),
            CorsOverride::Inherit => default.cloned(),
            CorsOverride::Disabled => None,
        }
    }

    /// Runs the route pipeline for this request, consuming the context.
    ///
    /// The pipeline is moved out of the context before it is invoked, so the
    /// context it receives carries no pipeline of its own. When the context has
    /// no pipeline, the result of `status!(405)` is returned instead.
    #[inline]
    pub(crate) async fn execute(self) -> HttpResult {
        let (request, pipeline, cors) = self.into_parts();
        if let Some(pipeline) = pipeline {
            pipeline.call(Self::from_parts(request, None, cors)).await
        } else {
            status!(405)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HttpBody;
    use hyper::Request;
    #[cfg(feature = "di")]
    use std::collections::HashMap;
    #[cfg(feature = "di")]
    use std::sync::Mutex;
    #[cfg(feature = "di")]
    use crate::di::ContainerBuilder;
    use crate::http::CorsConfig;

    /// Placeholder service registered in the DI container by the resolution tests.
    // `inner` is never read: the type only has to be registrable and resolvable,
    // so the dead-code lint is silenced rather than inventing a reader.
    #[cfg(feature = "di")]
    #[allow(dead_code)]
    #[derive(Clone, Default)]
    struct InMemoryCache {
        inner: Arc<Mutex<HashMap<String, String>>>,
    }

    /// Converts a built `hyper::Request` into the crate's `HttpRequest`.
    fn into_http_request(req: Request<HttpBody>) -> HttpRequest {
        let (parts, body) = req.into_parts();
        HttpRequest::from_parts(parts, body)
    }

    /// Builds a `GET http://localhost/` request whose body is `"foo"`.
    fn localhost_request() -> HttpRequest {
        into_http_request(
            Request::get("http://localhost/")
                .body(HttpBody::full("foo"))
                .unwrap(),
        )
    }

    /// Builds a context for `GET /` with an empty body, no pipeline and inherited CORS.
    fn create_ctx() -> HttpContext {
        let req = Request::get("/").body(HttpBody::empty()).unwrap();
        HttpContext::new(into_http_request(req), None, CorsOverride::Inherit)
    }

    /// Builds a context whose request carries a DI container with an
    /// `InMemoryCache` singleton registered.
    #[cfg(feature = "di")]
    fn ctx_with_cache_container() -> HttpContext {
        let mut container = ContainerBuilder::new();
        container.register_singleton(InMemoryCache::default());
        let req = Request::get("http://localhost/")
            .extension(container.build())
            .body(HttpBody::full("foo"))
            .unwrap();
        HttpContext::new(into_http_request(req), None, CorsOverride::Inherit)
    }

    /// Precomputed CORS headers allowing `GET`/`POST`, any header and any origin.
    fn get_post_cors() -> Arc<CorsHeaders> {
        Arc::new(
            CorsConfig::default()
                .with_methods(["GET", "POST"])
                .with_any_header()
                .with_any_origin()
                .precompute(),
        )
    }

    /// Precomputed, named CORS headers allowing any method, header and origin.
    fn permissive_cors() -> Arc<CorsHeaders> {
        Arc::new(
            CorsConfig::default()
                .with_name("permissive")
                .with_any_method()
                .with_any_header()
                .with_any_origin()
                .precompute(),
        )
    }

    #[test]
    fn it_debugs() {
        let ctx = create_ctx();
        assert_eq!(format!("{ctx:?}"), "HttpContext(..)");
    }

    #[test]
    fn it_splits_into_parts() {
        let (request, _, _) = create_ctx().into_parts();
        assert_eq!(request.uri(), "/");
    }

    #[test]
    #[cfg(feature = "di")]
    fn it_returns_err_if_there_is_no_di_container() {
        let ctx = HttpContext::new(localhost_request(), None, CorsOverride::Inherit);
        assert!(ctx.container().is_err());
    }

    #[test]
    #[cfg(feature = "di")]
    fn it_resolves_from_di_container() {
        let ctx = ctx_with_cache_container();
        assert!(ctx.resolve::<InMemoryCache>().is_ok());
    }

    #[test]
    #[cfg(feature = "di")]
    fn it_resolves_shared_from_di_container() {
        let ctx = ctx_with_cache_container();
        assert!(ctx.resolve_shared::<InMemoryCache>().is_ok());
    }

    #[test]
    fn it_resolves_cors() {
        let named = permissive_cors();
        let ctx = HttpContext::new(
            localhost_request(),
            None,
            CorsOverride::Named(Arc::clone(&named)),
        );
        let resolved = ctx
            .resolve_cors(None)
            .expect("named CORS should resolve without a default");
        assert!(Arc::ptr_eq(&resolved, &named));
    }

    #[test]
    fn it_prefers_named_cors_over_default() {
        let named = permissive_cors();
        let default_cors = get_post_cors();
        let ctx = HttpContext::new(
            localhost_request(),
            None,
            CorsOverride::Named(Arc::clone(&named)),
        );
        let resolved = ctx
            .resolve_cors(Some(&default_cors))
            .expect("named CORS should resolve");
        assert!(Arc::ptr_eq(&resolved, &named));
        assert!(!Arc::ptr_eq(&resolved, &default_cors));
    }

    #[test]
    fn it_resolves_default_cors() {
        let default_cors = get_post_cors();
        let ctx = HttpContext::new(localhost_request(), None, CorsOverride::Inherit);
        let resolved = ctx
            .resolve_cors(Some(&default_cors))
            .expect("inherited CORS should fall back to the default");
        assert!(Arc::ptr_eq(&resolved, &default_cors));
    }

    #[test]
    fn it_resolves_no_cors_when_inheriting_without_default() {
        let ctx = HttpContext::new(localhost_request(), None, CorsOverride::Inherit);
        assert!(ctx.resolve_cors(None).is_none());
    }

    #[test]
    fn it_resolves_disabled_cors() {
        let default_cors = get_post_cors();
        let ctx = HttpContext::new(localhost_request(), None, CorsOverride::Disabled);
        assert!(ctx.resolve_cors(Some(&default_cors)).is_none());
    }
}