use crate::{
Endpoint, EndpointExt, IntoEndpoint, Request, Response, Result,
endpoint::BoxEndpoint,
error::{NotFoundError, RouteError},
http::header,
route::{check_result, internal::trie::Trie},
};
#[derive(Default)]
pub struct RouteDomain {
tree: Trie<BoxEndpoint<'static>>,
}
impl RouteDomain {
pub fn new() -> Self {
Default::default()
}
#[must_use]
pub fn at<E>(self, pattern: impl AsRef<str>, ep: E) -> Self
where
E: IntoEndpoint,
E::Endpoint: 'static,
{
check_result(self.try_at(pattern, ep))
}
pub fn try_at<E>(mut self, pattern: impl AsRef<str>, ep: E) -> Result<Self, RouteError>
where
E: IntoEndpoint,
E::Endpoint: 'static,
{
self.tree.add(
pattern.as_ref(),
ep.into_endpoint().map_to_response().boxed(),
)?;
Ok(self)
}
}
impl Endpoint for RouteDomain {
type Output = Response;
async fn call(&self, req: Request) -> Result<Self::Output> {
let host = req
.headers()
.get(header::HOST)
.and_then(|host| host.to_str().ok())
.unwrap_or_default();
match self.tree.matches(host) {
Some(ep) => ep.call(req).await,
None => Err(NotFoundError.into()),
}
}
}
#[cfg(test)]
mod tests {
use http::StatusCode;
use super::*;
use crate::{endpoint::make_sync, handler, http::HeaderMap, test::TestClient};
async fn check(r: &RouteDomain, host: &str, value: &str) {
let cli = TestClient::new(r);
let mut req = cli.get("/");
if !host.is_empty() {
req = req.header(header::HOST, host);
}
let resp = req.send().await;
resp.assert_status_is_ok();
resp.assert_text(value).await;
}
#[tokio::test]
async fn route_domain() {
#[handler(internal)]
fn h(headers: &HeaderMap) -> String {
headers
.get(header::HOST)
.and_then(|value| value.to_str().ok())
.unwrap_or_default()
.to_string()
}
let r = RouteDomain::new()
.at("example.com", make_sync(|_| "1"))
.at("www.example.com", make_sync(|_| "2"))
.at("www.+.com", make_sync(|_| "3"))
.at("*.com", make_sync(|_| "4"))
.at("*", make_sync(|_| "5"));
check(&r, "example.com", "1").await;
check(&r, "www.example.com", "2").await;
check(&r, "www.abc.com", "3").await;
check(&r, "abc.com", "4").await;
check(&r, "rust-lang.org", "5").await;
check(&r, "", "5").await;
}
#[tokio::test]
async fn not_found() {
let r = RouteDomain::new()
.at("example.com", make_sync(|_| "1"))
.at("www.example.com", make_sync(|_| "2"))
.at("www.+.com", make_sync(|_| "3"))
.at("*.com", make_sync(|_| "4"));
let cli = TestClient::new(r);
cli.get("/")
.header(header::HOST, "rust-lang.org")
.send()
.await
.assert_status(StatusCode::NOT_FOUND);
cli.get("/")
.send()
.await
.assert_status(StatusCode::NOT_FOUND);
}
#[handler(internal)]
fn h() {}
#[test]
#[should_panic]
fn duplicate_1() {
let _ = RouteDomain::new().at("example.com", h).at("example.com", h);
}
#[test]
#[should_panic]
fn duplicate_2() {
let _ = RouteDomain::new()
.at("+.example.com", h)
.at("+.example.com", h);
}
#[test]
#[should_panic]
fn duplicate_3() {
let _ = RouteDomain::new()
.at("*.example.com", h)
.at("*.example.com", h);
}
#[test]
#[should_panic]
fn duplicate_4() {
let _ = RouteDomain::new().at("*", h).at("*", h);
}
}