francoisgib_webserver 1.0.3

HTTP Webserver
Documentation
mod test;

use std::{collections::VecDeque, str::FromStr};

use super::http::methods::HttpMethod;
use crate::{
    http::{
        errors::HttpError,
        headers::{ContentType, HeaderEntry, HttpHeader, HttpHeaderValue},
        requests::HttpRequest,
        responses::HttpResponse,
        status::HttpStatus,
    },
    utils::{
        buffer::Buffer,
        files::{get_file_extension, read_file},
    },
};
use smallvec::{SmallVec, smallvec};

/// A routing tree mapping URL path segments to [`HttpEndpoint`]s.
///
/// Each node matches one path segment and may hold the endpoint registered for the path
/// that ends at that node. The root node has an empty `path_part` and stands for `/`.
#[derive(Debug, Default, Clone)]
pub struct HttpTree {
    /// The path segment this node matches (empty for the root).
    path_part: String,
    /// Child nodes, one per distinct next segment.
    nodes: Vec<HttpTree>,
    /// The endpoint registered for the path ending at this node, if any.
    endpoint: Option<HttpEndpoint>,
}

impl HttpTree {
    /// Creates a new empty `HttpTree`.
    pub fn new() -> Self {
        HttpTree::default()
    }

    /// Adds an endpoint to the tree.
    ///
    /// One leading and one trailing `/` are ignored, so `"/api/data/"` and `"api/data"`
    /// address the same node; an empty path (or `"/"`) targets the root. Each node holds a
    /// single endpoint, so registering a path again replaces the previous endpoint.
    ///
    /// # Arguments
    /// * `path` - The URL path (e.g., "/api/data").
    /// * `method` - The HTTP method (`GET`, `POST`, etc.).
    /// * `endpoint_type` - The type of endpoint (file, directory, or handler).
    pub fn add_endpoint(&mut self, path: &str, method: HttpMethod, endpoint_type: EndpointType) {
        let path = path.strip_prefix('/').unwrap_or(path);
        let path = path.strip_suffix('/').unwrap_or(path);

        if path.is_empty() {
            self.endpoint = Some(HttpEndpoint::new(method, endpoint_type));
        } else {
            let parts = path.split('/').map(str::to_owned).collect();
            self.add_endpoint_rec(parts, method, endpoint_type);
        }
    }

    /// Adds an endpoint to the tree recursively.
    ///
    /// Walks (creating as needed) one child node per segment in `parts` and stores the
    /// endpoint on the node reached by the last segment. Does nothing if `parts` is empty.
    ///
    /// # Arguments
    /// * `parts` - The remaining segments of the URL path.
    /// * `method` - The HTTP method (`GET`, `POST`, etc.).
    /// * `endpoint_type` - The type of endpoint (file, directory, or handler).
    fn add_endpoint_rec(
        &mut self,
        mut parts: VecDeque<String>,
        method: HttpMethod,
        endpoint_type: EndpointType,
    ) {
        let Some(part) = parts.pop_front() else {
            return;
        };

        // Reuse the child for this segment if it exists, otherwise append a new one.
        let index = match self.nodes.iter().position(|node| node.path_part == part) {
            Some(index) => index,
            None => {
                self.nodes.push(HttpTree {
                    path_part: part,
                    ..HttpTree::default()
                });
                self.nodes.len() - 1
            }
        };
        let node = &mut self.nodes[index];

        if parts.is_empty() {
            // Last segment: the endpoint belongs on this node even when the node already
            // existed (e.g. registering "/api" after "/api/data").
            node.endpoint = Some(HttpEndpoint::new(method, endpoint_type));
        } else {
            node.add_endpoint_rec(parts, method, endpoint_type);
        }
    }

    /// Tries to find the endpoint matching the given path and method.
    ///
    /// A directory endpoint matches as soon as it is reached, and the still-unconsumed part
    /// of `path` is returned alongside it. Any other endpoint matches only when the whole
    /// path has been consumed, in which case the returned remaining path is empty. A `HEAD`
    /// request is accepted by every endpoint regardless of its registered method.
    ///
    /// # Returns
    /// * `Ok((endpoint, remaining_path))` if found.
    /// * `Err(HttpError::NotFound)` if the path does not exist.
    /// * `Err(HttpError::MethodNotAllowed)` if method mismatch.
    pub fn get_endpoint_from_path<'a>(
        &'a self,
        path: &'a str,
        method: HttpMethod,
    ) -> Result<(&'a HttpEndpoint, &'a str), HttpError> {
        let path = path.strip_prefix('/').unwrap_or(path);

        if let Some(endpoint) = &self.endpoint {
            let method_matches = endpoint.method == method || method == HttpMethod::HEAD;

            // A directory endpoint serves everything below it.
            if matches!(endpoint.endpoint_type, EndpointType::Directory(_)) && method_matches {
                return Ok((endpoint, path));
            }

            if path.is_empty() {
                return if method_matches {
                    Ok((endpoint, path))
                } else {
                    Err(HttpError::MethodNotAllowed)
                };
            }
        } else if path.is_empty() {
            return Err(HttpError::NotFound);
        }

        let (current, rest) = path.split_once('/').unwrap_or((path, ""));
        match self.nodes.iter().find(|node| node.path_part == current) {
            Some(node) => node.get_endpoint_from_path(rest, method),
            None => Err(HttpError::NotFound),
        }
    }
}

/// An HTTP endpoint that handles incoming requests.
#[derive(Debug, Clone)]
pub struct HttpEndpoint {
    method: HttpMethod,
    endpoint_type: EndpointType,
}

impl HttpEndpoint {
    /// Creates a new `HttpEndpoint` from a method and endpoint type.
    pub fn new(method: HttpMethod, endpoint_type: EndpointType) -> Self {
        Self {
            method,
            endpoint_type,
        }
    }

    pub fn handle(
        &self,
        request: &HttpRequest,
        response: &mut HttpResponse,
        remaining_path: &str,
    ) -> Result<(), String> {
        match &self.endpoint_type {
            EndpointType::File(path) => self.handle_file(request, response, path),
            EndpointType::Directory(path) => {
                self.handle_directory(request, response, path, remaining_path)
            }
            EndpointType::Handler(handler) => handler(request, response),
        }?;
        self.apply_range_if_needed(response);
        Ok(())
    }

    fn handle_file(
        &self,
        _: &HttpRequest,
        response: &mut HttpResponse,
        path: &str,
    ) -> Result<(), String> {
        let _ = read_file(path)
            .map(|content| {
                let content_type = self.determine_content_type(path);
                let mut headers = self.create_headers(content_type, content.len() as u64);
                response.headers.append(&mut headers);
                response.body = Some(Buffer::from(content));
                response.status = HttpStatus::Ok;
            })
            .map_err(|_| response.status = HttpStatus::NotFound);
        Ok(())
    }

    fn handle_directory(
        &self,
        _: &HttpRequest,
        response: &mut HttpResponse,
        path: &str,
        remaining_path: &str,
    ) -> Result<(), String> {
        let content_type = self.determine_content_type(remaining_path);

        let parsed_path = remaining_path.replace("../", "");
        let parsed_path = parsed_path.trim_start_matches("/");

        let file_path = if !path.ends_with("/") && !path.is_empty() {
            format!("{}/{}", path, parsed_path)
        } else {
            format!("{}{}", path, parsed_path)
        };

        let _ = read_file(&file_path)
            .map(|content| {
                let len = content.len() as u64;
                response.body = Some(Buffer::from(content));
                let mut default_headers = self.create_headers(content_type, len);
                response.headers.append(&mut default_headers);
                response.status = HttpStatus::Ok;
            })
            .map_err(|_| response.status = HttpStatus::NotFound);
        Ok(())
    }

    fn determine_content_type(&self, path: &str) -> ContentType {
        get_file_extension(path)
            .map(|ext| ContentType::from_str(&ext).unwrap_or(ContentType::ApplicationOctetStream))
            .unwrap_or(ContentType::ApplicationOctetStream)
    }

    fn create_headers(
        &self,
        content_type: ContentType,
        content_length: u64,
    ) -> SmallVec<[HeaderEntry; 4]> {
        smallvec![
            HeaderEntry::new(
                HttpHeader::ContentType,
                HttpHeaderValue::ContentType(content_type),
            ),
            HeaderEntry::new(
                HttpHeader::ContentLength,
                HttpHeaderValue::ContentLength(content_length),
            ),
        ]
    }

    fn apply_range_if_needed(&self, mut response: &mut HttpResponse) {
        if let Some((begin, end)) = response.range {
            if let Some(ref body) = response.body {
                let body_str = body.to_string();
                let begin = begin as usize;
                let end = end as usize;

                if begin < body_str.len() {
                    let end = std::cmp::min(end + 1, body_str.len());
                    let slice = &body_str[begin..end];

                    if let Ok(sliced_body) = Buffer::from_str(slice) {
                        response.body = Some(sliced_body);

                        if response.status == HttpStatus::Ok {
                            response.status = HttpStatus::PartialContent;
                        }

                        self.update_content_length(&mut response, (end - begin) as u64);
                    }
                }
            }
        }
    }

    fn update_content_length(&self, response: &mut HttpResponse, length: u64) {
        if let Some(content_length_header) = response.find_header(HttpHeader::ContentLength) {
            content_length_header.value = HttpHeaderValue::ContentLength(length);
        } else {
            response.headers.push(HeaderEntry {
                name: HttpHeader::ContentLength,
                value: HttpHeaderValue::ContentLength(length),
            });
        }
    }
}

/// Represents the different types of HTTP endpoints that can be registered in the router.
#[derive(Debug, Clone)]
pub enum EndpointType {
    /// Serves the single file at the given filesystem path.
    File(String),

    /// Serves files located under the given filesystem directory.
    ///
    /// The part of the request path below the endpoint's route selects the file within
    /// this directory. Only file contents are served; no directory listing is generated.
    Directory(String),

    /// Calls the given function to dynamically generate a response.
    ///
    /// An `Err` returned by the function is propagated by `HttpEndpoint::handle`.
    Handler(fn(&HttpRequest, &mut HttpResponse) -> Result<(), String>),
}