use std::net::SocketAddr;
use std::sync::Mutex;
use http::{Extensions, HeaderMap, Method, Uri};
use hyper::upgrade::OnUpgrade;
use crate::body::ReqBody;
use crate::error::{Error, Result};
use crate::state::{AppStateRef, StateMap};
use crate::ws::Upgrade;
pub mod body;
pub mod header;
pub mod path;
pub mod valid;
pub use header::{BearerToken, LastEventId, SseResume};
pub use path::{__extract_path_param, FromPathParam};
pub use valid::Valid;
/// Ordered list of `(name, value)` path parameters attached to a request.
///
/// Entries keep their insertion order. Lookup is a linear scan; when a name was pushed
/// more than once, [`PathParams::get`] yields the earliest value.
#[derive(Debug, Default, Clone)]
pub struct PathParams {
    entries: Vec<(String, String)>,
}

impl PathParams {
    /// Creates an empty parameter list; identical to [`Default::default`].
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
        }
    }

    /// Appends a `name = value` pair without replacing earlier pairs of the same name.
    pub fn push(&mut self, name: String, value: String) {
        let pair = (name, value);
        self.entries.push(pair);
    }

    /// Looks up the first value stored under `name`, or `None` if there is none.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.entries
            .iter()
            .find_map(|(key, value)| (key == name).then_some(value.as_str()))
    }

    /// Reports whether no parameters have been pushed.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Number of stored pairs, counting duplicates.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
}
/// Per-request data handed to extractors and handlers.
///
/// The body and the upgrade handle are each stored as `Mutex<Option<_>>` so that they can
/// be moved out at most once through a shared `&RequestContext` (via `take_body` and
/// `take_upgrade`); after the first take the slot holds `None`.
// NOTE(review): `Mutex` rather than `RefCell` presumably keeps the context `Sync`, so a
// `&RequestContext` can be captured by the `Send` futures `FromRequest` returns — confirm.
// Field order is also drop order; keep that in mind before reordering.
pub struct RequestContext {
    /// Request head; `RequestContext::new` removes any hyper `OnUpgrade` from its extensions.
    head: http::request::Parts,
    /// Named path parameters for this request (read via `path_params` / `path_param`).
    path_params: PathParams,
    /// Handle to the application state map (exposed as `&StateMap` by `state`).
    state: AppStateRef,
    /// Request body until `take_body` moves it out; `None` afterwards.
    body: Mutex<Option<ReqBody>>,
    /// Pending connection upgrade, if one was present and `take_upgrade` has not run yet.
    upgrade: Mutex<Option<Upgrade>>,
}
/// Remote (peer) socket address of the request's connection.
///
/// Stored in the request extensions and read back by `peer_addr_from_extensions`.
/// `Debug` is derived so the value is visible in diagnostics and assertion failures.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct RequestPeerAddr(pub(crate) SocketAddr);
/// URI scheme of an incoming request, stored in the request extensions and read back by
/// `scheme_from_extensions`.
///
/// `Debug` is derived so the value is visible in diagnostics and assertion failures.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RequestScheme {
    /// Plain `http`.
    Http,
    /// `https` (HTTP over TLS).
    Https,
}

impl RequestScheme {
    /// Returns the lowercase scheme name: `"http"` or `"https"`.
    pub(crate) fn as_str(self) -> &'static str {
        match self {
            Self::Http => "http",
            Self::Https => "https",
        }
    }
}
pub(crate) fn peer_addr_from_extensions(extensions: &Extensions) -> Option<SocketAddr> {
extensions.get::<RequestPeerAddr>().map(|peer| peer.0)
}
/// Reads the [`RequestScheme`] stored in `extensions`, if one was inserted.
pub(crate) fn scheme_from_extensions(extensions: &Extensions) -> Option<RequestScheme> {
    let scheme = extensions.get::<RequestScheme>()?;
    Some(*scheme)
}
impl RequestContext {
pub fn new(
mut head: http::request::Parts,
path_params: PathParams,
state: AppStateRef,
body: ReqBody,
) -> Self {
let upgrade = head.extensions.remove::<OnUpgrade>().map(Upgrade::Hyper);
Self {
head,
path_params,
state,
body: Mutex::new(Some(body)),
upgrade: Mutex::new(upgrade),
}
}
#[allow(dead_code)]
pub(crate) fn with_duplex_upgrade(
head: http::request::Parts,
path_params: PathParams,
state: AppStateRef,
body: ReqBody,
duplex: tokio::io::DuplexStream,
) -> Self {
Self {
head,
path_params,
state,
body: Mutex::new(Some(body)),
upgrade: Mutex::new(Some(Upgrade::Duplex(duplex))),
}
}
pub fn method(&self) -> &Method {
&self.head.method
}
pub fn uri(&self) -> &Uri {
&self.head.uri
}
pub fn headers(&self) -> &HeaderMap {
&self.head.headers
}
pub fn peer_addr(&self) -> Option<SocketAddr> {
peer_addr_from_extensions(&self.head.extensions)
}
pub fn scheme(&self) -> Option<&'static str> {
scheme_from_extensions(&self.head.extensions).map(RequestScheme::as_str)
}
pub fn head(&self) -> &http::request::Parts {
&self.head
}
pub fn state(&self) -> &StateMap {
self.state.as_ref()
}
pub fn resource<T: Clone + Send + Sync + 'static>(&self) -> Result<T> {
self.state()
.get::<T>()
.map(|value| (*value).clone())
.ok_or_else(|| {
Error::internal(format!(
"resource `{}` was not registered",
std::any::type_name::<T>()
))
.with_code("MISSING_RESOURCE")
})
}
pub fn path_params(&self) -> &PathParams {
&self.path_params
}
pub fn path_param(&self, name: &str) -> Option<&str> {
self.path_params.get(name)
}
pub fn take_body(&self) -> Result<ReqBody> {
self.body
.lock()
.expect("request body mutex poisoned")
.take()
.ok_or_else(|| Error::bad_request("request body has already been consumed"))
}
pub(crate) fn take_upgrade(&self) -> Result<Upgrade> {
self.upgrade
.lock()
.expect("request upgrade mutex poisoned")
.take()
.ok_or_else(|| {
Error::bad_request("request is not a WebSocket upgrade").with_code("NOT_AN_UPGRADE")
})
}
}
/// A value that can be extracted asynchronously from a [`RequestContext`].
///
/// Implementors borrow the context and return a `Send` future that resolves to the
/// extracted value or to an error.
pub trait FromRequest: Sized + Send {
    /// Extracts `Self` from `ctx`.
    fn from_request(
        ctx: &RequestContext,
    ) -> impl std::future::Future<Output = Result<Self>> + Send;
}
/// Extracts an `Arc<T>` by cloning the `Arc<T>` registered in the state map.
///
/// The lookup runs eagerly when `from_request` is called; the returned future is
/// already complete and does not borrow `ctx`.
impl<T> FromRequest for std::sync::Arc<T>
where
    T: Send + Sync + 'static,
{
    fn from_request(
        ctx: &RequestContext,
    ) -> impl std::future::Future<Output = Result<Self>> + Send {
        std::future::ready(ctx.resource::<Self>())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::body::box_body;
    use crate::error::ErrorKind;
    use bytes::Bytes;
    use http_body_util::Full;
    use std::sync::Arc;

    /// Builds a context from an explicit head, path params, state map and static body.
    fn context_with(
        head: http::request::Parts,
        path_params: PathParams,
        state: StateMap,
        body: &'static str,
    ) -> RequestContext {
        let body = box_body(Full::new(Bytes::from_static(body.as_bytes())));
        RequestContext::new(head, path_params, Arc::new(state), body)
    }

    /// Builds a context over a default request head and an empty state map.
    fn test_context(path_params: PathParams, body: &'static str) -> RequestContext {
        let head = http::Request::new(()).into_parts().0;
        context_with(head, path_params, StateMap::new(), body)
    }

    #[test]
    fn path_param_lookup_and_parse() {
        let mut params = PathParams::new();
        params.push("user_id".to_owned(), "42".to_owned());
        let ctx = test_context(params, "");
        let parsed: i64 = __extract_path_param(&ctx, "user_id").unwrap();
        assert_eq!(parsed, 42);
    }

    #[test]
    fn invalid_path_param_is_unprocessable() {
        let mut params = PathParams::new();
        params.push("user_id".to_owned(), "not-a-number".to_owned());
        let ctx = test_context(params, "");
        let error = __extract_path_param::<i64>(&ctx, "user_id").unwrap_err();
        assert_eq!(error.kind(), ErrorKind::Unprocessable);
    }

    #[test]
    fn path_params_keep_order_and_return_first_duplicate() {
        let mut params = PathParams::new();
        assert!(params.is_empty());
        params.push("id".to_owned(), "first".to_owned());
        params.push("id".to_owned(), "second".to_owned());
        assert_eq!(params.len(), 2);
        assert!(!params.is_empty());
        assert_eq!(params.get("id"), Some("first"));
        assert_eq!(params.get("missing"), None);
    }

    #[test]
    fn take_upgrade_errors_without_an_upgrade() {
        let ctx = test_context(PathParams::new(), "");
        let error = ctx
            .take_upgrade()
            .err()
            .expect("should error without an upgrade");
        assert_eq!(error.code(), "NOT_AN_UPGRADE");
    }

    #[test]
    fn body_can_only_be_taken_once() {
        let ctx = test_context(PathParams::new(), "hello");
        assert!(ctx.take_body().is_ok());
        let error = ctx.take_body().unwrap_err();
        assert_eq!(error.kind(), ErrorKind::BadRequest);
    }

    #[test]
    fn connection_metadata_defaults_to_none() {
        let ctx = test_context(PathParams::new(), "");
        assert_eq!(ctx.peer_addr(), None);
        assert_eq!(ctx.scheme(), None);
        assert_eq!(ctx.path_param("user_id"), None);
    }

    #[test]
    fn connection_metadata_is_read_from_extensions() {
        let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let mut head = http::Request::new(()).into_parts().0;
        head.extensions.insert(RequestPeerAddr(addr));
        head.extensions.insert(RequestScheme::Https);
        let ctx = context_with(head, PathParams::new(), StateMap::new(), "");
        assert_eq!(ctx.peer_addr(), Some(addr));
        assert_eq!(ctx.scheme(), Some("https"));
    }

    #[test]
    fn resource_is_cloned_from_registry() {
        let mut map = StateMap::new();
        map.insert(42_i64);
        let head = http::Request::new(()).into_parts().0;
        let ctx = context_with(head, PathParams::new(), map, "");
        assert_eq!(ctx.resource::<i64>().unwrap(), 42);
        assert!(ctx.resource::<String>().is_err());
    }

    #[test]
    fn arc_resource_is_shared_not_deep_cloned() {
        let mut map = StateMap::new();
        map.insert(Arc::new(7_u32));
        let head = http::Request::new(()).into_parts().0;
        let ctx = context_with(head, PathParams::new(), map, "");
        let first = ctx.resource::<Arc<u32>>().unwrap();
        let second = ctx.resource::<Arc<u32>>().unwrap();
        assert_eq!(*first, 7);
        assert!(Arc::ptr_eq(&first, &second));
    }
}