use std::{fs, path::PathBuf};
use async_trait::async_trait;
use axum::{
extract::Host,
handler::HandlerWithoutStateExt,
http::{StatusCode, Uri},
response::Redirect,
};
use axum_server::tls_rustls::RustlsConfig;
use rcgen::{date_time_ymd, Certificate, CertificateParams, DistinguishedName, DnType, SanType};
#[async_trait]
pub trait GenerateCertKey {
fn get_cert_key_path() -> anyhow::Result<(String, String)> {
fs::create_dir_all("certs/")?;
Ok(("certs/cert.pem".to_owned(), "certs/key.pem".to_owned()))
}
async fn get_rustls_config(create_if_not_exists: bool) -> anyhow::Result<RustlsConfig> {
let (cert, key) = Self::get_cert_key_path()?;
let cert_pathbuf = PathBuf::from(cert);
let key_pathbuf = PathBuf::from(key);
if create_if_not_exists && (!cert_pathbuf.exists() | !key_pathbuf.exists()) {
tracing::info!(
"generate cert at {} and key at {}",
cert_pathbuf.to_str().unwrap(),
key_pathbuf.to_str().unwrap()
);
Self::generate_cert_key()?;
}
Ok(RustlsConfig::from_pem_file(cert_pathbuf, key_pathbuf)
.await
.unwrap())
}
fn generate_cert_key() -> anyhow::Result<()> {
let cert = Certificate::from_params(Self::get_cert_params())?;
let pem_serialized = cert.serialize_pem()?;
println!("{}", pem_serialized);
println!("{}", cert.serialize_private_key_pem());
let (cert_path, key_path) = Self::get_cert_key_path()?;
fs::write(cert_path, pem_serialized.as_bytes())?;
fs::write(key_path, cert.serialize_private_key_pem().as_bytes())?;
Ok(())
}
fn get_cert_params() -> CertificateParams {
let mut params: CertificateParams = Default::default();
params.not_before = date_time_ymd(1975, 1, 1);
params.not_after = date_time_ymd(4096, 1, 1);
params.distinguished_name = DistinguishedName::new();
params
.distinguished_name
.push(DnType::OrganizationName, "Axum-restful");
params
.distinguished_name
.push(DnType::CommonName, "Axum-restful common name");
params.subject_alt_names = vec![SanType::DnsName("localhost".to_string())];
params
}
}
/// Serves plain HTTP on `http_ip:http_port` and permanently redirects every request to the
/// same host, path and query over HTTPS on `https_port`.
///
/// A request whose `Host` value or URI cannot be turned into a valid HTTPS URI is answered
/// with `400 Bad Request`, and a warning is logged.
///
/// # Panics
///
/// Panics if the listener cannot be bound to `http_ip:http_port`, or if the server returns
/// an error while serving.
pub async fn redirect_http_to_https(http_port: u16, https_port: u16, http_ip: &str) {
    /// Rewrites `uri` into an `https` URI that targets `host` on `https_port`.
    fn make_https(host: String, uri: Uri, https_port: u16) -> anyhow::Result<Uri> {
        let mut parts = uri.into_parts();
        parts.scheme = Some(axum::http::uri::Scheme::HTTPS);

        if parts.path_and_query.is_none() {
            parts.path_and_query = Some(axum::http::uri::PathAndQuery::from_static("/"));
        }

        // Parse the host instead of doing a textual port substitution: a substring replace
        // would also rewrite digits inside the host itself (e.g. `192.168.80.1:80`) and would
        // do nothing when the client sent no explicit port.
        let authority: axum::http::uri::Authority = host.parse()?;
        // `Authority::host` excludes the port and keeps the brackets of an IPv6 literal, so
        // appending `:port` yields a valid authority again.
        let https_authority = format!("{}:{}", authority.host(), https_port);
        parts.authority = Some(https_authority.parse()?);

        Ok(Uri::from_parts(parts)?)
    }

    let redirect = move |Host(host): Host, uri: Uri| async move {
        match make_https(host, uri, https_port) {
            Ok(uri) => Ok(Redirect::permanent(&uri.to_string())),
            Err(error) => {
                tracing::warn!(%error, "failed to convert URI to HTTPS");
                Err(StatusCode::BAD_REQUEST)
            }
        }
    };

    let addr = format!("{}:{}", http_ip, http_port);
    tracing::debug!("http redirect listening on {}", addr);

    let listener = tokio::net::TcpListener::bind(&addr)
        .await
        .unwrap_or_else(|error| {
            panic!("failed to bind http redirect listener on {}: {}", addr, error)
        });

    axum::serve(listener, redirect.into_make_service())
        .await
        .unwrap_or_else(|error| panic!("http redirect server on {} failed: {}", addr, error));
}