use rustolio_utils::http;
use rustolio_utils::threadsafe::Threadsafe;
/// The next step in a request-handling chain.
///
/// Middleware receives an implementor of this trait and calls [`Next::call`]
/// to pass the request, together with the shared `'static` context, on to
/// whatever follows it.
pub trait Next<B, C>: Threadsafe {
    /// Hands `req` and `ctx` to the next step.
    ///
    /// Returns a future that resolves to that step's response. The future
    /// itself is required to be [`Threadsafe`].
    fn call(
        &self,
        req: http::Request<B>,
        ctx: &'static C,
    ) -> impl std::future::Future<Output = http::Result<http::Response>>
           + Threadsafe;
}
/// Adapter that exposes a plain callable `F` as a [`Next`].
///
/// `B`, `C` and `Fut` are only mentioned in the bounds of this type's trait
/// impls, so they are recorded as phantom type parameters rather than stored.
#[doc(hidden)]
pub struct NextImpl<F, B, C, Fut> {
    /// The wrapped callable; [`Next::call`] invokes it directly.
    f: F,
    /// Ties `B`, `C` and `Fut` to the type without holding values of them.
    _marker: std::marker::PhantomData<(B, C, Fut)>,
}
/// Any thread-safe callable taking a request and a `'static` context and
/// returning a thread-safe response future can act as the next step.
impl<F, B, C, Fut> Next<B, C> for NextImpl<F, B, C, Fut>
where
    F: Threadsafe + Fn(http::Request<B>, &'static C) -> Fut,
    Fut: std::future::Future<Output = http::Result<http::Response>> + Threadsafe,
    B: hyper::body::Body + Threadsafe,
    C: Threadsafe,
{
    /// Calls the wrapped callable with `req` and `ctx` and returns the future
    /// it produces without awaiting or wrapping it.
    fn call(
        &self,
        req: http::Request<B>,
        ctx: &'static C,
    ) -> impl std::future::Future<Output = http::Result<http::Response>> + Threadsafe {
        // Borrow the callable out of `self`; `&F` is itself callable since `F: Fn`.
        let Self { f: handler, .. } = self;
        handler(req, ctx)
    }
}
/// Lets a suitable callable be turned into a [`NextImpl`] with `.into()` or
/// `NextImpl::from(..)`.
impl<F, B, C, Fut> From<F> for NextImpl<F, B, C, Fut>
where
    C: Threadsafe,
    B: hyper::body::Body + Threadsafe,
    Fut: std::future::Future<Output = http::Result<http::Response>> + Threadsafe,
    F: Threadsafe + Fn(http::Request<B>, &'static C) -> Fut,
{
    /// Wraps `f` so it can be used wherever a [`Next`] is expected.
    fn from(f: F) -> Self {
        Self {
            _marker: std::marker::PhantomData,
            f,
        }
    }
}
#[cfg(test)]
mod tests {
    use http::{HeaderName, StatusCode};
    use http_body_util::BodyExt as _;

    use super::*;

    /// Body type shared by every request built in these tests.
    type TestBody = http_body_util::Full<hyper::body::Bytes>;

    /// Middleware under test: asserts that the `Authorization` header equals
    /// `"token"`, then forwards the untouched request to `next`.
    async fn middleware(
        req: http::Request<TestBody>,
        ctx: &'static String,
        next: impl Next<TestBody, String>,
    ) -> http::Result<http::Response> {
        let auth_header = req
            .header(HeaderName::AUTHORIZATION)
            .expect("request should carry an Authorization header");
        assert_eq!(auth_header, "token");
        next.call(req, ctx).await
    }

    /// Terminal handler: replies `200 OK` with the request body text and the
    /// context value.
    async fn endpoint(req: http::Request<TestBody>, ctx: &String) -> http::Result<http::Response> {
        // Read the body as text via the project's `text()` helper, then take it out.
        let text = req.text().await?.into_body();
        Ok(http::Response::builder()
            .status(StatusCode::OK)
            .text(format!("Req: {text}, Ctx: {ctx}"))
            .build()
            .expect("a response with a status and a text body should build"))
    }

    #[tokio::test]
    async fn test_middleware_endpoint_creation() {
        // `Next::call` demands a `'static` context; leaking one small string in a test is fine.
        let ctx: &'static String = Box::leak(Box::new("TestContext".to_string()));
        let next = NextImpl::from(endpoint);
        let handler = |req, ctx| Box::pin(middleware(req, ctx, next));

        let req = http::Request::post("/test")
            .header(HeaderName::AUTHORIZATION, "token")
            .body(TestBody::new(hyper::body::Bytes::from(
                "This is a string message",
            )))
            .build()
            .unwrap();
        // `ctx` is already a `&'static String`; pass it directly instead of `&ctx`.
        let res = handler(req, ctx).await.unwrap();

        let body = res.into_body().collect().await.unwrap().to_bytes();
        let res_body = std::str::from_utf8(&body).expect("response body should be UTF-8");
        assert_eq!(res_body, "Req: This is a string message, Ctx: TestContext");
    }
}