#[cfg(test)]
mod tests;
use std::collections::HashMap;
use std::sync::Arc;
use crate::request::Request;
use crate::response::Response;
use crate::server::ConnectionInfo;
/// Named string values extracted from a request path, keyed by parameter name.
pub struct PathParams {
    params: HashMap<String, String>,
}

impl PathParams {
    /// Builds a parameter set with no entries.
    fn new() -> Self {
        Self::from_map(HashMap::new())
    }

    /// Wraps an already-populated name -> value map.
    pub(crate) fn from_map(params: HashMap<String, String>) -> Self {
        Self { params }
    }

    /// Looks up the value stored under `name`.
    ///
    /// Returns `None` when no parameter with that name has been recorded.
    pub fn get(&self, name: &str) -> Option<&str> {
        let value = self.params.get(name)?;
        Some(value.as_str())
    }

    /// Records `value` under `key`; a value already stored under the same key
    /// is replaced.
    fn insert(&mut self, key: String, value: String) {
        let _replaced = self.params.insert(key, value);
    }
}
/// One `/`-separated component of a parsed route pattern.
#[derive(Clone)]
enum Segment {
/// Fixed text; matches only a path segment that is exactly equal to it.
Literal(String),
/// Written `:name` in a pattern; matches exactly one path segment and
/// stores it under `name`.
Param(String),
/// Written `*name` in a pattern; captures every remaining path segment
/// (possibly none), joined with `/`, under `name`. Only matches when it is
/// the last segment of the pattern.
Wildcard(String),
}
/// Reference-counted, thread-safe route handler: receives the request, the
/// parameters captured from the path, and the connection info, and produces
/// the response.
type HandlerFn =
Arc<dyn Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static>;
/// A single registered route: method, parsed pattern, and handler.
#[derive(Clone)]
struct Route {
// HTTP method name (e.g. "GET"); compared for equality with the request's method.
method: String,
// Parsed pattern; an empty list represents the root path "/".
segments: Vec<Segment>,
// Callback invoked when both method and path match.
handler: HandlerFn,
}
/// Handler-free description of a registered route, as produced by
/// `Router::route_entries`.
#[derive(Clone)]
pub struct RouteInfo {
/// HTTP method the route was registered for (e.g. "GET").
pub method: String,
/// Route pattern rebuilt from its parsed segments, e.g. `/users/:id`.
pub pattern: String,
}
/// Collection of routes, optionally restricted to a single host.
#[derive(Clone)]
pub struct Router {
// Routes in registration order; `handle` tries them first to last.
routes: Vec<Route>,
// When set, `handle` only serves requests addressed to this host.
host: Option<String>,
}
impl Router {
    /// Creates a router with no routes and no host restriction.
    pub fn new() -> Self {
        Router { routes: Vec::new(), host: None }
    }

    /// Restricts this router to requests addressed to `host`.
    ///
    /// The request's host is taken from the connection's SNI hostname when one
    /// is present, otherwise from the `Host` header. The comparison ignores
    /// ASCII case. The header value is compared as-is: a port suffix such as
    /// `:8080` is not stripped.
    pub fn with_host(mut self, host: &str) -> Self {
        self.host = Some(host.to_string());
        self
    }

    /// Registers `handler` for `GET` requests matching `pattern`.
    ///
    /// Patterns are `/`-separated. A segment starting with `:` captures one
    /// path segment under the name that follows; a segment starting with `*`
    /// captures all remaining segments and must be last; anything else must
    /// match literally.
    pub fn get<F>(self, pattern: &str, handler: F) -> Self
    where
        F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static,
    {
        self.add("GET", pattern, handler)
    }

    /// Registers `handler` for `POST` requests matching `pattern`.
    pub fn post<F>(self, pattern: &str, handler: F) -> Self
    where
        F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static,
    {
        self.add("POST", pattern, handler)
    }

    /// Registers `handler` for `PUT` requests matching `pattern`.
    pub fn put<F>(self, pattern: &str, handler: F) -> Self
    where
        F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static,
    {
        self.add("PUT", pattern, handler)
    }

    /// Registers `handler` for `PATCH` requests matching `pattern`.
    pub fn patch<F>(self, pattern: &str, handler: F) -> Self
    where
        F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static,
    {
        self.add("PATCH", pattern, handler)
    }

    /// Registers `handler` for `DELETE` requests matching `pattern`.
    pub fn delete<F>(self, pattern: &str, handler: F) -> Self
    where
        F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static,
    {
        self.add("DELETE", pattern, handler)
    }

    /// Appends a route for `method` and `pattern`; shared by the per-method
    /// builder functions.
    fn add<F>(mut self, method: &str, pattern: &str, handler: F) -> Self
    where
        F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static,
    {
        self.routes.push(Route {
            method: method.to_string(),
            segments: Self::parse_pattern(pattern),
            handler: Arc::new(handler),
        });
        self
    }

    /// Splits `pattern` on `/` and classifies each non-empty piece.
    ///
    /// Empty pieces (leading, trailing or doubled slashes) are dropped, so
    /// `"/"` and `""` both yield an empty segment list.
    fn parse_pattern(pattern: &str) -> Vec<Segment> {
        pattern
            .split('/')
            .filter(|piece| !piece.is_empty())
            .map(|piece| {
                if let Some(name) = piece.strip_prefix(':') {
                    Segment::Param(name.to_string())
                } else if let Some(name) = piece.strip_prefix('*') {
                    Segment::Wildcard(name.to_string())
                } else {
                    Segment::Literal(piece.to_string())
                }
            })
            .collect()
    }

    /// Lists the method and pattern of every registered route, in
    /// registration order.
    ///
    /// Patterns are rebuilt from the parsed segments, so they are normalised
    /// (single leading slash, no empty segments) rather than echoed verbatim.
    pub fn route_entries(&self) -> Vec<RouteInfo> {
        self.routes
            .iter()
            .map(|route| RouteInfo {
                method: route.method.clone(),
                pattern: Self::segments_to_pattern(&route.segments),
            })
            .collect()
    }

    /// Renders parsed segments back into pattern syntax; an empty list
    /// becomes `"/"`.
    fn segments_to_pattern(segs: &[Segment]) -> String {
        if segs.is_empty() {
            return "/".to_string();
        }
        let parts: Vec<String> = segs
            .iter()
            .map(|seg| match seg {
                Segment::Literal(text) => text.clone(),
                Segment::Param(name) => format!(":{}", name),
                Segment::Wildcard(name) => format!("*{}", name),
            })
            .collect();
        format!("/{}", parts.join("/"))
    }

    /// Dispatches `request` to the first registered route whose method and
    /// path match, returning the handler's response.
    ///
    /// Returns `None` when the router is host-restricted and the request is
    /// for a different (or unknown) host, or when no route matches. Anything
    /// after the first `?` in the request URI is ignored for matching.
    pub fn handle(&self, request: &Request, connection: &ConnectionInfo) -> Option<Response> {
        if let Some(required_host) = &self.host {
            // Prefer the TLS SNI name; fall back to the Host header.
            let actual = connection.sni_hostname.as_deref().or_else(|| {
                request
                    .headers
                    .iter()
                    .find(|h| h.name.eq_ignore_ascii_case("host"))
                    .map(|h| h.value.as_str())
            });
            // Hostnames are case-insensitive, so `Example.COM` must match
            // a router configured for `example.com`.
            let host_matches =
                actual.map_or(false, |host| host.eq_ignore_ascii_case(required_host));
            if !host_matches {
                return None;
            }
        }

        let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
        let path_segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();

        for route in &self.routes {
            if route.method != request.method {
                continue;
            }
            if let Some(params) = Self::try_match(&route.segments, &path_segs) {
                return Some((route.handler)(request, &params, connection));
            }
        }
        None
    }

    /// Matches the request's path segments against a parsed pattern.
    ///
    /// Returns the captured parameters when every pattern segment matches and
    /// the whole path is consumed, `None` otherwise. A wildcard that is not
    /// the final pattern segment never matches.
    fn try_match(pattern: &[Segment], path: &[&str]) -> Option<PathParams> {
        let mut params = PathParams::new();
        // Index of the next unconsumed path segment; never exceeds `path.len()`.
        let mut pi = 0;
        for (si, seg) in pattern.iter().enumerate() {
            match seg {
                Segment::Literal(lit) => {
                    if path.get(pi).copied() != Some(lit.as_str()) {
                        return None;
                    }
                    pi += 1;
                }
                Segment::Param(name) => {
                    let value = path.get(pi)?;
                    params.insert(name.clone(), (*value).to_string());
                    pi += 1;
                }
                Segment::Wildcard(name) => {
                    if si + 1 != pattern.len() {
                        return None;
                    }
                    // `pi <= path.len()`, so the slice is in bounds; it is
                    // empty when the wildcard captures nothing.
                    params.insert(name.clone(), path[pi..].join("/"));
                    pi = path.len();
                }
            }
        }
        if pi == path.len() {
            Some(params)
        } else {
            None
        }
    }
}