use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::http::{HeaderMap, HeaderValue};
use std::collections::HashMap;
use std::convert::TryFrom;
use crate::tls_listener::ClientCertInfo;
use tracing::warn;
/// Pairs the (optional) client certificate observed for a request with an
/// accompanying `HeaderMap`.
#[derive(Debug, Clone)]
pub struct ClientCertExtension {
    /// Certificate details, or `None` when no client certificate is available.
    pub client_cert: Option<ClientCertInfo>,
    /// Header map supplied when the extension was constructed.
    pub headers: HeaderMap,
}
impl ClientCertExtension {
    /// Bundles an optional client certificate with a header map.
    pub fn new(client_cert: Option<ClientCertInfo>, headers: HeaderMap) -> Self {
        Self {
            client_cert,
            headers,
        }
    }

    /// Returns the certificate subject, or `None` when no certificate is present.
    pub fn subject(&self) -> Option<&str> {
        self.client_cert.as_ref().map(|cert| cert.subject.as_str())
    }

    /// Returns the certificate issuer, or `None` when no certificate is present.
    pub fn issuer(&self) -> Option<&str> {
        self.client_cert.as_ref().map(|cert| cert.issuer.as_str())
    }

    /// Returns the certificate serial, or `None` when no certificate is present.
    pub fn serial(&self) -> Option<&str> {
        self.client_cert.as_ref().map(|cert| cert.serial.as_str())
    }

    /// Returns `true` when a certificate is present and the current system time, in
    /// whole seconds since the Unix epoch, lies within `not_before..=not_after`.
    ///
    /// Fails closed: returns `false` when there is no certificate, when the system
    /// clock reads earlier than the Unix epoch, or when the timestamp does not fit
    /// in an `i64`.
    pub fn is_valid(&self) -> bool {
        let Some(cert) = &self.client_cert else {
            return false;
        };
        // A clock reading before the epoch cannot be trusted for a validity check;
        // treat it as "not valid" rather than pretending the time is 0.
        let Ok(elapsed) = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)
        else {
            return false;
        };
        let Ok(now) = i64::try_from(elapsed.as_secs()) else {
            return false;
        };
        cert.not_before <= now && now <= cert.not_after
    }

    /// Builds `x-client-cert-*` headers describing the certificate, plus
    /// `x-auth-method: mtls`.
    ///
    /// Returns an empty map when no certificate is present. A field whose text is not
    /// a valid header value is skipped (with a warning) instead of aborting the set.
    pub fn additional_headers(&self) -> HeaderMap {
        let mut headers = HeaderMap::new();
        if let Some(cert) = &self.client_cert {
            try_insert_header(&mut headers, "x-client-cert-subject", &cert.subject);
            try_insert_header(&mut headers, "x-client-cert-issuer", &cert.issuer);
            try_insert_header(&mut headers, "x-client-cert-serial", &cert.serial);
            let not_before = cert.not_before.to_string();
            try_insert_header(&mut headers, "x-client-cert-not-before", &not_before);
            let not_after = cert.not_after.to_string();
            try_insert_header(&mut headers, "x-client-cert-not-after", &not_after);
            headers.insert("x-auth-method", HeaderValue::from_static("mtls"));
        }
        headers
    }
}
/// Inserts `value` under `name` when it parses as a header value; otherwise logs a
/// warning and leaves `headers` unchanged.
fn try_insert_header(headers: &mut HeaderMap, name: &'static str, value: &str) {
    let header_value = match HeaderValue::from_str(value) {
        Ok(parsed) => parsed,
        Err(err) => {
            warn!("skipping header `{}` due to invalid value: {}", name, err);
            return;
        }
    };
    headers.insert(name, header_value);
}
/// Axum extractor yielding the `ClientCertInfo` stored in the request extensions.
///
/// Extraction never fails: a request without a certificate yields `client_cert: None`.
pub struct ClientCertExtractor {
    /// Clone of the certificate found in the request extensions, if any.
    pub client_cert: Option<ClientCertInfo>,
}

impl<S> FromRequestParts<S> for ClientCertExtractor
where
    S: Send + Sync,
{
    // Declared to satisfy the trait; `from_request_parts` below always returns `Ok`.
    type Rejection = axum::http::StatusCode;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let extensions = &parts.extensions;
        Ok(Self {
            client_cert: extensions.get::<ClientCertInfo>().cloned(),
        })
    }
}
/// Tower middleware that forwards requests to an inner service.
///
/// Derives `Clone` because tower/axum typically clone services per connection or
/// route, and the `Service` impl for this type already requires `S: Clone`.
#[derive(Clone, Debug)]
pub struct ClientCertMiddleware<S> {
    inner: S,
}

impl<S> ClientCertMiddleware<S> {
    /// Wraps `inner`.
    pub fn new(inner: S) -> Self {
        Self { inner }
    }
}
impl<S> tower::Service<axum::extract::Request> for ClientCertMiddleware<S>
where
    S: tower::Service<axum::extract::Request, Response = axum::response::Response>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Forwards the request unchanged to the wrapped service.
    ///
    /// Any `ClientCertInfo` already in the request extensions travels with the request
    /// as-is, so downstream extractors can read it without it being copied or
    /// re-inserted here.
    fn call(&mut self, req: axum::extract::Request) -> Self::Future {
        // `poll_ready` was driven on `self.inner`, so that instance must handle this
        // request. Leave a fresh clone behind for the next call and move the ready
        // instance into the returned future.
        let fresh = self.inner.clone();
        let mut ready = std::mem::replace(&mut self.inner, fresh);
        Box::pin(async move { ready.call(req).await })
    }
}
pub fn client_cert_middleware<S>(inner: S) -> ClientCertMiddleware<S> {
ClientCertMiddleware::new(inner)
}
/// Authentication details attached to a request (see `inject_auth_context`).
#[derive(Debug, Clone)]
pub struct AuthContext {
    /// Numeric service identifier.
    pub service_id: u64,
    /// Mechanism the request was authenticated with.
    pub auth_method: AuthMethod,
    /// Client certificate, when one was recorded.
    pub client_cert: Option<ClientCertInfo>,
    /// Extra string key/value headers carried with the context.
    pub additional_headers: HashMap<String, String>,
}
impl AuthContext {
    /// Creates a context with no client certificate and an empty header map.
    pub fn new(service_id: u64, auth_method: AuthMethod) -> Self {
        Self {
            service_id,
            auth_method,
            client_cert: None,
            additional_headers: HashMap::default(),
        }
    }

    /// Builder: replaces the client certificate.
    pub fn with_client_cert(self, client_cert: Option<ClientCertInfo>) -> Self {
        Self {
            client_cert,
            ..self
        }
    }

    /// Builder: replaces the entire additional-header map.
    pub fn with_headers(self, headers: HashMap<String, String>) -> Self {
        Self {
            additional_headers: headers,
            ..self
        }
    }

    /// Returns `true` when the auth method is `AuthMethod::Mtls`.
    pub fn is_mtls(&self) -> bool {
        self.auth_method == AuthMethod::Mtls
    }

    /// Returns the client certificate subject, or `None` without a certificate.
    pub fn client_cert_subject(&self) -> Option<&str> {
        let cert = self.client_cert.as_ref()?;
        Some(cert.subject.as_str())
    }
}
/// Mechanism by which a request was authenticated.
///
/// A fieldless enum, so it is cheap to `Copy`, and deriving `Eq`/`Hash` allows it to
/// be used as a `HashMap`/`HashSet` key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AuthMethod {
    ApiKey,
    AccessToken,
    /// Mutual TLS (client certificate).
    Mtls,
    OAuth,
}
/// Axum extractor for the `AuthContext` stored in the request extensions.
///
/// Rejects with `401 Unauthorized` when no `AuthContext` is present.
pub struct AuthContextExtractor {
    /// Clone of the request's `AuthContext`.
    pub context: AuthContext,
}

impl<S> FromRequestParts<S> for AuthContextExtractor
where
    S: Send + Sync,
{
    type Rejection = axum::http::StatusCode;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        parts
            .extensions
            .get::<AuthContext>()
            .map(|found| Self {
                context: found.clone(),
            })
            .ok_or(axum::http::StatusCode::UNAUTHORIZED)
    }
}
/// Stores `context` in the request extensions, replacing any earlier `AuthContext`,
/// and hands the request back.
pub fn inject_auth_context(
    mut req: axum::extract::Request,
    context: AuthContext,
) -> axum::extract::Request {
    let extensions = req.extensions_mut();
    extensions.insert(context);
    req
}
/// Returns a clone of the request's `ClientCertInfo` extension, or `None` if absent.
pub fn extract_client_cert_from_request(req: &axum::extract::Request) -> Option<ClientCertInfo> {
    let extensions = req.extensions();
    extensions.get::<ClientCertInfo>().cloned()
}
/// Returns a clone of the request's `AuthContext` extension, or `None` if absent.
pub fn extract_auth_context_from_request(req: &axum::extract::Request) -> Option<AuthContext> {
    let extensions = req.extensions();
    extensions.get::<AuthContext>().cloned()
}