#![deny(missing_docs)]
mod status_code;
use anyhow::{anyhow, Context, Error};
use bufstream::BufStream;
use fehler::{throw, throws};
use log::error;
use serde::{Deserialize, Serialize};
pub use status_code::StatusCode;
use std::collections::HashMap;
use std::convert::Infallible;
use std::fmt::{Debug, Display};
use std::io::{BufRead, Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::str::FromStr;
use std::sync::{Arc, RwLock};
use std::thread;
use url::Url;
/// Header name that compares case-insensitively, so `"host"` and `"Host"`
/// address the same map entry.
type HeaderName = unicase::UniCase<String>;

/// An in-flight HTTP request together with the response being built for it.
///
/// Route handlers receive `&mut Request`. They read the incoming method, URL,
/// headers and body, and use the `write_*` / `set_*` methods to fill in the
/// response status, headers and body.
pub struct Request {
    /// HTTP method from the request line (e.g. `GET`), compared verbatim
    /// against route methods.
    method: String,
    /// Values captured by the matched route's `:name` placeholders.
    path_params: HashMap<String, String>,
    /// Incoming request headers.
    req_headers: HashMap<HeaderName, String>,
    /// Raw request body.
    req_body: Vec<u8>,
    /// Full request URL.
    url: Url,
    /// Response status; initialised to `StatusCode::Ok` when the request is built.
    status: StatusCode,
    /// Response body bytes.
    resp_body: Vec<u8>,
    /// Response headers to send.
    resp_headers: HashMap<String, String>,
}
impl Request {
    /// Returns the full URL of the request.
    pub fn url(&self) -> &Url {
        &self.url
    }

    /// Returns the request headers; keys compare case-insensitively.
    pub fn headers(&self) -> &HashMap<HeaderName, String> {
        &self.req_headers
    }

    /// Deserializes the request body as JSON.
    ///
    /// # Errors
    ///
    /// Fails if the body is not valid JSON for `D`.
    #[throws]
    pub fn read_json<'a, D: Deserialize<'a>>(&'a self) -> D {
        serde_json::from_slice(&self.req_body)?
    }

    /// Sets the response body to a copy of `body` and the `Content-Type` to
    /// `application/octet-stream`.
    pub fn write_bytes(&mut self, body: &[u8]) {
        self.resp_body = body.to_vec();
        self.set_content_type("application/octet-stream");
    }

    /// Serializes `body` as the JSON response body and sets the
    /// `Content-Type` to `application/json`.
    ///
    /// # Errors
    ///
    /// Fails if `body` cannot be serialized; the response is left unchanged.
    #[throws]
    pub fn write_json<S: Serialize>(&mut self, body: &S) {
        self.resp_body = serde_json::to_vec(body)?;
        self.set_content_type("application/json");
    }

    /// Sets the response body to `body` and the `Content-Type` to UTF-8
    /// plain text.
    pub fn write_text(&mut self, body: &str) {
        self.resp_body = body.as_bytes().to_vec();
        self.set_content_type("text/plain; charset=UTF-8");
    }

    /// Sets the response status.
    pub fn set_status(&mut self, status: StatusCode) {
        self.status = status;
    }

    /// Sets the response status to `StatusCode::NotFound`.
    pub fn set_not_found(&mut self) {
        self.set_status(StatusCode::NotFound);
    }

    /// Sets response header `name` to `value`, replacing any previous value.
    ///
    /// HTTP header names are case-insensitive. Any existing header whose name
    /// differs only in ASCII case is therefore removed first. Without this,
    /// `set_header("content-type", ..)` followed by `set_content_type(..)`
    /// would send two Content-Type headers.
    pub fn set_header(&mut self, name: &str, value: &str) {
        self.resp_headers
            .retain(|existing, _| !existing.eq_ignore_ascii_case(name));
        self.resp_headers.insert(name.into(), value.into());
    }

    /// Sets the `Content-Type` response header.
    pub fn set_content_type(&mut self, value: &str) {
        self.set_header("Content-Type", value);
    }

    /// Parses the value captured by the matched route's `:name` placeholder.
    ///
    /// # Errors
    ///
    /// Fails if the matched route has no placeholder called `name`, or if its
    /// value does not parse as `F`.
    #[throws]
    pub fn path_param<F>(&self, name: &str) -> F
    where
        F::Err: std::error::Error + Send + Sync + 'static,
        F: FromStr,
    {
        let value = self
            .path_params
            .get(name)
            .ok_or_else(|| anyhow!("path param {} not found", name))?;
        value
            .parse()
            .with_context(|| format!("failed to parse path param {}", name))?
    }
}
/// A route handler.
///
/// It reads the request and fills in the response. An `Err(E)` it returns is
/// reported to the caller as [`RequestError::Custom`].
pub type Handler<E> = dyn Fn(&mut Request) -> Result<(), E> + Send + Sync;

/// Called while serving a connection when no route matches or the matched
/// handler fails. It may adjust the response status, headers and body.
pub type ErrorHandler<E> = dyn Fn(&mut Request, &RequestError<E>) + Send + Sync;

/// Error handler shared between the server and its connection threads.
type ErrorHandlerArc<E> = Arc<RwLock<ErrorHandler<E>>>;
/// A URL path split on `/`, used both for request paths and route patterns.
///
/// Empty segments are kept, so a leading `/` produces an empty first segment.
#[derive(Clone)]
struct Path {
    /// Segments between `/` separators, in order.
    parts: Vec<String>,
}
/// Matches a request `path` against a route pattern `route_path`.
///
/// Pattern segments starting with `:` are placeholders that match any single
/// segment. Every other segment must match exactly.
///
/// Returns the captured placeholder values, keyed by placeholder name without
/// the leading `:`. Returns `None` if any segment differs or if the two paths
/// have a different number of segments.
fn match_path(
    path: &Path,
    route_path: &Path,
) -> Option<HashMap<String, String>> {
    // `zip` stops at the shorter side. Without this check "/users/:id" would
    // match "/users/1/delete", and would also match "/users" without
    // capturing any `id`.
    if path.parts.len() != route_path.parts.len() {
        return None;
    }
    let mut params = HashMap::new();
    for (segment, pattern) in path.parts.iter().zip(&route_path.parts) {
        match pattern.strip_prefix(':') {
            Some(name) => {
                params.insert(name.to_string(), segment.to_string());
            }
            None if pattern != segment => return None,
            None => {}
        }
    }
    Some(params)
}
impl FromStr for Path {
type Err = Infallible;
#[throws(Self::Err)]
fn from_str(s: &str) -> Path {
Path {
parts: s.split('/').map(|p| p.to_string()).collect(),
}
}
}
/// A registered route. Requests whose method equals `method` and whose path
/// matches `path` are passed to `handler`.
struct Route<E> {
    /// HTTP method, compared case-sensitively with the request method.
    method: String,
    /// Route pattern; segments starting with `:` are placeholders.
    path: Path,
    /// Handler invoked for matching requests.
    handler: Box<Handler<E>>,
}
/// Route table shared between the server and its connection threads.
/// Routes are tried in registration order and the first match wins.
type Routes<E> = Arc<RwLock<Vec<Route<E>>>>;
/// Reason a request could not be handled successfully.
#[derive(Debug, thiserror::Error)]
pub enum RequestError<E: Debug + Display> {
    /// No registered route matched the request's method and path.
    #[error("not found")]
    NotFound,
    /// The matched route's handler returned this error.
    #[error("custom: {0}")]
    Custom(E),
}
fn default_error_handler<E: Debug + Display>(
req: &mut Request,
error: &RequestError<E>,
) {
match error {
RequestError::NotFound => {
error!("not found: {}", req.url().path());
req.set_status(StatusCode::NotFound);
req.write_text("not found");
}
RequestError::Custom(err) => {
error!("error handling {}: {}", req.url().path(), err);
req.set_status(StatusCode::InternalServerError);
req.write_text("internal server error");
}
}
}
/// Runs the handler of the first route whose method equals `req`'s method
/// and whose pattern matches `path`.
///
/// Before the handler runs, the captured placeholder values are stored in
/// `req`. The route table's read lock is held while the handler runs.
///
/// Returns `RequestError::NotFound` when nothing matches. A handler error is
/// returned wrapped in `RequestError::Custom`.
fn dispatch_request<E: Debug + Display>(
    routes: Routes<E>,
    path: &Path,
    req: &mut Request,
) -> Result<(), RequestError<E>> {
    let table = routes.read().unwrap();
    let hit = table
        .iter()
        .filter(|route| route.method == req.method)
        .find_map(|route| {
            match_path(path, &route.path).map(|params| (route, params))
        });
    let (route, params) = match hit {
        Some(found) => found,
        None => return Err(RequestError::NotFound),
    };
    req.path_params = params;
    (route.handler)(req).map_err(RequestError::Custom)
}
#[throws]
fn handle_connection<E: Debug + Display>(
stream: TcpStream,
routes: Routes<E>,
error_handler: ErrorHandlerArc<E>,
) {
let mut stream = BufStream::new(stream);
let mut line = String::new();
stream
.read_line(&mut line)
.context("missing request header")?;
let parts = line.split_whitespace().take(3).collect::<Vec<_>>();
if parts.len() != 3 {
throw!(anyhow!("invalid request: {}", line));
}
let method = parts[0];
let raw_path = parts[1];
let path = raw_path.parse::<Path>()?;
let mut headers: HashMap<HeaderName, String> = HashMap::new();
loop {
let mut line = String::new();
stream.read_line(&mut line).context("failed to read line")?;
let mut parts = line.splitn(2, ':');
if let Some(name) = parts.next() {
let value = parts.next().unwrap_or("");
headers.insert(name.into(), value.trim().to_string());
}
if line.trim().is_empty() {
break;
}
}
let mut req_body = Vec::new();
if let Some(len) = headers.get(&HeaderName::new("Content-Length".into())) {
if let Ok(len) = len.parse::<usize>() {
req_body.resize(len, 0);
stream.read_exact(&mut req_body)?;
}
}
let host = headers
.get(&HeaderName::new("host".into()))
.ok_or_else(|| anyhow!("missing host header"))?;
let mut url = Url::parse(&format!("http://{}", host))
.with_context(|| format!("failed to parse host {}", host))?;
url.set_path(raw_path);
let mut req = Request {
method: method.into(),
path_params: HashMap::new(),
req_headers: headers,
req_body,
url,
resp_body: Vec::new(),
status: StatusCode::Ok,
resp_headers: HashMap::new(),
};
if let Err(err) = dispatch_request(routes, &path, &mut req) {
(error_handler.read().unwrap())(&mut req, &err);
}
stream.write_all(
format!(
"HTTP/1.1 {} {}\n",
req.status,
req.status.canonical_reason(),
)
.as_bytes(),
)?;
for (name, value) in req.resp_headers {
stream.write_all(format!("{}: {}\n", name, value).as_bytes())?;
}
stream.write_all(
format!("Content-Length: {}\n", req.resp_body.len()).as_bytes(),
)?;
stream.write_all(b"\n")?;
stream.write_all(&req.resp_body)?;
}
/// A synthetic request for driving routes through [`Server::test_request`]
/// without a network connection.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TestRequest {
    /// Request body bytes.
    body: Vec<u8>,
    /// HTTP method, e.g. `GET`.
    method: String,
    /// Full request URL; the constructors always use host `example.com`.
    url: Url,
    /// Request headers.
    headers: HashMap<String, String>,
}
impl TestRequest {
#[throws]
pub fn new_with_body(s: &str, body: &[u8]) -> TestRequest {
let parts = s.split_whitespace().collect::<Vec<_>>();
TestRequest {
body: body.into(),
method: parts[0].into(),
url: Url::parse(&format!("http://example.com{}", parts[1]))?,
headers: HashMap::new(),
}
}
#[throws]
pub fn new_with_json<S: Serialize>(s: &str, body: &S) -> TestRequest {
let parts = s.split_whitespace().collect::<Vec<_>>();
TestRequest {
body: serde_json::to_vec(body)?,
method: parts[0].into(),
url: Url::parse(&format!("http://example.com{}", parts[1]))?,
headers: HashMap::new(),
}
}
#[throws]
pub fn new(s: &str) -> TestRequest {
Self::new_with_body(s, &Vec::new())?
}
#[throws]
fn path(&self) -> Path {
self.url.path().parse()?
}
}
/// Response produced by [`Server::test_request`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TestResponse {
    /// Status set by the handler (`StatusCode::Ok` unless the handler changed it).
    pub status: StatusCode,
    /// Response body bytes.
    pub body: Vec<u8>,
    /// Response headers set by the handler, keyed case-insensitively.
    pub headers: HashMap<HeaderName, String>,
}
impl TestResponse {
    /// Deserializes the response body as JSON.
    ///
    /// # Errors
    ///
    /// Fails if the body is not valid JSON for `D`.
    #[throws]
    pub fn json<'a, D: Deserialize<'a>>(&'a self) -> D {
        serde_json::from_slice(&self.body)?
    }
}
/// Copies a header map into one keyed by [`HeaderName`], so that lookups
/// ignore case.
fn convert_header_map_to_unicase(
    map: &HashMap<String, String>,
) -> HashMap<HeaderName, String> {
    let mut converted = HashMap::with_capacity(map.len());
    for (name, value) in map {
        converted.insert(HeaderName::new(name.to_owned()), value.to_owned());
    }
    converted
}
/// A minimal HTTP/1.1 server that handles each accepted connection on its
/// own thread.
///
/// `E` is the error type returned by route handlers.
pub struct Server<E: Debug + Display> {
    /// Address the listener binds to in [`Server::launch`].
    address: SocketAddr,
    /// Registered routes, in registration order.
    routes: Routes<E>,
    /// Handler invoked when routing or a handler fails on a live connection.
    error_handler: ErrorHandlerArc<E>,
}
impl<E: Debug + Display + 'static> Server<E> {
#[throws]
pub fn new(address: &str) -> Server<E> {
Server {
address: address.parse::<SocketAddr>()?,
routes: Arc::new(RwLock::new(Vec::new())),
error_handler: Arc::new(RwLock::new(Box::new(
default_error_handler,
))),
}
}
#[throws]
pub fn route(&mut self, route: &str, handler: &'static Handler<E>) {
let mut iter = route.split_whitespace();
let method = iter.next().ok_or_else(|| anyhow!("missing method"))?;
let path = iter.next().ok_or_else(|| anyhow!("missing path"))?;
let mut routes = self.routes.write().unwrap();
routes.push(Route {
method: method.into(),
path: path.parse()?,
handler: Box::new(handler),
});
}
pub fn set_error_handler(
&mut self,
error_handler: &'static ErrorHandler<E>,
) {
self.error_handler = Arc::new(RwLock::new(Box::new(error_handler)));
}
pub fn launch(self) -> Result<(), Error> {
let listener = TcpListener::bind(self.address)?;
loop {
let (tcp_stream, _addr) = listener.accept()?;
let routes = self.routes.clone();
let error_handler = self.error_handler.clone();
if let Err(err) = thread::Builder::new()
.name("shs-handler".into())
.spawn(move || {
if let Err(err) =
handle_connection(tcp_stream, routes, error_handler)
{
error!("{}", err);
}
})
{
error!("failed to spawn thread: {}", err);
}
}
}
pub fn test_request(
&self,
input: &TestRequest,
) -> Result<TestResponse, RequestError<E>> {
let mut req = Request {
method: input.method.clone(),
path_params: HashMap::new(),
req_headers: convert_header_map_to_unicase(&input.headers),
req_body: input.body.clone(),
url: input.url.clone(),
resp_body: Vec::new(),
status: StatusCode::Ok,
resp_headers: HashMap::new(),
};
let path = input.path().unwrap();
dispatch_request(self.routes.clone(), &path, &mut req)?;
Ok(TestResponse {
status: req.status,
body: req.resp_body,
headers: convert_header_map_to_unicase(&req.resp_headers),
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Echoes the numeric `:id` placeholder back as plain text.
    fn echo_id(req: &mut Request) -> Result<(), String> {
        let id: u32 = req.path_param("id").map_err(|e| e.to_string())?;
        req.write_text(&id.to_string());
        Ok(())
    }

    /// Always fails, to exercise the handler-error path.
    fn fail(_req: &mut Request) -> Result<(), String> {
        Err("boom".to_string())
    }

    #[test]
    fn path_keeps_empty_leading_segment() {
        let path: Path = "/users/42".parse().unwrap();
        assert_eq!(path.parts, vec!["", "users", "42"]);
    }

    #[test]
    fn placeholder_captures_segment() {
        let path: Path = "/users/42".parse().unwrap();
        let route: Path = "/users/:id".parse().unwrap();
        let params = match_path(&path, &route).expect("route should match");
        assert_eq!(params.get("id").map(String::as_str), Some("42"));
    }

    #[test]
    fn literal_mismatch_is_rejected() {
        let path: Path = "/posts/42".parse().unwrap();
        let route: Path = "/users/:id".parse().unwrap();
        assert!(match_path(&path, &route).is_none());
    }

    #[test]
    fn test_request_runs_matching_handler() {
        let mut server = Server::<String>::new("127.0.0.1:0").unwrap();
        server.route("GET /users/:id", &echo_id).unwrap();
        let resp = server
            .test_request(&TestRequest::new("GET /users/7").unwrap())
            .unwrap();
        assert_eq!(resp.status, StatusCode::Ok);
        assert_eq!(resp.body, b"7".to_vec());
        assert_eq!(
            resp.headers
                .get(&HeaderName::new("content-type".to_string()))
                .map(String::as_str),
            Some("text/plain; charset=UTF-8")
        );
    }

    #[test]
    fn test_request_reports_unknown_route() {
        let server = Server::<String>::new("127.0.0.1:0").unwrap();
        let result = server.test_request(&TestRequest::new("GET /missing").unwrap());
        assert!(matches!(result, Err(RequestError::NotFound)));
    }

    #[test]
    fn test_request_wraps_handler_error() {
        let mut server = Server::<String>::new("127.0.0.1:0").unwrap();
        server.route("GET /fail", &fail).unwrap();
        let result = server.test_request(&TestRequest::new("GET /fail").unwrap());
        assert!(matches!(&result, Err(RequestError::Custom(msg)) if msg == "boom"));
    }
}