mod test;
use std::{collections::VecDeque, str::FromStr};
use super::http::methods::HttpMethod;
use crate::{
http::{
errors::HttpError,
headers::{ContentType, HeaderEntry, HttpHeader, HttpHeaderValue},
requests::HttpRequest,
responses::HttpResponse,
status::HttpStatus,
},
utils::{
buffer::Buffer,
files::{get_file_extension, read_file},
},
};
use smallvec::{SmallVec, smallvec};
/// A node of the routing tree that maps URL paths to endpoints.
///
/// Every node owns one path segment, the child nodes for the segments that
/// may follow it, and optionally the endpoint registered at exactly this
/// position. The root node is the one produced by `HttpTree::new`, whose
/// segment is the empty string.
#[derive(Debug, Default, Clone)]
pub struct HttpTree {
// Path segment this node is matched against during insertion and lookup.
path_part: String,
// Child nodes, searched linearly by `path_part`.
nodes: Vec<HttpTree>,
// Endpoint registered at this node, if any.
endpoint: Option<HttpEndpoint>,
}
impl HttpTree {
pub fn new() -> Self {
HttpTree::default()
}
pub fn add_endpoint(&mut self, path: &str, method: HttpMethod, endpoint_type: EndpointType) {
let path = if path.starts_with("/") {
path.strip_prefix('/').unwrap()
} else {
path
};
let path = if path.ends_with("/") {
path.strip_suffix('/').unwrap()
} else {
path
};
if path.is_empty() {
self.endpoint = Some(HttpEndpoint::new(method, endpoint_type));
} else {
let parts = path.split('/').map(|part| part.to_owned()).collect();
self.add_endpoint_rec(parts, method, endpoint_type);
}
}
fn add_endpoint_rec(
&mut self,
mut parts: VecDeque<String>,
method: HttpMethod,
endpoint_type: EndpointType,
) {
if let Some(part) = parts.pop_front() {
if let Some(node) = self.nodes.iter_mut().find(|x| x.path_part == part) {
node.add_endpoint_rec(parts, method, endpoint_type);
return;
};
let mut new_tree = HttpTree::new();
new_tree.path_part = part.to_owned();
new_tree.nodes = vec![];
if parts.is_empty() {
new_tree.endpoint = Some(HttpEndpoint::new(method, endpoint_type));
} else {
new_tree.add_endpoint_rec(parts, method, endpoint_type);
}
self.nodes.push(new_tree);
}
}
pub fn get_endpoint_from_path<'a>(
&'a self,
path: &'a str,
method: HttpMethod,
) -> Result<(&'a HttpEndpoint, &'a str), HttpError> {
let path = if path.starts_with('/') {
path.strip_prefix('/').unwrap()
} else {
path
};
if let Some(endpoint) = &self.endpoint {
match endpoint.endpoint_type {
EndpointType::Directory(_) => {
if endpoint.method == method || method == HttpMethod::HEAD {
return Ok((endpoint, path));
}
}
_ => {}
}
}
if path.is_empty() {
if let Some(endpoint) = &self.endpoint {
if endpoint.method == method || method == HttpMethod::HEAD {
return Ok((endpoint, path));
} else {
return Err(HttpError::MethodNotAllowed);
}
}
return Err(HttpError::NotFound);
}
let mut parts = path.splitn(2, '/');
let current = parts.next().unwrap();
let rest = parts.next().unwrap_or("");
for node in &self.nodes {
if node.path_part == current {
return node.get_endpoint_from_path(rest, method);
}
}
Err(HttpError::NotFound)
}
}
/// An endpoint stored in the routing tree: the HTTP method it was registered
/// for and the kind of response it produces.
#[derive(Debug, Clone)]
pub struct HttpEndpoint {
// Method the endpoint was registered with; compared against the request
// method during lookup.
method: HttpMethod,
// What serves the request: a single file, a directory, or a handler function.
endpoint_type: EndpointType,
}
impl HttpEndpoint {
pub fn new(method: HttpMethod, endpoint_type: EndpointType) -> Self {
Self {
method,
endpoint_type,
}
}
pub fn handle(
&self,
request: &HttpRequest,
response: &mut HttpResponse,
remaining_path: &str,
) -> Result<(), String> {
match &self.endpoint_type {
EndpointType::File(path) => self.handle_file(request, response, path),
EndpointType::Directory(path) => {
self.handle_directory(request, response, path, remaining_path)
}
EndpointType::Handler(handler) => handler(request, response),
}?;
self.apply_range_if_needed(response);
Ok(())
}
fn handle_file(
&self,
_: &HttpRequest,
response: &mut HttpResponse,
path: &str,
) -> Result<(), String> {
let _ = read_file(path)
.map(|content| {
let content_type = self.determine_content_type(path);
let mut headers = self.create_headers(content_type, content.len() as u64);
response.headers.append(&mut headers);
response.body = Some(Buffer::from(content));
response.status = HttpStatus::Ok;
})
.map_err(|_| response.status = HttpStatus::NotFound);
Ok(())
}
fn handle_directory(
&self,
_: &HttpRequest,
response: &mut HttpResponse,
path: &str,
remaining_path: &str,
) -> Result<(), String> {
let content_type = self.determine_content_type(remaining_path);
let parsed_path = remaining_path.replace("../", "");
let parsed_path = parsed_path.trim_start_matches("/");
let file_path = if !path.ends_with("/") && !path.is_empty() {
format!("{}/{}", path, parsed_path)
} else {
format!("{}{}", path, parsed_path)
};
let _ = read_file(&file_path)
.map(|content| {
let len = content.len() as u64;
response.body = Some(Buffer::from(content));
let mut default_headers = self.create_headers(content_type, len);
response.headers.append(&mut default_headers);
response.status = HttpStatus::Ok;
})
.map_err(|_| response.status = HttpStatus::NotFound);
Ok(())
}
fn determine_content_type(&self, path: &str) -> ContentType {
get_file_extension(path)
.map(|ext| ContentType::from_str(&ext).unwrap_or(ContentType::ApplicationOctetStream))
.unwrap_or(ContentType::ApplicationOctetStream)
}
fn create_headers(
&self,
content_type: ContentType,
content_length: u64,
) -> SmallVec<[HeaderEntry; 4]> {
smallvec![
HeaderEntry::new(
HttpHeader::ContentType,
HttpHeaderValue::ContentType(content_type),
),
HeaderEntry::new(
HttpHeader::ContentLength,
HttpHeaderValue::ContentLength(content_length),
),
]
}
fn apply_range_if_needed(&self, mut response: &mut HttpResponse) {
if let Some((begin, end)) = response.range {
if let Some(ref body) = response.body {
let body_str = body.to_string();
let begin = begin as usize;
let end = end as usize;
if begin < body_str.len() {
let end = std::cmp::min(end + 1, body_str.len());
let slice = &body_str[begin..end];
if let Ok(sliced_body) = Buffer::from_str(slice) {
response.body = Some(sliced_body);
if response.status == HttpStatus::Ok {
response.status = HttpStatus::PartialContent;
}
self.update_content_length(&mut response, (end - begin) as u64);
}
}
}
}
}
fn update_content_length(&self, response: &mut HttpResponse, length: u64) {
if let Some(content_length_header) = response.find_header(HttpHeader::ContentLength) {
content_length_header.value = HttpHeaderValue::ContentLength(length);
} else {
response.headers.push(HeaderEntry {
name: HttpHeader::ContentLength,
value: HttpHeaderValue::ContentLength(length),
});
}
}
}
/// The ways an endpoint can produce a response.
#[derive(Debug, Clone)]
pub enum EndpointType {
// Serves the single file at the given path.
File(String),
// Serves files below the given directory; the unmatched rest of the request
// path selects the file.
Directory(String),
// Delegates to a function that fills in the response itself and may fail
// with an error message.
Handler(fn(&HttpRequest, &mut HttpResponse) -> Result<(), String>),
}