#![allow(clippy::new_without_default)]
use crate::http::cors::Cors;
use crate::http::date::DateTime;
use crate::http::headers::HeaderType;
use crate::http::method::Method;
use crate::http::request::{Request, RequestError};
use crate::http::response::Response;
use crate::http::status::StatusCode;
use crate::krauss::wildcard_match;
use crate::monitor::event::{Event, EventType};
use crate::monitor::MonitorConfig;
use crate::route::{Route, RouteHandler, SubApp};
use crate::stream::Stream;
use crate::thread::pool::ThreadPool;
use std::io::Write;
use std::net::{TcpListener, TcpStream, ToSocketAddrs};
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "tls")]
use rustls::ServerConfig;
/// A Humphrey application: routing configuration, shared state and server settings.
///
/// Build one with [`App::new`] or [`App::new_with_config`], configure it with the
/// `with_*` builder methods, then start serving with [`App::run`] (or `run_tls` when the
/// `tls` feature is enabled).
pub struct App<State = ()>
where
State: Send + Sync + 'static,
{
/// Worker pool; every accepted connection is handled by `connection_handler` on it.
thread_pool: ThreadPool,
/// Host-specific sub-apps, matched against the request's `Host` header.
subapps: Vec<SubApp<State>>,
/// Sub-app whose routes are searched when no host-specific route matches.
default_subapp: SubApp<State>,
/// Builds the response for error status codes such as `404 Not Found`.
error_handler: ErrorHandler,
/// State shared with every handler.
state: Arc<State>,
/// Destination for server events (connections, requests, errors).
monitor: MonitorConfig,
/// Function run on a pool thread for each accepted connection.
connection_handler: ConnectionHandler<State>,
/// Predicate deciding whether an accepted TCP connection is served or denied.
connection_condition: ConnectionCondition<State>,
/// Optional timeout handed to `connection_handler`; the built-in handler uses it when
/// reading each request.
connection_timeout: Option<Duration>,
/// TLS configuration set by `with_cert`; required by `run_tls`.
#[cfg(feature = "tls")]
tls_config: Option<Arc<ServerConfig>>,
/// Whether `run_tls` also runs the port-80 listener that redirects clients to HTTPS.
#[cfg(feature = "tls")]
force_https: bool,
}
/// Function run on a worker thread for every accepted connection.
///
/// Arguments, in order: the client stream, the host-specific sub-apps, the default
/// sub-app, the error handler, the shared state, the monitor configuration and the
/// optional connection timeout. Replace the built-in handler with
/// [`App::with_custom_connection_handler`].
pub type ConnectionHandler<State> = fn(
Stream,
Arc<Vec<SubApp<State>>>,
Arc<SubApp<State>>,
Arc<ErrorHandler>,
Arc<State>,
MonitorConfig,
Option<Duration>,
);
/// Predicate evaluated on each accepted TCP connection before it is handed to a worker.
///
/// Returning `false` rejects the connection, which is reported to the monitor as
/// `ConnectionDenied`. The stream is passed mutably so the predicate can inspect it.
pub type ConnectionCondition<State> = fn(&mut TcpStream, Arc<State>) -> bool;
// Re-export the handler traits (e.g. `RequestHandler`, `WebsocketHandler`) that appear in
// the `App` builder bounds, so they are reachable through this module.
pub use crate::handler_traits::*;
/// Builds the response sent for an error status code (for example `404 Not Found`).
pub type ErrorHandler = fn(StatusCode) -> Response;
/// Boxed error type returned by [`App::run`] and the TLS runner.
pub type HumphreyError = Box<dyn std::error::Error>;
impl<State> App<State>
where
    State: Send + Sync + 'static,
{
    /// Creates an app with a 32-thread pool and a default-constructed state.
    ///
    /// Equivalent to `App::new_with_config(32, State::default())`.
    pub fn new() -> Self
    where
        State: Default,
    {
        Self::new_with_config(32, State::default())
    }

    /// Creates an app with `threads` worker threads and the given initial state.
    ///
    /// The app starts with:
    /// * no routes or host sub-apps;
    /// * the built-in error and connection handlers;
    /// * a connection condition that accepts every connection;
    /// * no connection timeout;
    /// * a default monitor configuration.
    pub fn new_with_config(threads: usize, state: State) -> Self {
        Self {
            thread_pool: ThreadPool::new(threads),
            subapps: Vec::new(),
            default_subapp: SubApp::default(),
            error_handler,
            state: Arc::new(state),
            monitor: MonitorConfig::default(),
            connection_handler: client_handler,
            connection_condition: |_, _| true,
            connection_timeout: None,
            #[cfg(feature = "tls")]
            tls_config: None,
            #[cfg(feature = "tls")]
            force_https: false,
        }
    }

    /// Binds to `addr` and serves plain HTTP connections.
    ///
    /// Each accepted connection is first checked with the connection condition.
    /// * Accepted connections go to the connection handler on the thread pool.
    /// * Rejected ones are reported to the monitor as `ConnectionDenied`.
    /// * Accept errors are reported as `ConnectionError` and do not stop the server.
    ///
    /// # Errors
    ///
    /// Returns an error if `addr` cannot be bound. Otherwise it keeps serving connections.
    pub fn run<A>(mut self, addr: A) -> Result<(), HumphreyError>
    where
        A: ToSocketAddrs,
    {
        let socket = TcpListener::bind(addr)?;
        let subapps = Arc::new(self.subapps);
        let default_subapp = Arc::new(self.default_subapp);
        let error_handler = Arc::new(self.error_handler);

        self.thread_pool.register_monitor(self.monitor.clone());
        self.thread_pool.start();

        for stream in socket.incoming() {
            let mut stream = match stream {
                Ok(stream) => stream,
                Err(e) => {
                    self.monitor
                        .send(Event::new(EventType::ConnectionError).with_info(e.to_string()));
                    continue;
                }
            };

            if !(self.connection_condition)(&mut stream, self.state.clone()) {
                self.monitor.send(
                    Event::new(EventType::ConnectionDenied).with_peer_result(stream.peer_addr()),
                );
                continue;
            }

            self.monitor.send(
                Event::new(EventType::ConnectionSuccess).with_peer_result(stream.peer_addr()),
            );

            let cloned_state = self.state.clone();
            let cloned_monitor = self.monitor.clone();
            let cloned_subapps = Arc::clone(&subapps);
            let cloned_default_subapp = Arc::clone(&default_subapp);
            let cloned_error_handler = Arc::clone(&error_handler);
            let cloned_handler = self.connection_handler;
            let cloned_timeout = self.connection_timeout;

            self.thread_pool.execute(move || {
                cloned_monitor.send(
                    Event::new(EventType::ThreadPoolProcessStarted)
                        .with_peer_result(stream.peer_addr()),
                );
                (cloned_handler)(
                    Stream::Tcp(stream),
                    cloned_subapps,
                    cloned_default_subapp,
                    cloned_error_handler,
                    cloned_state,
                    cloned_monitor,
                    cloned_timeout,
                )
            });
        }

        Ok(())
    }

    /// Binds to `addr` and serves HTTPS connections using the certificate from `with_cert`.
    ///
    /// Connection handling mirrors [`App::run`], with each stream wrapped in a TLS session.
    ///
    /// If forced HTTPS is enabled, a redirect listener on port 80 is also started on the
    /// thread pool. It permanently occupies one worker. Its failure is reported to the
    /// monitor as `ConnectionError`.
    ///
    /// # Errors
    ///
    /// Returns an error, before anything is bound or spawned, if:
    /// * no certificate was supplied; or
    /// * forced HTTPS is enabled with fewer than two worker threads.
    ///
    /// Also returns an error if `addr` cannot be bound.
    #[cfg(feature = "tls")]
    pub fn run_tls<A>(mut self, addr: A) -> Result<(), HumphreyError>
    where
        A: ToSocketAddrs,
    {
        use rustls::ServerConnection;

        // Validate the configuration first so a misconfigured app fails immediately
        // instead of panicking on its first connection or exiting the whole process.
        let tls_config = self
            .tls_config
            .clone()
            .ok_or("TLS certificate not supplied")?;
        if self.force_https && self.thread_pool.thread_count() < 2 {
            return Err("A minimum of two threads are required to force HTTPS since one is required for redirects.".into());
        }

        let socket = TcpListener::bind(addr)?;
        let subapps = Arc::new(self.subapps);
        let default_subapp = Arc::new(self.default_subapp);
        let error_handler = Arc::new(self.error_handler);

        self.thread_pool.register_monitor(self.monitor.clone());
        self.thread_pool.start();

        if self.force_https {
            let redirect_monitor = self.monitor.clone();
            self.thread_pool.execute(move || {
                let error_monitor = redirect_monitor.clone();
                if let Err(e) = force_https_thread(redirect_monitor) {
                    error_monitor.send(
                        Event::new(EventType::ConnectionError)
                            .with_info(format!("HTTPS redirect listener stopped: {}", e)),
                    );
                }
            });
        }

        for sock in socket.incoming() {
            let mut sock = match sock {
                Ok(sock) => sock,
                Err(e) => {
                    self.monitor
                        .send(Event::new(EventType::ConnectionError).with_info(e.to_string()));
                    continue;
                }
            };

            if !(self.connection_condition)(&mut sock, self.state.clone()) {
                self.monitor.send(
                    Event::new(EventType::ConnectionDenied).with_peer_result(sock.peer_addr()),
                );
                continue;
            }

            self.monitor.send(
                Event::new(EventType::ConnectionSuccess).with_peer_result(sock.peer_addr()),
            );

            let cloned_state = self.state.clone();
            let cloned_monitor = self.monitor.clone();
            let cloned_subapps = Arc::clone(&subapps);
            let cloned_default_subapp = Arc::clone(&default_subapp);
            let cloned_error_handler = Arc::clone(&error_handler);
            let cloned_handler = self.connection_handler;
            let cloned_timeout = self.connection_timeout;
            let cloned_config = Arc::clone(&tls_config);

            self.thread_pool.execute(move || {
                cloned_monitor.send(
                    Event::new(EventType::ThreadPoolProcessStarted)
                        .with_peer_result(sock.peer_addr()),
                );

                // Report TLS session setup failures instead of panicking the worker thread.
                let server = match ServerConnection::new(cloned_config) {
                    Ok(server) => server,
                    Err(e) => {
                        cloned_monitor.send(
                            Event::new(EventType::ConnectionError).with_info(e.to_string()),
                        );
                        return;
                    }
                };
                let tls_stream = rustls::StreamOwned::new(server, sock);

                (cloned_handler)(
                    Stream::Tls(tls_stream),
                    cloned_subapps,
                    cloned_default_subapp,
                    cloned_error_handler,
                    cloned_state,
                    cloned_monitor,
                    cloned_timeout,
                )
            });
        }

        Ok(())
    }

    /// Replaces the shared state passed to handlers.
    pub fn with_state(mut self, state: State) -> Self {
        self.state = Arc::new(state);
        self
    }

    /// Registers `handler` as the sub-app for requests whose `Host` header matches `host`.
    ///
    /// `host` is compared with the request's host using wildcard matching.
    ///
    /// # Panics
    ///
    /// Panics if `host` is exactly `*`; use the default sub-app for catch-all routing.
    pub fn with_host(mut self, host: &str, mut handler: SubApp<State>) -> Self {
        if host == "*" {
            panic!("Cannot add a sub-app with wildcard `*`");
        }
        handler.host = host.to_string();
        self.subapps.push(handler);
        self
    }

    /// Adds a request handler for `route` to the default sub-app.
    pub fn with_route<T>(mut self, route: &str, handler: T) -> Self
    where
        T: RequestHandler<State> + 'static,
    {
        self.default_subapp = self.default_subapp.with_route(route, handler);
        self
    }

    /// Adds a handler that does not receive the app state for `route` to the default sub-app.
    pub fn with_stateless_route<T>(mut self, route: &str, handler: T) -> Self
    where
        T: StatelessRequestHandler<State> + 'static,
    {
        self.default_subapp = self.default_subapp.with_stateless_route(route, handler);
        self
    }

    /// Adds a path-aware handler for `route` to the default sub-app.
    pub fn with_path_aware_route<T>(mut self, route: &'static str, handler: T) -> Self
    where
        T: PathAwareRequestHandler<State> + 'static,
    {
        self.default_subapp = self.default_subapp.with_path_aware_route(route, handler);
        self
    }

    /// Adds a websocket handler for `route` to the default sub-app.
    pub fn with_websocket_route<T>(mut self, route: &str, handler: T) -> Self
    where
        T: WebsocketHandler<State> + 'static,
    {
        self.default_subapp = self.default_subapp.with_websocket_route(route, handler);
        self
    }

    /// Replaces the default sub-app, discarding any routes previously added to it.
    pub fn with_default_subapp(mut self, subapp: SubApp<State>) -> Self {
        self.default_subapp = subapp;
        self
    }

    /// Sets the monitor configuration that receives server events.
    pub fn with_monitor(mut self, monitor: MonitorConfig) -> Self {
        self.monitor = monitor;
        self
    }

    /// Sets the function that builds responses for error status codes.
    pub fn with_error_handler(mut self, handler: ErrorHandler) -> Self {
        self.error_handler = handler;
        self
    }

    /// Sets the predicate that decides whether an accepted TCP connection is served.
    pub fn with_connection_condition(mut self, condition: ConnectionCondition<State>) -> Self {
        self.connection_condition = condition;
        self
    }

    /// Sets the timeout passed to the connection handler (`None` disables it).
    pub fn with_connection_timeout(mut self, timeout: Option<Duration>) -> Self {
        self.connection_timeout = timeout;
        self
    }

    /// Applies `cors` to the default sub-app via `SubApp::with_cors`.
    pub fn with_cors(mut self, cors: Cors) -> Self {
        self.default_subapp = self.default_subapp.with_cors(cors);
        self
    }

    /// Applies `cors` to `route` on the default sub-app via `SubApp::with_cors_config`.
    pub fn with_cors_config(mut self, route: &str, cors: Cors) -> Self {
        self.default_subapp = self.default_subapp.with_cors_config(route, cors);
        self
    }

    /// Enables or disables the port-80 HTTPS redirect listener started by `run_tls`.
    ///
    /// The listener occupies one worker thread, so at least two threads are required.
    #[cfg(feature = "tls")]
    pub fn with_forced_https(mut self, forced: bool) -> Self {
        self.force_https = forced;
        self
    }

    /// Loads the PEM certificate chain and private key used by `run_tls`.
    ///
    /// Every X.509 certificate in `cert_path` is loaded, so full-chain files also serve
    /// their intermediate certificates. The first PEM item in `key_path` must be a PKCS#8
    /// or PKCS#1 (RSA) private key.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// * either file cannot be opened or read;
    /// * the certificate file contains no certificates;
    /// * the key file does not start with a supported key;
    /// * rustls rejects the certificate/key pair.
    #[cfg(feature = "tls")]
    pub fn with_cert(mut self, cert_path: impl AsRef<str>, key_path: impl AsRef<str>) -> Self {
        use rustls::{Certificate, PrivateKey};
        use rustls_pemfile::{read_one, Item};
        use std::fs::File;
        use std::io::BufReader;

        let mut cert_file =
            BufReader::new(File::open(cert_path.as_ref()).expect("failed to open cert file"));
        let mut key_file =
            BufReader::new(File::open(key_path.as_ref()).expect("failed to open key file"));

        // Collect the whole chain: reading only the first item would drop intermediates.
        let mut certs: Vec<Certificate> = Vec::new();
        while let Some(item) = read_one(&mut cert_file).expect("failed to read cert file") {
            if let Item::X509Certificate(cert) = item {
                certs.push(Certificate(cert));
            }
        }
        if certs.is_empty() {
            panic!("failed to parse cert file");
        }

        // rustls accepts DER keys in both PKCS#8 and PKCS#1 (RSA) form.
        let key: PrivateKey = match read_one(&mut key_file).expect("failed to read key file") {
            Some(Item::PKCS8Key(key)) | Some(Item::RSAKey(key)) => PrivateKey(key),
            _ => panic!("failed to parse key file"),
        };

        let config = Arc::new(
            ServerConfig::builder()
                .with_safe_defaults()
                .with_no_client_auth()
                .with_single_cert(certs, key)
                .expect("failed to create server config"),
        );
        self.tls_config = Some(config);
        self
    }

    /// Registers `handler` as a websocket handler for every path on the default sub-app.
    #[deprecated(since = "0.3.0", note = "Please use `with_websocket_route` instead")]
    pub fn with_websocket_handler<T>(mut self, handler: T) -> Self
    where
        T: WebsocketHandler<State> + 'static,
    {
        self.default_subapp = self.default_subapp.with_websocket_route("*", handler);
        self
    }

    /// Replaces the built-in per-connection handler with `handler`.
    pub fn with_custom_connection_handler(mut self, handler: ConnectionHandler<State>) -> Self {
        self.connection_handler = handler;
        self
    }

    /// Returns a new reference-counted handle to the shared state.
    pub fn get_state(&self) -> Arc<State> {
        self.state.clone()
    }
}
#[allow(clippy::too_many_arguments)]
fn client_handler<State>(
mut stream: Stream,
subapps: Arc<Vec<SubApp<State>>>,
default_subapp: Arc<SubApp<State>>,
error_handler: Arc<ErrorHandler>,
state: Arc<State>,
monitor: MonitorConfig,
timeout: Option<Duration>,
) {
let addr = if let Ok(addr) = stream.peer_addr() {
addr
} else {
monitor.send(EventType::StreamDisconnectedWhileWaiting);
return;
};
loop {
let request = match timeout {
Some(timeout) => Request::from_stream_with_timeout(&mut stream, addr, timeout),
None => Request::from_stream(&mut stream, addr),
};
let cloned_state = state.clone();
if let Ok(req) = &request {
if req.headers.get(&HeaderType::Upgrade) == Some("websocket") {
monitor.send(Event::new(EventType::WebsocketConnectionRequested).with_peer(addr));
call_websocket_handler(req, &subapps, &default_subapp, cloned_state, stream);
monitor.send(Event::new(EventType::WebsocketConnectionClosed).with_peer(addr));
break;
}
}
let keep_alive = if let Ok(request) = &request {
if let Some(connection) = request.headers.get(&HeaderType::Connection) {
connection.to_ascii_lowercase() == "keep-alive"
} else {
false
}
} else {
false
};
let response = match &request {
Ok(request) if request.method == Method::Options => {
let handler = get_handler(request, &subapps, &default_subapp);
match handler {
Some(handler) => {
let mut response = Response::empty(StatusCode::NoContent)
.with_header(HeaderType::Date, DateTime::now().to_string())
.with_header(HeaderType::Server, "Humphrey")
.with_header(
HeaderType::Connection,
match keep_alive {
true => "Keep-Alive",
false => "Close",
},
);
handler.cors.set_headers(&mut response.headers);
response
}
None => error_handler(StatusCode::NotFound),
}
}
Ok(request) => {
let handler = get_handler(request, &subapps, &default_subapp);
let mut response = match handler {
Some(handler) => {
let mut response: Response =
handler.handler.serve(request.clone(), state.clone());
handler.cors.set_headers(&mut response.headers);
response
}
None => error_handler(StatusCode::NotFound),
};
match response.headers.get_mut(HeaderType::Connection) {
Some(_) => (),
None => {
if let Some(connection) = &request.headers.get(&HeaderType::Connection) {
response.headers.add(HeaderType::Connection, connection);
} else {
response.headers.add(HeaderType::Connection, "Close");
}
}
}
match response.headers.get_mut(HeaderType::Server) {
Some(_) => (),
None => {
response.headers.add(HeaderType::Server, "Humphrey");
}
}
match response.headers.get_mut(HeaderType::Date) {
Some(_) => (),
None => {
response
.headers
.add(HeaderType::Date, DateTime::now().to_string());
}
}
match response.headers.get_mut(HeaderType::ContentLength) {
Some(_) => (),
None => {
response
.headers
.add(HeaderType::ContentLength, response.body.len().to_string());
}
}
response.version = request.version.clone();
response
}
Err(e) => match e {
RequestError::Request => error_handler(StatusCode::BadRequest),
RequestError::Timeout => error_handler(StatusCode::RequestTimeout),
RequestError::Disconnected => return,
RequestError::Stream => {
return monitor.send(Event::new(EventType::RequestServedError))
}
},
};
let status = response.status_code;
let response_bytes: Vec<u8> = response.into();
if let Err(e) = stream.write_all(&response_bytes) {
monitor.send(
Event::new(EventType::RequestServedError)
.with_peer(addr)
.with_info(e.to_string()),
);
break;
};
let status_str: &str = status.into();
match status {
StatusCode::OK => monitor.send(
Event::new(EventType::RequestServedSuccess)
.with_peer(addr)
.with_info(format!("200 OK {}", request.unwrap().uri)),
),
StatusCode::RequestTimeout => monitor.send(
Event::new(EventType::RequestTimeout)
.with_peer(addr)
.with_info("408 Request Timeout"),
),
e => {
if let Ok(request) = request {
monitor.send(
Event::new(EventType::RequestServedError)
.with_peer(addr)
.with_info(format!("{} {} {}", u16::from(e), status_str, request.uri)),
)
} else {
monitor.send(
Event::new(EventType::RequestServedError)
.with_peer(addr)
.with_info(format!("{} {}", u16::from(e), status_str)),
)
}
}
}
if !keep_alive {
break;
}
monitor.send(Event::new(EventType::KeepAliveRespected).with_peer(addr));
}
monitor.send(Event::new(EventType::ConnectionClosed).with_peer(addr));
}
/// Finds the HTTP route handler responsible for `request`.
///
/// Candidates are searched in this order:
/// 1. If the request carries a `Host` header, the first sub-app whose host pattern
///    wildcard-matches it.
/// 2. The default sub-app, which is always searched when step 1 yields no route.
///
/// Returns `None` if no candidate has a route matching the request URI.
pub(crate) fn get_handler<'a, State>(
    request: &'a Request,
    subapps: &'a [SubApp<State>],
    default_subapp: &'a SubApp<State>,
) -> Option<&'a RouteHandler<State>> {
    let host_subapp = request
        .headers
        .get(&HeaderType::Host)
        .and_then(|host| subapps.iter().find(|subapp| wildcard_match(&subapp.host, host)));

    host_subapp
        .into_iter()
        .chain(std::iter::once(default_subapp))
        .find_map(|subapp| {
            subapp
                .routes
                .iter()
                .find(|route| route.route.route_matches(&request.uri))
        })
}
/// Dispatches a websocket upgrade `request` to the matching websocket route.
///
/// The sub-app matching the `Host` header (if any) is searched first, then the default
/// sub-app. The first route whose pattern matches the URI takes ownership of `stream`.
/// If no route matches, `stream` is simply dropped.
fn call_websocket_handler<State>(
    request: &Request,
    subapps: &[SubApp<State>],
    default_subapp: &SubApp<State>,
    state: Arc<State>,
    stream: Stream,
) {
    let matching_host = request
        .headers
        .get(&HeaderType::Host)
        .and_then(|host| subapps.iter().find(|subapp| wildcard_match(&subapp.host, host)));

    let websocket_route = matching_host
        .into_iter()
        .chain(std::iter::once(default_subapp))
        .find_map(|subapp| {
            subapp
                .websocket_routes
                .iter()
                .find(|route| route.route.route_matches(&request.uri))
        });

    if let Some(route) = websocket_route {
        route.handler.serve(request.clone(), stream, state);
    }
}
/// Runs the plain-HTTP listener on `0.0.0.0:80` that sends clients to HTTPS.
///
/// * A request with a `Host` header gets `301 Moved Permanently` to the same host and URI
///   under `https://`.
/// * A request without one gets a short HTML page asking the user to switch to HTTPS.
/// * Each redirect is reported to the monitor as `HTTPSRedirect`.
///
/// Failures on an individual connection only drop that connection; they do not stop the
/// listener. Such failures are an unavailable peer address, a malformed request, or a write
/// error (reported as `RequestServedError`).
///
/// # Errors
///
/// Returns an error only if port 80 cannot be bound.
#[cfg(feature = "tls")]
fn force_https_thread(monitor: MonitorConfig) -> Result<(), Box<dyn std::error::Error>> {
    const HTTPS_NOTICE: &[u8] = b"<h1>Please access over HTTPS</h1>";

    let socket = TcpListener::bind("0.0.0.0:80")?;

    for mut stream in socket.incoming().flatten() {
        let addr = match stream.peer_addr() {
            Ok(addr) => addr,
            Err(_) => continue,
        };
        let request = match Request::from_stream(&mut stream, addr) {
            Ok(request) => request,
            Err(_) => continue,
        };

        let response = match request.headers.get(&HeaderType::Host) {
            Some(host) => Response::empty(StatusCode::MovedPermanently)
                .with_header(
                    HeaderType::Location,
                    format!("https://{}{}", host, request.uri),
                )
                .with_header(HeaderType::Connection, "Close"),
            None => Response::empty(StatusCode::OK)
                .with_bytes(HTTPS_NOTICE)
                .with_header(HeaderType::ContentLength, HTTPS_NOTICE.len().to_string())
                .with_header(HeaderType::Connection, "Close"),
        };

        let response_bytes: Vec<u8> = response.into();
        if let Err(e) = stream.write_all(&response_bytes) {
            monitor.send(
                Event::new(EventType::RequestServedError)
                    .with_peer(addr)
                    .with_info(e.to_string()),
            );
            continue;
        }

        monitor.send(Event::new(EventType::HTTPSRedirect).with_peer(addr));
    }

    Ok(())
}
/// Built-in [`ErrorHandler`]: responds with a minimal HTML page naming the status.
///
/// The body is `<html><body><h1>{code} {reason}</h1></body></html>`, and the response
/// carries `status_code` itself.
pub(crate) fn error_handler(status_code: StatusCode) -> Response {
    let code: u16 = status_code.into();
    let reason: &str = status_code.into();
    let page = format!("<html><body><h1>{} {}</h1></body></html>", code, reason);
    Response::new(status_code, page.as_bytes())
}