#![allow(dead_code)]
#![allow(clippy::cmp_owned)]
use brotli::enc::{BrotliCompress, BrotliEncoderParams};
use cached::proc_macro::cached;
use cookie::Cookie;
use core::f64;
use futures_lite::{future::Boxed, Future, FutureExt};
use hyper::{
body,
body::HttpBody,
header,
service::{make_service_fn, service_fn},
HeaderMap,
};
use hyper::{Body, Method, Request, Response, Server as HyperServer};
use libflate::gzip;
use route_recognizer::{Params, Router};
use std::{
cmp::Ordering,
fmt::Display,
io,
pin::Pin,
result::Result,
str::{from_utf8, Split},
string::ToString,
};
use time::OffsetDateTime;
use crate::dbg_msg;
/// Pinned, boxed, `Send` future that resolves to either a hyper response or an
/// error message. This is the return type of every route handler stored in the router.
type BoxResponse = Pin<Box<dyn Future<Output = Result<Response<Body>, String>> + Send>>;
/// Content codings this server can apply to a response body.
///
/// The derived `Ord` follows declaration order
/// (`Passthrough < Gzip < Brotli`); it is used as the tie-breaker when a
/// client weights two codings equally, so Brotli wins over Gzip.
#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
enum CompressionType {
    /// Leave the body untouched.
    Passthrough,
    /// The `gzip` content coding.
    Gzip,
    /// The `br` (Brotli) content coding.
    Brotli,
}

/// Coding used when the client sends the `*` wildcard.
const DEFAULT_COMPRESSOR: CompressionType = CompressionType::Gzip;

impl CompressionType {
    /// Maps a single content-coding token from an `Accept-Encoding` header to
    /// a [`CompressionType`].
    ///
    /// Content-coding names are case-insensitive in HTTP, so `GZIP` and `Br`
    /// are accepted as well as the lowercase spellings. `*` resolves to
    /// [`DEFAULT_COMPRESSOR`]. Any other token yields `None`. The caller is
    /// expected to have trimmed surrounding whitespace.
    fn parse(s: &str) -> Option<Self> {
        if s.eq_ignore_ascii_case("gzip") {
            Some(Self::Gzip)
        } else if s.eq_ignore_ascii_case("br") {
            Some(Self::Brotli)
        } else if s == "*" {
            Some(DEFAULT_COMPRESSOR)
        } else {
            None
        }
    }
}

impl Display for CompressionType {
    /// Writes the token used in the `Content-Encoding` header; `Passthrough`
    /// writes nothing.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let token = match self {
            Self::Gzip => "gzip",
            Self::Brotli => "br",
            Self::Passthrough => "",
        };
        f.write_str(token)
    }
}
/// Handle for registering handlers on one path; created by `Server::at`.
pub struct Route<'a> {
/// Mutable borrow of the server's router, on which handlers are registered.
router: &'a mut Router<fn(Request<Body>) -> BoxResponse>,
/// Path pattern the handlers are registered under.
path: String,
}
/// HTTP server: a routing table plus a set of headers added to every response.
pub struct Server {
/// Headers merged into every response produced by `listen`.
pub default_headers: HeaderMap,
/// Routing table; keys have the form `/{METHOD}{path}`.
router: Router<fn(Request<Body>) -> BoxResponse>,
}
/// Builds a `hyper::HeaderMap` from one or more `key => value` pairs.
///
/// Each value is converted with `HeaderValue::from_str`; a pair whose value
/// fails that conversion is silently skipped rather than reported.
#[macro_export]
macro_rules! headers(
{ $($key:expr => $value:expr),+ } => {
{
let mut m = hyper::HeaderMap::new();
$(
// Invalid header values are dropped instead of aborting the whole map.
if let Ok(val) = hyper::header::HeaderValue::from_str($value) {
m.insert($key, val);
}
)+
m
}
};
);
/// Convenience accessors for route parameters and cookies on a request.
pub trait RequestExt {
/// Returns the route parameters attached to the request.
fn params(&self) -> Params;
/// Returns the route parameter called `name`, if present.
fn param(&self, name: &str) -> Option<String>;
/// Attaches route parameters to the request, returning any previously attached set.
fn set_params(&mut self, params: Params) -> Option<Params>;
/// Returns the cookies carried by the request.
fn cookies(&self) -> Vec<Cookie<'_>>;
/// Returns the cookie called `name`, if present.
fn cookie(&self, name: &str) -> Option<Cookie<'_>>;
}
/// Convenience accessors for cookies on a response.
pub trait ResponseExt {
/// Returns the cookies carried by the response.
fn cookies(&self) -> Vec<Cookie<'_>>;
/// Adds `cookie` to the response.
fn insert_cookie(&mut self, cookie: Cookie<'_>);
/// Instructs the client to discard the cookie called `name`.
fn remove_cookie(&mut self, name: String);
}
impl RequestExt for Request<Body> {
    /// Returns a copy of the route parameters stored in the request
    /// extensions, or an empty set when none have been stored.
    fn params(&self) -> Params {
        match self.extensions().get::<Params>() {
            Some(stored) => stored.clone(),
            None => Params::new(),
        }
    }

    /// Looks up a single route parameter by name and returns it as an owned string.
    fn param(&self, name: &str) -> Option<String> {
        let all = self.params();
        all.find(name).map(String::from)
    }

    /// Stores `params` in the request extensions, handing back whatever
    /// parameter set was stored before.
    fn set_params(&mut self, params: Params) -> Option<Params> {
        let extensions = self.extensions_mut();
        extensions.insert(params)
    }

    /// Parses the `Cookie` header into individual cookies.
    ///
    /// The header is split on `"; "`. A segment that fails to parse, or a
    /// header that is not valid visible ASCII, produces a placeholder
    /// cookie built from the empty string.
    fn cookies(&self) -> Vec<Cookie<'_>> {
        let raw = match self.headers().get("Cookie") {
            Some(value) => value.to_str().unwrap_or_default(),
            None => return Vec::new(),
        };

        let mut jar = Vec::new();
        for segment in raw.split("; ") {
            let parsed = match Cookie::parse(segment) {
                Ok(cookie) => cookie,
                Err(_) => Cookie::from(""),
            };
            jar.push(parsed);
        }
        jar
    }

    /// Returns the first cookie whose name equals `name`.
    fn cookie(&self, name: &str) -> Option<Cookie<'_>> {
        for candidate in self.cookies() {
            if candidate.name() == name {
                return Some(candidate);
            }
        }
        None
    }
}
impl ResponseExt for Response<Body> {
    /// Returns every cookie this response sets.
    ///
    /// Responses carry cookies in `Set-Cookie` headers, one cookie per
    /// header line, which is also what [`ResponseExt::insert_cookie`] and
    /// [`ResponseExt::remove_cookie`] write. Each header line is therefore
    /// parsed as a whole, attributes included, instead of reading the
    /// request-only `Cookie` header. A line that cannot be parsed yields a
    /// placeholder cookie built from the empty string, mirroring the
    /// request-side accessor.
    fn cookies(&self) -> Vec<Cookie<'_>> {
        self.headers()
            .get_all(header::SET_COOKIE)
            .iter()
            .map(|value| Cookie::parse(value.to_str().unwrap_or_default()).unwrap_or_else(|_| Cookie::from("")))
            .collect()
    }

    /// Appends a `Set-Cookie` header for `cookie`.
    ///
    /// The header is appended, not replaced, so earlier cookies are kept.
    /// A cookie whose serialized form is not a valid header value is
    /// silently dropped.
    fn insert_cookie(&mut self, cookie: Cookie<'_>) {
        if let Ok(val) = header::HeaderValue::from_str(&cookie.to_string()) {
            self.headers_mut().append(header::SET_COOKIE, val);
        }
    }

    /// Appends a `Set-Cookie` header that expires the cookie called `name`.
    ///
    /// The removal cookie uses path `/`, is HTTP-only and expires at the
    /// current UTC time.
    fn remove_cookie(&mut self, name: String) {
        let removal_cookie = Cookie::build(name).path("/").http_only(true).expires(OffsetDateTime::now_utc());
        if let Ok(val) = header::HeaderValue::from_str(&removal_cookie.to_string()) {
            self.headers_mut().append(header::SET_COOKIE, val);
        }
    }
}
impl Route<'_> {
    /// Registers `dest` for this route's path under the given HTTP method.
    ///
    /// The router key is `/{METHOD}{path}`, so one path can hold a separate
    /// handler per method.
    fn method(&mut self, method: &Method, dest: fn(Request<Body>) -> BoxResponse) -> &mut Self {
        let mut pattern = String::from("/");
        pattern.push_str(method.as_str());
        pattern.push_str(&self.path);
        self.router.add(&pattern, dest);
        self
    }

    /// Registers `dest` as the `GET` handler for this path.
    pub fn get(&mut self, dest: fn(Request<Body>) -> BoxResponse) -> &mut Self {
        let verb = Method::GET;
        self.method(&verb, dest)
    }

    /// Registers `dest` as the `POST` handler for this path.
    pub fn post(&mut self, dest: fn(Request<Body>) -> BoxResponse) -> &mut Self {
        let verb = Method::POST;
        self.method(&verb, dest)
    }
}
impl Default for Server {
/// Equivalent to `Server::new()`.
fn default() -> Self {
Self::new()
}
}
impl Server {
/// Creates a server with no default headers and an empty routing table.
pub fn new() -> Self {
Self {
default_headers: HeaderMap::new(),
router: Router::new(),
}
}
/// Returns a `Route` handle for registering handlers on `path`.
pub fn at(&mut self, path: &str) -> Route<'_> {
Route {
path: path.to_owned(),
router: &mut self.router,
}
}
/// Consumes the server and returns a boxed future that serves requests on `addr`
/// until a shutdown signal arrives.
///
/// Per request: the path is normalized, the handler registered under
/// `/{METHOD}{path}` is looked up and awaited, the default headers are merged
/// into the response, and the body is compressed when eligible.
///
/// # Panics
///
/// Panics if `addr` cannot be parsed as a socket address. Installing the
/// shutdown signal handlers uses `expect` and panics on failure.
pub fn listen(self, addr: &str) -> Boxed<Result<(), hyper::Error>> {
let make_svc = make_service_fn(move |_conn| {
// Each connection gets its own copy of the router and the default headers.
let router = self.router.clone();
let default_headers = self.default_headers.clone();
async move {
Ok::<_, String>(service_fn(move |req: Request<Body>| {
// Request headers are kept so compression can consult Accept-Encoding after `req` is moved into the handler.
let req_headers = req.headers().clone();
let def_headers = default_headers.clone();
// Normalize the path: replace `//` with `/` and the encoded slash `%2F` with `/`.
let mut path = req.uri().path().replace("//", "/").replace("%2F", "/");
// Drop a trailing slash, except for the root path itself.
if path != "/" && path.ends_with('/') {
path.pop();
}
// HEAD is routed as GET; `is_head` later replaces the body with an empty one.
let (method, is_head) = match req.method() {
&Method::HEAD => (&Method::GET, true),
method => (method, false),
};
match router.recognize(&format!("/{}{}", method.as_str(), path)) {
Ok(found) => {
let mut parammed = req;
// Expose the matched route parameters to the handler through the request extensions.
parammed.set_params(found.params().clone());
let func = (found.handler().to_owned().to_owned())(parammed);
async move {
match func.await {
Ok(mut res) => {
res.headers_mut().extend(def_headers);
if is_head {
*res.body_mut() = Body::empty();
} else {
// Compression is best effort: its error is deliberately ignored.
let _ = compress_response(&req_headers, &mut res).await;
}
Ok(res)
}
// A handler error becomes a 500 response carrying the error message.
Err(msg) => new_boilerplate(def_headers, req_headers, 500, if is_head { Body::empty() } else { Body::from(msg) }).await,
}
}
.boxed()
}
// No matching route: 404 with the router's error message as the body.
Err(e) => new_boilerplate(def_headers, req_headers, 404, if is_head { Body::empty() } else { e.into() }).boxed(),
}
}))
}
});
let address = &addr.parse().unwrap_or_else(|_| panic!("Cannot parse {addr} as address (example format: 0.0.0.0:8080)"));
let server = HyperServer::bind(address).serve(make_svc).with_graceful_shutdown(async {
// Windows: shut down on Ctrl+C.
#[cfg(windows)]
tokio::signal::ctrl_c().await.expect("Failed to install CTRL+C signal handler");
// Unix: shut down on Ctrl+C or SIGTERM, whichever arrives first.
#[cfg(unix)]
{
let mut signal_terminate = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()).expect("Failed to install SIGTERM signal handler");
tokio::select! {
_ = tokio::signal::ctrl_c() => (),
_ = signal_terminate.recv() => ()
}
}
});
server.boxed()
}
}
/// Builds a bare response with the given status and body, for error and
/// not-found replies.
///
/// `compress_response` is offered the response first, then `default_headers`
/// are merged in. `default_headers` is already owned, so it is moved into
/// the response rather than cloned.
///
/// NOTE(review): the response is built without a `Content-Type` header and
/// `compress_response` returns early when that header is absent, so the
/// compression call looks like a no-op here. Confirm whether that is intended.
///
/// # Errors
///
/// Returns the builder's error message if the response cannot be constructed
/// (for example, an invalid status code).
async fn new_boilerplate(
    default_headers: HeaderMap<header::HeaderValue>,
    req_headers: HeaderMap<header::HeaderValue>,
    status: u16,
    body: Body,
) -> Result<Response<Body>, String> {
    let mut res = Response::builder().status(status).body(body).map_err(|e| e.to_string())?;
    // Best effort: a compression failure must not turn into a failed response.
    let _ = compress_response(&req_headers, &mut res).await;
    res.headers_mut().extend(default_headers);
    Ok(res)
}
/// Picks the content coding to use for a given `Accept-Encoding` header value.
///
/// The header is a comma-separated list of codings, each optionally followed
/// by a `;q=<weight>` parameter (default weight `1.0`). The supported coding
/// with the highest weight wins; equal weights are broken by the ordering of
/// [`CompressionType`], which prefers Brotli over Gzip.
///
/// Rules applied per entry:
/// * Unsupported codings are ignored.
/// * Optional whitespace around the coding and around the weight parameter
///   is tolerated (`gzip; q=0.8`).
/// * A weight of `0` means "not acceptable", so such an entry is never
///   selected.
/// * A malformed or out-of-range weight on a supported coding rejects the
///   whole header (`None`).
///
/// Returns `None` when the header is empty or offers no acceptable supported
/// coding. Results are memoized per header string by `#[cached]`.
#[cached]
fn determine_compressor(accept_encoding: String) -> Option<CompressionType> {
    if accept_encoding.is_empty() {
        return None;
    }

    // Best candidate so far. A weight of negative infinity marks "nothing
    // acceptable seen yet", so any real weight beats it.
    let mut best_alg = CompressionType::Passthrough;
    let mut best_q: f64 = f64::NEG_INFINITY;

    for entry in accept_encoding.split(',') {
        let mut parts: Split<'_, char> = entry.split(';');

        let alg = match parts.next().and_then(|name| CompressionType::parse(name.trim())) {
            Some(alg) => alg,
            None => continue,
        };

        let q: f64 = match parts.next() {
            None => 1.0,
            Some(param) => {
                let weight = param.trim().strip_prefix("q=")?.parse::<f64>().ok()?;
                // Also rejects NaN, which is not contained in any range.
                if !(0.0..=1.0).contains(&weight) {
                    return None;
                }
                weight
            }
        };

        // q=0 is an explicit refusal of this coding, not a low preference.
        if q == 0.0 {
            continue;
        }

        let ranking = q.total_cmp(&best_q).then_with(|| alg.cmp(&best_alg));
        if ranking == Ordering::Greater {
            best_alg = alg;
            best_q = q;
        }
    }

    if best_q == f64::NEG_INFINITY {
        None
    } else {
        Some(best_alg)
    }
}
/// Compresses the body of `res` in place when the response is eligible and
/// the client accepts a supported coding.
///
/// The response is left untouched (`Ok(())`) when any of these hold:
/// * it has no `Content-Type`, or the type is neither `text/*` nor
///   `application/json`;
/// * it already carries a `Content-Encoding`, since compressing again would
///   corrupt what the client decodes;
/// * the body's lower size bound is under 1452 bytes;
/// * the request has no usable `Accept-Encoding`, or none of its codings is
///   supported.
///
/// On success the body is replaced by the compressed bytes,
/// `Content-Encoding` is set, and any `Content-Length` supplied by the
/// handler is removed because it described the uncompressed body.
///
/// # Errors
///
/// Returns an error message when the `Content-Type` header is not valid
/// UTF-8, when reading the body fails, or when the compressor fails. If the
/// compressor fails, the original body is put back so the response can still
/// be sent uncompressed, which matters because callers treat compression as
/// best effort and ignore this error.
async fn compress_response(req_headers: &HeaderMap<header::HeaderValue>, res: &mut Response<Body>) -> Result<(), String> {
    // Only textual payloads and JSON are worth compressing.
    match res.headers().get(header::CONTENT_TYPE) {
        None => return Ok(()),
        Some(hdr) => match from_utf8(hdr.as_bytes()) {
            Ok(content_type) => {
                if !(content_type.starts_with("text/") || content_type.starts_with("application/json")) {
                    return Ok(());
                }
            }
            Err(e) => {
                dbg_msg!(e);
                return Err(e.to_string());
            }
        },
    }

    // A handler that already encoded its body must not be encoded twice.
    if res.headers().contains_key(header::CONTENT_ENCODING) {
        return Ok(());
    }

    // Skip bodies below the size threshold.
    if res.body().size_hint().lower() < 1452 {
        return Ok(());
    }

    let accept_encoding: String = match req_headers.get(header::ACCEPT_ENCODING) {
        None => return Ok(()),
        Some(hdr) => match String::from_utf8(hdr.as_bytes().into()) {
            Ok(val) => val,
            #[cfg(debug_assertions)]
            Err(e) => {
                dbg_msg!(e);
                return Ok(());
            }
            #[cfg(not(debug_assertions))]
            Err(_) => return Ok(()),
        },
    };

    let compressor: CompressionType = match determine_compressor(accept_encoding) {
        Some(c) => c,
        None => return Ok(()),
    };

    // Reading the body drains it from the response; keep the bytes so they
    // can be restored if compression fails.
    let original = match body::to_bytes(res.body_mut()).await {
        Ok(bytes) => bytes,
        Err(e) => {
            dbg_msg!(e);
            return Err(e.to_string());
        }
    };

    match compress_body(compressor, original.to_vec()) {
        Ok(compressed) => {
            let encoding = compressor
                .to_string()
                .parse()
                .expect("content-coding tokens are valid header values");
            let headers = res.headers_mut();
            headers.insert(header::CONTENT_ENCODING, encoding);
            // The old length no longer matches the compressed body.
            headers.remove(header::CONTENT_LENGTH);
            *res.body_mut() = Body::from(compressed);
            Ok(())
        }
        Err(e) => {
            *res.body_mut() = Body::from(original);
            Err(e)
        }
    }
}
/// Compresses `body_bytes` with the requested coding and returns the encoded
/// bytes.
///
/// Results are memoized by `#[cached]` using the attribute's `size = 100`
/// and `time = 600` settings; `result = true` is also set on the attribute.
///
/// # Errors
///
/// Returns `"unsupported compressor"` for `CompressionType::Passthrough`,
/// or the I/O error message if the encoder fails. I/O errors are also passed
/// to `dbg_msg!` before being returned.
#[cached(size = 100, time = 600, result = true)]
fn compress_body(compressor: CompressionType, body_bytes: Vec<u8>) -> Result<Vec<u8>, String> {
    // Shared failure path: report the error, then hand back its message.
    let fail = |err: io::Error| -> String {
        dbg_msg!(err);
        err.to_string()
    };

    let mut source = io::Cursor::new(body_bytes);

    match compressor {
        CompressionType::Passthrough => Err("unsupported compressor".to_string()),
        CompressionType::Gzip => {
            let mut encoder: gzip::Encoder<Vec<u8>> = gzip::Encoder::new(Vec::new()).map_err(fail)?;
            io::copy(&mut source, &mut encoder).map_err(fail)?;
            encoder.finish().into_result().map_err(fail)
        }
        CompressionType::Brotli => {
            let params = BrotliEncoderParams::default();
            let mut output = Vec::<u8>::new();
            BrotliCompress(&mut source, &mut output, &params).map_err(fail)?;
            Ok(output)
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use brotli::Decompressor as BrotliDecompressor;
use futures_lite::future::block_on;
use lipsum::lipsum;
use std::{boxed::Box, io};
/// Checks coding selection for unsupported, single, wildcard, equally weighted
/// and explicitly weighted `Accept-Encoding` values, plus a malformed weight.
#[test]
fn test_determine_compressor() {
// Single codings, the wildcard and an unsupported value.
assert_eq!(determine_compressor("unsupported".to_string()), None);
assert_eq!(determine_compressor("gzip".to_string()), Some(CompressionType::Gzip));
assert_eq!(determine_compressor("*".to_string()), Some(DEFAULT_COMPRESSOR));
// Equal weights: Brotli is chosen regardless of the order in the header.
assert_eq!(determine_compressor("gzip, br".to_string()), Some(CompressionType::Brotli));
assert_eq!(determine_compressor("gzip;q=0.8, br;q=0.3".to_string()), Some(CompressionType::Gzip));
assert_eq!(determine_compressor("br, gzip".to_string()), Some(CompressionType::Brotli));
assert_eq!(determine_compressor("br;q=0.3, gzip;q=0.4".to_string()), Some(CompressionType::Gzip));
// A weight that is not a number in range yields no coding.
assert_eq!(determine_compressor("gzip;q=NAN".to_string()), None);
}
/// Round-trips a large text body through `compress_response` for several
/// `Accept-Encoding` values and verifies both the `Content-Encoding` header and
/// that decompressing the body yields the original text.
#[test]
fn test_compress_response() {
// Builds an Accept-Encoding value by joining the given codings with ", ".
macro_rules! ae_gen {
($x:expr) => {
$x.to_string().as_str()
};
($x:expr, $($y:expr),+) => {
format!("{}, {}", $x.to_string(), ae_gen!($($y),+)).as_str()
};
}
for accept_encoding in [
"*",
ae_gen!(CompressionType::Gzip),
ae_gen!(CompressionType::Brotli, CompressionType::Gzip),
ae_gen!(CompressionType::Brotli),
] {
// The coding the server is expected to pick for this header.
let expected_encoding: CompressionType = match determine_compressor(accept_encoding.to_string()) {
Some(s) => s,
None => panic!("determine_compressor(accept_encoding.to_string()) => None"),
};
let mut req_headers = HeaderMap::new();
req_headers.insert(header::ACCEPT_ENCODING, header::HeaderValue::from_str(accept_encoding).unwrap());
// Generated filler text used as the `text/plain` body to compress.
let lorem_ipsum: String = lipsum(10000);
let expected_lorem_ipsum = Vec::<u8>::from(lorem_ipsum.as_str());
let mut res = Response::builder()
.status(200)
.header(header::CONTENT_TYPE, "text/plain")
.body(Body::from(lorem_ipsum))
.unwrap();
if let Err(e) = block_on(compress_response(&req_headers, &mut res)) {
panic!("compress_response(&req_headers, &mut res) => Err(\"{e}\")");
};
// The response must advertise the coding that was selected.
assert_eq!(
res
.headers()
.get(header::CONTENT_ENCODING)
.unwrap_or_else(|| panic!("missing content-encoding header"))
.to_str()
.unwrap_or_else(|_| panic!("failed to convert Content-Encoding header::HeaderValue to String")),
expected_encoding.to_string()
);
let body_vec = match block_on(body::to_bytes(res.body_mut())) {
Ok(b) => b.to_vec(),
Err(e) => panic!("{e}"),
};
// Passthrough means the body must be byte-identical to the input.
if expected_encoding == CompressionType::Passthrough {
assert!(body_vec.eq(&expected_lorem_ipsum));
continue;
}
// Otherwise decode with the matching decompressor and compare with the input.
let mut body_cursor: io::Cursor<Vec<u8>> = io::Cursor::new(body_vec);
let mut decoder: Box<dyn io::Read> = match expected_encoding {
CompressionType::Gzip => match gzip::Decoder::new(&mut body_cursor) {
Ok(dgz) => Box::new(dgz),
Err(e) => panic!("{e}"),
},
CompressionType::Brotli => Box::new(BrotliDecompressor::new(body_cursor, expected_lorem_ipsum.len())),
_ => panic!("no decompressor for {}", expected_encoding),
};
let mut decompressed = Vec::<u8>::new();
if let Err(e) = io::copy(&mut decoder, &mut decompressed) {
panic!("{e}");
};
assert!(decompressed.eq(&expected_lorem_ipsum));
}
}
}