mod logger;
mod metrics;
mod test;
use crate::{
config::{CARGO_PKG_NAME, CARGO_PKG_VERSION, Config},
http::{
errors::HttpError,
headers::{HeaderEntry, HttpHeader, HttpHeaderValue},
methods::HttpMethod,
requests::{
HttpRequest,
body::read_request_body,
parser::{parse_headers, parse_request_line},
},
responses::HttpResponse,
status::HttpStatus,
},
tree::{EndpointType, HttpEndpoint, HttpTree},
};
use chrono::Utc;
use logger::Logger;
use metrics::MetricsHandler;
use smallvec::smallvec;
use std::{
fs::exists,
sync::Arc,
time::{Duration, Instant},
};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter},
net::{TcpListener, TcpStream},
runtime::Builder,
signal, spawn,
sync::{Notify, RwLock},
};
/// Fallback endpoint handler used when no custom not-found file is configured.
///
/// Ignores the request entirely, marks the response as `404 Not Found` over
/// HTTP/1.1 and always succeeds.
pub fn default_not_found_handler(
    _request: &HttpRequest,
    response: &mut HttpResponse,
) -> Result<(), String> {
    response.version = (1, 1);
    response.status = HttpStatus::NotFound;
    Ok(())
}
async fn handle_client(server: Arc<HttpServer>, stream: &mut TcpStream) -> Result<(), String> {
let timeout_duration = Duration::from_secs(5);
loop {
let mut reader = BufReader::new(&mut *stream);
let start_time = Instant::now();
loop {
if start_time.elapsed() >= timeout_duration {
return Ok(());
}
if !reader
.fill_buf()
.await
.map_err(|e| e.to_string())?
.is_empty()
{
break;
}
}
let (method_str, uri, version) = match parse_request_line(&mut reader).await {
Ok(res) => res,
Err(err) => {
let mut response = server.get_default_response();
response.status = HttpStatus::from(err);
return server.write_response(response, stream).await;
}
};
if !server.is_supported_version(version) {
let mut response = server.get_default_response();
response.status = HttpStatus::from(HttpError::HttpVersionNotSupported);
return server.write_response(response, stream).await;
}
let headers = parse_headers(&mut reader).await?;
let mut request = HttpRequest::new(method_str, uri, version, headers, None);
request.process_headers();
request.body = read_request_body(&mut reader, &request).await?;
let response = server.get_response(&mut request).await;
if let Some(metrics_handler) = &server.metrics_handler {
if metrics_handler.read().await.metrics_endpoint != request.uri {
metrics_handler
.write()
.await
.add_metrics(&request, &response);
}
}
server
.logger
.write_req_res(&request, &response, &stream.peer_addr().unwrap());
server.write_response(response, stream).await?;
if !request.keep_alive {
break;
}
}
Ok(())
}
/// An HTTP/1.x server.
///
/// Holds the routing tree, the fallback not-found endpoint, the logger, optional
/// metrics collection, and runtime settings.
pub struct HttpServer {
    /// Server name, advertised in the `Server` header as `<name>/<crate version>`.
    name: String,
    /// Host (optionally with port) used when printing the public metrics URL.
    hostname: String,
    /// Routing tree that maps a request URI and method to an endpoint.
    tree: HttpTree,
    /// Endpoint invoked when the routing tree reports `HttpError::NotFound`.
    not_found_endpoint: HttpEndpoint,
    /// TCP port the listener binds to (on 127.0.0.1).
    port: u16,
    /// Request/response logger.
    logger: Logger,
    /// Metrics collector; present only when a metrics endpoint is configured.
    metrics_handler: Option<RwLock<MetricsHandler>>,
    /// Tokio worker thread count; `None` keeps the runtime builder's default.
    worker_threads: Option<usize>,
}
impl HttpServer {
pub fn new(tree: HttpTree, config: Config) -> Self {
let pool_config = config.pool;
let server_config = config.server;
let name = server_config.name.unwrap_or(CARGO_PKG_NAME.to_owned());
let hostname = server_config
.hostname
.unwrap_or(format!("localhost:{}", server_config.port));
let not_found_endpoint = server_config
.not_found_endpoint_path
.inspect(|endpoint| {
if !exists(&endpoint).unwrap() {
println!("Warn: The provided not found endpoint file path does not exist.");
}
})
.map(|endpoint| HttpEndpoint::new(HttpMethod::GET, EndpointType::File(endpoint)))
.or(Some(HttpEndpoint::new(
HttpMethod::GET,
EndpointType::Handler(default_not_found_handler),
)))
.unwrap();
let metrics_handler = if let Some(metrics_endpoint) = server_config.metrics_endpoint {
let metrics_endpoint = if !metrics_endpoint.starts_with('/') {
format!("/{}", metrics_endpoint)
} else {
metrics_endpoint
};
Some(RwLock::new(MetricsHandler::new(
name.to_owned(),
metrics_endpoint,
)))
} else {
None
};
let logger = Logger::new(
server_config.log_file,
server_config.log_level,
server_config.debug,
);
Self {
name,
hostname,
tree,
not_found_endpoint,
port: server_config.port,
logger,
metrics_handler,
worker_threads: pool_config.worker_threads,
}
}
pub fn set_worker_threads(&mut self, worker_threads: usize) {
self.worker_threads = Some(worker_threads);
}
pub fn start(self) {
let mut runtime_builder = Builder::new_multi_thread();
if let Some(worker_threads) = self.worker_threads {
runtime_builder.worker_threads(worker_threads);
}
let runtime = runtime_builder.enable_all().build().unwrap();
runtime.block_on(async {
println!(
"Runtime started with {} worker threads.",
runtime.metrics().num_workers()
);
let _ = self.async_start().await;
})
}
pub async fn async_start(self) {
let listener = match TcpListener::bind(format!("127.0.0.1:{}", self.port)).await {
Ok(listener) => listener,
Err(e) => {
eprintln!("Error binding to port {}: {}", self.port, e);
return ();
}
};
println!("Starting server...");
if let Some(metrics_handler) = &self.metrics_handler {
println!(
"Metrics available at: http://{}{}",
self.hostname,
metrics_handler.read().await.metrics_endpoint
);
}
println!("Waiting for connections on port {}.", self.port);
let server = Arc::new(self);
let shutdown_signal = Arc::new(Notify::new());
let shutdown_signal_clone = shutdown_signal.clone();
spawn(async move {
match signal::ctrl_c().await {
Ok(()) => {
println!("\nCtrl-C received, initiating graceful shutdown...");
shutdown_signal_clone.notify_one();
}
Err(err) => {
eprintln!("Error setting up Ctrl-C handler: {}", err);
}
}
});
loop {
tokio::select! {
accept_result = listener.accept() => {
match accept_result {
Ok((mut stream, _)) => {
let server_arc_clone = Arc::clone(&server);
spawn(async move {
println!("Accepted connection from {}", stream.peer_addr().unwrap());
let _ = handle_client(server_arc_clone, &mut stream).await.inspect_err(|e| eprintln!("Error handling client: {}", e));
stream.shutdown().await.unwrap();
});
}
Err(e) => {
eprintln!("Error accepting connection: {}", e);
}
};
}
_ = shutdown_signal.notified() => {
println!("Shutdown signal received. Closing listener...");
break;
}
}
}
println!("Server shutdown complete.");
}
pub async fn get_response(&self, request: &mut HttpRequest) -> HttpResponse {
let mut response = self.get_default_response();
let endpoint = self
.tree
.get_endpoint_from_path(&request.uri, request.method);
match endpoint {
Ok((endpoint, path)) => {
let _ = endpoint.handle(&request, &mut response, path).map_err(|e| {
println!("Error handling request: {}", e);
response.status = HttpStatus::InternalServerError;
});
}
Err(error) => {
if error == HttpError::NotFound {
self.not_found_endpoint
.handle(request, &mut response, "")
.unwrap()
} else {
response.status = HttpStatus::from(error);
}
}
}
if request.method == HttpMethod::HEAD {
request.body = None;
request.body_type = None;
}
if let Some(metrics_handler) = &self.metrics_handler {
let lock = metrics_handler.read().await;
if request.uri == lock.metrics_endpoint {
response = metrics_handler.read().await.metrics();
}
}
response.range = request.range;
response
}
pub async fn write_response(
&self,
response: HttpResponse,
stream: &mut TcpStream,
) -> Result<(), String> {
let mut writer = BufWriter::new(&mut *stream);
writer
.write_all(&response.to_bytes())
.await
.map_err(|error| format!("Error writing to stream : {}", error))?;
writer
.flush()
.await
.map_err(|error| format!("Error flushing stream : {}", error))?;
Ok(())
}
pub fn get_default_response(&self) -> HttpResponse {
let host_header = HeaderEntry::new(
HttpHeader::Server,
HttpHeaderValue::Server(format!("{}/{}", self.name, CARGO_PKG_VERSION)),
);
let date_header = HeaderEntry::new(HttpHeader::Date, HttpHeaderValue::Date(Utc::now()));
HttpResponse::new(
(1, 1),
HttpStatus::Ok,
smallvec![host_header, date_header],
None,
None,
)
}
pub fn is_supported_version(&self, version: (u8, u8)) -> bool {
version <= (1, 1)
}
}