use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
use tokio::net::{TcpListener, TcpStream};
use crate::error::Error;
use crate::method::Method;
use crate::middleware::Next;
use crate::request::Request;
use crate::response::Response;
use crate::router::Router;
use crate::status::Status;
/// An HTTP server bound to a single socket address.
///
/// Construct it with `Server::bind` and start it with `Server::serve`.
#[derive(Debug)]
pub struct Server {
    /// Address the listener binds to when the server starts serving.
    addr: SocketAddr,
}
impl Server {
    /// Creates a server that will listen on `addr` once [`Server::serve`] is called.
    ///
    /// `addr` is parsed as a [`SocketAddr`] (for example `"127.0.0.1:8080"` or `"[::1]:80"`).
    /// No socket is opened here.
    ///
    /// # Panics
    /// Panics with `"invalid socket address"` if `addr` does not parse as a [`SocketAddr`].
    pub fn bind(addr: &str) -> Self {
        let addr: SocketAddr = addr.parse().expect("invalid socket address");
        Self { addr }
    }

    /// Accepts connections and serves each on its own task using `router`.
    ///
    /// Runs until Ctrl-C (or `SIGTERM` on Unix) is received. It then stops accepting and
    /// waits for every spawned connection task to finish before returning.
    ///
    /// # Errors
    /// Returns an error if the listening socket cannot be bound. Failures on individual
    /// connections are not reported to the caller.
    pub async fn serve(self, router: Router) -> Result<(), Error> {
        let listener = TcpListener::bind(self.addr).await?;
        let router = Arc::new(router);
        let mut tasks = tokio::task::JoinSet::new();
        let shutdown = shutdown_signal();
        tokio::pin!(shutdown);
        loop {
            tokio::select! {
                // Branches are polled strictly top to bottom.
                biased;
                () = &mut shutdown => {
                    break;
                }
                // Reap finished connection tasks *before* accepting. If `accept` came first,
                // a listener that is continuously ready would starve this branch, and
                // completed tasks would accumulate in the JoinSet without bound.
                Some(_) = tasks.join_next(), if !tasks.is_empty() => {}
                res = listener.accept() => {
                    // NOTE(review): accept errors are ignored. Persistent errors such as
                    // running out of file descriptors make this loop retry immediately;
                    // consider a back-off.
                    let Ok((stream, _remote_addr)) = res else { continue };
                    let router = Arc::clone(&router);
                    tasks.spawn(async move {
                        // Per-connection errors have nowhere to go; the connection is dropped.
                        let _ = serve_connection(stream, router).await;
                    });
                }
            }
        }
        // Graceful shutdown: no new connections, but open ones run to completion.
        // NOTE(review): an idle keep-alive client can hold this wait open indefinitely.
        while tasks.join_next().await.is_some() {}
        Ok(())
    }
}
async fn serve_connection(stream: TcpStream, router: Arc<Router>) -> Result<(), Error> {
let (read_half, mut write_half) = stream.into_split();
let mut reader = BufReader::new(read_half);
loop {
let mut line = String::new();
if reader.read_line(&mut line).await? == 0 {
break; }
let line = line.trim_end();
let mut parts = line.splitn(3, ' ');
let method_str = parts.next().unwrap_or("");
let raw = parts.next().unwrap_or("/");
let (path, query) = match raw.find('?') {
Some(i) => (&raw[..i], &raw[i + 1..]),
None => (raw, ""),
};
let path = path.to_owned();
let query = query.to_owned();
let Ok(method) = method_str.parse::<Method>() else { break };
let mut headers: Vec<(String, String)> = Vec::new();
loop {
let mut hline = String::new();
reader.read_line(&mut hline).await?;
let hline = hline.trim_end();
if hline.is_empty() { break; }
if let Some((name, value)) = hline.split_once(": ") {
headers.push((name.to_owned(), value.to_owned()));
}
}
let body = read_body(&mut reader, &headers).await?;
let response = match router.lookup(method, &path) {
Some((handler, middleware, params)) => {
let req = Request::new(body, headers, method, params, path, query);
Next::new(middleware, handler).call(req).await
}
None => Response::status(Status::NotFound),
};
response.write_to(&mut write_half).await?;
}
Ok(())
}
/// Reads the request body whose length is given by the `Content-Length` header.
///
/// The header name is matched case-insensitively, and the first match wins. With no
/// `Content-Length` header, the body is empty and nothing is read.
///
/// The body is read incrementally, so memory grows only as bytes actually arrive. A huge
/// declared length therefore cannot trigger one giant up-front allocation.
///
/// # Errors
/// - Returns `InvalidData` if any `Transfer-Encoding` header is present. Chunked framing is
///   not supported, and ignoring it would let body bytes be parsed as a new request.
/// - Returns `InvalidData` if `Content-Length` is not a non-negative integer. RFC 9112 §6.3
///   treats invalid framing as an unrecoverable error.
/// - Returns `UnexpectedEof` if the stream ends before the declared number of bytes arrives.
/// - Propagates any other read error.
async fn read_body<R: AsyncBufReadExt + Unpin>(
    reader: &mut R,
    headers: &[(String, String)],
) -> Result<Vec<u8>, Error> {
    /// Upper bound on the capacity reserved before any body bytes have arrived.
    const MAX_PREALLOC: usize = 64 * 1024;

    if headers
        .iter()
        .any(|(name, _)| name.eq_ignore_ascii_case("transfer-encoding"))
    {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "Transfer-Encoding is not supported",
        )
        .into());
    }
    let Some((_, raw)) = headers
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case("content-length"))
    else {
        return Ok(Vec::new());
    };
    let len = raw.trim().parse::<u64>().map_err(|_| {
        std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid Content-Length header")
    })?;
    let capacity = usize::try_from(len).map_or(MAX_PREALLOC, |n| n.min(MAX_PREALLOC));
    let mut buf = Vec::with_capacity(capacity);
    let read = (&mut *reader).take(len).read_to_end(&mut buf).await?;
    if (read as u64) < len {
        return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof).into());
    }
    Ok(buf)
}
/// Completes when the process is asked to stop.
///
/// Fires on Ctrl-C on every platform and, on Unix, also on `SIGTERM`. On non-Unix targets
/// the second source is a future that never completes, so only Ctrl-C applies.
///
/// # Panics
/// Panics if the Ctrl-C or `SIGTERM` handler cannot be installed.
async fn shutdown_signal() {
    let interrupt = async {
        let installed = tokio::signal::ctrl_c().await;
        installed.expect("failed to install Ctrl-C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        use tokio::signal::unix::{signal, SignalKind};
        let mut sigterm =
            signal(SignalKind::terminate()).expect("failed to install SIGTERM handler");
        sigterm.recv().await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    // Whichever source fires first ends the wait.
    tokio::select! {
        () = interrupt => {}
        () = terminate => {}
    }
}