roa-core 0.6.1

Core components of the roa web framework.
Documentation
mod storage;

use std::any::Any;
use std::borrow::Cow;
use std::net::SocketAddr;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;

use http::header::AsHeaderName;
use http::{Method, StatusCode, Uri, Version};
pub use storage::Variable;
use storage::{Storage, Value};

use crate::{status, Executor, Request, Response};

/// A structure to share request, response and other data between middlewares.
///
/// `Context<S>` dereferences (mutably and immutably) to the application state `S`,
/// so state fields can be accessed directly as `ctx.field`.
///
/// ### Example
///
/// ```rust
/// use roa_core::{App, Context, Next, Result};
/// use tracing::info;
/// use tokio::fs::File;
///
/// let app = App::new().gate(gate).end(end);
/// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
///     info!("{} {}", ctx.method(), ctx.uri());
///     next.await
/// }
///
/// async fn end(ctx: &mut Context) -> Result {
///     ctx.resp.write_reader(File::open("assets/welcome.html").await?);
///     Ok(())
/// }
/// ```
pub struct Context<S = ()> {
    /// The request, to read http method, uri, version, headers and body.
    pub req: Request,

    /// The response, to set http status, version, headers and body.
    pub resp: Response,

    /// The executor, to spawn futures or blocking works.
    pub exec: Executor,

    /// Socket address of the peer connection: the client itself, or the last
    /// proxy in front of it.
    pub remote_addr: SocketAddr,

    // Scoped key-value storage backing `store`/`load` and their `_scoped`
    // variants; `Context::new` initializes it with `Storage::default()`.
    storage: Storage,
    // Application state; reachable through the `Deref`/`DerefMut` impls.
    state: S,
}

impl<S> Context<S> {
    /// Construct a context from a request, the application state, an executor
    /// and the remote socket address of the peer.
    ///
    /// The response starts as `Response::default()` and the storage starts as
    /// `Storage::default()`.
    #[inline]
    pub(crate) fn new(request: Request, state: S, exec: Executor, remote_addr: SocketAddr) -> Self {
        Self {
            req: request,
            resp: Response::default(),
            state,
            exec,
            storage: Storage::default(),
            remote_addr,
        }
    }

    /// Borrow the request URI.
    ///
    /// ### Example
    /// ```rust
    /// use roa_core::{App, Context, Result};
    ///
    /// let app = App::new().end(get);
    ///
    /// async fn get(ctx: &mut Context) -> Result {
    ///     assert_eq!("/", ctx.uri().to_string());
    ///     Ok(())
    /// }
    /// ```
    #[inline]
    pub fn uri(&self) -> &Uri {
        &self.req.uri
    }

    /// Borrow the request method.
    ///
    /// ### Example
    /// ```rust
    /// use roa_core::{App, Context, Result};
    /// use roa_core::http::Method;
    ///
    /// let app = App::new().end(get);
    ///
    /// async fn get(ctx: &mut Context) -> Result {
    ///     assert_eq!(Method::GET, ctx.method());
    ///     Ok(())
    /// }
    /// ```
    #[inline]
    pub fn method(&self) -> &Method {
        &self.req.method
    }

    /// Search for a request header value and try to get its string reference.
    ///
    /// Returns `None` if the header is absent, or if its value cannot be
    /// converted to `&str` (`HeaderValue::to_str` fails). Use `must_get` to
    /// turn both cases into a `400 BAD REQUEST` instead.
    ///
    /// ### Example
    /// ```rust
    /// use roa_core::{App, Context, Result};
    /// use roa_core::http::header::CONTENT_TYPE;
    ///
    /// let app = App::new().end(get);
    ///
    /// async fn get(ctx: &mut Context) -> Result {
    ///     assert_eq!(
    ///         Some("text/plain"),
    ///         ctx.get(CONTENT_TYPE),
    ///     );
    ///     Ok(())
    /// }
    /// ```
    #[inline]
    pub fn get(&self, name: impl AsHeaderName) -> Option<&str> {
        self.req
            .headers
            .get(name)
            .and_then(|value| value.to_str().ok())
    }

    /// Search for a request header value and get its string reference.
    ///
    /// # Errors
    ///
    /// Returns a `400 BAD REQUEST` status if the header is absent, or if its
    /// value cannot be converted to `&str` (in that case the conversion error
    /// is attached to the status).
    ///
    /// ### Example
    /// ```rust
    /// use roa_core::{App, Context, Result};
    /// use roa_core::http::header::CONTENT_TYPE;
    ///
    /// let app = App::new().end(get);
    ///
    /// async fn get(ctx: &mut Context) -> Result {
    ///     assert_eq!(
    ///         "text/plain",
    ///         ctx.must_get(CONTENT_TYPE)?,
    ///     );
    ///     Ok(())
    /// }
    /// ```
    #[inline]
    pub fn must_get(&self, name: impl AsHeaderName) -> crate::Result<&str> {
        let value = self
            .req
            .headers
            .get(name)
            .ok_or_else(|| status!(StatusCode::BAD_REQUEST))?;
        value
            .to_str()
            .map_err(|err| status!(StatusCode::BAD_REQUEST, err))
    }

    /// Get the current response status (a copy; `StatusCode` is `Copy`).
    ///
    /// ### Example
    /// ```rust
    /// use roa_core::{App, Context, Result};
    /// use roa_core::http::StatusCode;
    ///
    /// let app = App::new().end(get);
    ///
    /// async fn get(ctx: &mut Context) -> Result {
    ///     assert_eq!(StatusCode::OK, ctx.status());
    ///     Ok(())
    /// }
    /// ```
    #[inline]
    pub fn status(&self) -> StatusCode {
        self.resp.status
    }

    /// Get the HTTP version of the request (a copy; `Version` is `Copy`).
    ///
    /// ### Example
    /// ```rust
    /// use roa_core::{App, Context, Result};
    /// use roa_core::http::Version;
    ///
    /// let app = App::new().end(get);
    ///
    /// async fn get(ctx: &mut Context) -> Result {
    ///     assert_eq!(Version::HTTP_11, ctx.version());
    ///     Ok(())
    /// }
    /// ```
    #[inline]
    pub fn version(&self) -> Version {
        self.req.version
    }

    /// Store a key-value pair in a specific scope.
    ///
    /// The scope is identified by the type `SC`; the `scope` value itself only
    /// selects that type. Values stored under one scope are not visible from
    /// another (see `AnotherScope` in the example). The return value is
    /// whatever the underlying storage insert yields (presumably the value
    /// previously stored under the same scope and key, if any).
    ///
    /// ### Example
    /// ```rust
    /// use roa_core::{App, Context, Result, Next};
    ///
    /// struct Scope;
    /// struct AnotherScope;
    ///
    /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
    ///     ctx.store_scoped(Scope, "id", "1".to_string());
    ///     next.await
    /// }
    ///
    /// async fn end(ctx: &mut Context) -> Result {
    ///     assert_eq!(1, ctx.load_scoped::<Scope, String>("id").unwrap().parse::<i32>()?);
    ///     assert!(ctx.load_scoped::<AnotherScope, String>("id").is_none());
    ///     Ok(())
    /// }
    ///
    /// let app = App::new().gate(gate).end(end);
    /// ```
    #[inline]
    pub fn store_scoped<SC, K, V>(&mut self, scope: SC, key: K, value: V) -> Option<Arc<V>>
    where
        SC: Any,
        K: Into<Cow<'static, str>>,
        V: Value,
    {
        self.storage.insert(scope, key, value)
    }

    /// Store a key-value pair in the public scope.
    ///
    /// Equivalent to `store_scoped` with this module's private public-scope
    /// marker type; read it back with `load`.
    ///
    /// ### Example
    /// ```rust
    /// use roa_core::{App, Context, Result, Next};
    ///
    /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
    ///     ctx.store("id", "1".to_string());
    ///     next.await
    /// }
    ///
    /// async fn end(ctx: &mut Context) -> Result {
    ///     assert_eq!(1, ctx.load::<String>("id").unwrap().parse::<i32>()?);
    ///     Ok(())
    /// }
    ///
    /// let app = App::new().gate(gate).end(end);
    /// ```
    #[inline]
    pub fn store<K, V>(&mut self, key: K, value: V) -> Option<Arc<V>>
    where
        K: Into<Cow<'static, str>>,
        V: Value,
    {
        self.store_scoped(PublicScope, key, value)
    }

    /// Search for a value by key in a specific scope.
    ///
    /// Returns `None` when no matching value is stored under `key` in scope `SC`.
    ///
    /// ### Example
    ///
    /// ```rust
    /// use roa_core::{App, Context, Result, Next};
    ///
    /// struct Scope;
    ///
    /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
    ///     ctx.store_scoped(Scope, "id", "1".to_owned());
    ///     next.await
    /// }
    ///
    /// async fn end(ctx: &mut Context) -> Result {
    ///     assert_eq!(1, ctx.load_scoped::<Scope, String>("id").unwrap().parse::<i32>()?);
    ///     Ok(())
    /// }
    ///
    /// let app = App::new().gate(gate).end(end);
    /// ```
    #[inline]
    pub fn load_scoped<'a, SC, V>(&self, key: &'a str) -> Option<Variable<'a, V>>
    where
        SC: Any,
        V: Value,
    {
        self.storage.get::<SC, V>(key)
    }

    /// Search for a value by key in the public scope (the scope written by `store`).
    ///
    /// Returns `None` when no matching value is stored under `key`.
    ///
    /// ### Example
    /// ```rust
    /// use roa_core::{App, Context, Result, Next};
    ///
    /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
    ///     ctx.store("id", "1".to_string());
    ///     next.await
    /// }
    ///
    /// async fn end(ctx: &mut Context) -> Result {
    ///     assert_eq!(1, ctx.load::<String>("id").unwrap().parse::<i32>()?);
    ///     Ok(())
    /// }
    ///
    /// let app = App::new().gate(gate).end(end);
    /// ```
    #[inline]
    pub fn load<'a, V>(&self, key: &'a str) -> Option<Variable<'a, V>>
    where
        V: Value,
    {
        self.load_scoped::<PublicScope, V>(key)
    }
}

/// Marker type for the public storage scope.
///
/// `Context::store` and `Context::load` read and write this scope. The type is
/// private, so code outside this module (and its submodules) cannot name it and
/// reaches this scope only through `store`/`load`.
struct PublicScope;

/// Lets a `Context<S>` be used as a shared reference to its application state `S`.
impl<S> Deref for Context<S> {
    type Target = S;

    /// Borrow the application state.
    #[inline]
    fn deref(&self) -> &S {
        let Self { state, .. } = self;
        state
    }
}

/// Lets middlewares mutate the application state directly through the context.
impl<S> DerefMut for Context<S> {
    /// Mutably borrow the application state.
    #[inline]
    fn deref_mut(&mut self) -> &mut S {
        let Self { state, .. } = self;
        state
    }
}

impl<S: Clone> Clone for Context<S> {
    /// Produce a context sharing this one's state, executor, storage and
    /// remote address.
    ///
    /// Note: the request and response are NOT copied. The clone gets a fresh
    /// `Request::default()` and `Response::new()`.
    #[inline]
    fn clone(&self) -> Self {
        // Bindings are evaluated in the same order as the original field list.
        let req = Request::default();
        let resp = Response::new();
        let state = self.state.clone();
        let exec = self.exec.clone();
        let storage = self.storage.clone();
        let remote_addr = self.remote_addr;
        Self {
            req,
            resp,
            exec,
            remote_addr,
            storage,
            state,
        }
    }
}

#[cfg(all(test, feature = "runtime"))]
mod tests_with_runtime {
    use std::error::Error;

    use http::{HeaderValue, StatusCode, Version};

    use crate::{App, Context, Next, Request, Status};

    /// A default request is seen as HTTP/1.1 with an initial 200 OK status.
    #[tokio::test]
    async fn status_and_version() -> Result<(), Box<dyn Error>> {
        async fn test(ctx: &mut Context) -> Result<(), Status> {
            assert_eq!(Version::HTTP_11, ctx.version());
            assert_eq!(StatusCode::OK, ctx.status());
            Ok(())
        }
        let service = App::new().end(test).http_service();
        let resp = service.serve(Request::default()).await;
        // The endpoint returned Ok, so the status must be untouched.
        assert_eq!(StatusCode::OK, resp.status);
        Ok(())
    }

    #[derive(Clone)]
    struct State {
        data: usize,
    }

    /// A state mutation made by a gate through `DerefMut` is visible downstream.
    #[tokio::test]
    async fn state() -> Result<(), Box<dyn Error>> {
        async fn gate(ctx: &mut Context<State>, next: Next<'_>) -> Result<(), Status> {
            ctx.data = 1;
            next.await
        }

        async fn test(ctx: &mut Context<State>) -> Result<(), Status> {
            assert_eq!(1, ctx.data);
            Ok(())
        }
        // Start from 0 (not 1) so the endpoint's assertion can only pass if the
        // gate's write actually reached the next middleware.
        let service = App::state(State { data: 0 })
            .gate(gate)
            .end(test)
            .http_service();
        let resp = service.serve(Request::default()).await;
        assert_eq!(StatusCode::OK, resp.status);
        Ok(())
    }

    /// `must_get` returns present headers and maps a missing one to 400.
    #[tokio::test]
    async fn must_get() -> Result<(), Box<dyn Error>> {
        use http::header::{CONTENT_TYPE, HOST};
        async fn test(ctx: &mut Context) -> Result<(), Status> {
            assert_eq!(Ok("github.com"), ctx.must_get(HOST));
            ctx.must_get(CONTENT_TYPE)?;
            unreachable!()
        }
        let service = App::new().end(test).http_service();
        let mut req = Request::default();
        req.headers
            .insert(HOST, HeaderValue::from_static("github.com"));
        let resp = service.serve(req).await;
        assert_eq!(StatusCode::BAD_REQUEST, resp.status);
        Ok(())
    }
}