#![deny(unsafe_code)]
extern crate base64;
#[cfg(feature = "brotli")]
extern crate brotli;
extern crate chrono;
#[cfg(feature = "gzip")]
extern crate deflate;
extern crate filetime;
extern crate multipart;
extern crate rand;
extern crate serde;
#[macro_use]
extern crate serde_derive;
pub extern crate percent_encoding;
extern crate serde_json;
extern crate sha1_smol;
extern crate threadpool;
extern crate time;
extern crate tiny_http;
pub extern crate url;
/// Default percent-encoding set: `percent_encoding::CONTROLS` extended with the space character
/// and `"`, `` ` ``, `#`, `?`, `<`, `>`, `{` and `}`.
pub const DEFAULT_ENCODE_SET: &percent_encoding::AsciiSet = &percent_encoding::CONTROLS
    // whitespace and quoting characters
    .add(b' ')
    .add(b'"')
    .add(b'`')
    // fragment and query delimiters
    .add(b'#')
    .add(b'?')
    // angle brackets and curly braces
    .add(b'<')
    .add(b'>')
    .add(b'{')
    .add(b'}');
pub use assets::extension_to_mime;
pub use assets::match_assets;
pub use log::{log, log_custom};
pub use response::{Response, ResponseBody};
pub use tiny_http::ReadWrite;
use std::error::Error;
use std::fmt;
use std::io::Cursor;
use std::io::Read;
use std::io::Result as IoResult;
use std::marker::PhantomData;
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
use std::panic;
use std::panic::AssertUnwindSafe;
use std::slice::Iter as SliceIter;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc;
use std::sync::Arc;
use std::sync::Mutex;
use std::thread;
use std::time::Duration;
pub mod cgi;
pub mod content_encoding;
pub mod input;
pub mod proxy;
pub mod session;
pub mod websocket;
mod assets;
mod find_route;
mod log;
mod response;
mod router;
#[doc(hidden)]
pub mod try_or_400;
/// Evaluates `$result`; yields the `Ok` value, or returns an empty `404` response from the
/// enclosing function when it is an `Err` (the error value is discarded).
///
/// Expands to `return $crate::Response::empty_404()`, so the enclosing function must return a
/// `Response`.
#[macro_export]
macro_rules! try_or_404 {
    ($result:expr) => {
        if let Ok(value) = $result {
            value
        } else {
            return $crate::Response::empty_404();
        }
    };
}
/// Returns an empty `400` response from the enclosing function when `$cond` is false, and does
/// nothing otherwise.
///
/// Expands to `return $crate::Response::empty_400()`, so the enclosing function must return a
/// `Response`.
#[macro_export]
macro_rules! assert_or_400 {
($cond:expr) => {
// `!$cond` accepts any type whose `Not` output is `bool`, not only `bool` itself.
if !$cond {
return $crate::Response::empty_400();
}
};
}
/// Starts an HTTP server on `addr` that answers every request with `handler`, blocking forever.
///
/// Each request runs on its own thread (the `Server` default executor).
///
/// # Panics
///
/// Panics if the server cannot be created, or if the request loop ever ends.
pub fn start_server<A, F>(addr: A, handler: F) -> !
where
    A: ToSocketAddrs,
    F: Send + Sync + 'static + Fn(&Request) -> Response,
{
    let server = Server::new(addr, handler).expect("Failed to start server");
    server.run();
    // `run` only returns when tiny_http stops yielding requests.
    panic!("The server socket closed unexpectedly")
}
/// Like `start_server`, but dispatches requests to a fixed-size thread pool.
///
/// `pool_size` is the number of worker threads. When it is `None`, eight threads per unit of
/// `thread::available_parallelism()` are used (or 8 if that cannot be determined).
///
/// # Panics
///
/// Panics if the server cannot be created, or if the request loop ever ends.
/// NOTE(review): `Some(0)` is forwarded to `threadpool::ThreadPool::new`, which is documented to
/// panic on zero threads — confirm.
pub fn start_server_with_pool<A, F>(addr: A, pool_size: Option<usize>, handler: F) -> !
where
    A: ToSocketAddrs,
    F: Send + Sync + 'static + Fn(&Request) -> Response,
{
    let workers = match pool_size {
        Some(explicit) => explicit,
        None => {
            let cores = thread::available_parallelism().map_or(1, |n| n.get());
            8 * cores
        }
    };
    let server = Server::new(addr, handler).expect("Failed to start server");
    server.pool_size(workers).run();
    panic!("The server socket closed unexpectedly")
}
/// Guard that increments a shared counter when created and decrements it when dropped, so the
/// counter equals the number of live guards.
struct AtomicCounter(Arc<AtomicUsize>);

impl AtomicCounter {
    /// Registers one more live guard on `count` and returns it.
    fn new(count: &Arc<AtomicUsize>) -> Self {
        let shared = Arc::clone(count);
        shared.fetch_add(1, Ordering::Relaxed);
        AtomicCounter(shared)
    }
}

impl Drop for AtomicCounter {
    fn drop(&mut self) {
        // `Release` presumably pairs with the `Acquire` load in `Executor::join`.
        let AtomicCounter(ref live) = *self;
        live.fetch_sub(1, Ordering::Release);
    }
}
/// Strategy used by `Server` to run request-handling jobs.
enum Executor {
    /// Spawns a dedicated thread per job; `count` is the number of those threads still running.
    Threaded { count: Arc<AtomicUsize> },
    /// Queues jobs on a fixed-size `threadpool::ThreadPool`.
    Pooled { pool: threadpool::ThreadPool },
}

impl Executor {
    /// Builds a pooled executor with `size` worker threads.
    ///
    /// NOTE(review): `threadpool::ThreadPool::new` is documented to panic for `size == 0` — confirm.
    fn with_size(size: usize) -> Self {
        Executor::Pooled {
            pool: threadpool::ThreadPool::new(size),
        }
    }

    /// Runs `f` in the background according to the executor's strategy.
    #[inline]
    fn execute<F: FnOnce() + Send + 'static>(&self, f: F) {
        match self {
            Executor::Pooled { pool } => pool.execute(f),
            Executor::Threaded { count } => {
                // Register the job before spawning so `join` cannot miss it.
                let guard = AtomicCounter::new(count);
                thread::spawn(move || {
                    // Held until `f` finishes (or unwinds), then decrements the counter.
                    let _guard = guard;
                    f()
                });
            }
        }
    }

    /// Blocks until every job handed to `execute` has finished.
    ///
    /// For the thread-per-job strategy this polls the live-thread counter every 100 ms.
    fn join(&self) {
        match self {
            Executor::Pooled { pool } => pool.join(),
            Executor::Threaded { count } => {
                while count.load(Ordering::Acquire) != 0 {
                    thread::sleep(Duration::from_millis(100));
                }
            }
        }
    }
}

impl Default for Executor {
    /// Thread-per-job executor with no jobs running.
    fn default() -> Self {
        let count = Arc::new(AtomicUsize::new(0));
        Executor::Threaded { count }
    }
}
/// An HTTP(S) server that passes each incoming request to a handler and sends back its response.
///
/// Created with `Server::new` (or `Server::new_ssl`). Requests are only processed while one of
/// `run`, `poll`, `poll_timeout` or `stoppable` is driving it.
pub struct Server<F> {
// Underlying tiny_http listener.
server: tiny_http::Server,
// Request handler, wrapped in `AssertUnwindSafe` so it can be invoked inside `catch_unwind`.
handler: Arc<AssertUnwindSafe<F>>,
// Runs each request either on a fresh thread or on a thread pool.
executor: Executor,
}
impl<F> Server<F>
where
    F: Send + Sync + 'static + Fn(&Request) -> Response,
{
    /// Creates a plain-HTTP server listening on `addr` that answers every request with `handler`.
    ///
    /// By default each request is handled on its own newly spawned thread; call `pool_size` to
    /// use a fixed thread pool instead. Nothing is processed until `run`, `poll`,
    /// `poll_timeout` or `stoppable` is called.
    ///
    /// # Errors
    ///
    /// Returns the error from `tiny_http::Server::http` if the listener cannot be created.
    pub fn new<A>(addr: A, handler: F) -> Result<Server<F>, Box<dyn Error + Send + Sync + 'static>>
    where
        A: ToSocketAddrs,
    {
        let server = tiny_http::Server::http(addr)?;
        Ok(Server {
            server,
            executor: Executor::default(),
            handler: Arc::new(AssertUnwindSafe(handler)),
        })
    }

    /// Same as `new`, but serves HTTPS using `certificate` and `private_key`.
    ///
    /// Only available with the `ssl` or `rustls` feature.
    /// NOTE(review): the expected encoding of the key material (presumably PEM) is defined by
    /// `tiny_http::SslConfig` — confirm there.
    ///
    /// # Errors
    ///
    /// Returns the error from `tiny_http::Server::https` if the listener cannot be created.
    #[cfg(any(feature = "ssl", feature = "rustls"))]
    pub fn new_ssl<A>(
        addr: A,
        handler: F,
        certificate: Vec<u8>,
        private_key: Vec<u8>,
    ) -> Result<Server<F>, Box<dyn Error + Send + Sync + 'static>>
    where
        A: ToSocketAddrs,
    {
        let ssl_config = tiny_http::SslConfig {
            certificate,
            private_key,
        };
        let server = tiny_http::Server::https(addr, ssl_config)?;
        Ok(Server {
            server,
            executor: Executor::default(),
            handler: Arc::new(AssertUnwindSafe(handler)),
        })
    }

    /// Replaces the executor with a pool of `pool_size` worker threads.
    ///
    /// NOTE(review): `threadpool::ThreadPool::new` is documented to panic on zero threads, so
    /// `pool_size(0)` presumably panics — confirm.
    pub fn pool_size(mut self, pool_size: usize) -> Self {
        self.executor = Executor::with_size(pool_size);
        self
    }

    /// Returns the local address the server is listening on.
    ///
    /// # Panics
    ///
    /// Panics if the listener is not bound to an IP address.
    #[inline]
    pub fn server_addr(&self) -> SocketAddr {
        self.server
            .server_addr()
            .to_ip()
            .expect("Unexpected Unix socket listener")
    }

    /// Processes requests on the current thread until tiny_http's request iterator ends.
    #[inline]
    pub fn run(self) {
        for request in self.server.incoming_requests() {
            self.process(request);
        }
    }

    /// Dispatches every request `try_recv` yields, returning once it yields none (or an error).
    #[inline]
    pub fn poll(&self) {
        while let Ok(Some(request)) = self.server.try_recv() {
            self.process(request);
        }
    }

    /// Moves the server onto a background thread that can be stopped later.
    ///
    /// Returns the thread's `JoinHandle` and a `Sender`. Sending `()` stops the loop within
    /// about one second (the `recv_timeout` interval). Dropping the sender without sending does
    /// not stop the server. The stop signal is checked between every request, so steady traffic
    /// cannot postpone shutdown. The background thread does not wait for handlers that were
    /// already dispatched to the executor.
    #[inline]
    pub fn stoppable(self) -> (thread::JoinHandle<()>, mpsc::Sender<()>) {
        let (tx, rx) = mpsc::channel();
        let handle = thread::spawn(move || {
            while rx.try_recv().is_err() {
                // At most one request per iteration, so the stop signal is re-checked promptly.
                if let Ok(Some(request)) = self.server.recv_timeout(Duration::from_secs(1)) {
                    self.process(request);
                }
            }
        });
        (handle, tx)
    }

    /// Dispatches requests until `recv_timeout(dur)` yields none (or an error).
    #[inline]
    pub fn poll_timeout(&self, dur: std::time::Duration) {
        while let Ok(Some(request)) = self.server.recv_timeout(dur) {
            self.process(request);
        }
    }

    /// Blocks until every request dispatched so far has been fully handled.
    pub fn join(&self) {
        self.executor.join();
    }

    /// Converts `request` into a `Request`, runs the handler on the executor and sends the
    /// resulting `Response` back through tiny_http.
    ///
    /// A panicking handler is answered with a generic `500` HTML page. If the response carries
    /// an upgrade handler, the connection is upgraded and handed to `Upgrade::build` instead.
    fn process(&self, request: tiny_http::Request) {
        let handler = self.handler.clone();
        self.executor.execute(move || {
            // Body reader given to the `Request`. It reads from the shared tiny_http request,
            // which must stay reachable so it can be taken below to send the response.
            struct RequestRead(Arc<Mutex<Option<tiny_http::Request>>>);
            impl Read for RequestRead {
                #[inline]
                fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
                    match self.0.lock().unwrap().as_mut() {
                        Some(inner) => inner.as_reader().read(buf),
                        // The tiny_http request was already taken to send the response (e.g. a
                        // `Request` derived via `remove_prefix` outlived the handler); report an
                        // I/O error rather than panicking.
                        None => Err(std::io::Error::new(
                            std::io::ErrorKind::Other,
                            "the request has already been answered",
                        )),
                    }
                }
            }

            let tiny_http_request;
            let rouille_request = {
                let url = request.url().to_owned();
                let method = request.method().as_str().to_owned();
                let headers = request
                    .headers()
                    .iter()
                    .map(|h| (h.field.to_string(), h.value.clone().into()))
                    .collect();
                let remote_addr = request.remote_addr().copied();
                // Reflect whether tiny_http received this request over TLS, so that
                // `Request::is_secure` is correct for servers built with `new_ssl`.
                let https = request.secure();
                tiny_http_request = Arc::new(Mutex::new(Some(request)));
                let data = Arc::new(Mutex::new(Some(
                    Box::new(RequestRead(tiny_http_request.clone())) as Box<_>,
                )));
                Request {
                    url,
                    method,
                    headers,
                    https,
                    data,
                    remote_addr,
                }
            };

            let mut rouille_response = {
                let rouille_request = AssertUnwindSafe(rouille_request);
                let res = panic::catch_unwind(move || {
                    // Moves the whole wrapper into the closure (keeps it unwind-safe).
                    let rouille_request = rouille_request;
                    handler(&rouille_request)
                });
                match res {
                    Ok(r) => r,
                    Err(_) => Response::html(
                        "<h1>Internal Server Error</h1>\
                         <p>An internal error has occurred on the server.</p>",
                    )
                    .with_status_code(500),
                }
            };

            let (res_data, res_len) = rouille_response.data.into_reader_and_size();
            let mut response = tiny_http::Response::empty(rouille_response.status_code)
                .with_data(res_data, res_len);

            // The `Upgrade` header value is passed to tiny_http's `upgrade` below instead of
            // being added as an ordinary header; without an upgrade handler it is dropped.
            let mut upgrade_header = "".into();
            for (key, value) in rouille_response.headers {
                // The body length is already given to tiny_http through `with_data` above.
                if key.eq_ignore_ascii_case("Content-Length") {
                    continue;
                }
                if key.eq_ignore_ascii_case("Upgrade") {
                    upgrade_header = value;
                    continue;
                }
                // Headers tiny_http cannot parse are dropped.
                if let Ok(header) = tiny_http::Header::from_bytes(key.as_bytes(), value.as_bytes())
                {
                    response.add_header(header);
                }
            }

            if let Some(ref mut upgrade) = rouille_response.upgrade {
                let trq = tiny_http_request.lock().unwrap().take().unwrap();
                let socket = trq.upgrade(&upgrade_header, response);
                upgrade.build(socket);
            } else {
                // Failures while writing the response are ignored.
                let _ = tiny_http_request
                    .lock()
                    .unwrap()
                    .take()
                    .unwrap()
                    .respond(response);
            }
        });
    }
}
/// Handler for connections whose response requests a protocol upgrade.
///
/// When a `Response` carries an upgrade handler, the server passes the response and the
/// response's `Upgrade` header value to tiny_http's `upgrade`, then calls `build` with the
/// resulting connection instead of answering normally.
pub trait Upgrade {
/// Receives the upgraded connection (the socket returned by tiny_http's `upgrade`).
fn build(&mut self, socket: Box<dyn ReadWrite + Send>);
}
/// An HTTP request as seen by a handler.
pub struct Request {
// HTTP method as received, e.g. "GET".
method: String,
// Raw request target: still percent-encoded and including any query string.
url: String,
// Header (name, value) pairs in the order tiny_http reported them.
headers: Vec<(String, String)>,
// Whether the request is considered to have arrived over HTTPS (see `is_secure`).
https: bool,
// Request body. The first call to `data()` takes it, leaving `None`. Requests derived with
// `remove_prefix` share this same slot.
data: Arc<Mutex<Option<Box<dyn Read + Send>>>>,
// Client address; `None` presumably for non-IP (Unix socket) connections — see `remote_addr()`.
remote_addr: Option<SocketAddr>,
}
impl fmt::Debug for Request {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Request")
.field("method", &self.method)
.field("url", &self.url)
.field("headers", &self.headers)
.field("https", &self.https)
.field("remote_addr", &self.remote_addr)
.finish()
}
}
impl Request {
pub fn fake_http<U, M>(
method: M,
url: U,
headers: Vec<(String, String)>,
data: Vec<u8>,
) -> Request
where
U: Into<String>,
M: Into<String>,
{
let data = Arc::new(Mutex::new(Some(Box::new(Cursor::new(data)) as Box<_>)));
let remote_addr = Some("127.0.0.1:12345".parse().unwrap());
Request {
url: url.into(),
method: method.into(),
https: false,
data,
headers,
remote_addr,
}
}
pub fn fake_http_from<U, M>(
from: SocketAddr,
method: M,
url: U,
headers: Vec<(String, String)>,
data: Vec<u8>,
) -> Request
where
U: Into<String>,
M: Into<String>,
{
let data = Arc::new(Mutex::new(Some(Box::new(Cursor::new(data)) as Box<_>)));
Request {
url: url.into(),
method: method.into(),
https: false,
data,
headers,
remote_addr: Some(from),
}
}
pub fn fake_https<U, M>(
method: M,
url: U,
headers: Vec<(String, String)>,
data: Vec<u8>,
) -> Request
where
U: Into<String>,
M: Into<String>,
{
let data = Arc::new(Mutex::new(Some(Box::new(Cursor::new(data)) as Box<_>)));
let remote_addr = Some("127.0.0.1:12345".parse().unwrap());
Request {
url: url.into(),
method: method.into(),
https: true,
data,
headers,
remote_addr,
}
}
pub fn fake_https_from<U, M>(
from: SocketAddr,
method: M,
url: U,
headers: Vec<(String, String)>,
data: Vec<u8>,
) -> Request
where
U: Into<String>,
M: Into<String>,
{
let data = Arc::new(Mutex::new(Some(Box::new(Cursor::new(data)) as Box<_>)));
Request {
url: url.into(),
method: method.into(),
https: true,
data,
headers,
remote_addr: Some(from),
}
}
pub fn remove_prefix(&self, prefix: &str) -> Option<Request> {
if !self.url().starts_with(prefix) {
return None;
}
assert!(self.url.starts_with(prefix));
Some(Request {
method: self.method.clone(),
url: self.url[prefix.len()..].to_owned(),
headers: self.headers.clone(), https: self.https,
data: self.data.clone(),
remote_addr: self.remote_addr,
})
}
#[inline]
pub fn is_secure(&self) -> bool {
self.https
}
#[inline]
pub fn method(&self) -> &str {
&self.method
}
#[inline]
pub fn raw_url(&self) -> &str {
&self.url
}
#[inline]
pub fn raw_query_string(&self) -> &str {
if let Some(pos) = self.url.bytes().position(|c| c == b'?') {
self.url.split_at(pos + 1).1
} else {
""
}
}
pub fn url(&self) -> String {
let url = self.url.as_bytes();
let url = if let Some(pos) = url.iter().position(|&c| c == b'?') {
&url[..pos]
} else {
url
};
percent_encoding::percent_decode(url)
.decode_utf8_lossy()
.into_owned()
}
pub fn get_param(&self, param_name: &str) -> Option<String> {
let name_pattern = &format!("{}=", param_name);
let param_pairs = self.raw_query_string().split('&');
param_pairs
.filter(|pair| pair.starts_with(name_pattern) || pair == ¶m_name)
.map(|pair| pair.split('=').nth(1).unwrap_or(""))
.next()
.map(|value| {
percent_encoding::percent_decode(value.replace('+', " ").as_bytes())
.decode_utf8_lossy()
.into_owned()
})
}
#[inline]
pub fn header(&self, key: &str) -> Option<&str> {
self.headers
.iter()
.find(|&&(ref k, _)| k.eq_ignore_ascii_case(key))
.map(|&(_, ref v)| &v[..])
}
#[inline]
pub fn headers(&self) -> HeadersIter {
HeadersIter {
iter: self.headers.iter(),
}
}
pub fn do_not_track(&self) -> Option<bool> {
match self.header("DNT") {
Some(h) if h == "1" => Some(true),
Some(h) if h == "0" => Some(false),
_ => None,
}
}
pub fn data(&self) -> Option<RequestBody> {
let reader = self.data.lock().unwrap().take();
reader.map(|r| RequestBody {
body: r,
marker: PhantomData,
})
}
#[inline]
pub fn remote_addr(&self) -> &SocketAddr {
self.remote_addr
.as_ref()
.expect("Unexpected Unix socket for request")
}
}
/// Iterator over a request's headers, yielding `(name, value)` string slices in stored order.
#[derive(Debug, Clone)]
pub struct HeadersIter<'a> {
    iter: SliceIter<'a, (String, String)>,
}

impl<'a> Iterator for HeadersIter<'a> {
    type Item = (&'a str, &'a str);

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let (name, value) = self.iter.next()?;
        Some((name.as_str(), value.as_str()))
    }

    /// Exact: delegates to the underlying slice iterator.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

impl<'a> ExactSizeIterator for HeadersIter<'a> {}
/// Reader over a request body, borrowed from the `Request` it was taken from.
pub struct RequestBody<'a> {
    body: Box<dyn Read + Send>,
    // Carries the `'a` lifetime without storing a reference.
    marker: PhantomData<&'a ()>,
}

impl<'a> Read for RequestBody<'a> {
    /// Forwards directly to the boxed body reader.
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
        let inner: &mut dyn Read = &mut *self.body;
        inner.read(buf)
    }
}
#[cfg(test)]
mod tests {
    use Request;

    /// Builds a fake `GET` request for `url` with no headers and no body.
    fn get(url: &str) -> Request {
        Request::fake_http("GET", url, vec![], vec![])
    }

    /// Builds a fake `GET /` request carrying exactly one header.
    fn get_root_with_header(name: &str, value: &str) -> Request {
        Request::fake_http("GET", "/", vec![(name.to_owned(), value.to_owned())], vec![])
    }

    #[test]
    fn header() {
        let request = get_root_with_header("Host", "localhost");
        for &name in &["Host", "host"] {
            assert_eq!(request.header(name), Some("localhost"));
        }
    }

    #[test]
    fn get_param() {
        assert_eq!(get("/?p=hello").get_param("p"), Some("hello".to_owned()));
    }

    #[test]
    fn get_param_multiple_param() {
        let request = get("/?foo=bar&message=hello");
        assert_eq!(request.get_param("message"), Some("hello".to_owned()));
    }

    #[test]
    fn get_param_no_match() {
        assert_eq!(get("/?hello=world").get_param("foo"), None);
    }

    #[test]
    fn get_param_partial_suffix_match() {
        assert_eq!(get("/?hello=world").get_param("lo"), None);
    }

    #[test]
    fn get_param_partial_prefix_match() {
        assert_eq!(get("/?hello=world").get_param("he"), None);
    }

    #[test]
    fn get_param_superstring_match() {
        assert_eq!(get("/?jan=01").get_param("january"), None);
    }

    #[test]
    fn get_param_flag_with_equals() {
        assert_eq!(get("/?flag=").get_param("flag"), Some(String::new()));
    }

    #[test]
    fn get_param_flag_without_equals() {
        assert_eq!(get("/?flag").get_param("flag"), Some(String::new()));
    }

    #[test]
    fn get_param_flag_with_multiple_params() {
        assert_eq!(get("/?flag&foo=bar").get_param("flag"), Some(String::new()));
    }

    #[test]
    fn body_twice() {
        let request = Request::fake_http("GET", "/", vec![], vec![62, 62, 62]);
        assert!(request.data().is_some());
        assert!(request.data().is_none());
    }

    #[test]
    fn url_strips_get_query() {
        assert_eq!(get("/?p=hello").url(), "/");
    }

    #[test]
    fn urlencode_query_string() {
        let request = get("/?p=hello%20world");
        assert_eq!(request.get_param("p"), Some("hello world".to_owned()));
    }

    #[test]
    fn plus_in_query_string() {
        let request = get("/?p=hello+world");
        assert_eq!(request.get_param("p"), Some("hello world".to_owned()));
    }

    #[test]
    fn encoded_plus_in_query_string() {
        let request = get("/?p=hello%2Bworld");
        assert_eq!(request.get_param("p"), Some("hello+world".to_owned()));
    }

    #[test]
    fn url_encode() {
        assert_eq!(get("/hello%20world").url(), "/hello world");
    }

    #[test]
    fn plus_in_url() {
        assert_eq!(get("/hello+world").url(), "/hello+world");
    }

    #[test]
    fn dnt() {
        assert_eq!(get_root_with_header("DNT", "1").do_not_track(), Some(true));
        assert_eq!(get_root_with_header("DNT", "0").do_not_track(), Some(false));
        assert_eq!(get("/").do_not_track(), None);
        assert_eq!(get_root_with_header("DNT", "malformed").do_not_track(), None);
    }
}