use actix_web::http::header::HeaderMap;
use actix_web::http::{Method, StatusCode};
use actix_web::web::Bytes;
use actix_web::{web, App, HttpRequest, HttpResponse, HttpServer};
pub use askama;
pub use askama::Template;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
pub mod authentication;
pub type ArcRenderModel = Arc<dyn RenderModel>;
/// Per-request data handed through the middleware chain and into route actions.
#[derive(Clone)]
pub struct RequestContext {
    /// Query-string parameters, keyed by parameter name.
    pub params: HashMap<String, String>,
    /// Values captured from `{name}` placeholders of the matched route pattern.
    pub path_params: HashMap<String, String>,
    /// Request headers.
    pub headers: HeaderMap,
    /// Request path.
    pub path: String,
    /// Raw request body.
    pub body: Vec<u8>,
    /// Request method, mapped onto [`HttpMethod`].
    pub method: HttpMethod,
    /// Rules of the route resolved for this request before middleware runs (empty if none).
    pub rules: Vec<RouteRules>,
    /// The current user, if any. The server starts every request with `None`; presumably
    /// populated by authentication middleware — confirm against the `authentication` module.
    pub user: Option<User>,
}
/// The identity attached to a request.
// `Debug`/`PartialEq`/`Eq` make the type loggable and comparable, as expected of a public
// data type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct User {
    /// User name.
    pub name: String,
    /// Roles held by the user; checked against `RouteRules::Roles`.
    pub roles: Vec<String>,
}
/// Outcome of a route action or middleware; the server turns it into an HTTP response.
#[derive(Clone)]
pub enum ActionResult {
    /// 200 with a `text/html` body.
    Html(String),
    /// Template rendered through [`RenderModel::render_html`]: 200 `text/html` on success,
    /// 500 when rendering fails.
    View(ArcRenderModel),
    /// 302 Found redirect to the given URL.
    Redirect(String),
    /// Static file below `wwwroot`, looked up by request path.
    File(String),
    /// 404 Not Found.
    NotFound,
    /// 413 Payload Too Large with the given message.
    PayloadTooLarge(String),
    /// 401 Unauthorized with the given message.
    UnAuthorized(String),
    /// 403 Forbidden with the given message.
    Forbidden(String),
    /// 200 with the given body, sent as `application/json`.
    Ok(String),
    /// 400 Bad Request with the given body, sent as `application/json`.
    BadRequest(String),
    /// Arbitrary status code and body (sent as `application/json`); an invalid code becomes 500.
    StatusCode(u16, String),
}
/// Object-safe rendering hook, so `ActionResult::View` can hold any template behind a
/// shared pointer.
pub trait RenderModel: Sync + Send {
    /// Renders the model to an HTML string.
    ///
    /// # Errors
    /// Returns the `askama::Error` produced when rendering fails.
    fn render_html(&self) -> Result<String, askama::Error>;
}
/// Every `Send + Sync` askama template can be used as a [`RenderModel`].
impl<T> RenderModel for T
where
    T: askama::Template + Send + Sync,
{
    fn render_html(&self) -> Result<String, askama::Error> {
        // Delegate straight to the template's own `render`.
        <T as askama::Template>::render(self)
    }
}
/// Terminal request handler: turns a [`RequestContext`] into an [`ActionResult`].
pub type ActionFn = Arc<dyn Fn(RequestContext) -> ActionResult + Sync + Send + 'static>;
/// Middleware: receives the request and the next handler of the chain; it may short-circuit
/// by returning its own [`ActionResult`] instead of calling `next`.
pub type MiddlewareFn =
    Arc<dyn Fn(RequestContext, ActionFn) -> ActionResult + Sync + Send + 'static>;
/// Per-route policies attached when a route is registered.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RouteRules {
    /// Marks the route as requiring an authenticated user. Not checked by the router itself;
    /// presumably enforced by middleware reading `RequestContext::rules` — confirm.
    Authorize,
    /// Marks the route as open to anonymous users. Same caveat as `Authorize`.
    AllowAnonymous,
    /// The current user must hold at least one of these roles (checked by the router).
    Roles(Vec<String>),
    /// Maximum accepted request-body size in bytes; larger bodies get `PayloadTooLarge`.
    RequestSizeLimit(usize),
}
/// HTTP request methods understood by the router.
///
/// `NotSupported` stands in for any method the server does not map to another variant.
// The variant names mirror the HTTP method tokens and are public API, so they keep their
// upper-case spelling.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HttpMethod {
    GET,
    POST,
    PUT,
    DELETE,
    PATCH,
    OPTIONS,
    HEAD,
    TRACE,
    CONNECT,
    NotSupported,
}
/// A registered endpoint: a path pattern, the method it answers, its handler and its rules.
#[derive(Clone)]
pub struct Route {
    /// Path pattern; a `{name}` segment captures the corresponding request segment.
    pub path: String,
    /// Handler invoked when the route matches.
    pub action: ActionFn,
    /// Rules attached to the route.
    pub rules: Vec<RouteRules>,
    /// Method the route answers.
    pub method: HttpMethod,
}
/// Minimal routing server built on actix-web.
pub struct Server {
    /// Registered routes in registration order; the first route matching a request wins.
    routes: Vec<Route>,
    /// Middleware stack; the first registered middleware is the outermost layer.
    middlewares: Vec<MiddlewareFn>,
}
impl Server {
pub fn new() -> Self {
let mut server = Self {
routes: Vec::new(),
middlewares: Vec::new(),
};
server.add_middleware(|ctx, next| {
println!("--- Incoming Request ---");
println!("Path: {}", ctx.path);
println!("Query Params: {:?}", ctx.params);
println!("Headers:");
for (key, value) in ctx.headers.iter() {
println!(" {}: {:?}", key, value);
}
println!("------------------------");
let result = next(ctx.clone());
match &result {
ActionResult::Html(_) => println!("Response: Html"),
ActionResult::View(_) => println!("Response: View"),
ActionResult::Redirect(url) => println!("Response: Redirect to {:?}", url),
ActionResult::File(path) => println!("Response: File {:?}", path),
ActionResult::NotFound => println!("Response: NotFound"),
ActionResult::PayloadTooLarge(content) => println!("Response: {:?}", content),
ActionResult::Forbidden(content) => println!("Response: {:?}", content),
ActionResult::UnAuthorized(content) => println!("Response: {:?}", content),
ActionResult::Ok(content) => println!("Response: {:?}", content),
ActionResult::BadRequest(content) => println!("Response: {:?}", content),
ActionResult::StatusCode(code, body) => println!("Response: {:?} {:?}", code, body),
}
println!("--- End of Request ---\n");
result
});
server
}
fn match_and_extract_params(pattern: &str, path: &str) -> Option<HashMap<String, String>> {
let pattern_segments: Vec<&str> = pattern.split('/').collect();
let path_segments: Vec<&str> = path.split('/').collect();
if pattern_segments.len() != path_segments.len() {
return None;
}
let mut params = HashMap::new();
for (p_segment, r_segment) in pattern_segments.iter().zip(path_segments.iter()) {
if p_segment.starts_with('{') && p_segment.ends_with('}') {
let key = p_segment.trim_matches(|c| c == '{' || c == '}').to_string();
params.insert(key, r_segment.to_string());
} else if p_segment != r_segment {
return None;
}
}
Some(params)
}
pub fn add_middleware<F>(&mut self, mw: F)
where
F: Fn(RequestContext, ActionFn) -> ActionResult + Send + Sync + 'static,
{
self.middlewares.push(Arc::new(mw));
}
pub fn use_static_files(&mut self) {
let middleware = move |ctx: RequestContext, next: ActionFn| {
if ctx.method == HttpMethod::GET && ctx.path.contains('.') {
return ActionResult::File(ctx.path);
}
next(ctx)
};
self.add_middleware(middleware);
}
pub fn get<F>(&mut self, path: &str, action: F, rules: Vec<RouteRules>)
where
F: Fn(RequestContext) -> ActionResult + Send + Sync + 'static,
{
self.add_route(path, action, HttpMethod::GET, rules);
}
pub fn post<F>(&mut self, path: &str, action: F, rules: Vec<RouteRules>)
where
F: Fn(RequestContext) -> ActionResult + Send + Sync + 'static,
{
self.add_route(path, action, HttpMethod::POST, rules);
}
pub fn put<F>(&mut self, path: &str, action: F, rules: Vec<RouteRules>)
where
F: Fn(RequestContext) -> ActionResult + Send + Sync + 'static,
{
self.add_route(path, action, HttpMethod::PUT, rules);
}
pub fn delete<F>(&mut self, path: &str, action: F, rules: Vec<RouteRules>)
where
F: Fn(RequestContext) -> ActionResult + Send + Sync + 'static,
{
self.add_route(path, action, HttpMethod::DELETE, rules);
}
pub fn add_route<F>(
&mut self,
path: &str,
action: F,
method: HttpMethod,
rules: Vec<RouteRules>,
) where
F: Fn(RequestContext) -> ActionResult + Send + Sync + 'static,
{
self.routes.push(Route {
path: path.to_string(),
action: Arc::new(action),
method,
rules,
});
}
fn handle_request(&self, ctx: RequestContext) -> ActionResult {
let routes = self.routes.clone();
let route_handler: ActionFn = Arc::new(move |mut ctx: RequestContext| {
for route in routes.iter() {
if route.method != ctx.method {
continue;
}
if let Some(path_params) = Server::match_and_extract_params(&route.path, &ctx.path)
{
ctx.path_params = path_params;
for rule in route.rules.clone() {
if let RouteRules::RequestSizeLimit(limit) = rule {
if ctx.body.len() > limit {
return ActionResult::PayloadTooLarge(format!(
"Request to route '{}' exceeded the allowed size: {} bytes",
route.path, limit
));
}
} else if let RouteRules::Roles(roles) = rule {
match &ctx.user {
Some(user) => {
let has_role = roles.iter().any(|r| user.roles.contains(r));
if !has_role {
return ActionResult::UnAuthorized(
"You do not have the required role(s)".into(),
);
}
}
None => (),
}
}
}
return (route.action)(ctx);
}
}
ActionResult::NotFound
});
let mut next = route_handler;
for mw in self.middlewares.iter().rev() {
let current_next = next.clone();
let mw_clone = mw.clone();
next = Arc::new(move |ctx: RequestContext| mw_clone(ctx, current_next.clone()));
}
next(ctx)
}
pub async fn start(self, addr: &str) -> std::io::Result<()> {
println!("Server listening at http://{}", addr);
let shared_routes = web::Data::new(self);
HttpServer::new(move || {
App::new()
.app_data(shared_routes.clone())
.default_service(web::to(
|req: HttpRequest, body: Bytes, srv: web::Data<Server>| {
let mut params = HashMap::new();
for (key, value) in req
.query_string()
.split('&')
.filter(|s| !s.is_empty())
.map(|pair| {
let mut kv = pair.splitn(2, '=');
(kv.next().unwrap_or(""), kv.next().unwrap_or(""))
})
{
params.insert(key.to_string(), value.to_string());
}
let mapped_methods = match req.method() {
&Method::GET => HttpMethod::GET,
&Method::POST => HttpMethod::POST,
&Method::PUT => HttpMethod::PUT,
&Method::DELETE => HttpMethod::DELETE,
&Method::PATCH => HttpMethod::PATCH,
&Method::CONNECT => HttpMethod::CONNECT,
&Method::OPTIONS => HttpMethod::OPTIONS,
&Method::HEAD => HttpMethod::HEAD,
&Method::TRACE => HttpMethod::TRACE,
_ => HttpMethod::NotSupported,
};
let route_rules = match srv.routes.iter().find(|r| {
r.path == req.path().to_string() && r.method == mapped_methods
}) {
Some(r) => r.rules.clone(),
None => Vec::new(),
};
let ctx = RequestContext {
path: req.path().to_string(),
headers: req.headers().clone(),
params,
path_params: HashMap::new(),
body: body.to_vec(),
method: mapped_methods,
rules: route_rules,
user: None,
};
let result = srv.handle_request(ctx);
let body = match result {
ActionResult::Html(s) => {
HttpResponse::Ok().content_type("text/html").body(s)
}
ActionResult::StatusCode(code, body) => {
let valid_code = StatusCode::from_u16(code)
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
HttpResponse::build(valid_code)
.content_type("application/json")
.body(body)
}
ActionResult::View(renderer_arc) => match renderer_arc.render_html() {
Ok(html) => HttpResponse::Ok().content_type("text/html").body(html),
Err(e) => {
eprintln!("Askama Rendering Error: {}", e);
HttpResponse::InternalServerError()
.content_type("application/json")
.body(format!("Template Rendering Error: {}", e))
}
},
ActionResult::Ok(content) => HttpResponse::Ok()
.content_type("application/json")
.body(content),
ActionResult::BadRequest(content) => HttpResponse::BadRequest()
.content_type("application/json")
.body(content),
ActionResult::Redirect(url) => HttpResponse::Found()
.append_header(("Location", url))
.finish(),
ActionResult::File(path) => {
let wwwroot = std::env::current_dir()
.unwrap()
.join("wwwroot")
.canonicalize()
.unwrap();
let requested = Path::new(path.trim_start_matches(['/', '\\']));
let file_path = wwwroot.join(requested).canonicalize();
println!("wwwroot: {}", wwwroot.display());
println!("requested path: {:?}", requested);
println!("file_path: {:?}", file_path);
match file_path {
Ok(path) if path.starts_with(&wwwroot) => {
match std::fs::read(&path) {
Ok(bytes) => {
let content_type = mime_guess::from_path(&path)
.first_or_octet_stream();
HttpResponse::Ok()
.content_type(content_type.as_ref())
.body(bytes)
}
Err(_) => HttpResponse::NotFound().body("Not found"),
}
}
_ => HttpResponse::Forbidden().body("Access denied"),
}
}
ActionResult::PayloadTooLarge(body) => HttpResponse::PayloadTooLarge()
.content_type("application/json")
.body(body),
ActionResult::Forbidden(body) => HttpResponse::Forbidden()
.content_type("application/json")
.body(body),
ActionResult::UnAuthorized(body) => HttpResponse::Unauthorized()
.content_type("application/json")
.body(body),
ActionResult::NotFound => HttpResponse::NotFound()
.content_type("application/json")
.body("Not found"),
};
async move { body }
},
))
})
.bind(addr)?
.run()
.await
}
}