use http::StatusCode;
use http::header::LOCATION;
use crate::body::TakoBody;
use crate::responder::Responder;
use crate::types::Response;
/// An HTTP redirect response description: a status code plus the target that
/// is written to the `Location` header when converted into a response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Redirect {
/// Status code the response is sent with.
status: StatusCode,
/// Redirect target, stored as given by the caller.
location: String,
}
impl Redirect {
    /// Creates a redirect to `location` that will be sent with `status`.
    ///
    /// The status is stored exactly as supplied; it is not checked to be a
    /// 3xx code.
    #[inline]
    #[must_use]
    pub fn with_status(location: impl Into<String>, status: StatusCode) -> Self {
        let location = location.into();
        Self { status, location }
    }

    /// Creates a redirect to `location` using `StatusCode::FOUND`.
    #[inline]
    #[must_use]
    pub fn found(location: impl Into<String>) -> Self {
        Self {
            status: StatusCode::FOUND,
            location: location.into(),
        }
    }

    /// Creates a redirect to `location` using `StatusCode::SEE_OTHER`.
    #[inline]
    #[must_use]
    pub fn see_other(location: impl Into<String>) -> Self {
        Self {
            status: StatusCode::SEE_OTHER,
            location: location.into(),
        }
    }

    /// Creates a redirect to `location` using `StatusCode::TEMPORARY_REDIRECT`.
    #[inline]
    #[must_use]
    pub fn temporary(location: impl Into<String>) -> Self {
        Self {
            status: StatusCode::TEMPORARY_REDIRECT,
            location: location.into(),
        }
    }

    /// Creates a redirect to `location` using `StatusCode::MOVED_PERMANENTLY`.
    #[inline]
    #[must_use]
    pub fn permanent_moved(location: impl Into<String>) -> Self {
        Self {
            status: StatusCode::MOVED_PERMANENTLY,
            location: location.into(),
        }
    }

    /// Creates a redirect to `location` using `StatusCode::PERMANENT_REDIRECT`.
    #[inline]
    #[must_use]
    pub fn permanent(location: impl Into<String>) -> Self {
        Self {
            status: StatusCode::PERMANENT_REDIRECT,
            location: location.into(),
        }
    }
}
impl Responder for Redirect {
    /// Converts the redirect into a response.
    ///
    /// On success the response has an empty body, the stored status, and a
    /// `Location` header holding the target. If the target cannot be turned
    /// into a header value, a `500 Internal Server Error` with a short
    /// explanatory body is returned instead.
    fn into_response(self) -> Response {
        match http::HeaderValue::try_from(self.location.as_str()) {
            Ok(target) => {
                let mut response = http::Response::new(TakoBody::empty());
                *response.status_mut() = self.status;
                response.headers_mut().insert(LOCATION, target);
                response
            }
            Err(_) => {
                // The location is not representable as a header value, so no
                // redirect can be emitted at all.
                let mut failure = http::Response::new(TakoBody::from(
                    "redirect location contains invalid header characters",
                ));
                *failure.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
                failure
            }
        }
    }
}
/// Free-function shorthand for [`Redirect::found`].
pub fn found(location: impl Into<String>) -> Redirect {
    let target: String = location.into();
    Redirect::found(target)
}
/// Free-function shorthand for [`Redirect::see_other`].
pub fn see_other(location: impl Into<String>) -> Redirect {
    let target: String = location.into();
    Redirect::see_other(target)
}
/// Free-function shorthand for [`Redirect::temporary`].
pub fn temporary(location: impl Into<String>) -> Redirect {
    let target: String = location.into();
    Redirect::temporary(target)
}
/// Free-function shorthand for [`Redirect::permanent_moved`].
pub fn permanent_moved(location: impl Into<String>) -> Redirect {
    let target: String = location.into();
    Redirect::permanent_moved(target)
}
/// Free-function shorthand for [`Redirect::permanent`].
pub fn permanent(location: impl Into<String>) -> Redirect {
    let target: String = location.into();
    Redirect::permanent(target)
}
/// Validates a raw `Host` header value and extracts the bare host from it.
///
/// The value is trimmed, then rejected (`None`) when it is empty, contains
/// CR, LF, NUL, space, tab or `@`, or does not parse as an
/// `http::uri::Authority`. On success the returned string is the host with
/// any `:port` suffix removed; a bracketed IPv6 literal is returned with its
/// brackets (for example `[::1]`).
fn validate_host(header_value: &str) -> Option<String> {
    let trimmed = header_value.trim();
    if trimmed.is_empty() {
        return None;
    }
    // Control and whitespace bytes would allow header injection. `@` is
    // rejected because the authority grammar accepts a `userinfo@` prefix that
    // a `Host` header may never carry: without this check
    // `alice:pw@example.com` would yield the host `alice`, and
    // `trusted.example@evil.example` would be echoed into `Location`.
    if trimmed
        .bytes()
        .any(|b| matches!(b, b'\r' | b'\n' | 0 | b' ' | b'\t' | b'@'))
    {
        return None;
    }
    // Parsed purely as a syntax check; the host is sliced out of `trimmed`.
    let _authority: http::uri::Authority = trimmed.parse().ok()?;
    let host = if let Some(after_bracket) = trimmed.strip_prefix('[') {
        // `end` is relative to the text after `[`, so the closing bracket
        // sits at `end + 1` in `trimmed`.
        let end = after_bracket.find(']')?;
        trimmed[..=end + 1].to_string()
    } else {
        // Everything before the first `:` is the host; an empty host (such as
        // `:8080`) is rejected.
        trimmed
            .split(':')
            .next()
            .filter(|s| !s.is_empty())?
            .to_string()
    };
    Some(host)
}
/// Builds a router that redirects requests to `https://` on `https_port`.
///
/// No host allowlist is configured: the inner builder is called with an empty
/// list of allowed hosts.
pub fn http_to_https_router(https_port: u16) -> crate::router::Router {
    let no_allowlist: Vec<String> = Vec::new();
    http_to_https_router_inner(https_port, no_allowlist)
}
/// Builds an HTTP→HTTPS redirect router restricted to `allowed_hosts`.
///
/// Each entry is converted to a `String` and ASCII-lowercased before being
/// handed to the inner builder together with `https_port`.
pub fn http_to_https_router_with_allowed_hosts(
    https_port: u16,
    allowed_hosts: impl IntoIterator<Item = impl Into<String>>,
) -> crate::router::Router {
    let mut normalized: Vec<String> = Vec::new();
    for entry in allowed_hosts {
        let owned: String = entry.into();
        normalized.push(owned.to_ascii_lowercase());
    }
    http_to_https_router_inner(https_port, normalized)
}
/// Builds the router whose fallback handler performs the HTTPS redirect.
///
/// For every request the handler:
/// 1. reads the `Host` header (missing or non-string values count as empty)
///    and validates it, answering `400 Bad Request` when invalid;
/// 2. if `allowed_hosts` is non-empty, answers `421 Misdirected Request`
///    unless the ASCII-lowercased host is in the list;
/// 3. otherwise answers with a permanent redirect to
///    `https://{host}[:{port}]{path_and_query}`, omitting the port when it
///    is 443.
fn http_to_https_router_inner(
    https_port: u16,
    allowed_hosts: Vec<String>,
) -> crate::router::Router {
    let shared_hosts = std::sync::Arc::new(allowed_hosts);
    let mut router = crate::router::Router::new();
    router.fallback(move |req: crate::types::Request| {
        // Each invocation gets its own copies to move into the future.
        let target_port = https_port;
        let permitted_hosts = std::sync::Arc::clone(&shared_hosts);
        async move {
            let raw_host = match req.headers().get(http::header::HOST) {
                Some(value) => value.to_str().unwrap_or(""),
                None => "",
            };
            let host = match validate_host(raw_host) {
                Some(valid) => valid,
                None => {
                    return http::Response::builder()
                        .status(StatusCode::BAD_REQUEST)
                        .body(TakoBody::from("invalid Host header"))
                        .expect("static 400 response is well-formed");
                }
            };
            // An empty allowlist means every validated host is accepted.
            let host_permitted = permitted_hosts.is_empty()
                || permitted_hosts.contains(&host.to_ascii_lowercase());
            if !host_permitted {
                return http::Response::builder()
                    .status(StatusCode::MISDIRECTED_REQUEST)
                    .body(TakoBody::from("host not allowed"))
                    .expect("static 421 response is well-formed");
            }
            let suffix = req
                .uri()
                .path_and_query()
                .map(|pq| pq.as_str())
                .unwrap_or("/");
            let location = match target_port {
                443 => format!("https://{host}{suffix}"),
                other => format!("https://{host}:{other}{suffix}"),
            };
            Redirect::permanent(location).into_response()
        }
    });
    router
}