#[cfg(not(feature = "tls"))]
use std::io::prelude::*;
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::panic::catch_unwind;
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::thread::{JoinHandle, spawn};
#[cfg(feature = "tls")]
use {
super::{TlsConfig, TlsConfigProvider},
rustls::{ServerConnection, Stream as RustlsStream},
};
use crate::http::common::ReadWrite;
use crate::{Fail, Result};
use super::{ErrorHandler, Handler, HttpRequest, HttpSettings};
/// HTTP server owning a bound TCP listener and the threads that accept
/// connections on it.
#[derive(Debug)]
pub struct HttpServer {
/// Listening socket; acceptor threads take the read lock to call `accept`.
listener: Arc<RwLock<TcpListener>>,
/// Shared settings (threading mode, socket timeouts, header limits).
settings: Arc<HttpSettings>,
/// Callback invoked with every successfully parsed request.
handler: Handler,
/// Callback that turns a request-processing error into response bytes.
error_handler: ErrorHandler,
/// Join handles of the acceptor threads started by `HttpServer::new`.
threads: RwLock<Vec<JoinHandle<()>>>,
/// Optional provider that is called to obtain the TLS configuration.
#[cfg(feature = "tls")]
tls_config: Option<TlsConfigProvider>,
}
impl HttpServer {
pub fn new(
addr: String,
settings: Arc<HttpSettings>,
handler: Handler,
error_handler: ErrorHandler,
#[cfg(feature = "tls")] tls_config: Option<TlsConfigProvider>,
) -> Result<Arc<Self>> {
let listener = TcpListener::bind(addr)?;
let server = Self {
listener: Arc::new(RwLock::new(listener)),
settings,
handler,
error_handler,
threads: RwLock::default(),
#[cfg(feature = "tls")]
tls_config,
};
let server = Arc::new(server);
use super::HttpThreads::{CONSTANT, SPAWN};
let (no_catch, threads) = match server.settings.threads {
SPAWN(threads) => (true, threads),
CONSTANT(threads) => (false, threads),
};
(0..threads).for_each(|_| {
let server_clone = server.clone();
server.threads_mut().unwrap().push(spawn(move || {
if no_catch {
accept_all(server_clone);
} else {
loop {
catch_unwind(|| accept_all(server_clone.clone())).ok();
eprintln!("HTTP thread panicked, restarting...");
}
}
}));
});
Ok(server)
}
pub fn settings(&self) -> &HttpSettings {
&self.settings
}
#[cfg(feature = "tls")]
pub fn tls_config(&self) -> Option<TlsConfig> {
self.tls_config.map(|f| f())
}
pub fn threads(&self) -> Result<RwLockReadGuard<Vec<JoinHandle<()>>>> {
self.threads.read().or_else(Fail::from)
}
pub fn threads_mut(&self) -> Result<RwLockWriteGuard<Vec<JoinHandle<()>>>> {
self.threads.write().or_else(Fail::from)
}
pub fn block(&self) -> Result<()> {
while let Some(thread) = self.threads_mut()?.pop() {
thread
.join()
.or_else(|_| Fail::from("listener thread crashed"))?;
}
Ok(())
}
}
/// Reads one request from `stream` and passes it to `handler`.
///
/// The header is read with `read_header`; the header text and any body bytes
/// that were read along with it are given to `HttpRequest::from` together with
/// the stream, the peer address and the settings. Whatever `handler` returns
/// for the resulting request is returned unchanged.
///
/// # Errors
///
/// Propagates errors from reading the header, from building the request and
/// from the handler itself.
fn process_request(
    stream: &mut impl ReadWrite,
    address: SocketAddr,
    settings: &HttpSettings,
    handler: Handler,
) -> Result<Vec<u8>> {
    let (head, body_prefix) = read_header(stream, settings)?;
    let parsed = HttpRequest::from(&head, body_prefix, stream, address, settings)?;
    handler(parsed)
}
/// Accept loop run by every acceptor thread; it never returns.
///
/// Each successfully accepted connection is handed to `accepted`: on a newly
/// spawned thread when `settings.threads` is `SPAWN`, or inline on the calling
/// thread when it is `CONSTANT`. Failed `accept` calls and errors returned by
/// `accepted` are ignored.
///
/// # Panics
///
/// Panics if the lock around the listener is poisoned.
fn accept_all(server: Arc<HttpServer>) {
    use super::HttpThreads::{CONSTANT, SPAWN};

    #[cfg(feature = "tls")]
    let tls_config = server.tls_config();
    loop {
        // Complete `accept` in its own statement so the read guard on the
        // listener is released before the connection is handled, instead of
        // being held for the whole request.
        let incoming = server.listener.read().unwrap().accept();
        let (stream, address) = match incoming {
            Ok(connection) => connection,
            Err(_) => continue,
        };

        match server.settings.threads {
            SPAWN(_) => {
                // The worker thread needs owned handles; clone exactly once.
                let server = Arc::clone(&server);
                #[cfg(feature = "tls")]
                let tls_config = tls_config.clone();
                spawn(move || {
                    accepted(
                        &server,
                        stream,
                        address,
                        #[cfg(feature = "tls")]
                        tls_config,
                    )
                    .ok();
                });
            }
            CONSTANT(_) => {
                accepted(
                    &server,
                    stream,
                    address,
                    #[cfg(feature = "tls")]
                    tls_config.clone(),
                )
                .ok();
            }
        }
    }
}
/// Serves a single accepted connection.
///
/// Applies the configured read and write timeouts to the socket, wraps it in
/// a TLS session when the `tls` feature is enabled and a TLS configuration is
/// supplied, runs the request through `process_request` and writes the
/// resulting bytes back. If processing fails, the bytes produced by the
/// server's error handler are written instead.
///
/// # Errors
///
/// Returns an error if a timeout cannot be set, if the TLS session cannot be
/// created, or if writing or flushing the response fails.
fn accepted(
    server: &HttpServer,
    mut stream: TcpStream,
    address: SocketAddr,
    #[cfg(feature = "tls")] tls_config: Option<TlsConfig>,
) -> Result<()> {
    stream.set_read_timeout(server.settings.read_timeout)?;
    stream.set_write_timeout(server.settings.write_timeout)?;

    // Declared outside the `match` so the session outlives the boxed stream
    // that borrows it.
    #[cfg(feature = "tls")]
    let mut session;
    // `tls_config` is owned and not used afterwards, so it is consumed
    // directly rather than cloned.
    #[cfg(feature = "tls")]
    let mut stream: Box<dyn ReadWrite> = match tls_config {
        Some(tls_config) => {
            session = ServerConnection::new(tls_config)
                .or_else(|_| Fail::from("could not initialize server connection"))?;
            Box::new(RustlsStream::new(&mut session, &mut stream))
        }
        None => Box::new(stream),
    };

    let response = match process_request(&mut stream, address, server.settings(), server.handler) {
        Ok(response) => response,
        // Let the error handler turn the failure into a response body.
        Err(err) => (server.error_handler)(err),
    };
    stream.write_all(&response)?;
    stream.flush().or_else(Fail::from)
}
/// Reads from `stream` until the end of the HTTP header (`\r\n\r\n`) is seen.
///
/// Returns the header text, including the terminating blank line, together
/// with any bytes that were read past it (the beginning of the body).
///
/// # Errors
///
/// Returns an error if
///
/// * a read from `stream` fails,
/// * the bytes read so far exceed `http_settings.max_header_size`
///   ("Max header size exceeded"),
/// * more than `http_settings.header_read_attempts` reads returned fewer than
///   `http_settings.header_buffer` bytes ("Read header failed too often"), or
/// * the header is not valid UTF-8.
fn read_header(
    stream: &mut impl ReadWrite,
    http_settings: &HttpSettings,
) -> Result<(String, Vec<u8>)> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    let mut header: Vec<u8> = Vec::new();
    let mut buf = vec![0u8; http_settings.header_buffer];
    let mut read_fails = 0;

    let rest = loop {
        let length = stream.read(&mut buf)?;
        if header.len() + length > http_settings.max_header_size {
            return Fail::from("Max header size exceeded");
        }

        // A terminator may straddle two reads, so re-scan the last three
        // bytes already collected together with the new data. Everything
        // before that point was searched in an earlier iteration.
        let scan_from = header.len().saturating_sub(TERMINATOR.len() - 1);
        header.extend_from_slice(&buf[..length]);
        let found = header[scan_from..]
            .windows(TERMINATOR.len())
            .position(|window| window == TERMINATOR);
        if let Some(offset) = found {
            // Keep the terminator in the header; whatever follows it already
            // belongs to the body.
            break header.split_off(scan_from + offset + TERMINATOR.len());
        }

        // Short reads count as failed attempts. An empty read is always
        // counted so that a zero-sized buffer cannot spin forever.
        if length == 0 || length < http_settings.header_buffer {
            read_fails += 1;
            if read_fails > http_settings.header_read_attempts {
                return Fail::from("Read header failed too often");
            }
        }
    };

    let header = match String::from_utf8(header) {
        Ok(header) => header,
        Err(err) => return Fail::from(err),
    };
    Ok((header, rest))
}