use super::server::LuaHttpHandler;
use super::{Request, Response};
use crate::transport::Context;
use std::collections::HashMap;
use std::ops::{Add, Deref};
use std::rc::Rc;
#[cfg(feature = "open-api")]
use crate::transport::http::openapi::RouteOperation;
/// Signature shared by every request handler: a callable that receives a
/// mutable `Context` together with the incoming `Request` and returns either a
/// `Response` or an error of type `E`.
type HandlerFn<E> = dyn Fn(&mut Context, Request) -> Result<Response, E> + 'static;
/// Newtype around a boxed `HandlerFn`.
///
/// It can be built from a compatible closure through the `From` impl in this
/// file, and it dereferences to the boxed callable so it can be invoked
/// directly.
pub struct Handler<E>(pub Box<HandlerFn<E>>);
impl<R, E, F> From<F> for Handler<E>
where
R: Into<Response>,
F: Fn(&mut Context, Request) -> Result<R, E> + 'static,
{
fn from(f: F) -> Self {
Handler(Box::new(move |ctx, req| f(ctx, req).map(Into::into)))
}
}
impl<E> Deref for Handler<E> {
type Target = Box<HandlerFn<E>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// Signature of a middleware: a callable that takes ownership of a `Handler`
/// and returns a `Handler` with the same error type.
type MiddlewareFn<E> = dyn Fn(Handler<E>) -> Handler<E> + 'static;
/// Newtype around a boxed `MiddlewareFn`.
///
/// It can be built from a compatible closure through the `From` impl in this
/// file, and it dereferences to the boxed callable so it can be invoked
/// directly.
pub struct Middleware<E>(pub Box<MiddlewareFn<E>>);
impl<E, F> From<F> for Middleware<E>
where
F: Fn(Handler<E>) -> Handler<E> + 'static,
{
fn from(f: F) -> Self {
Middleware(Box::new(f))
}
}
impl<E> Deref for Middleware<E> {
type Target = Box<MiddlewareFn<E>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// Step-by-step configuration of a `Route`: path, method, middlewares and,
/// with the `open-api` feature, an optional operation description.
pub struct Builder<E> {
// Path accumulated so far; `with_path` appends to it.
path: String,
// HTTP method name; `new` initialises it to "GET".
method: &'static str,
// Middlewares in registration order. They are reference-counted so that
// `Group::builder` can hand out copies without cloning the closures.
middlewares: Vec<Rc<Middleware<E>>>,
// Operation description set through `define_open_api`, if any.
#[cfg(feature = "open-api")]
open_api_op: Option<RouteOperation>,
}
impl<E> Default for Builder<E> {
fn default() -> Self {
Builder::new()
}
}
impl<E> Builder<E> {
    /// Creates a builder with an empty path, the `GET` method and no
    /// middlewares.
    pub fn new() -> Self {
        Self {
            path: String::new(),
            method: "GET",
            middlewares: Vec::new(),
            #[cfg(feature = "open-api")]
            open_api_op: None,
        }
    }

    /// Registers a middleware after the ones already present.
    ///
    /// Middlewares are applied in registration order by [`Builder::build`]:
    /// each one receives the handler returned by the previous one.
    pub fn with_middleware(mut self, md: impl Into<Middleware<E>>) -> Self {
        self.middlewares.push(Rc::new(md.into()));
        self
    }

    /// Appends `path` to the path accumulated so far.
    ///
    /// The segment is concatenated verbatim (no separator is inserted) and
    /// copied into the builder, so any string slice is accepted, not only
    /// `'static` literals.
    pub fn with_path(self, path: &str) -> Self {
        let path = self.path.add(path);
        Self { path, ..self }
    }

    /// Sets the HTTP method name of the route, replacing the previous one.
    pub fn with_method(mut self, method: &'static str) -> Self {
        self.method = method;
        self
    }

    /// Attaches an operation description to the route, replacing any
    /// previously attached one.
    #[cfg(feature = "open-api")]
    pub fn define_open_api(mut self, spec: RouteOperation) -> Self {
        self.open_api_op = Some(spec);
        self
    }

    /// Wraps the current configuration into a [`Group`], from which any number
    /// of builders sharing that configuration can be derived.
    pub fn group(self) -> Group<E> {
        Group { inner: self }
    }

    /// Finishes the configuration and produces a [`Route`] for `handler`.
    ///
    /// The handler is passed through every registered middleware in
    /// registration order. With the `open-api` feature enabled, the attached
    /// operation (or a default one when none was attached) is handed to
    /// `update_global_doc` together with the route path and method.
    pub fn build(self, handler: impl Into<Handler<E>>) -> Route<E> {
        let mut handler = handler.into();
        for mw in self.middlewares {
            handler = mw(handler);
        }

        #[cfg(feature = "open-api")]
        {
            let open_api_op = self.open_api_op.unwrap_or_default();
            open_api_op.update_global_doc(&self.path, self.method);
        }

        Route {
            path: self.path,
            method: self.method,
            handler,
        }
    }
}
/// A shared route configuration, created by `Builder::group`.
///
/// `Group::builder` returns a fresh `Builder` initialised from the stored one.
pub struct Group<E> {
// Configuration that every derived builder starts from.
inner: Builder<E>,
}
impl<E> Group<E> {
    /// Returns a new builder that starts from a copy of the group's path,
    /// method, middlewares and (with `open-api`) operation description.
    pub fn builder(&self) -> Builder<E> {
        let base = &self.inner;
        Builder {
            path: base.path.clone(),
            method: base.method,
            // Cloning the vector only bumps the `Rc` counters of the middlewares.
            middlewares: base.middlewares.clone(),
            #[cfg(feature = "open-api")]
            open_api_op: base.open_api_op.clone(),
        }
    }

    /// Returns a copy of the group's operation description, or a new
    /// `RouteOperation` when the group has none.
    #[cfg(feature = "open-api")]
    pub fn open_api(&self) -> RouteOperation {
        if let Some(op) = &self.inner.open_api_op {
            op.clone()
        } else {
            RouteOperation::new()
        }
    }
}
/// A fully configured route produced by `Builder::build`.
pub struct Route<E> {
// HTTP method name, returned by `LuaHttpHandler::method`.
method: &'static str,
// Full route path, returned by `LuaHttpHandler::path` and also stored in the
// request context under the "path" key when the route is handled.
path: String,
// Handler with all middlewares already applied.
handler: Handler<E>,
}
impl<E> LuaHttpHandler for Route<E>
where
    E: From<crate::tarantool::error::Error> + From<serde_json::Error>,
{
    /// Runs the route handler against `req`.
    ///
    /// A fresh `Context` is created for every call and the route path is
    /// stored in it under the `"path"` key. If the handler fails, the error
    /// value is discarded and a generic 500 response is returned instead.
    fn handle(&self, req: Request) -> Response {
        let mut ctx = Context::new();
        ctx.put("path", self.path.clone());

        match (self.handler)(&mut ctx, req) {
            Ok(response) => response,
            Err(_) => {
                let mut headers = HashMap::new();
                headers.insert(
                    "content-type".to_string(),
                    "text/html; charset=utf-8".to_string(),
                );
                Response {
                    headers,
                    status: 500,
                    body: b"internal server error".to_vec(),
                }
            }
        }
    }

    /// HTTP method name this route was built with.
    fn method(&self) -> &str {
        self.method
    }

    /// Full path this route was built with.
    fn path(&self) -> &str {
        &self.path
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::cell::Cell;
    use std::error::Error;

    /// A registered middleware must run once for every handled request.
    #[test]
    fn test_middleware() {
        let hits = Rc::new(Cell::new(0));
        let hits_for_middleware = Rc::clone(&hits);

        let route = Builder::new()
            .with_middleware(move |next: Handler<Box<dyn Error>>| {
                let hits = Rc::clone(&hits_for_middleware);
                Handler(Box::new(move |ctx, request| {
                    let outcome = next(ctx, request);
                    hits.set(hits.get() + 1);
                    outcome
                }))
            })
            .build(|_: &mut Context, _: Request| Ok(()));

        for _ in 0..3 {
            route.handle(Request::default());
        }

        assert_eq!(3, hits.get())
    }

    /// A failing handler must be turned into a 500 response.
    #[test]
    fn test_error_handling() {
        let failing =
            |_: &mut Context, _: Request| -> Result<(), Box<dyn Error>> { Err("error".into()) };
        let route = Builder::new().build(failing);

        let response = route.handle(Request::default());

        assert_eq!(500, response.status)
    }

    /// Successive `with_path` calls concatenate their segments in order.
    #[test]
    fn test_path() {
        let builder = Builder::<Box<dyn Error>>::new()
            .with_path("/1")
            .with_path("/2")
            .with_path("/3");

        let route = builder.build(|_: &mut Context, _: Request| Ok(()));

        assert_eq!("/1/2/3", route.path.as_str());
    }

    /// Builders derived from a group inherit its path prefix and middlewares.
    #[test]
    fn test_group() {
        let group = Builder::new()
            .with_middleware(|next: Handler<Box<dyn Error>>| {
                Handler(Box::new(move |ctx, request| {
                    let outcome: Result<Response, _> = next(ctx, request);
                    match outcome {
                        Ok(mut response) => {
                            response.status = 201;
                            Ok(response)
                        }
                        Err(err) => Err(err),
                    }
                }))
            })
            .with_path("/group")
            .group();

        let route1 = group
            .builder()
            .with_path("/route1")
            .build(|_: &mut Context, _: Request| Ok(()));
        let route2 = group
            .builder()
            .with_path("/route2")
            .build(|_: &mut Context, _: Request| Ok(()));

        assert_eq!("/group/route1", route1.path.as_str());
        assert_eq!("/group/route2", route2.path.as_str());

        assert_eq!(201, route1.handle(Request::default()).status);
        assert_eq!(201, route2.handle(Request::default()).status);
    }
}