//! francoisgib_webserver 1.0.3
//!
//! HTTP webserver.
mod logger;
mod metrics;
mod test;

use crate::{
    config::{CARGO_PKG_NAME, CARGO_PKG_VERSION, Config},
    http::{
        errors::HttpError,
        headers::{HeaderEntry, HttpHeader, HttpHeaderValue},
        methods::HttpMethod,
        requests::{
            HttpRequest,
            body::read_request_body,
            parser::{parse_headers, parse_request_line},
        },
        responses::HttpResponse,
        status::HttpStatus,
    },
    tree::{EndpointType, HttpEndpoint, HttpTree},
};
use chrono::Utc;
use logger::Logger;
use metrics::MetricsHandler;
use smallvec::smallvec;
use std::{
    fs::exists,
    sync::Arc,
    time::{Duration, Instant},
};
use tokio::{
    io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter},
    net::{TcpListener, TcpStream},
    runtime::Builder,
    signal, spawn,
    sync::{Notify, RwLock},
};

/// Built-in not-found handler, used when no not-found file path is configured.
///
/// Marks the response being built as `404 Not Found` over HTTP/1.1. Headers and any
/// other response state set by the caller are left untouched, and the request itself
/// is not inspected.
///
/// # Arguments
/// * `_request` - The unmatched HTTP request (ignored)
/// * `response` - The response being built, updated in place
///
/// # Errors
/// Never fails; always returns `Ok(())`.
pub fn default_not_found_handler(
    _request: &HttpRequest,
    response: &mut HttpResponse,
) -> Result<(), String> {
    response.version = (1, 1);
    response.status = HttpStatus::NotFound;
    Ok(())
}

/// Handles a client connection by processing HTTP requests.
///
/// This function:
/// 1. Reads HTTP requests from the stream
/// 2. Parses the request line and headers
/// 3. Finds the appropriate endpoint handler
/// 4. Generates and sends a response
/// 5. Continues processing requests as long as keep-alive is enabled
///
/// # Arguments
/// * `server` - Read guard to the HTTP server
/// * `mut stream` - TCP stream for the client connection
///
/// # Returns
/// `Ok(())` if the connection is handled successfully, otherwise an error
async fn handle_client(server: Arc<HttpServer>, stream: &mut TcpStream) -> Result<(), String> {
    let timeout_duration = Duration::from_secs(5);

    loop {
        let mut reader = BufReader::new(&mut *stream);
        let start_time = Instant::now();

        loop {
            if start_time.elapsed() >= timeout_duration {
                return Ok(());
            }
            if !reader
                .fill_buf()
                .await
                .map_err(|e| e.to_string())?
                .is_empty()
            {
                break;
            }
        }

        let (method_str, uri, version) = match parse_request_line(&mut reader).await {
            Ok(res) => res,
            Err(err) => {
                let mut response = server.get_default_response();
                response.status = HttpStatus::from(err);
                return server.write_response(response, stream).await;
            }
        };

        if !server.is_supported_version(version) {
            let mut response = server.get_default_response();
            response.status = HttpStatus::from(HttpError::HttpVersionNotSupported);
            return server.write_response(response, stream).await;
        }

        let headers = parse_headers(&mut reader).await?;
        let mut request = HttpRequest::new(method_str, uri, version, headers, None);
        request.process_headers();

        request.body = read_request_body(&mut reader, &request).await?;

        let response = server.get_response(&mut request).await;

        if let Some(metrics_handler) = &server.metrics_handler {
            if metrics_handler.read().await.metrics_endpoint != request.uri {
                metrics_handler
                    .write()
                    .await
                    .add_metrics(&request, &response);
            }
        }

        server
            .logger
            .write_req_res(&request, &response, &stream.peer_addr().unwrap());
        server.write_response(response, stream).await?;

        if !request.keep_alive {
            break;
        }
    }
    Ok(())
}

/// HTTP server that dispatches incoming requests to the endpoints of an [`HttpTree`].
///
/// Highlights:
/// - Connections are served concurrently on a multi-threaded Tokio runtime
/// - Routing is resolved through a tree of registered endpoints
/// - Unmatched paths fall back to a configurable not-found endpoint
/// - Keep-alive connections are supported
pub struct HttpServer {
    /// Product name reported in the `Server` response header.
    name: String,
    /// Host (and port) shown in the metrics URL printed at startup.
    hostname: String,
    /// TCP port the listener binds to.
    port: u16,
    /// Routing tree of registered endpoints.
    tree: HttpTree,
    /// Endpoint used when no route in `tree` matches the request path.
    not_found_endpoint: HttpEndpoint,
    /// Request/response logger.
    logger: Logger,
    /// Metrics collector; `Some` only when a metrics endpoint is configured.
    metrics_handler: Option<RwLock<MetricsHandler>>,
    /// Tokio worker-thread count; `None` keeps the runtime builder's default.
    worker_threads: Option<usize>,
}

impl HttpServer {
    pub fn new(tree: HttpTree, config: Config) -> Self {
        let pool_config = config.pool;
        let server_config = config.server;

        let name = server_config.name.unwrap_or(CARGO_PKG_NAME.to_owned());

        let hostname = server_config
            .hostname
            .unwrap_or(format!("localhost:{}", server_config.port));

        let not_found_endpoint = server_config
            .not_found_endpoint_path
            .inspect(|endpoint| {
                if !exists(&endpoint).unwrap() {
                    println!("Warn: The provided not found endpoint file path does not exist.");
                }
            })
            .map(|endpoint| HttpEndpoint::new(HttpMethod::GET, EndpointType::File(endpoint)))
            .or(Some(HttpEndpoint::new(
                HttpMethod::GET,
                EndpointType::Handler(default_not_found_handler),
            )))
            .unwrap();

        let metrics_handler = if let Some(metrics_endpoint) = server_config.metrics_endpoint {
            let metrics_endpoint = if !metrics_endpoint.starts_with('/') {
                format!("/{}", metrics_endpoint)
            } else {
                metrics_endpoint
            };
            Some(RwLock::new(MetricsHandler::new(
                name.to_owned(),
                metrics_endpoint,
            )))
        } else {
            None
        };

        let logger = Logger::new(
            server_config.log_file,
            server_config.log_level,
            server_config.debug,
        );

        Self {
            name,
            hostname,
            tree,
            not_found_endpoint,
            port: server_config.port,
            logger,
            metrics_handler,
            worker_threads: pool_config.worker_threads,
        }
    }

    pub fn set_worker_threads(&mut self, worker_threads: usize) {
        self.worker_threads = Some(worker_threads);
    }

    // Setup a tokio runtime and start the server into it.
    pub fn start(self) {
        let mut runtime_builder = Builder::new_multi_thread();
        if let Some(worker_threads) = self.worker_threads {
            runtime_builder.worker_threads(worker_threads);
        }
        let runtime = runtime_builder.enable_all().build().unwrap();
        runtime.block_on(async {
            println!(
                "Runtime started with {} worker threads.",
                runtime.metrics().num_workers()
            );
            let _ = self.async_start().await;
        })
    }

    // Start the server with the tokio runtime.
    pub async fn async_start(self) {
        let listener = match TcpListener::bind(format!("127.0.0.1:{}", self.port)).await {
            Ok(listener) => listener,
            Err(e) => {
                eprintln!("Error binding to port {}: {}", self.port, e);
                return ();
            }
        };

        println!("Starting server...");

        if let Some(metrics_handler) = &self.metrics_handler {
            println!(
                "Metrics available at: http://{}{}",
                self.hostname,
                metrics_handler.read().await.metrics_endpoint
            );
        }

        println!("Waiting for connections on port {}.", self.port);

        let server = Arc::new(self);
        let shutdown_signal = Arc::new(Notify::new());
        let shutdown_signal_clone = shutdown_signal.clone();

        spawn(async move {
            match signal::ctrl_c().await {
                Ok(()) => {
                    println!("\nCtrl-C received, initiating graceful shutdown...");
                    shutdown_signal_clone.notify_one();
                }
                Err(err) => {
                    eprintln!("Error setting up Ctrl-C handler: {}", err);
                }
            }
        });

        loop {
            tokio::select! {
                accept_result = listener.accept() => {
                    match accept_result {
                        Ok((mut stream, _)) => {
                            let server_arc_clone = Arc::clone(&server);
                            spawn(async move {
                                println!("Accepted connection from {}", stream.peer_addr().unwrap());
                                let _ = handle_client(server_arc_clone, &mut stream).await.inspect_err(|e| eprintln!("Error handling client: {}", e));
                                stream.shutdown().await.unwrap();
                            });
                        }
                        Err(e) => {
                            eprintln!("Error accepting connection: {}", e);
                        }
                    };
                }
                _ = shutdown_signal.notified() => {
                    println!("Shutdown signal received. Closing listener...");
                    break;
                }
            }
        }
        println!("Server shutdown complete.");
    }

    pub async fn get_response(&self, request: &mut HttpRequest) -> HttpResponse {
        let mut response = self.get_default_response();

        let endpoint = self
            .tree
            .get_endpoint_from_path(&request.uri, request.method);

        match endpoint {
            Ok((endpoint, path)) => {
                let _ = endpoint.handle(&request, &mut response, path).map_err(|e| {
                    println!("Error handling request: {}", e);
                    response.status = HttpStatus::InternalServerError;
                });
            }
            Err(error) => {
                if error == HttpError::NotFound {
                    self.not_found_endpoint
                        .handle(request, &mut response, "")
                        .unwrap()
                } else {
                    response.status = HttpStatus::from(error);
                }
            }
        }

        if request.method == HttpMethod::HEAD {
            request.body = None;
            request.body_type = None;
        }

        if let Some(metrics_handler) = &self.metrics_handler {
            let lock = metrics_handler.read().await;
            if request.uri == lock.metrics_endpoint {
                response = metrics_handler.read().await.metrics();
            }
        }
        response.range = request.range;
        response
    }

    pub async fn write_response(
        &self,
        response: HttpResponse,
        stream: &mut TcpStream,
    ) -> Result<(), String> {
        let mut writer = BufWriter::new(&mut *stream);
        writer
            .write_all(&response.to_bytes())
            .await
            .map_err(|error| format!("Error writing to stream : {}", error))?;
        writer
            .flush()
            .await
            .map_err(|error| format!("Error flushing stream : {}", error))?;
        Ok(())
    }

    pub fn get_default_response(&self) -> HttpResponse {
        let host_header = HeaderEntry::new(
            HttpHeader::Server,
            HttpHeaderValue::Server(format!("{}/{}", self.name, CARGO_PKG_VERSION)),
        );
        let date_header = HeaderEntry::new(HttpHeader::Date, HttpHeaderValue::Date(Utc::now()));

        HttpResponse::new(
            (1, 1),
            HttpStatus::Ok,
            smallvec![host_header, date_header],
            None,
            None,
        )
    }

    pub fn is_supported_version(&self, version: (u8, u8)) -> bool {
        version <= (1, 1)
    }
}