use crate::errors::{error, nil, New};
use crate::net::http::request::{canonicalize, Header, Request};
use crate::net::http::response::ResponseWriter;
use crate::net::url::URL;
use crate::types::{int64, string};
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex, OnceLock};
/// Type-erased, shareable handler closure; this is the form [`ServeMux`]
/// stores its routes in.
pub type HandlerFunc = Arc<dyn Fn(&mut ResponseWriter, &mut Request) + Send + Sync + 'static>;

/// Something that can answer an HTTP request (the counterpart of Go's
/// `http.Handler`).
pub trait Handler: Send + Sync + 'static {
    /// Writes the response for `r` into `w`.
    // The method keeps Go's exported name `ServeHTTP`, so silence the
    // snake_case lint for it.
    #[allow(non_snake_case)]
    fn ServeHTTP(&self, w: &mut ResponseWriter, r: &mut Request);
}

/// Any thread-safe closure or function taking `(&mut ResponseWriter,
/// &mut Request)` is a [`Handler`], much like Go's `http.HandlerFunc` adapter.
impl<F> Handler for F
where
    F: Fn(&mut ResponseWriter, &mut Request) + Send + Sync + 'static,
{
    fn ServeHTTP(&self, w: &mut ResponseWriter, r: &mut Request) {
        self(w, r)
    }
}
/// Request router modelled on Go's `http.ServeMux`.
///
/// The route table sits behind an `Arc<Mutex<..>>`, so every clone of a mux
/// shares it: handlers registered through one clone are visible to all.
#[derive(Default, Clone)]
pub struct ServeMux {
    routes: Arc<Mutex<Vec<(string, HandlerFunc)>>>,
}

impl ServeMux {
    /// Creates a mux with no routes.
    pub fn new() -> Self {
        ServeMux::default()
    }

    /// Registers `f` for `pattern`.
    ///
    /// A pattern ending in `/` (e.g. `"/images/"`) matches every path in
    /// that subtree. Any other pattern (e.g. `"/favicon.ico"`) matches only
    /// that exact path. When several patterns match a request, the longest
    /// one wins. If the same pattern is registered twice, the later
    /// registration wins.
    // Go-style exported method name.
    #[allow(non_snake_case)]
    pub fn HandleFunc<F>(&self, pattern: &str, f: F)
    where
        F: Fn(&mut ResponseWriter, &mut Request) + Send + Sync + 'static,
    {
        self.routes
            .lock()
            .expect("ServeMux route table poisoned")
            .push((pattern.to_owned(), Arc::new(f)));
    }

    /// Registers a [`Handler`] for `pattern`, using the same matching rules
    /// as [`ServeMux::HandleFunc`].
    // Go-style exported method name.
    #[allow(non_snake_case)]
    pub fn Handle<H: Handler>(&self, pattern: &str, h: H) {
        let h = Arc::new(h);
        self.HandleFunc(pattern, move |w, r| h.ServeHTTP(w, r));
    }

    /// Reports whether `pattern` matches `path`: subtree match for patterns
    /// ending in `/`, exact match otherwise.
    fn pattern_matches(pattern: &str, path: &str) -> bool {
        if pattern.ends_with('/') {
            path.starts_with(pattern)
        } else {
            path == pattern
        }
    }

    /// Returns the handler of the longest pattern matching `path`, if any.
    fn match_route(&self, path: &str) -> Option<HandlerFunc> {
        let routes = self.routes.lock().expect("ServeMux route table poisoned");
        let mut best: Option<(usize, &HandlerFunc)> = None;
        for (pattern, handler) in routes.iter() {
            if !Self::pattern_matches(pattern, path) {
                continue;
            }
            // `>=` lets a later duplicate registration override an earlier one.
            if best.map_or(true, |(len, _)| pattern.len() >= len) {
                best = Some((pattern.len(), handler));
            }
        }
        best.map(|(_, handler)| Arc::clone(handler))
    }
}
/// Routes the request to the handler chosen by `match_route` for the
/// request path, or answers `404 page not found` when no route applies.
impl Handler for ServeMux {
    fn ServeHTTP(&self, w: &mut ResponseWriter, r: &mut Request) {
        let Some(route) = self.match_route(&r.URL.Path) else {
            w.WriteHeader(404);
            let _ = w.Write(b"404 page not found\n");
            return;
        };
        route(w, r);
    }
}
/// Returns the lazily created process-wide mux (Go's `DefaultServeMux`).
///
/// The package-level [`HandleFunc`] registers on it, and [`ListenAndServe`]
/// falls back to it when given no mux.
fn default_mux() -> &'static ServeMux {
    static DEFAULT: OnceLock<ServeMux> = OnceLock::new();
    DEFAULT.get_or_init(ServeMux::default)
}
/// Registers `f` for `pattern` on the default mux, like Go's `http.HandleFunc`.
// Go-style exported function name.
#[allow(non_snake_case)]
pub fn HandleFunc<F>(pattern: &str, f: F)
where
    F: Fn(&mut ResponseWriter, &mut Request) + Send + Sync + 'static,
{
    let mux = default_mux();
    mux.HandleFunc(pattern, f)
}
/// Listens on `addr` and serves with `handler`, blocking until the server
/// stops.
///
/// When `handler` converts to `None` (see [`IntoMux`]), the default mux is
/// used instead.
// Go-style exported function name.
#[allow(non_snake_case)]
pub fn ListenAndServe<H: IntoMux>(addr: &str, handler: H) -> error {
    let mux = match handler.into_mux() {
        Some(mux) => mux,
        None => default_mux().clone(),
    };
    Server::new(addr, mux).ListenAndServe()
}
pub trait IntoMux {
fn into_mux(self) -> Option<ServeMux>;
}
impl IntoMux for ServeMux {
fn into_mux(self) -> Option<ServeMux> { Some(self) }
}
impl IntoMux for Option<ServeMux> {
fn into_mux(self) -> Option<ServeMux> { self }
}
impl IntoMux for crate::errors::error {
fn into_mux(self) -> Option<ServeMux> {
if self == crate::errors::nil { None } else { None }
}
}
/// An HTTP server with a fixed listen address and mux (cf. Go's `http.Server`).
// Field names mirror Go's exported `Addr` and `Handler`.
#[allow(non_snake_case)]
pub struct Server {
    /// Listen address as accepted by `parse_addr`, e.g. `":8080"`.
    pub Addr: string,
    /// Mux that receives every request.
    pub Handler: ServeMux,
}

impl Server {
    /// Builds a server for `addr` that dispatches to `handler`.
    pub fn new(addr: &str, handler: ServeMux) -> Self {
        Self {
            Addr: addr.to_owned(),
            Handler: handler,
        }
    }

    /// Parses `Addr` and runs the server on `super::block_on`.
    ///
    /// Returns the address-parse error, an `http: ...` error if serving
    /// fails, or `nil` if serving ends cleanly.
    // Go-style exported method name.
    #[allow(non_snake_case)]
    pub fn ListenAndServe(&self) -> error {
        let addr = self.Addr.clone();
        let mux = self.Handler.clone();
        super::block_on(async move {
            let socket = match parse_addr(&addr) {
                Ok(socket) => socket,
                Err(err) => return err,
            };
            run_server(socket, mux)
                .await
                .map_or_else(|e| New(&format!("http: {}", e)), |()| nil)
        })
    }
}
/// Turns a Go-style listen address into a socket address.
///
/// Accepted forms:
/// * `host:port`, where `host` is an IP literal (`[::1]:80` for IPv6) or a
///   name such as `localhost`, resolved with the system resolver. The first
///   resolved address is used.
/// * `:port`, meaning all IPv4 interfaces.
/// * `""`, meaning port 80 on all IPv4 interfaces (Go's `":http"` default).
///
/// # Errors
///
/// Returns an `http: parse addr ...` error if the address is malformed,
/// cannot be resolved, or resolves to no addresses.
fn parse_addr(addr: &str) -> Result<SocketAddr, error> {
    use std::net::ToSocketAddrs;

    let a = if addr.is_empty() {
        "0.0.0.0:80".to_owned()
    } else if let Some(rest) = addr.strip_prefix(':') {
        format!("0.0.0.0:{}", rest)
    } else {
        addr.to_owned()
    };

    // IP literals need no resolver round-trip.
    if let Ok(socket) = a.parse::<SocketAddr>() {
        return Ok(socket);
    }

    // Host names need resolution. This may block, which is acceptable
    // because it happens once, before the listener starts.
    match a.to_socket_addrs() {
        Ok(mut resolved) => resolved.next().ok_or_else(|| {
            New(&format!("http: parse addr {:?}: no addresses found", addr))
        }),
        Err(e) => Err(New(&format!("http: parse addr {:?}: {}", addr, e))),
    }
}
/// Binds `addr` and serves HTTP/1.1 on every accepted connection with
/// `handler`, one tokio task per connection.
///
/// Never returns `Ok`.
///
/// # Errors
///
/// Returns `Err` only when binding `addr` fails. Accept errors do not stop
/// the server, mirroring Go's `Server.Serve`:
/// * Errors tied to a single connection that died before it was accepted
///   are skipped.
/// * Any other accept error (for example running out of file descriptors)
///   is reported on stderr and retried after a backoff that starts at 5ms
///   and doubles up to 1s.
async fn run_server(addr: SocketAddr, handler: ServeMux) -> Result<(), std::io::Error> {
    use hyper::server::conn::http1;
    use hyper_util::rt::TokioIo;
    use std::io::ErrorKind;
    use std::time::Duration;
    use tokio::net::TcpListener;

    const MIN_BACKOFF: Duration = Duration::from_millis(5);
    const MAX_BACKOFF: Duration = Duration::from_secs(1);

    let listener = TcpListener::bind(addr).await?;
    let mut backoff: Option<Duration> = None;
    loop {
        let (stream, remote) = match listener.accept().await {
            Ok(conn) => {
                backoff = None;
                conn
            }
            Err(e)
                if matches!(
                    e.kind(),
                    ErrorKind::ConnectionAborted
                        | ErrorKind::ConnectionReset
                        | ErrorKind::ConnectionRefused
                        | ErrorKind::Interrupted
                        | ErrorKind::WouldBlock
                ) =>
            {
                continue
            }
            Err(e) => {
                let delay = backoff.map_or(MIN_BACKOFF, |d| (d * 2).min(MAX_BACKOFF));
                backoff = Some(delay);
                eprintln!("http: Accept error: {}; retrying in {:?}", e, delay);
                // Sleep on the blocking pool rather than via `tokio::time`, so
                // this code does not depend on tokio's optional `time` feature.
                let _ = tokio::task::spawn_blocking(move || std::thread::sleep(delay)).await;
                continue;
            }
        };
        let handler = handler.clone();
        tokio::spawn(async move {
            let io = TokioIo::new(stream);
            let service =
                hyper::service::service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
                    let handler = handler.clone();
                    async move { Ok::<_, Infallible>(serve_one(handler, remote, req).await) }
                });
            // A failing connection (client hang-up, malformed request) ends
            // only this task, not the server.
            let _ = http1::Builder::new().serve_connection(io, service).await;
        });
    }
}
async fn serve_one(
handler: ServeMux,
remote: SocketAddr,
req: hyper::Request<hyper::body::Incoming>,
) -> hyper::Response<http_body_util::Full<bytes::Bytes>> {
use http_body_util::BodyExt;
let method = req.method().as_str().to_owned();
let uri = req.uri().clone();
let version = format!("{:?}", req.version());
let host_hdr = req
.headers()
.get(hyper::header::HOST)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_owned();
let mut header = Header::new();
for (name, value) in req.headers().iter() {
if let Ok(s) = value.to_str() {
header.Add(name.as_str(), s);
}
}
let body_bytes = req
.into_body()
.collect()
.await
.map(|c| c.to_bytes())
.unwrap_or_default();
let content_length = body_bytes.len() as int64;
let mut url = URL::default();
url.Path = uri.path().to_owned();
if let Some(q) = uri.query() { url.RawQuery = q.to_owned(); }
url.Host = host_hdr.clone();
let mut request = Request::from_parts(
method,
url,
version,
header,
crate::net::http::body::Body::from_bytes(body_bytes.to_vec()),
host_hdr,
remote.to_string(),
content_length,
);
let (ctx, cancel) = crate::context::WithCancel(crate::context::Background());
request.set_context(ctx);
struct CancelOnDrop(Option<crate::context::CancelFunc>);
impl Drop for CancelOnDrop {
fn drop(&mut self) {
if let Some(c) = self.0.take() { c.call(); }
}
}
let _cancel_guard = CancelOnDrop(Some(cancel));
let mut w = ResponseWriter::new();
let mut request = request;
let w = tokio::task::spawn_blocking(move || {
handler.ServeHTTP(&mut w, &mut request);
w
})
.await
.unwrap_or_else(|_| {
let mut w = ResponseWriter::new();
w.WriteHeader(500);
let _ = w.Write(b"handler panicked");
w
});
let mut builder = hyper::Response::builder()
.status(u16::try_from(w.status).unwrap_or(200));
for (k, vs) in w.header.iter() {
for v in vs {
builder = builder.header(k.as_str(), v);
}
}
if header_missing(&w.header, "Content-Type") {
builder = builder.header("Content-Type", "text/plain; charset=utf-8");
}
builder
.body(http_body_util::Full::new(bytes::Bytes::from(w.body)))
.unwrap()
}
/// Returns `true` when no key of `h` equals `canonicalize(name)`.
///
/// Keys of `h` are compared as-is, so this assumes `h` stores them in
/// canonical form (TODO confirm against `Header`).
fn header_missing(h: &Header, name: &str) -> bool {
    let wanted = canonicalize(name);
    !h.iter().any(|(key, _)| key.as_str() == wanted)
}