mod storage;
use std::any::Any;
use std::borrow::Cow;
use std::net::SocketAddr;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use http::header::AsHeaderName;
use http::{Method, StatusCode, Uri, Version};
pub use storage::Variable;
use storage::{Storage, Value};
use crate::{status, Executor, Request, Response};
/// Context passed to middleware and endpoints while a request is handled.
///
/// It bundles:
/// - the incoming request;
/// - the response being built;
/// - an executor handle;
/// - the remote address;
/// - a type-scoped key/value storage;
/// - the application state `S`.
///
/// `S` defaults to `()` for applications without state. The state is reachable
/// directly on the context through the `Deref`/`DerefMut` impls (e.g.
/// `ctx.data = 1`).
pub struct Context<S = ()> {
    /// The incoming HTTP request.
    pub req: Request,
    /// The response being built for this request.
    pub resp: Response,
    /// Executor handle carried with the context (presumably for spawning tasks — confirm in `Executor`).
    pub exec: Executor,
    /// Remote address supplied when the context was created.
    pub remote_addr: SocketAddr,
    /// Backing store for `store*` / `load*`.
    storage: Storage,
    /// Application state, exposed via `Deref` / `DerefMut`.
    state: S,
}
impl<S> Context<S> {
    /// Builds the context for a single request.
    ///
    /// The response starts as `Response::default()` and the storage as
    /// `Storage::default()`; everything else comes from the arguments.
    #[inline]
    pub(crate) fn new(request: Request, state: S, exec: Executor, remote_addr: SocketAddr) -> Self {
        let resp = Response::default();
        let storage = Storage::default();
        Self {
            req: request,
            resp,
            state,
            exec,
            storage,
            remote_addr,
        }
    }

    /// Borrows the request URI.
    #[inline]
    pub fn uri(&self) -> &Uri {
        &self.req.uri
    }

    /// Borrows the request method.
    #[inline]
    pub fn method(&self) -> &Method {
        &self.req.method
    }

    /// Looks up a request header and returns its value as a string slice.
    ///
    /// Returns `None` both when the header is absent and when
    /// `HeaderValue::to_str` rejects the value (it only accepts visible ASCII).
    /// Use [`Context::must_get`] to turn either case into an error instead.
    #[inline]
    pub fn get(&self, name: impl AsHeaderName) -> Option<&str> {
        let value = self.req.headers.get(name)?;
        value.to_str().ok()
    }

    /// Like [`Context::get`], but reports failure as a `400 Bad Request` status.
    ///
    /// # Errors
    ///
    /// - Header absent: `status!(StatusCode::BAD_REQUEST)` with no extra detail.
    /// - Value rejected by `HeaderValue::to_str`: a `400` carrying the conversion error.
    #[inline]
    pub fn must_get(&self, name: impl AsHeaderName) -> crate::Result<&str> {
        match self.req.headers.get(name) {
            None => Err(status!(StatusCode::BAD_REQUEST)),
            Some(value) => value
                .to_str()
                .map_err(|err| status!(StatusCode::BAD_REQUEST, err)),
        }
    }

    /// Current status code of the response being built.
    #[inline]
    pub fn status(&self) -> StatusCode {
        self.resp.status
    }

    /// HTTP version of the incoming request.
    #[inline]
    pub fn version(&self) -> Version {
        self.req.version
    }

    /// Stores `value` under `key` in the scope identified by the type `SC`.
    ///
    /// The scope type is forwarded to `Storage::insert`. Presumably keys are
    /// namespaced per scope type, so a private marker type keeps values away
    /// from other code (verify in `storage`). The return value is whatever
    /// `Storage::insert` yields — presumably the value previously stored for
    /// this scope/key, if any.
    #[inline]
    pub fn store_scoped<SC, K, V>(&mut self, scope: SC, key: K, value: V) -> Option<Arc<V>>
    where
        SC: Any,
        K: Into<Cow<'static, str>>,
        V: Value,
    {
        self.storage.insert(scope, key, value)
    }

    /// Stores `value` under `key` in the default scope (`PublicScope`).
    #[inline]
    pub fn store<K: Into<Cow<'static, str>>, V: Value>(&mut self, key: K, value: V) -> Option<Arc<V>> {
        Self::store_scoped(self, PublicScope, key, value)
    }

    /// Fetches the value of type `V` stored under `key` in scope `SC`.
    ///
    /// Returns `None` when `Storage::get` finds nothing. Presumably that means
    /// the key is absent or holds a value of another type — verify in `storage`.
    #[inline]
    pub fn load_scoped<'a, SC, V>(&self, key: &'a str) -> Option<Variable<'a, V>>
    where
        SC: Any,
        V: Value,
    {
        self.storage.get::<SC, V>(key)
    }

    /// Fetches the value of type `V` stored under `key` in the default scope
    /// (`PublicScope`).
    #[inline]
    pub fn load<'a, V: Value>(&self, key: &'a str) -> Option<Variable<'a, V>> {
        Self::load_scoped::<PublicScope, V>(self, key)
    }
}
/// Marker type naming the default storage scope used by `Context::store` and
/// `Context::load`.
///
/// It is private to this module, so code elsewhere cannot pass it to
/// `store_scoped`/`load_scoped` directly. It can only reach this scope through
/// `store`/`load`.
struct PublicScope;
/// Exposes the application state `S` directly on the context, so handlers can
/// read state fields as `ctx.field`.
impl<S> Deref for Context<S> {
    type Target = S;

    #[inline]
    fn deref(&self) -> &S {
        &self.state
    }
}
/// Mutable access to the application state `S` through the context, so
/// `ctx.field = value` writes straight into the state.
impl<S> DerefMut for Context<S> {
    #[inline]
    fn deref_mut(&mut self) -> &mut S {
        &mut self.state
    }
}
/// Produces a context that shares what can be shared with `self`.
///
/// NOTE: this is not a full copy of the context.
/// - Not copied: `self.req` and `self.resp`. The clone gets a blank
///   `Request::default()` and a fresh `Response::new()` instead, presumably
///   because request/response bodies cannot be duplicated (confirm).
/// - Copied from `self`: state, executor, storage (each via `clone`) and the
///   remote address.
impl<S: Clone> Clone for Context<S> {
    #[inline]
    fn clone(&self) -> Self {
        let req = Request::default();
        let resp = Response::new();
        let state = self.state.clone();
        let exec = self.exec.clone();
        let storage = self.storage.clone();
        Self {
            req,
            resp,
            state,
            exec,
            storage,
            remote_addr: self.remote_addr,
        }
    }
}
/// End-to-end tests that drive `Context` through a real `App` service.
#[cfg(all(test, feature = "runtime"))]
mod tests_with_runtime {
    use std::error::Error;

    use http::{HeaderValue, StatusCode, Version};

    use crate::{App, Context, Next, Request, Status};

    #[tokio::test]
    async fn status_and_version() -> Result<(), Box<dyn Error>> {
        async fn test(ctx: &mut Context) -> Result<(), Status> {
            assert_eq!(Version::HTTP_11, ctx.version());
            assert_eq!(StatusCode::OK, ctx.status());
            Ok(())
        }
        let service = App::new().end(test).http_service();
        let resp = service.serve(Request::default()).await;
        // Don't let an endpoint failure slip by unnoticed.
        assert_eq!(StatusCode::OK, resp.status);
        Ok(())
    }

    #[derive(Clone)]
    struct State {
        data: usize,
    }

    #[tokio::test]
    async fn state() -> Result<(), Box<dyn Error>> {
        async fn gate(ctx: &mut Context<State>, next: Next<'_>) -> Result<(), Status> {
            ctx.data = 1;
            next.await
        }
        async fn test(ctx: &mut Context<State>) -> Result<(), Status> {
            assert_eq!(1, ctx.data);
            Ok(())
        }
        // Start from 0: the endpoint's assertion can then only pass if the gate's
        // write through `DerefMut` is visible downstream. With an initial value
        // of 1 the test passed even if the gate never ran.
        let service = App::state(State { data: 0 })
            .gate(gate)
            .end(test)
            .http_service();
        let resp = service.serve(Request::default()).await;
        assert_eq!(StatusCode::OK, resp.status);
        Ok(())
    }

    #[tokio::test]
    async fn must_get() -> Result<(), Box<dyn Error>> {
        use http::header::{CONTENT_TYPE, HOST};
        async fn test(ctx: &mut Context) -> Result<(), Status> {
            assert_eq!(Ok("github.com"), ctx.must_get(HOST));
            // The missing header must short-circuit with an error.
            ctx.must_get(CONTENT_TYPE)?;
            unreachable!()
        }
        let service = App::new().end(test).http_service();
        let mut req = Request::default();
        req.headers
            .insert(HOST, HeaderValue::from_static("github.com"));
        let resp = service.serve(req).await;
        assert_eq!(StatusCode::BAD_REQUEST, resp.status);
        Ok(())
    }
}