use std::fmt::{self, Debug, Formatter};
use std::pin::Pin;
use std::sync::{Arc, LazyLock};
use headers::HeaderValue;
use http::header::{ALT_SVC, CONTENT_TYPE};
use http::uri::Scheme;
use hyper::service::Service as HyperService;
use hyper::{Method, Request as HyperRequest, Response as HyperResponse};
use crate::catcher::{Catcher, write_error_default};
use crate::conn::SocketAddr;
use crate::fuse::ArcFusewire;
use crate::handler::{Handler, WhenHoop};
use crate::http::body::{ReqBody, ResBody};
use crate::http::{Mime, Request, Response, StatusCode};
use crate::routing::{FlowCtrl, PathState, Router};
use crate::{Depot, async_trait};
/// Top-level request service: a root [`Router`] plus service-wide settings.
///
/// Create it with [`Service::new`] (or `From`) and configure it with the builder-style
/// methods; each connection then gets a `HyperHandler` via `Service::hyper_handler`.
#[non_exhaustive]
pub struct Service {
/// Root router used to match every incoming request.
pub router: Arc<Router>,
/// Catcher invoked for error responses (4xx/5xx) whose body is empty or an error, except
/// for `HEAD` requests. When `None`, a default error body is written instead.
pub catcher: Option<Arc<Catcher>>,
/// Service-level hoops. On a route match they run before the route's own hoops; when no
/// route matches (and this list is non-empty) they run on their own.
pub hoops: Vec<Arc<dyn Handler>>,
/// Permitted response media types, compared by type/subtype only. Empty means "no
/// restriction"; otherwise a response whose `Content-Type` is not listed gets status 415.
pub allowed_media_types: Arc<Vec<Mime>>,
}
impl Debug for Service {
    /// Formats the router and catcher in full, but only the *number* of hoops and allowed
    /// media types.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let hoop_count = self.hoops.len();
        let media_type_count = self.allowed_media_types.len();

        let mut builder = f.debug_struct("Service");
        builder.field("router", &self.router);
        builder.field("catcher", &self.catcher);
        builder.field("hoops", &hoop_count);
        builder.field("allowed_media_types", &media_type_count);
        builder.finish()
    }
}
impl Service {
    /// Creates a service around `router`.
    ///
    /// The new service has no catcher, no service-level hoops and an empty (unrestricted)
    /// media-type list.
    #[inline]
    pub fn new<T>(router: T) -> Self
    where
        T: Into<Arc<Router>>,
    {
        let router: Arc<Router> = router.into();
        Self {
            router,
            catcher: None,
            hoops: Vec::new(),
            allowed_media_types: Arc::new(Vec::new()),
        }
    }

    /// Returns a new handle to the shared root router (an `Arc` clone, not a deep copy).
    #[inline]
    #[must_use]
    pub fn router(&self) -> Arc<Router> {
        Arc::clone(&self.router)
    }

    /// Sets the catcher used to render error responses, replacing any previous one.
    #[inline]
    #[must_use]
    pub fn catcher(self, catcher: impl Into<Arc<Catcher>>) -> Self {
        Self {
            catcher: Some(catcher.into()),
            ..self
        }
    }

    /// Appends a service-level hoop.
    ///
    /// Service hoops run ahead of route hoops on matched requests, and on their own for
    /// unmatched requests.
    #[inline]
    #[must_use]
    pub fn hoop<H: Handler>(mut self, hoop: H) -> Self {
        let boxed: Arc<dyn Handler> = Arc::new(hoop);
        self.hoops.push(boxed);
        self
    }

    /// Appends a service-level hoop wrapped in a `WhenHoop` together with `filter`.
    ///
    /// `filter` receives the request and depot; presumably the hoop only runs when it
    /// returns `true` (see `WhenHoop`).
    #[inline]
    #[must_use]
    pub fn hoop_when<H, F>(mut self, hoop: H, filter: F) -> Self
    where
        H: Handler,
        F: Fn(&Request, &Depot) -> bool + Send + Sync + 'static,
    {
        let conditional = WhenHoop {
            inner: hoop,
            filter,
        };
        self.hoops.push(Arc::new(conditional));
        self
    }

    /// Replaces the list of permitted response media types.
    ///
    /// A non-empty list makes responses whose `Content-Type` type/subtype is not listed end
    /// with status 415.
    #[inline]
    #[must_use]
    pub fn allowed_media_types<T>(self, allowed_media_types: T) -> Self
    where
        T: Into<Arc<Vec<Mime>>>,
    {
        Self {
            allowed_media_types: allowed_media_types.into(),
            ..self
        }
    }

    /// Builds a `HyperHandler` bound to the given addresses, scheme, optional fusewire and
    /// optional HTTP/3 `Alt-Svc` header value.
    ///
    /// The router, catcher, hoops and media types are snapshotted into a shared, `Arc`-ed
    /// state. Later changes to this `Service` value do not affect the handler.
    #[doc(hidden)]
    #[inline]
    #[must_use]
    pub fn hyper_handler(
        &self,
        local_addr: SocketAddr,
        remote_addr: SocketAddr,
        http_scheme: Scheme,
        fusewire: Option<ArcFusewire>,
        alt_svc_h3: Option<HeaderValue>,
    ) -> HyperHandler {
        let snapshot = HyperHandlerState {
            router: Arc::clone(&self.router),
            catcher: self.catcher.clone(),
            hoops: self.hoops.clone(),
            allowed_media_types: Arc::clone(&self.allowed_media_types),
        };
        HyperHandler {
            local_addr,
            remote_addr,
            http_scheme,
            state: Arc::new(snapshot),
            fusewire,
            alt_svc_h3,
        }
    }

    /// Handles `request` in-process (test builds only).
    ///
    /// The request's own local/remote addresses and scheme are reused, with no fusewire
    /// and no `Alt-Svc` header.
    #[cfg(feature = "test")]
    #[inline]
    pub async fn handle(&self, request: impl Into<Request> + Send) -> Response {
        let request: Request = request.into();
        let local = request.local_addr.clone();
        let remote = request.remote_addr.clone();
        let scheme = request.scheme.clone();
        let pending = self
            .hyper_handler(local, remote, scheme, None, None)
            .handle(request);
        pending.await
    }
}
impl<T> From<T> for Service
where
T: Into<Arc<Router>>,
{
#[inline]
fn from(router: T) -> Self {
Self::new(router)
}
}
/// Hoop placed directly in front of a matched route's goal.
///
/// It lets the rest of the chain run and then defaults an unset status to `200 OK`.
struct DefaultStatusOK;
#[async_trait]
impl Handler for DefaultStatusOK {
    async fn handle(
        &self,
        req: &mut Request,
        depot: &mut Depot,
        res: &mut Response,
        ctrl: &mut FlowCtrl,
    ) {
        // Downstream handlers (normally the goal) run first.
        ctrl.call_next(req, depot, res).await;
        // Keep whatever status they chose; only fill the gap when there is none.
        res.status_code = res.status_code.or(Some(StatusCode::OK));
    }
}
/// Lazily created shared `DefaultStatusOK` instance.
///
/// Each request clones the `Arc` rather than allocating a new handler.
static DEFAULT_STATUS_OK_HANDLER: LazyLock<Arc<dyn Handler>> =
    LazyLock::new(|| Arc::new(DefaultStatusOK) as Arc<dyn Handler>);
/// Immutable snapshot of a `Service`'s configuration shared (via `Arc`) by `HyperHandler`s.
#[doc(hidden)]
pub(crate) struct HyperHandlerState {
/// Root router used for request matching.
pub(crate) router: Arc<Router>,
/// Catcher for error responses; `None` falls back to `write_error_default`.
pub(crate) catcher: Option<Arc<Catcher>>,
/// Service-level hoops, run ahead of route hoops (or alone when no route matches).
pub(crate) hoops: Vec<Arc<dyn Handler>>,
/// Permitted response media types (type/subtype); empty means unrestricted.
pub(crate) allowed_media_types: Arc<Vec<Mime>>,
}
/// Hyper-facing handler built by `Service::hyper_handler`.
///
/// It carries the addresses and scheme it was bound to plus the shared service state.
/// Cloning only bumps reference counts for the shared state.
#[doc(hidden)]
#[derive(Clone)]
pub struct HyperHandler {
/// Local address copied onto every handled request.
pub(crate) local_addr: SocketAddr,
/// Remote (peer) address copied onto every handled request.
pub(crate) remote_addr: SocketAddr,
/// Scheme used when the incoming request URI carries none.
pub(crate) http_scheme: Scheme,
/// Shared router/catcher/hoops/media-type configuration.
pub(crate) state: Arc<HyperHandlerState>,
/// Optional fusewire attached to each request body.
pub(crate) fusewire: Option<ArcFusewire>,
/// Optional `Alt-Svc` header value inserted into every response (HTTP/3 advertisement).
pub(crate) alt_svc_h3: Option<HeaderValue>,
}
impl Debug for HyperHandler {
    /// Formats the handler's bindings and shared state.
    ///
    /// Like `Service`'s `Debug`, it reports the number of service-level hoops rather than
    /// the hoops themselves.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("HyperHandler")
            .field("local_addr", &self.local_addr)
            .field("remote_addr", &self.remote_addr)
            .field("http_scheme", &self.http_scheme)
            .field("router", &self.state.router)
            .field("catcher", &self.state.catcher)
            .field("hoops", &self.state.hoops.len())
            .field("allowed_media_types", &self.state.allowed_media_types)
            .field("alt_svc_h3", &self.alt_svc_h3)
            .finish()
    }
}
impl HyperHandler {
    /// Runs `req` through the service pipeline and resolves to the final [`Response`].
    ///
    /// The returned future is `'static` (it owns a clone of the shared state), so callers
    /// can box it. Steps, in order:
    ///
    /// 1. Stamp this handler's local/remote addresses onto `req` and seed the response
    ///    (request cookies with the `cookie` feature, plus `Alt-Svc` when configured).
    /// 2. Route. On a match, run service hoops → route hoops → `DefaultStatusOK` → goal.
    ///    On a miss, preset 404/405 and run only the service hoops, if there are any.
    /// 3. Default a still-unset status to `404 Not Found`.
    /// 4. If `allowed_media_types` is non-empty and the response `Content-Type` is not
    ///    listed, set status 415.
    /// 5. For non-`HEAD` requests with a 4xx/5xx status and an empty or error body, run
    ///    the catcher (or `write_error_default`).
    /// 6. Strip any body from `HEAD` responses.
    pub fn handle(&self, mut req: Request) -> impl Future<Output = Response> + 'static {
        let state = self.state.clone();
        req.local_addr = self.local_addr.clone();
        req.remote_addr = self.remote_addr.clone();
        #[cfg(not(feature = "cookie"))]
        let mut res = Response::new();
        #[cfg(feature = "cookie")]
        let mut res = Response::with_cookies(req.cookies.clone());
        if let Some(alt_svc_h3) = &self.alt_svc_h3
            && !res.headers().contains_key(ALT_SVC)
        {
            res.headers_mut().insert(ALT_SVC, alt_svc_h3.clone());
        }
        let mut depot = Depot::new();
        let mut path_state = PathState::new(req.uri().path());
        async move {
            if let Some(dm) = state.router.detect(&mut req, &mut path_state).await {
                req.params = path_state.params;
                #[cfg(feature = "matched-path")]
                {
                    req.matched_path = path_state.matched_parts.join("/");
                }
                // Service hoops first, then route hoops; the default-200 hoop sits right in
                // front of the goal.
                let mut handlers = Vec::with_capacity(state.hoops.len() + dm.hoops.len() + 2);
                handlers.extend(state.hoops.iter().cloned());
                handlers.extend(dm.hoops);
                handlers.push(DEFAULT_STATUS_OK_HANDLER.clone());
                handlers.push(dm.goal);
                let mut ctrl = FlowCtrl::new(handlers);
                ctrl.call_next(&mut req, &mut depot, &mut res).await;
                // A hoop may stop the chain before the default-200 hoop ever runs.
                if res.status_code.is_none() {
                    res.status_code = Some(StatusCode::OK);
                }
            } else if !state.hoops.is_empty() {
                req.params = path_state.params;
                // `once_ended` selects 405 over 404 for unmatched requests.
                if path_state.once_ended {
                    res.status_code = Some(StatusCode::METHOD_NOT_ALLOWED);
                } else {
                    res.status_code = Some(StatusCode::NOT_FOUND);
                }
                let mut ctrl = FlowCtrl::new(state.hoops.clone());
                ctrl.call_next(&mut req, &mut depot, &mut res).await;
                if res.status_code.is_none() && path_state.once_ended {
                    res.status_code = Some(StatusCode::METHOD_NOT_ALLOWED);
                }
            } else if path_state.once_ended {
                res.status_code = Some(StatusCode::METHOD_NOT_ALLOWED);
            }
            if res.status_code.is_none() {
                res.status_code = Some(StatusCode::NOT_FOUND);
            }
            if !state.allowed_media_types.is_empty()
                && let Some(ctype) = res
                    .headers()
                    .get(CONTENT_TYPE)
                    .and_then(|c| c.to_str().ok())
                    .and_then(|c| c.parse::<Mime>().ok())
                && !state.allowed_media_types.iter().any(|mime| {
                    mime.type_() == ctype.type_() && mime.subtype() == ctype.subtype()
                })
            {
                res.status_code(StatusCode::UNSUPPORTED_MEDIA_TYPE);
            }
            // Read the status *after* the media-type check so a 415 override drives the
            // warning and catcher decisions below (it was previously read too early).
            let status_code = res.status_code.unwrap_or(StatusCode::NOT_FOUND);
            let has_error = status_code.is_client_error() || status_code.is_server_error();
            if res.body.is_none()
                && !has_error
                && !status_code.is_redirection()
                && status_code != StatusCode::NO_CONTENT
                && status_code != StatusCode::SWITCHING_PROTOCOLS
                // Only warn about a missing content type when it really is missing.
                && !res.headers().contains_key(CONTENT_TYPE)
                && [Method::GET, Method::POST, Method::PATCH, Method::PUT].contains(req.method())
            {
                tracing::warn!(
                    uri = ?req.uri(),
                    method = req.method().as_str(),
                    "http response content type header not set"
                );
            }
            if Method::HEAD != *req.method()
                && (res.body.is_none() || res.body.is_error())
                && has_error
            {
                if let Some(catcher) = &state.catcher {
                    catcher.catch(&mut req, &mut depot, &mut res).await;
                } else {
                    write_error_default(&req, &mut res, None);
                }
            }
            if Method::HEAD == *req.method() && !res.body.is_none() {
                tracing::debug!("stripping response body for HEAD request per RFC 9110 §9.3.2");
                res.take_body();
            }
            // Hand HTTP/3 / WebTransport connection objects over to the response so the
            // protocol layer can pick them up after handling.
            #[cfg(feature = "quinn")]
            {
                use std::sync::Mutex;
                use bytes::Bytes;
                if let Some(session) =
                    req.extensions.remove::<Arc<
                        crate::proto::WebTransportSession<salvo_http3::quinn::Connection, Bytes>,
                    >>()
                {
                    res.extensions.insert(session);
                }
                if let Some(conn) = req.extensions.remove::<Arc<
                    Mutex<salvo_http3::server::Connection<salvo_http3::quinn::Connection, Bytes>>,
                >>() {
                    res.extensions.insert(conn);
                }
                if let Some(stream) = req.extensions.remove::<Arc<
                    salvo_http3::server::RequestStream<
                        salvo_http3::quinn::BidiStream<Bytes>,
                        Bytes,
                    >,
                >>() {
                    res.extensions.insert(stream);
                }
            }
            res
        }
    }
}
/// Adapts [`HyperHandler`] to hyper's `Service` trait.
///
/// Each call converts the hyper request into a Salvo [`Request`], runs it through
/// `HyperHandler::handle`, and converts the result back into a hyper response. Handling
/// itself never fails; errors are expressed as HTTP statuses.
impl<B> HyperService<HyperRequest<B>> for HyperHandler
where
    B: Into<ReqBody>,
{
    type Response = HyperResponse<ResBody>;
    type Error = hyper::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    #[inline]
    fn call(
        &self,
        #[cfg(not(feature = "fix-http1-request-uri"))] req: HyperRequest<B>,
        #[cfg(feature = "fix-http1-request-uri")] mut req: HyperRequest<B>,
    ) -> Self::Future {
        // Prefer the scheme carried by the URI; fall back to the connection's scheme.
        let scheme = req
            .uri()
            .scheme()
            .cloned()
            .unwrap_or_else(|| self.http_scheme.clone());
        // With `fix-http1-request-uri`, turn a scheme-less request target into an absolute
        // URI using the `Host` header as authority.
        #[cfg(feature = "fix-http1-request-uri")]
        if req.uri().scheme().is_none()
            && let Some(host) = req
                .headers()
                .get(http::header::HOST)
                .and_then(|host| host.to_str().ok())
                .and_then(|host| host.parse::<http::uri::Authority>().ok())
        {
            // Rebuild from a clone (cheap: `Uri` is `Bytes`-backed) rather than moving the
            // URI out. `Uri::from_parts` rejects targets without a path-and-query (e.g.
            // authority-form `CONNECT host:port`). Previously that left the request with
            // the default `/` URI; now the original target is kept on failure.
            let mut uri_parts = req.uri().clone().into_parts();
            uri_parts.scheme = Some(scheme.clone());
            uri_parts.authority = Some(host);
            if let Ok(uri) = http::uri::Uri::from_parts(uri_parts) {
                *req.uri_mut() = uri;
            }
        }
        let mut request = Request::from_hyper(req, scheme);
        request.body.set_fusewire(self.fusewire.clone());
        let response = self.handle(request);
        Box::pin(async move { Ok(response.await.into_hyper()) })
    }
}
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use crate::test::{ResponseExt, TestClient};

    /// Hoops nested across three router levels short-circuit at the level selected by the
    /// `b` query parameter; each level renders its own marker text.
    #[tokio::test]
    async fn test_service() {
        #[handler]
        async fn before1(
            req: &mut Request,
            depot: &mut Depot,
            res: &mut Response,
            ctrl: &mut FlowCtrl,
        ) {
            res.render(Text::Plain("before1"));
            if req.query::<String>("b").unwrap_or_default() == "1" {
                ctrl.skip_rest();
            } else {
                ctrl.call_next(req, depot, res).await;
            }
        }
        #[handler]
        async fn before2(
            req: &mut Request,
            depot: &mut Depot,
            res: &mut Response,
            ctrl: &mut FlowCtrl,
        ) {
            res.render(Text::Plain("before2"));
            if req.query::<String>("b").unwrap_or_default() == "2" {
                ctrl.skip_rest();
            } else {
                ctrl.call_next(req, depot, res).await;
            }
        }
        #[handler]
        async fn before3(
            req: &mut Request,
            depot: &mut Depot,
            res: &mut Response,
            ctrl: &mut FlowCtrl,
        ) {
            res.render(Text::Plain("before3"));
            if req.query::<String>("b").unwrap_or_default() == "3" {
                ctrl.skip_rest();
            } else {
                ctrl.call_next(req, depot, res).await;
            }
        }
        #[handler]
        async fn hello() -> Result<&'static str, ()> {
            Ok("hello")
        }

        async fn access(service: &Service, b: &str) -> String {
            TestClient::get(format!("http://127.0.0.1:5801/level1/level2/hello?b={b}"))
                .send(service)
                .await
                .take_string()
                .await
                .unwrap()
        }

        let innermost = Router::with_hoop(before3).path("hello").goal(hello);
        let middle = Router::with_hoop(before2).path("level2").push(innermost);
        let router = Router::with_path("level1").hoop(before1).push(middle);
        let service = Service::new(router);

        let expectations = [
            ("", "before1before2before3hello"),
            ("1", "before1"),
            ("2", "before1before2"),
            ("3", "before1before2before3"),
        ];
        for (b, expected) in expectations {
            assert_eq!(access(&service, b).await, expected);
        }
    }

    /// Unknown paths yield 404, while known paths hit with an unsupported method yield 405.
    #[tokio::test]
    async fn test_service_405_or_404_error() {
        #[handler]
        async fn login() -> &'static str {
            "login"
        }
        #[handler]
        async fn hello() -> &'static str {
            "hello"
        }

        let login_routes = Router::with_path("login")
            .post(login)
            .push(Router::with_path("user").get(login));
        let router = Router::new()
            .push(Router::with_path("hello").goal(hello))
            .push(login_routes);
        let service = Service::new(router);

        let cases = [
            (TestClient::get("http://127.0.0.1:5801/hello"), StatusCode::OK),
            (TestClient::put("http://127.0.0.1:5801/hello"), StatusCode::OK),
            (TestClient::post("http://127.0.0.1:5801/login"), StatusCode::OK),
            (
                TestClient::get("http://127.0.0.1:5801/login"),
                StatusCode::METHOD_NOT_ALLOWED,
            ),
            (
                TestClient::get("http://127.0.0.1:5801/login2"),
                StatusCode::NOT_FOUND,
            ),
            (
                TestClient::get("http://127.0.0.1:5801/login/user"),
                StatusCode::OK,
            ),
            (
                TestClient::post("http://127.0.0.1:5801/login/user"),
                StatusCode::METHOD_NOT_ALLOWED,
            ),
            (
                TestClient::post("http://127.0.0.1:5801/login/user1"),
                StatusCode::NOT_FOUND,
            ),
        ];
        for (request, expected) in cases {
            let res = request.send(&service).await;
            assert_eq!(res.status_code.unwrap(), expected);
        }
    }
}