// blueprint_auth/request_extensions.rs

use std::collections::HashMap;
use std::convert::TryFrom;
use std::time::{SystemTime, UNIX_EPOCH};

use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::http::{HeaderMap, HeaderValue};
use tracing::warn;

use crate::tls_listener::ClientCertInfo;

13#[derive(Clone, Debug)]
15pub struct ClientCertExtension {
16 pub client_cert: Option<ClientCertInfo>,
17 pub headers: HeaderMap,
18}
19
20impl ClientCertExtension {
21 pub fn new(client_cert: Option<ClientCertInfo>, headers: HeaderMap) -> Self {
23 Self {
24 client_cert,
25 headers,
26 }
27 }
28
29 pub fn subject(&self) -> Option<&str> {
31 self.client_cert.as_ref().map(|cert| cert.subject.as_str())
32 }
33
34 pub fn issuer(&self) -> Option<&str> {
36 self.client_cert.as_ref().map(|cert| cert.issuer.as_str())
37 }
38
39 pub fn serial(&self) -> Option<&str> {
41 self.client_cert.as_ref().map(|cert| cert.serial.as_str())
42 }
43
44 pub fn is_valid(&self) -> bool {
46 if let Some(cert) = &self.client_cert {
47 let now = std::time::SystemTime::now()
48 .duration_since(std::time::UNIX_EPOCH)
49 .unwrap_or_default()
50 .as_secs();
51 let Ok(now) = i64::try_from(now) else {
52 return false;
53 };
54 cert.not_before <= now && now <= cert.not_after
55 } else {
56 false
57 }
58 }
59
60 pub fn additional_headers(&self) -> HeaderMap {
62 let mut headers = HeaderMap::new();
63
64 if let Some(cert) = &self.client_cert {
65 try_insert_header(&mut headers, "x-client-cert-subject", &cert.subject);
67 try_insert_header(&mut headers, "x-client-cert-issuer", &cert.issuer);
68 try_insert_header(&mut headers, "x-client-cert-serial", &cert.serial);
69
70 let not_before = cert.not_before.to_string();
71 try_insert_header(&mut headers, "x-client-cert-not-before", ¬_before);
72
73 let not_after = cert.not_after.to_string();
74 try_insert_header(&mut headers, "x-client-cert-not-after", ¬_after);
75
76 headers.insert("x-auth-method", HeaderValue::from_static("mtls"));
77 }
78
79 headers
80 }
81}
82
83fn try_insert_header(headers: &mut HeaderMap, name: &'static str, value: &str) {
84 match HeaderValue::from_str(value) {
85 Ok(header_value) => {
86 headers.insert(name, header_value);
87 }
88 Err(err) => {
89 warn!("skipping header `{}` due to invalid value: {}", name, err);
90 }
91 }
92}
93
94pub struct ClientCertExtractor {
96 pub client_cert: Option<ClientCertInfo>,
97}
98
99impl<S> FromRequestParts<S> for ClientCertExtractor
100where
101 S: Send + Sync,
102{
103 type Rejection = axum::http::StatusCode;
104
105 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
106 let client_cert = parts.extensions.get::<ClientCertInfo>().cloned();
108
109 Ok(Self { client_cert })
110 }
111}
112
/// Tower middleware wrapping an inner service `S`.
///
/// Derives `Clone` (when `S: Clone`) because tower/axum layering requires the
/// wrapped service to be cloneable; the `tower::Service` impl already demands
/// `S: Clone`, so this adds no new requirement on callers.
#[derive(Clone, Debug)]
pub struct ClientCertMiddleware<S> {
    inner: S,
}

impl<S> ClientCertMiddleware<S> {
    /// Wraps `inner`.
    pub fn new(inner: S) -> Self {
        Self { inner }
    }
}

124impl<S> tower::Service<axum::extract::Request> for ClientCertMiddleware<S>
125where
126 S: tower::Service<axum::extract::Request, Response = axum::response::Response>
127 + Clone
128 + Send
129 + 'static,
130 S::Future: Send + 'static,
131{
132 type Response = S::Response;
133 type Error = S::Error;
134 type Future = std::pin::Pin<
135 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
136 >;
137
138 fn poll_ready(
139 &mut self,
140 cx: &mut std::task::Context<'_>,
141 ) -> std::task::Poll<Result<(), Self::Error>> {
142 self.inner.poll_ready(cx)
143 }
144
145 fn call(&mut self, mut req: axum::extract::Request) -> Self::Future {
146 let client_cert = req.extensions().get::<ClientCertInfo>().cloned();
149
150 if let Some(cert) = client_cert {
152 req.extensions_mut().insert(cert);
153 }
154
155 let inner = self.inner.clone();
156 let mut inner = std::mem::replace(&mut self.inner, inner);
157
158 Box::pin(async move { inner.call(req).await })
159 }
160}
161
162pub fn client_cert_middleware<S>(inner: S) -> ClientCertMiddleware<S> {
164 ClientCertMiddleware::new(inner)
165}
166
167#[derive(Clone, Debug)]
169pub struct AuthContext {
170 pub service_id: u64,
171 pub auth_method: AuthMethod,
172 pub client_cert: Option<ClientCertInfo>,
173 pub additional_headers: HashMap<String, String>,
174}
175
176impl AuthContext {
177 pub fn new(service_id: u64, auth_method: AuthMethod) -> Self {
178 Self {
179 service_id,
180 auth_method,
181 client_cert: None,
182 additional_headers: HashMap::new(),
183 }
184 }
185
186 pub fn with_client_cert(mut self, client_cert: Option<ClientCertInfo>) -> Self {
187 self.client_cert = client_cert;
188 self
189 }
190
191 pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
192 self.additional_headers = headers;
193 self
194 }
195
196 pub fn is_mtls(&self) -> bool {
197 matches!(self.auth_method, AuthMethod::Mtls)
198 }
199
200 pub fn client_cert_subject(&self) -> Option<&str> {
201 self.client_cert.as_ref().map(|cert| cert.subject.as_str())
202 }
203}
204
/// Mechanism by which a request was authenticated.
///
/// The enum is fieldless, so it is `Copy` and can be used as a hash-map or
/// hash-set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AuthMethod {
    /// API-key authentication.
    ApiKey,
    /// Access-token authentication.
    AccessToken,
    /// Mutual-TLS (client certificate) authentication.
    Mtls,
    /// OAuth authentication.
    OAuth,
}

214pub struct AuthContextExtractor {
216 pub context: AuthContext,
217}
218
219impl<S> FromRequestParts<S> for AuthContextExtractor
220where
221 S: Send + Sync,
222{
223 type Rejection = axum::http::StatusCode;
224
225 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
226 let context = parts
228 .extensions
229 .get::<AuthContext>()
230 .cloned()
231 .ok_or(axum::http::StatusCode::UNAUTHORIZED)?;
232
233 Ok(Self { context })
234 }
235}
236
237pub fn inject_auth_context(
239 mut req: axum::extract::Request,
240 context: AuthContext,
241) -> axum::extract::Request {
242 req.extensions_mut().insert(context);
243 req
244}
245
246pub fn extract_client_cert_from_request(req: &axum::extract::Request) -> Option<ClientCertInfo> {
248 req.extensions().get::<ClientCertInfo>().cloned()
249}
250
251pub fn extract_auth_context_from_request(req: &axum::extract::Request) -> Option<AuthContext> {
253 req.extensions().get::<AuthContext>().cloned()
254}