1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use std::{io, net::SocketAddr, sync::Arc};

use hyper::{
    server::{
        conn::{AddrIncoming, AddrStream},
        Server as HyperServer,
    },
    service::{make_service_fn, service_fn},
};

use viz_core::{
    config::Config,
    http,
    types::{Params, State, StateFactory},
    Context, Error, Result,
};
use viz_router::{Method, Router, Tree};
use viz_utils::{anyhow::anyhow, tracing};

/// Viz Server
///
/// Collects a route tree, optional shared state and optional configuration, and serves
/// them over HTTP once [`Server::listen`] is called.
pub struct Server {
    /// Route tree populated by [`Server::routes`] and handed to the request handler.
    tree: Arc<Tree>,
    /// Configuration stored by [`Server::config`]; when unset, `listen` uses the default.
    config: Option<Arc<Config>>,
    /// State factories registered through [`Server::state`]; `None` until the first one
    /// is added.
    state: Option<Vec<Arc<dyn StateFactory>>>,
}

impl Default for Server {
    fn default() -> Self {
        Self::new()
    }
}

impl Server {
    /// Creates a server with an empty route tree, no shared state and no configuration.
    pub fn new() -> Self {
        Self { tree: Arc::new(Tree::new()), config: None, state: None }
    }

    /// Registers a value as shared state.
    ///
    /// The value is wrapped in a [`State`] factory; every registered factory is invoked
    /// against the extensions of each incoming request (see [`serve`]).
    pub fn state<T>(&mut self, state: T) -> &mut Self
    where
        T: Clone + Send + Sync + 'static,
    {
        self.state.get_or_insert_with(Vec::new).push(Arc::new(State::new(state)));
        self
    }

    /// Writes the routes of `router` into this server's route tree.
    ///
    /// # Panics
    ///
    /// Panics if the route tree is already shared, i.e. another `Arc` handle to it exists.
    /// Within this type the tree is only cloned by [`Server::listen`], which consumes the
    /// server, so this is not expected to happen through the methods defined here.
    pub fn routes(&mut self, router: Router) -> &mut Self {
        let tree = Arc::get_mut(&mut self.tree)
            .expect("route tree must be uniquely owned while routes are being registered");
        router.finish(tree);
        self
    }

    /// Loads the configuration, stores it on the server and returns a handle to it.
    ///
    /// If `Config::load` does not yield a configuration, `Config::default()` is used
    /// instead. Any previously stored configuration is replaced.
    pub async fn config(&mut self) -> Arc<Config> {
        tracing::info!("loading config");
        let config = Arc::new(Config::load().await.unwrap_or_default());
        self.config = Some(Arc::clone(&config));
        config
    }

    /// Binds to `addr` and serves requests until the underlying hyper server stops.
    ///
    /// `addr` is converted with `to_string` and must parse as a [`SocketAddr`]
    /// (for example `"127.0.0.1:3000"`).
    ///
    /// # Errors
    ///
    /// Returns an error when `addr` is not a valid socket address, when binding the
    /// listener fails, or when the hyper server terminates with an error.
    pub async fn listen<A: ToString>(self, addr: A) -> Result<()> {
        let addr = addr
            .to_string()
            .parse::<SocketAddr>()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;

        let mut incoming = AddrIncoming::bind(&addr)?;
        // Re-read the address from the listener so the log line shows what was really bound.
        let addr = incoming.local_addr();
        incoming.set_nodelay(true);

        let Self { tree, config, state } = self;
        let config = config.unwrap_or_default();
        let state = state.unwrap_or_default();

        let srv =
            HyperServer::builder(incoming).serve(make_service_fn(move |stream: &AddrStream| {
                // One set of handles per connection; each request clones them again below.
                let remote_addr = stream.remote_addr();
                let config = config.clone();
                let state = state.clone();
                let tree = tree.clone();
                async move {
                    Ok::<_, Error>(service_fn(move |req| {
                        serve(req, remote_addr, state.clone(), config.clone(), tree.clone())
                    }))
                }
            }));

        tracing::info!("listening on http://{}", addr);

        srv.await.map_err(|e| anyhow!(e))
    }
}

/// Handles one request and produces its response.
///
/// The request is wrapped in a [`Context`] whose extensions receive the peer address,
/// the configuration and whatever each state factory creates. A handler is then looked
/// up in `tree` by method and path, trying in order:
///
/// 1. the routes registered for the request's own method,
/// 2. the `GET` routes, but only when the request method is `HEAD`,
/// 3. the `Method::All` routes.
///
/// On a match the handler is installed as the context's middleware and the matched
/// parameters are stored in the extensions. `cx.next()` is awaited whether or not a
/// route matched, and its result is converted into the response.
///
/// # Errors
///
/// Propagates any error returned by `cx.next()`.
pub async fn serve(
    req: http::Request,
    addr: SocketAddr,
    state: Vec<Arc<dyn StateFactory>>,
    config: Arc<Config>,
    tree: Arc<Tree>,
) -> Result<http::Response> {
    let mut cx = Context::from(req);

    // Populate the per-request extensions before any handler runs.
    cx.extensions_mut().insert(addr);
    cx.extensions_mut().insert(config);
    for factory in &state {
        factory.create(cx.extensions_mut());
    }

    let method = cx.method().to_owned();
    let path = cx.path();

    // Exact method first.
    let mut route = tree.get(&Method::Verb(method.clone())).and_then(|routes| routes.find(path));

    // A HEAD request may be answered by a GET route.
    if route.is_none() && method == http::Method::HEAD {
        route = tree.get(&Method::Verb(http::Method::GET)).and_then(|routes| routes.find(path));
    }

    // Last resort: routes registered under `Method::All`.
    if route.is_none() {
        route = tree.get(&Method::All).and_then(|routes| routes.find(path));
    }

    if let Some((handler, captures)) = route {
        // Convert everything to owned values first so the borrow of `cx` held through
        // `path` has ended before the context is mutated.
        let middleware = handler.to_owned();
        let params: Params = captures.into();
        *cx.middleware_mut() = middleware;
        cx.extensions_mut().insert(params);
    }

    let response = cx.next().await?;
    Ok(response.into())
}