use std::{
    error::Error,
    io::{self, BufReader, BufWriter, Write},
    net::{TcpListener, ToSocketAddrs},
    thread,
    time::{Duration, SystemTime},
};
use headers::{HeaderMapExt, HeaderValue};
use http::{Method, Request, Response, StatusCode, Version};
#[cfg(feature = "threadpool")]
use threadpool::ThreadPool;
use crate::{
body::HttpBody,
read_queue::ReadQueue,
request::{self, ParseError},
response::{self, Outcome},
Body, Connection,
};
/// The request type handed to a [`Service`]: an `http::Request` carrying this crate's [`Body`].
type IncomingRequest = Request<Body>;
/// Handles the requests read from a connection, one at a time.
///
/// Any `FnMut(IncomingRequest) -> Result<Response<B>, E>` closure or function
/// already implements this trait, so implementing it by hand is only needed to
/// customise [`Service::should_continue`].
pub trait Service {
    /// Body type of the responses this service produces.
    type Body: HttpBody;

    /// Error returned when handling a request fails.
    type Error: Into<Box<dyn Error + Send + Sync>>;

    /// Produces the response for `request`.
    fn call(&mut self, request: IncomingRequest) -> Result<Response<Self::Body>, Self::Error>;

    /// Chooses the reply to a request that carries `Expect: 100-continue`.
    ///
    /// Consulted before [`Service::call`]. Returning [`StatusCode::CONTINUE`]
    /// lets the request proceed to `call`; any other status is sent to the
    /// client instead and `call` is not invoked for that request.
    /// The default implementation always answers `CONTINUE`.
    fn should_continue(&mut self, _request: &IncomingRequest) -> StatusCode {
        StatusCode::CONTINUE
    }
}
/// Lets any `FnMut(IncomingRequest) -> Result<Response<B>, E>` be used as a [`Service`].
///
/// The generic parameters are deliberately not named `Body`/`Err`, so they do
/// not shadow the crate's `Body` type inside this impl.
impl<F, B, E> Service for F
where
    F: FnMut(IncomingRequest) -> Result<Response<B>, E>,
    B: HttpBody,
    E: Into<Box<dyn Error + Send + Sync>>,
{
    type Body = B;
    type Error = E;

    fn call(&mut self, request: IncomingRequest) -> Result<Response<Self::Body>, Self::Error> {
        (*self)(request)
    }
}
/// An HTTP server driven by a stream of incoming connections.
///
/// Created with [`Server::bind`], [`Server::builder`], or `From<TcpListener>`,
/// and run with one of the `serve*` / `make_service` methods.
pub struct Server<'a> {
/// Pool that the `threadpool`-feature serving methods run connections on.
#[cfg(feature = "threadpool")]
thread_pool: ThreadPool,
/// Connections still to be served, as produced by `ServerBuilder::from_connections`.
incoming: Box<dyn Iterator<Item = Connection> + 'a>,
}
impl From<TcpListener> for Server<'static> {
fn from(listener: TcpListener) -> Self {
Self::builder().from_connections(TcpAcceptor { listener })
}
}
impl Server<'_> {
    /// Returns a [`ServerBuilder`] with default settings.
    pub fn builder() -> ServerBuilder {
        Default::default()
    }

    /// Binds a TCP listener on `addr` using the default builder settings.
    ///
    /// # Panics
    ///
    /// Panics if the address cannot be bound; use [`ServerBuilder::try_bind`]
    /// for a fallible version.
    pub fn bind<A: ToSocketAddrs>(addr: A) -> Server<'static> {
        Self::builder().bind(addr)
    }

    /// Serves every incoming connection on the thread pool, each with its own
    /// clone of `service`.
    ///
    /// Returns once the connection source is exhausted *and* every connection
    /// already handed to the pool has finished. This matches
    /// [`Server::serve_single_thread`], which also only returns after all
    /// connections have been served.
    #[cfg(feature = "threadpool")]
    pub fn serve<S>(self, service: S) -> io::Result<()>
    where
        S: Service,
        S: Send + Clone + 'static,
    {
        for conn in self.incoming {
            let mut app = service.clone();
            self.thread_pool.execute(move || {
                // Errors from one connection are discarded so they cannot stop the accept loop.
                serve(conn, &mut app).ok();
            });
        }
        // Without this, the method returned while pool threads were still
        // serving, and a caller exiting afterwards cut those responses off.
        self.thread_pool.join();
        Ok(())
    }

    /// Serves every incoming connection to completion on the current thread,
    /// reusing the same `service` for all of them.
    pub fn serve_single_thread<S>(self, mut service: S) -> io::Result<()>
    where
        S: Service,
    {
        for conn in self.incoming {
            // Errors from one connection are discarded so the next one is still served.
            serve(conn, &mut service).ok();
        }
        Ok(())
    }

    /// Builds a dedicated service per connection with `make_service` and
    /// serves it on the thread pool.
    ///
    /// Connections for which `make_service` fails are dropped without being
    /// served. Like [`Server::serve`], this returns only after every dispatched
    /// connection has finished.
    #[cfg(feature = "threadpool")]
    pub fn make_service<M>(self, make_service: M) -> io::Result<()>
    where
        M: MakeService + 'static,
        <M as MakeService>::Service: Send,
    {
        for conn in self.incoming {
            if let Ok(mut handler) = make_service.call(&conn) {
                self.thread_pool.execute(move || {
                    serve(conn, &mut handler).ok();
                });
            }
        }
        // Wait for in-flight connections, consistent with `serve`.
        self.thread_pool.join();
        Ok(())
    }
}
/// Configuration for a [`Server`], adjusted with chained setter methods.
///
/// Derives `Debug` and `Clone` so a configuration can be logged or reused
/// for several servers; every field is a plain value.
#[derive(Debug, Clone)]
pub struct ServerBuilder {
    /// Size of the thread pool created for the `threadpool`-feature serving methods.
    #[cfg(feature = "threadpool")]
    max_threads: usize,
    /// Value passed to `Connection::set_read_timeout` for every connection.
    read_timeout: Option<Duration>,
    /// Value passed to `Connection::set_nodelay` for every connection.
    nodelay: bool,
}
impl Default for ServerBuilder {
fn default() -> Self {
Self {
#[cfg(feature = "threadpool")]
max_threads: 512,
read_timeout: None,
nodelay: false,
}
}
}
impl ServerBuilder {
#[cfg(feature = "threadpool")]
pub fn max_threads(self, max_threads: usize) -> Self {
Self {
max_threads,
..self
}
}
pub fn read_timeout<T: Into<Option<Duration>>>(self, timeout: T) -> Self {
Self {
read_timeout: timeout.into(),
..self
}
}
pub fn nodelay(self, nodelay: bool) -> Self {
Self { nodelay, ..self }
}
pub fn bind<A: ToSocketAddrs>(self, addr: A) -> Server<'static> {
self.try_bind(addr).unwrap()
}
pub fn try_bind<A: ToSocketAddrs>(self, addr: A) -> io::Result<Server<'static>> {
let listener = TcpListener::bind(addr)?;
Ok(self.from_connections(TcpAcceptor { listener }))
}
pub fn from_connections<'a, C: Into<Connection>>(
self,
conns: impl IntoIterator<Item = C> + 'a,
) -> Server<'a> {
Server {
#[cfg(feature = "threadpool")]
thread_pool: ThreadPool::new(self.max_threads),
incoming: Box::new(conns.into_iter().filter_map(move |conn| {
let conn = conn.into();
conn.set_read_timeout(self.read_timeout).ok()?;
conn.set_nodelay(self.nodelay).ok()?;
Some(conn)
})),
}
}
}
/// Adapts a [`TcpListener`] into an iterator of accepted [`Connection`]s.
struct TcpAcceptor {
/// The listener that connections are accepted from.
listener: TcpListener,
}
impl Iterator for TcpAcceptor {
type Item = Connection;
fn next(&mut self) -> Option<Self::Item> {
Some(self.listener.accept().ok()?.into())
}
}
pub trait MakeService {
type Service: Service;
type Error: Into<Box<dyn Error + Send + Sync>>;
fn call(&self, conn: &Connection) -> Result<Self::Service, Self::Error>;
}
impl<F, S, Err> MakeService for F
where
F: Fn(&Connection) -> Result<S, Err>,
Err: Into<Box<dyn Error + Send + Sync>>,
S: Service + Send,
{
type Service = S;
type Error = Err;
fn call(&self, conn: &Connection) -> Result<Self::Service, Self::Error> {
self(conn)
}
}
/// Drives a single HTTP/1.x connection, dispatching each request on it to `app`.
///
/// Requests are parsed in order from a [`ReadQueue`] over a clone of `conn`.
/// Responses are written through a [`BufWriter`] over `conn` itself.
///
/// The loop ends when:
/// - the peer closes the connection;
/// - either side asks for the connection to be closed;
/// - an `Expect: 100-continue` request is rejected;
/// - a response upgrades the connection. In that case the raw connection is
///   handed to the upgrade handler.
///
/// # Errors
///
/// Returns an error if:
/// - a request fails to parse for any reason other than the peer closing the
///   connection;
/// - `app` returns an error;
/// - writing or flushing a response fails.
fn serve<A: Service>(conn: Connection, app: &mut A) -> io::Result<()> {
    let mut read_queue = ReadQueue::new(BufReader::new(conn.clone()));
    let mut reader = read_queue.enqueue();
    let mut writer = BufWriter::new(conn);

    loop {
        let req = match request::parse_request(reader) {
            Ok(req) => req,
            Err(ParseError::ConnectionClosed) => break,
            Err(err) => return Err(io::Error::new(io::ErrorKind::Other, err)),
        };
        reader = read_queue.enqueue();

        // Parse the Connection header once and check the two options we care about.
        let connection = req.headers().typed_get::<headers::Connection>();
        let asks_for_close = matches!(&connection, Some(c) if c.contains("close"));
        let asks_for_keep_alive = matches!(&connection, Some(c) if c.contains("keep-alive"));
        let version = req.version();
        let method = req.method().clone();

        // HTTP/0.9 never persists, HTTP/1.0 persists only on request, and
        // later versions persist unless asked not to.
        let demands_close = match version {
            Version::HTTP_09 => true,
            Version::HTTP_10 => !asks_for_keep_alive,
            _ => asks_for_close,
        };

        // RFC 7231 §5.1.1: a server MUST ignore a 100-continue expectation
        // received in an HTTP/1.0 request.
        let expects_continue = !matches!(version, Version::HTTP_09 | Version::HTTP_10)
            && req.headers().typed_get::<headers::Expect>() == Some(headers::Expect::CONTINUE);

        if expects_continue {
            let status = app.should_continue(&req);
            let proceed = status == StatusCode::CONTINUE;
            let mut reply = Response::new(());
            *reply.status_mut() = status;
            if !proceed {
                // The request is being refused before its body was read. It is
                // unknown whether the client will still send that body, so
                // announce a close and stop. Continuing could misread the body
                // as the next request.
                reply
                    .headers_mut()
                    .insert("connection", HeaderValue::from_static("close"));
            }
            response::write_response(reply, &mut writer, true)?;
            writer.flush()?;
            if !proceed {
                break;
            }
        }

        let mut res = app
            .call(req)
            .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
        *res.version_mut() = version;
        // Tell the client the connection is about to close (RFC 7230 §6.6).
        // HTTP/0.9 responses carry no headers.
        if demands_close && version != Version::HTTP_09 {
            res.headers_mut()
                .insert("connection", HeaderValue::from_static("close"));
        }
        if res.headers().typed_get::<headers::Date>().is_none() {
            res.headers_mut()
                .typed_insert(headers::Date::from(SystemTime::now()));
        }
        let should_write_body = match method {
            Method::HEAD => false,
            Method::CONNECT => res.status().is_success(),
            _ => true,
        };

        let keep_alive = match response::write_response(res, &mut writer, should_write_body)? {
            Outcome::KeepAlive => !demands_close,
            Outcome::Close => false,
            Outcome::Upgrade(upgrade) => {
                // Drop our read-side handles before handing the raw connection over.
                drop(reader);
                drop(read_queue);
                upgrade.handler.handle(writer.into_inner()?);
                return Ok(());
            }
        };
        // Flush on every path: BufWriter's Drop would silently swallow a
        // failure on the last response before a close.
        writer.flush()?;
        if !keep_alive {
            break;
        }
    }
    Ok(())
}