//! civet 0.12.0-alpha.5
//!
//! civetweb-based server implementation for conduit.
#![warn(rust_2018_idioms)]

extern crate civet_sys as _;
extern crate conduit;
extern crate libc;

use std::io::prelude::*;
use std::io::{self, BufWriter};
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};

use conduit::{
    header, Body, Extensions, Handler, HeaderMap, Host, Method, Scheme, StartInstant, Version,
};

use raw::{get_header, get_headers, get_request_info};
use raw::{Header, RequestInfo};

pub use config::Config;

mod config;
mod raw;

/// A request being served on a civetweb connection, plus a record of whether
/// any response bytes have been written back to the client.
///
/// If nothing has been written by the time this value is dropped, its `Drop`
/// impl sends a bare `500 Internal Server Error` response.
pub struct Connection<'a> {
    /// The conduit-facing view of the request.
    request: CivetRequest<'a>,
    /// Set by every `Write::write` call; checked on drop.
    written: bool,
}

/// A conduit request backed by a civetweb connection.
///
/// Method, version and headers are copied out of civetweb when the request is
/// built. The body, `Host` and `Content-Length` are still read through `conn`.
pub struct CivetRequest<'a> {
    /// Underlying civetweb connection; the request body is read from it directly.
    conn: &'a raw::Connection,
    /// civetweb's request info (URL, query string, remote address, SSL flag).
    request_info: RequestInfo<'a>,
    /// Request headers, copied at construction time.
    headers: HeaderMap,
    /// Per-request extension data; starts out holding a `StartInstant`.
    extensions: Extensions,
    version: Version,
    method: Method,
    /// Replacement path set through `path_mut`; `None` means "use civetweb's URL".
    path_rewrite: Option<String>,
}

/// Exposes a civetweb request through conduit's `RequestExt` interface.
impl<'a> conduit::RequestExt for CivetRequest<'a> {
    fn http_version(&self) -> Version {
        self.version
    }

    fn method(&self) -> &Method {
        &self.method
    }

    /// `Https` when civetweb flags the connection as SSL, `Http` otherwise.
    fn scheme(&self) -> Scheme {
        if self.request_info.is_ssl() {
            Scheme::Https
        } else {
            Scheme::Http
        }
    }

    /// The value of the `Host` request header.
    ///
    /// `Host` is optional in HTTP/1.0, so a missing header yields an empty host
    /// name instead of panicking inside the handler.
    fn host(&self) -> Host<'_> {
        Host::Name(get_header(self.conn, header::HOST).unwrap_or(""))
    }

    /// Always `None`: this server never reports a virtual root.
    fn virtual_root(&self) -> Option<&str> {
        None
    }

    /// The request path: the rewritten path if `path_mut` has been used,
    /// otherwise the URL reported by civetweb.
    fn path(&self) -> &str {
        match self.path_rewrite.as_deref() {
            Some(rewrite) => rewrite,
            None => self.request_info.url().unwrap(),
        }
    }

    /// Mutable access to the request path.
    ///
    /// On first use the civetweb URL is copied into an owned `String`. Later
    /// calls, and `path`, see the rewritten value.
    fn path_mut(&mut self) -> &mut String {
        // Borrow the disjoint field up front so the closure doesn't need `self`.
        let info = &self.request_info;
        self.path_rewrite
            .get_or_insert_with(|| info.url().unwrap().to_owned())
    }

    fn query_string(&self) -> Option<&str> {
        self.request_info.query_string()
    }

    /// The peer address.
    ///
    /// `remote_ip()` is treated as an IPv4 address packed into an integer,
    /// with the first octet in the most significant byte.
    fn remote_addr(&self) -> SocketAddr {
        let ip = self.request_info.remote_ip();
        let ip = Ipv4Addr::new(
            (ip >> 24) as u8,
            (ip >> 16) as u8,
            (ip >> 8) as u8,
            ip as u8,
        );
        SocketAddr::V4(SocketAddrV4::new(ip, self.request_info.remote_port()))
    }

    /// The parsed `Content-Length` header; `None` if absent or not a valid `u64`.
    fn content_length(&self) -> Option<u64> {
        get_header(self.conn, header::CONTENT_LENGTH).and_then(|s| s.parse().ok())
    }

    fn headers(&self) -> &HeaderMap {
        &self.headers
    }

    /// The request body, read straight from the civetweb connection.
    fn body(&mut self) -> &mut dyn Read {
        self
    }

    fn extensions(&self) -> &Extensions {
        &self.extensions
    }

    fn mut_extensions(&mut self) -> &mut Extensions {
        &mut self.extensions
    }
}

impl<'a> Connection<'a> {
    /// Wraps the request currently being served on `conn`.
    ///
    /// The method, HTTP version and request headers are copied out of civetweb
    /// into owned conduit types, and a `StartInstant` is stored in the request
    /// extensions.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if civetweb has no request info for
    /// `conn`, or if the method, a header name or a header value is invalid.
    fn new(conn: &raw::Connection) -> Result<Connection<'_>, String> {
        let info = request_info(conn)?;

        // A missing method becomes an empty byte string, which
        // `Method::from_bytes` rejects.
        let method =
            Method::from_bytes(info.method().unwrap_or_default()).map_err(|e| e.to_string())?;

        // Only 1.0 and 1.1 are recognised. Anything else, including a missing
        // version, falls back to `Version::default()` instead of panicking.
        let version = match info.http_version() {
            Some(b"1.0") => Version::HTTP_10,
            Some(b"1.1") => Version::HTTP_11,
            _ => Version::default(),
        };

        let mut headers = HeaderMap::new();
        for (name, value) in HeaderIterator::new(conn) {
            let name = header::HeaderName::from_bytes(name).map_err(|e| e.to_string())?;
            let value = header::HeaderValue::from_bytes(value).map_err(|e| e.to_string())?;
            // Use `append`, not `insert`: a header sent several times (e.g.
            // `Accept` or `X-Forwarded-For`) must keep every value, not just
            // the last one.
            headers.append(name, value);
        }

        let mut extensions = Extensions::new();
        extensions.insert(StartInstant::now());

        Ok(Connection {
            request: CivetRequest {
                conn,
                request_info: info,
                headers,
                extensions,
                method,
                version,
                path_rewrite: None,
            },
            written: false,
        })
    }
}

impl<'a> Write for Connection<'a> {
    /// Sends `buf` to the client through civetweb.
    ///
    /// The connection is marked as written before the call, so `Drop` will not
    /// emit its fallback 500 afterwards, even if this write fails. A negative
    /// civetweb return code becomes an `io::ErrorKind::Other` error carrying
    /// that code.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.written = true;
        let rc = raw::write(self.request.conn, buf);
        if rc < 0 {
            let msg = format!("write error ({})", rc);
            return Err(io::Error::new(io::ErrorKind::Other, &msg[..]));
        }
        Ok(rc as usize)
    }

    /// Nothing is buffered at this layer (`write` goes straight to civetweb),
    /// so there is nothing to flush.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

impl<'a> Read for CivetRequest<'a> {
    /// Reads request-body bytes straight from civetweb into `buf`.
    ///
    /// A negative civetweb return code becomes an `io::ErrorKind::Other` error
    /// carrying that code.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let rc = raw::read(self.conn, buf);
        if rc < 0 {
            let msg = format!("read error ({})", rc);
            return Err(io::Error::new(io::ErrorKind::Other, &msg[..]));
        }
        Ok(rc as usize)
    }
}

impl<'a> Drop for Connection<'a> {
    /// Makes sure the client gets *some* response. If nothing was ever written
    /// on this connection, a bare `500 Internal Server Error` is sent.
    ///
    /// Write failures are ignored, since `drop` has no way to report them.
    fn drop(&mut self) {
        if self.written {
            return;
        }
        let _ = write!(
            self,
            "HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\n\r\n"
        );
    }
}

/// Iterator over the `(name, value)` byte pairs of the headers that
/// `get_headers` returns for a connection.
///
/// When `next` reaches a header with no name, it consumes that entry and
/// returns `None`, so a `for` loop stops there. The value of a named header is
/// unwrapped, so it is assumed to be present.
struct HeaderIterator<'a> {
    headers: Vec<Header<'a>>,
    next_index: usize,
}

impl<'a> HeaderIterator<'a> {
    fn new(conn: &raw::Connection) -> HeaderIterator<'_> {
        let headers = get_headers(conn);
        HeaderIterator {
            headers,
            next_index: 0,
        }
    }
}

impl<'a> Iterator for HeaderIterator<'a> {
    type Item = (&'a [u8], &'a [u8]);

    fn next(&mut self) -> Option<Self::Item> {
        let current = self.headers.get(self.next_index)?;
        self.next_index += 1;
        let name = current.name()?;
        Some((name, current.value().unwrap()))
    }
}

/// A running civetweb server that hands every request to a conduit `Handler`.
pub struct Server(raw::Server<Box<dyn Handler + 'static + Sync>>);

impl Server {
    /// Starts civetweb with `options` and serves every request with `handler`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from the raw layer if the server cannot be started,
    /// e.g. because the configured port is already bound.
    pub fn start<H: Handler + 'static + Sync>(options: Config, handler: H) -> io::Result<Server> {
        // The `&Box<..>` parameter presumably mirrors the callback signature
        // `raw::ServerCallback::new` expects, so clippy's suggestion can't apply.
        #[allow(clippy::borrowed_box)]
        fn internal_handler(
            conn: &mut raw::Connection,
            handler: &Box<dyn Handler + 'static + Sync>,
        ) -> Result<(), ()> {
            // Sent when civetweb's request data can't be turned into a conduit
            // request (invalid method, header name or header value).
            const BAD_REQUEST: &[u8] = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n";

            let mut connection = match Connection::new(conn) {
                Ok(connection) => connection,
                Err(_) => {
                    // Answer the malformed request directly instead of
                    // panicking. No `Connection` exists here, so its `Drop`
                    // fallback response would not be sent either.
                    let _ = raw::write(conn, BAD_REQUEST);
                    return Err(());
                }
            };
            let response = handler.call(&mut connection.request);
            let mut writer = BufWriter::new(connection);

            fn err<W: Write>(writer: &mut W) {
                let _ = write!(
                    writer,
                    "HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\n\r\n"
                );
            }

            let (head, body) = match response {
                Ok(r) => r,
                Err(_) => {
                    err(&mut writer);
                    return Err(());
                }
            }
            .into_parts();

            write!(
                &mut writer,
                "HTTP/1.1 {} {}\r\n",
                head.status.as_str(),
                head.status.canonical_reason().unwrap_or("UNKNOWN")
            )
            .map_err(|_| ())?;

            // `write_all`, not `write`: `write` may accept only part of the
            // buffer, and the original code discarded the returned byte count.
            for (key, value) in head.headers.iter() {
                write!(&mut writer, "{}: ", key).map_err(|_| ())?;
                writer.write_all(value.as_bytes()).map_err(|_| ())?;
                writer.write_all(b"\r\n").map_err(|_| ())?;
            }

            writer.write_all(b"\r\n").map_err(|_| ())?;
            match body {
                Body::Static(slice) => writer.write_all(slice),
                Body::Owned(vec) => writer.write_all(&vec),
                Body::File(mut file) => io::copy(&mut file, &mut writer).map(|_| ()),
            }
            .map_err(|_| ())?;

            // Flush explicitly. `BufWriter`'s `Drop` would also flush, but it
            // silently discards any error.
            writer.flush().map_err(|_| ())
        }

        let handler = Box::new(handler);
        let raw_callback = raw::ServerCallback::new(internal_handler, handler);
        Ok(Server(raw::Server::start(options, raw_callback)?))
    }
}

/// Looks up civetweb's request info for `connection`.
///
/// # Errors
///
/// Returns a descriptive message when `get_request_info` yields nothing for
/// the connection.
fn request_info(connection: &raw::Connection) -> Result<RequestInfo<'_>, String> {
    get_request_info(connection)
        .ok_or_else(|| "Couldn't get request info for connection".to_string())
}

#[cfg(test)]
mod test {
    use super::{Config, Server};
    use conduit::{box_error, Body, Handler, HandlerResult, HttpResult, RequestExt, Response};
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, TcpStream};
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::mpsc::{channel, Sender};
    use std::sync::Mutex;

    /// Handler for tests that never send a request; reaching it is a bug.
    fn noop(_: &mut dyn RequestExt) -> HttpResult {
        unreachable!()
    }

    /// Sends `req` (with leading whitespace trimmed) to `addr` and returns
    /// everything the server writes back until it closes the connection.
    fn request(addr: SocketAddr, req: &str) -> String {
        use std::io::{Read, Write};

        let mut s = TcpStream::connect(&addr).unwrap();
        s.write_all(req.trim_start().as_bytes()).unwrap();
        let mut ret = String::new();
        s.read_to_string(&mut ret).unwrap();
        ret
    }

    /// Hands out a different port on each call, starting at 13038, so tests
    /// running concurrently don't try to bind the same one.
    fn port() -> u16 {
        static CNT: AtomicUsize = AtomicUsize::new(0);
        CNT.fetch_add(1, Ordering::SeqCst) as u16 + 13038
    }

    /// Server configuration listening on `port` with a single worker thread.
    fn cfg(port: u16) -> Config {
        let mut cfg = Config::new();
        cfg.port(port).threads(1);
        cfg
    }

    #[test]
    fn smoke() {
        Server::start(cfg(port()), noop).unwrap();
    }

    #[test]
    fn dupe_port() {
        let port = port();
        let s1 = Server::start(cfg(port), noop);
        assert!(s1.is_ok());
        let s2 = Server::start(cfg(port), noop);
        assert!(s2.is_err());
    }

    #[test]
    fn drops_handler() {
        // An atomic instead of `static mut`: the flag can be set and checked
        // without any `unsafe`.
        static DROPPED: AtomicBool = AtomicBool::new(false);
        struct Foo;
        impl Handler for Foo {
            fn call(&self, _req: &mut dyn RequestExt) -> HandlerResult {
                panic!()
            }
        }
        impl Drop for Foo {
            fn drop(&mut self) {
                DROPPED.store(true, Ordering::SeqCst);
            }
        }

        drop(Server::start(cfg(port()), Foo));
        assert!(DROPPED.load(Ordering::SeqCst));
    }

    #[test]
    fn invokes() {
        struct Foo(Mutex<Sender<()>>);
        impl Handler for Foo {
            fn call(&self, _req: &mut dyn RequestExt) -> HandlerResult {
                let Foo(ref tx) = *self;
                tx.lock().unwrap().send(()).unwrap();
                Response::builder().body(Body::empty()).map_err(box_error)
            }
        }

        let (tx, rx) = channel();
        let handler = Foo(Mutex::new(tx));
        let port = port();
        let ip = Ipv4Addr::new(127, 0, 0, 1);
        let addr = SocketAddr::V4(SocketAddrV4::new(ip, port));
        let _s = Server::start(cfg(port), handler);
        request(
            addr,
            r"
GET / HTTP/1.1

",
        );
        rx.recv().unwrap();
    }

    #[test]
    fn header_sent() {
        struct Foo(Mutex<Sender<Vec<u8>>>);
        impl Handler for Foo {
            fn call(&self, req: &mut dyn RequestExt) -> HandlerResult {
                let Foo(ref tx) = *self;
                let mut header_val = Vec::new();
                header_val.extend_from_slice(req.headers().get("Foo").unwrap().as_bytes());
                tx.lock().unwrap().send(header_val).unwrap();
                Response::builder().body(Body::empty()).map_err(box_error)
            }
        }

        let (tx, rx) = channel();
        let handler = Foo(Mutex::new(tx));
        let port = port();
        let ip = Ipv4Addr::new(127, 0, 0, 1);
        let addr = SocketAddr::V4(SocketAddrV4::new(ip, port));
        let _s = Server::start(cfg(port), handler);
        request(
            addr,
            r"
GET / HTTP/1.1
Foo: bar

",
        );
        assert_eq!(rx.recv().unwrap(), b"bar");
    }

    #[test]
    fn failing_handler() {
        struct Foo;
        impl Handler for Foo {
            fn call(&self, _req: &mut dyn RequestExt) -> HandlerResult {
                panic!()
            }
        }

        let port = port();
        let ip = Ipv4Addr::new(127, 0, 0, 1);
        let addr = SocketAddr::V4(SocketAddrV4::new(ip, port));
        let _s = Server::start(cfg(port), Foo);
        request(
            addr,
            r"
GET / HTTP/1.1
Foo: bar

",
        );
    }

    #[test]
    fn failing_handler_is_500() {
        struct Foo;
        impl Handler for Foo {
            fn call(&self, _req: &mut dyn RequestExt) -> HandlerResult {
                panic!()
            }
        }

        let port = port();
        let ip = Ipv4Addr::new(127, 0, 0, 1);
        let addr = SocketAddr::V4(SocketAddrV4::new(ip, port));
        let _s = Server::start(cfg(port), Foo);
        let response = request(
            addr,
            r"
GET / HTTP/1.1
Foo: bar

",
        );
        assert!(
            response.contains("500 Internal"),
            "not a failing response: {}",
            response
        );
    }
}