#[cfg(test)]
mod tests;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::app::App;
use crate::application::Application;
use crate::core::New;
use crate::request::Request;
use crate::response::Response;
use crate::router::PathParams;
use crate::server::ConnectionInfo;
/// A pinned, heap-allocated future that is `Send` and holds no non-`'static`
/// borrows, so it can be stored type-erased and polled later.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
/// Type-erased route handler kept in the routing table: it is given the
/// request, the captured path parameters, the connection info and the shared
/// state, and returns a boxed future that resolves to the response.
type AsyncHandlerFn<S> =
    Arc<dyn Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> BoxFuture<Response> + Send + Sync>;
/// One component of a parsed route pattern (see `parse_pattern`).
///
/// `Debug` and `PartialEq` are derived so parsed patterns can be inspected and
/// compared, e.g. when diagnosing why a route does not match.
#[derive(Clone, Debug, PartialEq, Eq)]
enum Segment {
    /// Plain text; must equal the corresponding request path segment exactly.
    Literal(String),
    /// `:name` — captures exactly one request path segment under `name`.
    Param(String),
    /// `*name` — captures every remaining request path segment, re-joined with
    /// `/`; `try_match` only accepts it as the final pattern segment.
    Wildcard(String),
}
/// Splits a route pattern such as `/users/:id/*rest` into its segments.
///
/// Empty pieces (from leading, trailing or doubled slashes) are skipped, so
/// both `"/"` and `""` produce no segments. A piece beginning with `:` becomes
/// a named parameter, a piece beginning with `*` becomes a named wildcard, and
/// every other piece must match literally.
fn parse_pattern(pattern: &str) -> Vec<Segment> {
    let mut segments = Vec::new();
    for piece in pattern.split('/') {
        if piece.is_empty() {
            continue;
        }
        // `:` is checked before `*`, mirroring the precedence of the prefixes.
        let segment = match (piece.strip_prefix(':'), piece.strip_prefix('*')) {
            (Some(name), _) => Segment::Param(name.to_owned()),
            (None, Some(name)) => Segment::Wildcard(name.to_owned()),
            (None, None) => Segment::Literal(piece.to_owned()),
        };
        segments.push(segment);
    }
    segments
}
/// Matches the split request `path` against a parsed route `pattern`.
///
/// Returns the captured `:param` / `*wildcard` values keyed by their names
/// when every pattern segment lines up and the whole path is consumed, and
/// `None` otherwise. A wildcard is only accepted as the last pattern segment;
/// it captures the remaining path pieces re-joined with `/`, which may be the
/// empty string when nothing remains.
fn try_match(pattern: &[Segment], path: &[&str]) -> Option<HashMap<String, String>> {
    let mut captured = HashMap::new();
    // Unconsumed tail of the request path.
    let mut rest = path;

    for (idx, segment) in pattern.iter().enumerate() {
        match segment {
            Segment::Literal(expected) => {
                let (head, tail) = rest.split_first()?;
                if *head != expected.as_str() {
                    return None;
                }
                rest = tail;
            }
            Segment::Param(name) => {
                let (head, tail) = rest.split_first()?;
                captured.insert(name.clone(), String::from(*head));
                rest = tail;
            }
            Segment::Wildcard(name) => {
                if idx + 1 != pattern.len() {
                    return None;
                }
                captured.insert(name.clone(), rest.join("/"));
                rest = &[];
            }
        }
    }

    rest.is_empty().then_some(captured)
}
/// A registered route: the HTTP method it answers, its parsed path pattern,
/// and the type-erased handler to run when both match.
struct AsyncRoute<S> {
    method: String,
    segments: Vec<Segment>,
    handler: AsyncHandlerFn<S>,
}

// Implemented by hand: `#[derive(Clone)]` would add an `S: Clone` bound even
// though the handler is reference-counted and `S` itself is never cloned.
impl<S> Clone for AsyncRoute<S> {
    fn clone(&self) -> Self {
        AsyncRoute {
            method: self.method.clone(),
            segments: self.segments.clone(),
            handler: Arc::clone(&self.handler),
        }
    }
}

/// An application whose routes are async handlers that all share one state
/// value of type `S` (held behind an `Arc`).
pub struct AsyncAppWithState<S> {
    state: Arc<S>,
    routes: Vec<AsyncRoute<S>>,
}

// Implemented by hand so that cloning the app does not require `S: Clone`:
// the clone shares the same `Arc`'d state and the same handler `Arc`s.
impl<S> Clone for AsyncAppWithState<S> {
    fn clone(&self) -> Self {
        AsyncAppWithState {
            state: Arc::clone(&self.state),
            routes: self.routes.clone(),
        }
    }
}
impl<S: Send + Sync + 'static> AsyncAppWithState<S> {
pub fn new(state: S) -> Self {
AsyncAppWithState { state: Arc::new(state), routes: Vec::new() }
}
pub fn state(&self) -> &S {
&self.state
}
fn add<F, Fut>(mut self, method: &str, pattern: &str, handler: F) -> Self
where
F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Response> + Send + 'static,
{
self.routes.push(AsyncRoute {
method: method.to_string(),
segments: parse_pattern(pattern),
handler: Arc::new(move |req, params, conn, state| Box::pin(handler(req, params, conn, state))),
});
self
}
pub fn get<F, Fut>(self, pattern: &str, handler: F) -> Self
where
F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Response> + Send + 'static,
{
self.add("GET", pattern, handler)
}
pub fn post<F, Fut>(self, pattern: &str, handler: F) -> Self
where
F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Response> + Send + 'static,
{
self.add("POST", pattern, handler)
}
pub fn put<F, Fut>(self, pattern: &str, handler: F) -> Self
where
F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Response> + Send + 'static,
{
self.add("PUT", pattern, handler)
}
pub fn patch<F, Fut>(self, pattern: &str, handler: F) -> Self
where
F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Response> + Send + 'static,
{
self.add("PATCH", pattern, handler)
}
pub fn delete<F, Fut>(self, pattern: &str, handler: F) -> Self
where
F: Fn(Request, PathParams, ConnectionInfo, Arc<S>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Response> + Send + 'static,
{
self.add("DELETE", pattern, handler)
}
async fn execute_async(
&self,
request: &Request,
connection: &ConnectionInfo,
) -> Result<Response, String> {
let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
let path_segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
for route in &self.routes {
if route.method != request.method {
continue;
}
if let Some(params_map) = try_match(&route.segments, &path_segs) {
let params = PathParams::from_map(params_map);
let fut = (route.handler)(
request.clone(),
params,
connection.clone(),
Arc::clone(&self.state),
);
return Ok(fut.await);
}
}
App::new().execute(request, connection)
}
}
impl<S: Send + Sync + 'static> Application for AsyncAppWithState<S> {
fn execute(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
let request = request.clone();
let connection = connection.clone();
match tokio::runtime::Handle::try_current() {
Ok(_) => {
std::thread::scope(|s| {
s.spawn(|| {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(self.execute_async(&request, &connection))
})
.join()
.unwrap()
})
}
Err(_) => {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(self.execute_async(&request, &connection))
}
}
}
}