blueprint_auth/
request_extensions.rs

//! Request-extension plumbing for client-certificate (mTLS) identity.
//!
//! Provides extractors, middleware, and helpers for reading and injecting mTLS identity
//! information via request extensions.
3
4use axum::extract::FromRequestParts;
5use axum::http::request::Parts;
6use axum::http::{HeaderMap, HeaderValue};
7use std::collections::HashMap;
8use std::convert::TryFrom;
9
10use crate::tls_listener::ClientCertInfo;
11use tracing::warn;
12
13/// Request extension that carries client certificate information
14#[derive(Clone, Debug)]
15pub struct ClientCertExtension {
16    pub client_cert: Option<ClientCertInfo>,
17    pub headers: HeaderMap,
18}
19
20impl ClientCertExtension {
21    /// Create a new client certificate extension
22    pub fn new(client_cert: Option<ClientCertInfo>, headers: HeaderMap) -> Self {
23        Self {
24            client_cert,
25            headers,
26        }
27    }
28
29    /// Extract client certificate subject if available
30    pub fn subject(&self) -> Option<&str> {
31        self.client_cert.as_ref().map(|cert| cert.subject.as_str())
32    }
33
34    /// Extract client certificate issuer if available
35    pub fn issuer(&self) -> Option<&str> {
36        self.client_cert.as_ref().map(|cert| cert.issuer.as_str())
37    }
38
39    /// Extract client certificate serial if available
40    pub fn serial(&self) -> Option<&str> {
41        self.client_cert.as_ref().map(|cert| cert.serial.as_str())
42    }
43
44    /// Check if client certificate is valid (not expired)
45    pub fn is_valid(&self) -> bool {
46        if let Some(cert) = &self.client_cert {
47            let now = std::time::SystemTime::now()
48                .duration_since(std::time::UNIX_EPOCH)
49                .unwrap_or_default()
50                .as_secs();
51            let Ok(now) = i64::try_from(now) else {
52                return false;
53            };
54            cert.not_before <= now && now <= cert.not_after
55        } else {
56            false
57        }
58    }
59
60    /// Get additional headers to inject based on client certificate
61    pub fn additional_headers(&self) -> HeaderMap {
62        let mut headers = HeaderMap::new();
63
64        if let Some(cert) = &self.client_cert {
65            // Inject client certificate information as headers, skipping values that cannot be represented.
66            try_insert_header(&mut headers, "x-client-cert-subject", &cert.subject);
67            try_insert_header(&mut headers, "x-client-cert-issuer", &cert.issuer);
68            try_insert_header(&mut headers, "x-client-cert-serial", &cert.serial);
69
70            let not_before = cert.not_before.to_string();
71            try_insert_header(&mut headers, "x-client-cert-not-before", &not_before);
72
73            let not_after = cert.not_after.to_string();
74            try_insert_header(&mut headers, "x-client-cert-not-after", &not_after);
75
76            headers.insert("x-auth-method", HeaderValue::from_static("mtls"));
77        }
78
79        headers
80    }
81}
82
83fn try_insert_header(headers: &mut HeaderMap, name: &'static str, value: &str) {
84    match HeaderValue::from_str(value) {
85        Ok(header_value) => {
86            headers.insert(name, header_value);
87        }
88        Err(err) => {
89            warn!("skipping header `{}` due to invalid value: {}", name, err);
90        }
91    }
92}
93
94/// Extractor for client certificate information from request
95pub struct ClientCertExtractor {
96    pub client_cert: Option<ClientCertInfo>,
97}
98
99impl<S> FromRequestParts<S> for ClientCertExtractor
100where
101    S: Send + Sync,
102{
103    type Rejection = axum::http::StatusCode;
104
105    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
106        // Extract client certificate information from request extensions
107        let client_cert = parts.extensions.get::<ClientCertInfo>().cloned();
108
109        Ok(Self { client_cert })
110    }
111}
112
/// Tower middleware wrapping an inner service `S`, intended to surface client-certificate
/// information through request extensions.
///
/// `Clone` is derived (available when `S: Clone`) because axum's `Router::layer` and
/// `route_layer` require the layered service to be `Clone`. The `tower::Service` impl for this
/// type already requires `S: Clone`.
#[derive(Clone, Debug)]
pub struct ClientCertMiddleware<S> {
    inner: S,
}
117
118impl<S> ClientCertMiddleware<S> {
119    pub fn new(inner: S) -> Self {
120        Self { inner }
121    }
122}
123
124impl<S> tower::Service<axum::extract::Request> for ClientCertMiddleware<S>
125where
126    S: tower::Service<axum::extract::Request, Response = axum::response::Response>
127        + Clone
128        + Send
129        + 'static,
130    S::Future: Send + 'static,
131{
132    type Response = S::Response;
133    type Error = S::Error;
134    type Future = std::pin::Pin<
135        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
136    >;
137
138    fn poll_ready(
139        &mut self,
140        cx: &mut std::task::Context<'_>,
141    ) -> std::task::Poll<Result<(), Self::Error>> {
142        self.inner.poll_ready(cx)
143    }
144
145    fn call(&mut self, mut req: axum::extract::Request) -> Self::Future {
146        // Extract client certificate information from the request
147        // This would typically come from the TLS connection
148        let client_cert = req.extensions().get::<ClientCertInfo>().cloned();
149
150        // Add client certificate extension to the request
151        if let Some(cert) = client_cert {
152            req.extensions_mut().insert(cert);
153        }
154
155        let inner = self.inner.clone();
156        let mut inner = std::mem::replace(&mut self.inner, inner);
157
158        Box::pin(async move { inner.call(req).await })
159    }
160}
161
162/// Helper function to create client certificate middleware
163pub fn client_cert_middleware<S>(inner: S) -> ClientCertMiddleware<S> {
164    ClientCertMiddleware::new(inner)
165}
166
167/// Request extension for authentication context
168#[derive(Clone, Debug)]
169pub struct AuthContext {
170    pub service_id: u64,
171    pub auth_method: AuthMethod,
172    pub client_cert: Option<ClientCertInfo>,
173    pub additional_headers: HashMap<String, String>,
174}
175
176impl AuthContext {
177    pub fn new(service_id: u64, auth_method: AuthMethod) -> Self {
178        Self {
179            service_id,
180            auth_method,
181            client_cert: None,
182            additional_headers: HashMap::new(),
183        }
184    }
185
186    pub fn with_client_cert(mut self, client_cert: Option<ClientCertInfo>) -> Self {
187        self.client_cert = client_cert;
188        self
189    }
190
191    pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
192        self.additional_headers = headers;
193        self
194    }
195
196    pub fn is_mtls(&self) -> bool {
197        matches!(self.auth_method, AuthMethod::Mtls)
198    }
199
200    pub fn client_cert_subject(&self) -> Option<&str> {
201        self.client_cert.as_ref().map(|cert| cert.subject.as_str())
202    }
203}
204
/// Mechanism used to authenticate a request.
///
/// This is a fieldless enum, so it is `Copy` and can be passed by value. `Eq` and `Hash` let it
/// serve as a key in maps and sets.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AuthMethod {
    /// API-key authentication.
    ApiKey,
    /// Access-token authentication.
    AccessToken,
    /// Mutual-TLS (client certificate) authentication.
    Mtls,
    /// OAuth authentication.
    OAuth,
}
213
214/// Extractor for authentication context
215pub struct AuthContextExtractor {
216    pub context: AuthContext,
217}
218
219impl<S> FromRequestParts<S> for AuthContextExtractor
220where
221    S: Send + Sync,
222{
223    type Rejection = axum::http::StatusCode;
224
225    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
226        // Extract authentication context from request extensions
227        let context = parts
228            .extensions
229            .get::<AuthContext>()
230            .cloned()
231            .ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
232
233        Ok(Self { context })
234    }
235}
236
237/// Helper function to inject authentication context into request
238pub fn inject_auth_context(
239    mut req: axum::extract::Request,
240    context: AuthContext,
241) -> axum::extract::Request {
242    req.extensions_mut().insert(context);
243    req
244}
245
246/// Helper function to extract client certificate from request
247pub fn extract_client_cert_from_request(req: &axum::extract::Request) -> Option<ClientCertInfo> {
248    req.extensions().get::<ClientCertInfo>().cloned()
249}
250
251/// Helper function to extract authentication context from request
252pub fn extract_auth_context_from_request(req: &axum::extract::Request) -> Option<AuthContext> {
253    req.extensions().get::<AuthContext>().cloned()
254}