#![doc = include_str!("../README.md")]
use std::{
collections::HashMap,
future::Future,
io::{Read, Write, BufWriter},
net::{TcpListener, TcpStream, Shutdown},
pin::Pin,
sync::Arc,
time::{Instant, SystemTime, UNIX_EPOCH},
task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
};
/// A source of response-body bytes that is drained one chunk at a time.
///
/// Each call to `next_chunk` hands out the next piece of the body; `None`
/// signals that nothing is left to send.
pub trait BodyStream: Send {
    /// Returns the next chunk of body bytes, or `None` once the body is exhausted.
    fn next_chunk(&mut self) -> Option<Vec<u8>>;
}

/// An in-memory body: the whole buffer is handed out as a single chunk, after
/// which the vector is left empty and further calls return `None`.
impl BodyStream for Vec<u8> {
    fn next_chunk(&mut self) -> Option<Vec<u8>> {
        match self.len() {
            0 => None,
            _ => Some(std::mem::replace(self, Vec::new())),
        }
    }
}
/// Conversion of response payloads into raw bytes (used by `Reply::body`).
pub trait IntoBytes {
    /// Consumes `self` and returns its byte representation.
    fn into_bytes(self) -> Vec<u8>;
}

/// Reuses the string's own UTF-8 buffer (calls the inherent `String::into_bytes`).
impl IntoBytes for String {
    fn into_bytes(self) -> Vec<u8> {
        String::into_bytes(self)
    }
}

/// Copies the borrowed UTF-8 bytes into a freshly allocated vector.
impl IntoBytes for &str {
    fn into_bytes(self) -> Vec<u8> {
        Vec::from(self.as_bytes())
    }
}

/// Already bytes: returned unchanged.
impl IntoBytes for Vec<u8> {
    fn into_bytes(self) -> Vec<u8> {
        self
    }
}
/// An incoming HTTP request as handed to route and 404 handlers.
pub struct Req {
    /// Method token from the request line, e.g. `"GET"`.
    pub method: String,
    /// Full request target, including any query string.
    pub path: String,
    /// Request body as text (may be empty).
    pub body: String,
    /// Request headers; keys are lower-cased header names.
    pub headers: HashMap<String, String>,
}

/// Values captured from `<name>` segments of the matched route pattern, keyed by `name`.
pub struct Params(pub HashMap<String, String>);
/// HTTP status codes accepted by `Reply::new`.
///
/// Each discriminant is the numeric code, so `code as u16` yields the wire value.
#[repr(u16)]
#[derive(Clone, Copy)]
pub enum StatusCode {
    /// 200 — the request succeeded.
    Ok = 200,
    /// 401 — authentication is required.
    Unauthorized = 401,
    /// 403 — the client may not access the resource.
    Forbidden = 403,
    /// 404 — nothing exists at the requested path.
    NotFound = 404,
}
/// An HTTP response under construction: status code, headers and a streamed body.
pub struct Reply {
    /// Numeric HTTP status code.
    pub status: u16,
    /// Response headers keyed by name.
    pub headers: HashMap<String, String>,
    /// Body source, drained chunk by chunk when the response is written.
    pub body: Box<dyn BodyStream>,
}

impl Reply {
    /// Starts a reply with the given status, no headers and an empty body.
    pub fn new(status: StatusCode) -> Self {
        let code = status as u16;
        let empty: Vec<u8> = Vec::new();
        Reply {
            status: code,
            headers: HashMap::new(),
            body: Box::new(empty),
        }
    }

    /// Sets header `key` to `value`, replacing any value previously stored under
    /// exactly the same key string.
    pub fn header(mut self, key: &str, value: &str) -> Self {
        let (k, v) = (String::from(key), String::from(value));
        self.headers.insert(k, v);
        self
    }

    /// Replaces the body with `data`, converted to bytes through `IntoBytes`.
    pub fn body<T: IntoBytes>(mut self, data: T) -> Self {
        let bytes: Vec<u8> = data.into_bytes();
        self.body = Box::new(bytes);
        self
    }
}
/// A heap-allocated, pinned, `Send` future producing `T` that may borrow for `'a`.
pub type BoxFuture<'a, T> =
    Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// A type-erased route handler: takes the request and captured path parameters
/// and returns a boxed future resolving to the reply.
pub type Handler =
    Box<dyn Fn(Req, Params) -> BoxFuture<'static, Reply> + Send + Sync>;

/// A pre-routing hook given the raw request target; `Some(reply)` short-circuits routing.
pub type Middleware =
    Box<dyn Fn(&str) -> Option<Reply> + Send + Sync>;
/// HTTP methods that routes can be registered for.
pub enum Method {
    /// The `GET` method.
    GET,
    /// The `POST` method.
    POST,
}

// Re-exported so callers can write `GET` / `POST` without the `Method::` prefix.
pub use self::Method::*;
/// The application: registered routes, an optional middleware and custom 404 handlers.
///
/// `WebIo::default()` yields the same empty application as `WebIo::new()`
/// (no routes, no middleware, no 404 handlers).
#[derive(Default)]
pub struct WebIo {
    /// `(method, path pattern, handler)` triples, searched in registration order.
    routes: Vec<(String, String, Handler)>,
    /// Optional pre-routing hook that can short-circuit a request.
    mw: Option<Middleware>,
    /// Custom not-found handlers, keyed by `"html"` or `"json"`.
    handlers_404: HashMap<String, Handler>,
}
impl WebIo {
    /// Largest request body, in bytes, that is buffered into `Req::body`.
    /// Any bytes beyond this limit are left unread on the socket.
    const MAX_BODY_BYTES: usize = 8 * 1024 * 1024;

    /// Creates an application with no routes, no middleware and no custom 404 handlers.
    pub fn new() -> Self {
        Self {
            routes: Vec::new(),
            mw: None,
            handlers_404: HashMap::new(),
        }
    }

    /// Installs `f` as the middleware, replacing any previously installed one.
    ///
    /// The middleware receives the raw request target (path plus query string)
    /// before routing. Returning `Some(reply)` sends that reply immediately and
    /// skips route matching.
    pub fn use_mw<F>(&mut self, f: F)
    where
        F: Fn(&str) -> Option<Reply> + Send + Sync + 'static,
    {
        self.mw = Some(Box::new(f));
    }

    /// Registers a handler for requests that match no route.
    ///
    /// The handler is executed once, right here, with an empty request so that
    /// its reply's `Content-Type` header can be inspected (header name matched
    /// case-insensitively). A content type containing `json` registers it for
    /// non-HTML clients; anything else registers it for clients whose `Accept`
    /// header includes `text/html`. A later registration of the same kind
    /// replaces the earlier one.
    pub fn on_404<F, Fut>(&mut self, handler: F)
    where
        F: Fn(Req, Params) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Reply> + Send + 'static,
    {
        let h: Handler = Box::new(move |r, p| Box::pin(handler(r, p)));
        let probe = Req {
            method: String::new(),
            path: String::new(),
            body: String::new(),
            headers: HashMap::new(),
        };
        let sniff = execute(h(probe, Params(HashMap::new())));
        // Header names are case-insensitive, so `content-type` must count too.
        let is_json = sniff
            .headers
            .iter()
            .any(|(k, v)| k.eq_ignore_ascii_case("content-type") && v.contains("json"));
        let kind = if is_json { "json" } else { "html" };
        self.handlers_404.insert(kind.to_string(), h);
    }

    /// Registers `handler` for `method` requests whose path matches `path`.
    ///
    /// `path` is split on `/`. A segment written as `<name>` matches any single
    /// path segment and captures it into `Params` under `name`; other segments
    /// must match exactly. Routes are tried in registration order and the first
    /// match wins.
    pub fn route<F, Fut>(&mut self, method: Method, path: &str, handler: F)
    where
        F: Fn(Req, Params) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Reply> + Send + 'static,
    {
        let m = match method {
            GET => "GET",
            POST => "POST",
        };
        let h: Handler = Box::new(move |r, p| Box::pin(handler(r, p)));
        self.routes.push((m.to_string(), path.to_string(), h));
    }

    /// Binds `host:port` and serves connections, spawning one OS thread per
    /// accepted connection. Failed `accept` calls are skipped.
    ///
    /// # Panics
    ///
    /// Panics if the address cannot be bound.
    pub async fn run(self, host: &str, port: &str) {
        let listener = TcpListener::bind(format!("{}:{}", host, port)).expect("Bind failed");
        println!("🦀 WebIo Live: http://{}:{}", host, port);
        let app = Arc::new(self);
        for s in listener.incoming().flatten() {
            let a = Arc::clone(&app);
            std::thread::spawn(move || execute(a.handle_connection(s)));
        }
    }

    /// Serves a single request on `stream`.
    ///
    /// Reads the request head with one read of up to 4096 bytes (150 ms read
    /// timeout), runs the middleware, dispatches to the first matching route or
    /// to the 404 handling, and writes the reply via `finalize`. Unparseable
    /// request lines and `/favicon.ico` requests are dropped without a response.
    async fn handle_connection(&self, mut stream: TcpStream) {
        let start_time = Instant::now();
        let _ = stream.set_nodelay(true);
        let _ = stream.set_read_timeout(Some(std::time::Duration::from_millis(150)));
        let mut buffer = [0; 4096];
        let n = match stream.read(&mut buffer) {
            Ok(n) if n > 0 => n,
            _ => return,
        };
        let raw = &buffer[..n];
        let header_str = String::from_utf8_lossy(raw);
        let mut lines = header_str.lines();
        let parts: Vec<&str> = lines.next().unwrap_or("").split_whitespace().collect();
        if parts.len() < 2 || parts[1] == "/favicon.ico" {
            return;
        }
        let (method, full_path) = (parts[0], parts[1]);
        let headers = Self::parse_headers(lines);

        if let Some(mw_func) = &self.mw {
            if let Some(early_reply) = mw_func(full_path) {
                self.finalize(stream, early_reply, method, full_path, start_time).await;
                return;
            }
        }

        let path_only = full_path.split('?').next().unwrap_or("/");
        let matched = self.match_route(method, path_only);
        let body = Self::read_body(&mut stream, raw, &headers);
        let req = Req {
            method: method.to_string(),
            path: full_path.to_string(),
            body,
            headers,
        };
        let reply = match matched {
            Some((handler, params)) => handler(req, Params(params)).await,
            None => self.not_found(req).await,
        };
        self.finalize(stream, reply, method, full_path, start_time).await;
    }

    /// Parses `Name: value` header lines up to the blank line ending the header section.
    ///
    /// Names are lower-cased. Whitespace around the value is optional in HTTP,
    /// so both `Host:x` and `Host: x` are accepted and the value is trimmed.
    fn parse_headers<'a>(lines: impl Iterator<Item = &'a str>) -> HashMap<String, String> {
        let mut headers = HashMap::new();
        for line in lines {
            if line.is_empty() {
                break;
            }
            if let Some((k, v)) = line.split_once(':') {
                headers.insert(k.trim().to_lowercase(), v.trim().to_string());
            }
        }
        headers
    }

    /// Finds the first route registered for `method` whose pattern matches `path`,
    /// returning its handler together with the captured parameters.
    fn match_route(&self, method: &str, path: &str) -> Option<(&Handler, HashMap<String, String>)> {
        let path_segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
        for (r_method, r_path, handler) in &self.routes {
            if r_method != method {
                continue;
            }
            if let Some(params) = Self::match_segments(r_path, &path_segments) {
                return Some((handler, params));
            }
        }
        None
    }

    /// Matches one route pattern against already-split path segments.
    ///
    /// Returns the `<name>` captures on success, `None` on any mismatch or
    /// when the segment counts differ.
    fn match_segments(pattern: &str, path_segments: &[&str]) -> Option<HashMap<String, String>> {
        let route_segments: Vec<&str> = pattern.split('/').filter(|s| !s.is_empty()).collect();
        if route_segments.len() != path_segments.len() {
            return None;
        }
        let mut params = HashMap::new();
        for (r_seg, p_seg) in route_segments.iter().zip(path_segments) {
            if let Some(name) = r_seg.strip_prefix('<').and_then(|s| s.strip_suffix('>')) {
                params.insert(name.to_string(), p_seg.to_string());
            } else if r_seg != p_seg {
                return None;
            }
        }
        Some(params)
    }

    /// Collects the request body declared by `Content-Length`.
    ///
    /// `initial` is the data from the first read; bytes after the blank line
    /// that ends the headers are the start of the body. Further reads are issued
    /// until the declared length (capped at `MAX_BODY_BYTES`) is reached, the
    /// peer stops sending, or the read timeout fires. Without a valid
    /// `Content-Length` the body is empty, as HTTP/1.1 specifies for requests.
    fn read_body(stream: &mut TcpStream, initial: &[u8], headers: &HashMap<String, String>) -> String {
        let declared = headers
            .get("content-length")
            .and_then(|v| v.trim().parse::<usize>().ok())
            .unwrap_or(0);
        let wanted = declared.min(Self::MAX_BODY_BYTES);
        if wanted == 0 {
            return String::new();
        }
        // If the end of the headers was not in the first read, the body's start
        // offset is unknown; treat the body as absent rather than guess.
        let start = match Self::find_header_end(initial) {
            Some(start) => start,
            None => return String::new(),
        };
        let mut body = initial[start..].to_vec();
        let mut chunk = [0u8; 4096];
        while body.len() < wanted {
            match stream.read(&mut chunk) {
                Ok(0) => break,
                Ok(n) => body.extend_from_slice(&chunk[..n]),
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                // Timeout or connection error: keep whatever arrived so far.
                Err(_) => break,
            }
        }
        // Anything past the declared length does not belong to this request.
        body.truncate(wanted);
        String::from_utf8_lossy(&body).into_owned()
    }

    /// Returns the offset just past the blank line ending the header section
    /// (`\r\n\r\n`, or bare `\n\n` from lenient clients), if present in `data`.
    fn find_header_end(data: &[u8]) -> Option<usize> {
        data.windows(4)
            .position(|w| w == b"\r\n\r\n")
            .map(|i| i + 4)
            .or_else(|| data.windows(2).position(|w| w == b"\n\n").map(|i| i + 2))
    }

    /// Builds the reply for a request that matched no route.
    ///
    /// Clients whose `Accept` header contains `text/html` get the `"html"`
    /// handler, all others the `"json"` handler; if that handler was never
    /// registered a plain `404 Not Found` body is returned.
    async fn not_found(&self, req: Req) -> Reply {
        let wants_html = req
            .headers
            .get("accept")
            .map_or(false, |a| a.contains("text/html"));
        let kind = if wants_html { "html" } else { "json" };
        match self.handlers_404.get(kind) {
            Some(h) => h(req, Params(HashMap::new())).await,
            None => Reply::new(StatusCode::NotFound).body("404 Not Found"),
        }
    }

    /// Writes `reply` to `stream` with chunked transfer encoding, logs the
    /// request with its elapsed time, and shuts the connection down.
    ///
    /// Write errors are deliberately ignored: the client may already be gone.
    async fn finalize(&self, stream: TcpStream, reply: Reply, method: &str, path: &str, start: Instant) {
        let status = reply.status;
        {
            let mut writer = BufWriter::with_capacity(65536, &stream);
            let mut head = format!(
                "HTTP/1.1 {} {}\r\nConnection: close\r\nTransfer-Encoding: chunked\r\n",
                status,
                Self::reason_phrase(status)
            );
            for (k, v) in &reply.headers {
                // Framing is fixed to chunked above; a user-supplied Content-Length
                // or Transfer-Encoding would contradict it and confuse clients.
                if k.eq_ignore_ascii_case("content-length")
                    || k.eq_ignore_ascii_case("transfer-encoding")
                {
                    continue;
                }
                head.push_str(k);
                head.push_str(": ");
                head.push_str(v);
                head.push_str("\r\n");
            }
            head.push_str("\r\n");
            let _ = writer.write_all(head.as_bytes());
            let mut body = reply.body;
            while let Some(data) = body.next_chunk() {
                // A zero-length chunk is the chunked-encoding terminator; emitting
                // one here would end the body early, so skip empty chunks.
                if data.is_empty() {
                    continue;
                }
                let _ = writer.write_all(format!("{:X}\r\n", data.len()).as_bytes());
                let _ = writer.write_all(&data);
                let _ = writer.write_all(b"\r\n");
            }
            let _ = writer.write_all(b"0\r\n\r\n");
            let _ = writer.flush();
        }
        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs();
        println!(
            "[{:02}:{:02}:{:02}] {} {} -> {} ({:?})",
            (now / 3600) % 24,
            (now / 60) % 60,
            now % 60,
            method,
            path,
            status,
            start.elapsed()
        );
        let _ = stream.shutdown(Shutdown::Both);
    }

    /// Returns the standard reason phrase for `status`, or an empty string for
    /// codes not listed (an empty reason phrase is valid in an HTTP/1.1 status line).
    fn reason_phrase(status: u16) -> &'static str {
        match status {
            100 => "Continue",
            101 => "Switching Protocols",
            200 => "OK",
            201 => "Created",
            202 => "Accepted",
            204 => "No Content",
            206 => "Partial Content",
            301 => "Moved Permanently",
            302 => "Found",
            303 => "See Other",
            304 => "Not Modified",
            307 => "Temporary Redirect",
            308 => "Permanent Redirect",
            400 => "Bad Request",
            401 => "Unauthorized",
            403 => "Forbidden",
            404 => "Not Found",
            405 => "Method Not Allowed",
            409 => "Conflict",
            413 => "Content Too Large",
            415 => "Unsupported Media Type",
            422 => "Unprocessable Content",
            429 => "Too Many Requests",
            500 => "Internal Server Error",
            501 => "Not Implemented",
            502 => "Bad Gateway",
            503 => "Service Unavailable",
            504 => "Gateway Timeout",
            _ => "",
        }
    }
}
/// Drives `future` to completion on the calling thread and returns its output.
///
/// This is a minimal blocking executor. When the future is pending, the thread
/// parks until the future's waker unparks it, instead of busy-spinning. Some
/// futures return `Poll::Pending` without ever arranging a wake-up, so each park
/// is bounded by a short timeout; such futures are still re-polled rather than
/// hanging forever.
pub fn execute<F: Future>(mut future: F) -> F::Output {
    /// Waker that unparks the thread running `execute`.
    struct ThreadWaker(std::thread::Thread);

    impl std::task::Wake for ThreadWaker {
        fn wake(self: Arc<Self>) {
            self.0.unpark();
        }

        fn wake_by_ref(self: &Arc<Self>) {
            self.0.unpark();
        }
    }

    // Longest sleep between polls of a future that went pending without waking us.
    const MAX_PARK: std::time::Duration = std::time::Duration::from_millis(1);

    // SAFETY: the original `future` binding is shadowed by this pinned reference,
    // so it can never be moved again, and it lives on this stack frame until the
    // function returns — exactly the guarantee `Pin` requires.
    let mut future = unsafe { Pin::new_unchecked(&mut future) };
    let waker = Waker::from(Arc::new(ThreadWaker(std::thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match future.as_mut().poll(&mut cx) {
            Poll::Ready(v) => return v,
            // Spurious returns from `park_timeout` are harmless: we simply re-poll.
            Poll::Pending => std::thread::park_timeout(MAX_PARK),
        }
    }
}