use std::fmt;
use std::io::{self, ErrorKind, BufWriter, Write};
use std::net::{SocketAddr, ToSocketAddrs, Shutdown};
use std::sync::Arc;
use std::time::Duration;
use cogo::coroutine::yield_now;
pub use self::request::Request;
pub use self::response::Response;
pub use crate::net::{Fresh, Streaming};
use crate::{Error, runtime};
use crate::buffer::BufReader;
use crate::header::{Headers, Expect, Connection};
use crate::http;
use crate::method::Method;
use crate::net::{NetworkListener, NetworkStream, HttpListener, HttpsListener, SslServer};
use crate::status::StatusCode;
use crate::uri::RequestUri;
use crate::version::HttpVersion::Http11;
use self::listener::ListenerPool;
pub mod request;
pub mod response;
mod listener;
/// An HTTP server that accepts connections from a listener of type `L`
/// (a plain-HTTP `HttpListener` unless another listener type is chosen).
#[derive(Debug)]
pub struct Server<L = HttpListener> {
    /// Listener from which incoming connections are accepted.
    pub listener: L,
    /// Timeout / keep-alive settings handed to every connection worker.
    pub timeouts: Timeouts,
}
/// Timeout settings applied to the connections a `Server` serves.
#[derive(Clone, Copy, Debug)]
pub struct Timeouts {
    /// Read timeout passed to each request's stream before it is handled.
    read: Option<Duration>,
    /// Read timeout applied between keep-alive requests; `None` turns
    /// keep-alive off for every request.
    keep_alive: Option<Duration>,
    /// Strategy deciding when an idle connection's coroutine gives up.
    keep_alive_type: KeepAliveType,
}
/// How a per-connection coroutine decides to stop waiting for more requests.
#[derive(Clone, Copy, Debug)]
pub enum KeepAliveType {
    /// Stop once this much time has elapsed since the idle timer was last
    /// refreshed; checked after each pass that served no keep-alive request.
    WaitTime(Duration),
    /// Stop after this many passes, counted over the whole connection, that
    /// served no keep-alive request.
    WaitError(i32),
}
impl Default for Timeouts {
    /// No per-request read timeout, a 5-second keep-alive timeout, and a
    /// `WaitTime` idle strategy of the same 5 seconds.
    fn default() -> Timeouts {
        let five_secs = Duration::from_secs(5);
        Timeouts {
            read: None,
            keep_alive: Some(five_secs),
            keep_alive_type: KeepAliveType::WaitTime(five_secs),
        }
    }
}
impl<L: NetworkListener> Server<L> {
    /// Wraps `listener` in a server configured with `Timeouts::default()`.
    #[inline]
    pub fn new(listener: L) -> Server<L> {
        let timeouts = Timeouts::default();
        Server { listener, timeouts }
    }

    /// Sets the keep-alive timeout; `None` disables keep-alive for every request.
    ///
    /// This also resets the idle strategy to `KeepAliveType::WaitTime`, using
    /// `timeout` or 5 seconds when `timeout` is `None`. Call
    /// `set_keep_alive_type` afterwards to choose a different strategy.
    #[inline]
    pub fn keep_alive(mut self, timeout: Option<Duration>) -> Self {
        let idle_wait = timeout.unwrap_or(Duration::from_secs(5));
        self.timeouts.keep_alive = timeout;
        self.timeouts.keep_alive_type = KeepAliveType::WaitTime(idle_wait);
        self
    }

    /// Forwards `dur` to the listener and records it as the per-request read timeout.
    pub fn set_read_timeout(mut self, dur: Option<Duration>) -> Self {
        self.listener.set_read_timeout(dur);
        self.timeouts.read = dur;
        self
    }

    /// Forwards `dur` to the listener as its write timeout.
    pub fn set_write_timeout(mut self, dur: Option<Duration>) -> Self {
        self.listener.set_write_timeout(dur);
        self
    }

    /// Chooses how idle connections decide to stop waiting for requests.
    pub fn set_keep_alive_type(mut self, t: KeepAliveType) -> Self {
        self.timeouts.keep_alive_type = t;
        self
    }

    /// Returns the address the underlying listener is bound to.
    pub fn local_addr(&mut self) -> io::Result<SocketAddr> {
        let listener = &mut self.listener;
        listener.local_addr()
    }
}
impl Server<HttpListener> {
    /// Binds a plain-HTTP listener to `addr` and wraps it in a default-configured server.
    pub fn http<To: ToSocketAddrs>(addr: To) -> crate::Result<Server<HttpListener>> {
        let listener = HttpListener::new(addr)?;
        Ok(Server::new(listener))
    }
}
impl<S: SslServer + Clone + Send> Server<HttpsListener<S>> {
    /// Binds a TLS listener to `addr`, using `ssl` for the handshake, and wraps
    /// it in a default-configured server.
    pub fn https<A: ToSocketAddrs>(addr: A, ssl: S) -> crate::Result<Server<HttpsListener<S>>> {
        let listener = HttpsListener::new(addr, ssl)?;
        Ok(Server::new(listener))
    }
}
// "try-continue": evaluates a `Result` inside a loop body. An `Ok` value is
// yielded. An `Err` is logged together with the stringified expression, and
// the enclosing loop moves on to its next iteration.
macro_rules! t_c {
    ($e:expr) => {
        match $e {
            Err(err) => {
                error!("call = {:?}\nerr = {:?}", stringify!($e), err);
                continue;
            }
            Ok(val) => val,
        }
    };
}
impl<L: NetworkListener + Send + 'static> Server<L> {
pub fn handle<H: Handler + 'static>(self, handler: H) -> crate::Result<Listening> {
Self::handle_stack(self, handler, 0x2000)
}
pub fn handle_stack<H: Handler + 'static>(self, handler: H, stack_size: usize) -> crate::Result<Listening> {
let worker = Arc::new(Worker::new(handler, self.timeouts));
let mut listener = self.listener.clone();
let h = runtime::spawn_stack_size(move || {
for stream in listener.incoming() {
let mut stream = t_c!(stream);
let w = worker.clone();
runtime::spawn_stack_size(move || {
{
#[cfg(unix)]
stream.set_nonblocking(true);
{
match w.timeouts.keep_alive_type {
KeepAliveType::WaitTime(timeout) => {
let mut now = std::time::Instant::now();
loop {
stream.reset_io();
let keep_alive = w.handle_connection(&mut stream);
stream.wait_io();
if keep_alive == false {
if now.elapsed() >= timeout {
return;
} else {
yield_now();
continue;
}
} else {
if now.elapsed() <= timeout {
now = std::time::Instant::now();
}
}
}
}
KeepAliveType::WaitError(total) => {
let mut count = 0;
loop {
stream.reset_io();
let keep_alive = w.handle_connection(&mut stream);
stream.wait_io();
if keep_alive == false {
count += 1;
if count >= total {
return;
}
yield_now();
}
}
}
}
}
}
}, stack_size);
}
}, stack_size);
let socket = r#try!(self.listener.clone().local_addr());
return Ok(Listening {
_guard: Some(h),
socket: socket,
});
}
pub fn handle_tasks<H: Handler + 'static>(self, handler: H, tasks: usize) -> crate::Result<Listening> {
handle_task(self, handler, tasks)
}
}
fn handle_task<H, L>(mut server: Server<L>, handler: H, tasks: usize) -> crate::Result<Listening>
where H: Handler + 'static, L: NetworkListener + Send + 'static {
let socket = r#try!(server.listener.local_addr());
debug!("tasks = {:?}", tasks);
let pool = ListenerPool::new(server.listener);
let worker = Worker::new(handler, server.timeouts);
let work = move |mut stream| {
worker.handle_connection(&mut stream);
};
let guard = runtime::spawn(move || {
pool.accept(work, tasks);
});
Ok(Listening {
_guard: Some(guard),
socket: socket,
})
}
/// Serves connections by parsing requests and dispatching them to a `Handler`,
/// applying the server's `Timeouts` along the way.
pub struct Worker<H: Handler + 'static> {
    /// User handler invoked once per parsed request.
    handler: H,
    /// Read / keep-alive timeouts copied from the owning `Server`.
    timeouts: Timeouts,
}
impl<H: Handler + 'static> Worker<H> {
    /// Creates a worker that dispatches requests to `handler` and applies
    /// `timeouts` to the connections it serves.
    pub fn new(handler: H, timeouts: Timeouts) -> Worker<H> {
        Worker { handler, timeouts }
    }

    /// Serves requests arriving on `stream` until the keep-alive loop stops.
    ///
    /// `Handler::on_connection_start` and `Handler::on_connection_end` are
    /// each called exactly once per invocation, including when the peer
    /// address cannot be determined.
    ///
    /// Returns `true` if at least one request finished with keep-alive still
    /// in effect and the follow-up keep-alive read timeout was applied.
    /// Returns `false` otherwise.
    pub fn handle_connection<S>(&self, stream: &mut S) -> bool
    where
        S: NetworkStream,
    {
        debug!("Incoming stream");
        self.handler.on_connection_start();
        let addr = match stream.peer_addr() {
            Ok(addr) => addr,
            Err(e) => {
                info!("Peer Name error: {:?}", e);
                // Keep on_connection_start / on_connection_end balanced.
                self.handler.on_connection_end();
                return false;
            }
        };
        // The reader and the writer both need `&mut` access to the same stream,
        // so the reader works on a bitwise alias of `*stream`.
        // SAFETY: the alias is wrapped in `ManuallyDrop`, so its destructor never
        // runs. That holds on normal return and also while unwinding from a
        // panicking handler, so the stream's resources are released exactly once,
        // by whoever owns `*stream`. The alias does not escape this function.
        // NOTE(review): this still assumes `S` tolerates two live aliases used
        // for reading and writing respectively. Confirm for every stream type.
        let mut read_half =
            std::mem::ManuallyDrop::new(unsafe { std::mem::transmute_copy::<S, S>(&*stream) });
        let reader_stream: &mut dyn NetworkStream = &mut *read_half;
        let mut rdr = BufReader::new(reader_stream);
        let mut wrt = BufWriter::new(stream);
        let mut keep_alive = false;
        while self.keep_alive_loop(&mut rdr, &mut wrt, addr) {
            // Between requests, wait at most the keep-alive timeout for the next one.
            if let Err(e) = self.set_read_timeout(*rdr.get_ref(), self.timeouts.keep_alive) {
                info!("set_read_timeout keep_alive {:?}", e);
                break;
            }
            keep_alive = true;
        }
        self.handler.on_connection_end();
        debug!("keep_alive loop ending for {}", addr);
        keep_alive
    }

    /// Applies `timeout` as the read timeout of `s`.
    fn set_read_timeout(&self, s: &dyn NetworkStream, timeout: Option<Duration>) -> io::Result<()> {
        s.set_read_timeout(timeout)
    }

    /// Parses and handles a single request.
    ///
    /// Returns `true` when the connection should stay open for another
    /// request. Returns `false` on a parse or I/O error, when an
    /// `Expect: 100-continue` check is rejected, when the per-request read
    /// timeout cannot be set, or when either side opts out of keep-alive.
    fn keep_alive_loop<W: Write>(
        &self,
        rdr: &mut BufReader<&mut dyn NetworkStream>,
        wrt: &mut W,
        addr: SocketAddr,
    ) -> bool {
        let req = match Request::new(rdr, addr) {
            Ok(req) => req,
            Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => {
                trace!("tcp closed, cancelling keep-alive loop");
                return false;
            }
            Err(Error::Io(e)) => {
                debug!("ioerror in keepalive loop = {:?}", e);
                return false;
            }
            Err(e) => {
                info!("request error = {:?}", e);
                return false;
            }
        };
        if !self.handle_expect(&req, wrt) {
            return false;
        }
        if let Err(e) = req.set_read_timeout(self.timeouts.read) {
            info!("set_read_timeout {:?}", e);
            return false;
        }
        // Keep-alive requires it to be enabled server-side and requested by the client.
        let keep_alive =
            self.timeouts.keep_alive.is_some() && http::should_keep_alive(req.version, &req.headers);
        let version = req.version;
        let mut res_headers = Headers::with_capacity(1);
        if !keep_alive {
            res_headers.set(Connection::close());
        }
        {
            let mut res = Response::new(wrt, &mut res_headers);
            res.version = version;
            self.handler.handle(req, res);
        }
        // The handler may itself have opted out (e.g. by setting `Connection: close`).
        let keep_alive = keep_alive && http::should_keep_alive(version, &res_headers);
        debug!("keep_alive = {:?} for {}", keep_alive, addr);
        keep_alive
    }

    /// Answers an HTTP/1.1 `Expect: 100-continue` request with the status from
    /// `Handler::check_continue`.
    ///
    /// Returns `false` if writing that status fails or if it is anything other
    /// than `100 Continue`. Requests without the header pass through with `true`.
    fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool {
        if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) {
            let status = self.handler.check_continue((&req.method, &req.uri, &req.headers));
            match write!(wrt, "{} {}\r\n\r\n", Http11, status).and_then(|_| wrt.flush()) {
                Ok(..) => (),
                Err(e) => {
                    info!("error writing 100-continue: {:?}", e);
                    return false;
                }
            }
            if status != StatusCode::Continue {
                debug!("non-100 status ({}) for Expect 100 request", status);
                return false;
            }
        }
        true
    }
}
/// Handle to a running server, returned by the `Server::handle*` methods.
pub struct Listening {
    /// Join handle of the accept coroutine. It is joined when this value is
    /// dropped, unless `close` has already taken it.
    _guard: Option<runtime::JoinHandle<()>>,
    /// Local address the server's listener is bound to.
    pub socket: SocketAddr,
}
impl fmt::Debug for Listening {
    /// Shows only the bound socket address; the join handle is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let addr = &self.socket;
        write!(f, "Listening {{ socket: {:?} }}", addr)
    }
}
impl Drop for Listening {
    /// Joins the accept coroutine, unless `close` already detached it.
    /// The join result is ignored.
    fn drop(&mut self) {
        if let Some(guard) = self._guard.take() {
            let _ = guard.join();
        }
    }
}
impl Listening {
    /// Drops the accept coroutine's join handle, so that dropping this value
    /// no longer waits for it. Always returns `Ok(())`.
    ///
    /// NOTE(review): nothing here signals the accept loop to stop. Whether
    /// dropping the handle cancels or merely detaches it depends on
    /// `runtime::JoinHandle`. Verify before relying on this to shut down.
    pub fn close(&mut self) -> crate::Result<()> {
        drop(self._guard.take());
        debug!("closing server");
        Ok(())
    }
}
/// Application callback invoked by the server's workers.
pub trait Handler: Sync + Send {
    /// Handles one parsed request, writing the reply through `resp`.
    fn handle(&self, req: Request, resp: Response<'_, Fresh>);

    /// Chooses the status line sent for an HTTP/1.1 request carrying
    /// `Expect: 100-continue`. That status is written immediately. Any value
    /// other than `StatusCode::Continue` ends the connection's request loop
    /// without calling `handle`. The default accepts every such request.
    fn check_continue(&self, _head: (&Method, &RequestUri, &Headers)) -> StatusCode {
        StatusCode::Continue
    }

    /// Called at the beginning of each `Worker::handle_connection` call.
    fn on_connection_start(&self) {}

    /// Called at the end of each `Worker::handle_connection` call.
    fn on_connection_end(&self) {}
}
/// Lets any thread-safe `Fn(Request, Response<Fresh>)`, whether a closure or a
/// plain function, be used directly as a `Handler`.
impl<F> Handler for F
where
    F: Fn(Request, Response<Fresh>) + Sync + Send,
{
    fn handle(&self, req: Request, res: Response) {
        let callback = self;
        callback(req, res)
    }
}
#[cfg(test)]
mod tests {
    use crate::header::Headers;
    use crate::method::Method;
    use crate::mock::MockStream;
    use crate::status::StatusCode;
    use crate::uri::RequestUri;
    use super::{Request, Response, Fresh, Handler, Worker};

    /// A POST carrying `Expect: 100-continue` and a 10-byte body, used by both tests.
    fn expect_continue_request() -> MockStream {
        MockStream::with_input(b"\
            POST /upload HTTP/1.1\r\n\
            Host: example.domain\r\n\
            Expect: 100-continue\r\n\
            Content-Length: 10\r\n\
            \r\n\
            1234567890\
            ")
    }

    #[test]
    fn test_check_continue_default() {
        fn handle(_: Request, res: Response<Fresh>) {
            res.start().unwrap().end().unwrap();
        }

        let mut mock = expect_continue_request();
        Worker::new(handle, Default::default()).handle_connection(&mut mock);

        // The interim 100 response must precede the final 200 status line.
        let continue_line = b"HTTP/1.1 100 Continue\r\n\r\n";
        let ok_line = b"HTTP/1.1 200 OK\r\n";
        let split = continue_line.len();
        assert_eq!(&mock.write[..split], continue_line);
        assert_eq!(&mock.write[split..split + ok_line.len()], ok_line);
    }

    #[test]
    fn test_check_continue_reject() {
        struct Reject;

        impl Handler for Reject {
            fn handle(&self, _: Request, res: Response<'_, Fresh>) {
                res.start().unwrap().end().unwrap();
            }

            fn check_continue(&self, _: (&Method, &RequestUri, &Headers)) -> StatusCode {
                StatusCode::ExpectationFailed
            }
        }

        let mut mock = expect_continue_request();
        Worker::new(Reject, Default::default()).handle_connection(&mut mock);

        // Only the rejection status is written; the handler never runs.
        assert_eq!(mock.write, &b"HTTP/1.1 417 Expectation Failed\r\n\r\n"[..]);
    }
}