1use crate::error::{AuthError, AuthResult};
2use crate::providers::AuthProvider;
3use crate::types::{AuthProvider as AuthProviderType, AuthToken, UserContext};
4use axum::{
5 extract::{FromRequestParts, Request},
6 http::{request::Parts, StatusCode},
7 response::{IntoResponse, Response},
8};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12#[derive(Debug, Clone)]
14pub struct AuthExtension {
15 pub user: UserContext,
16 provider_index: usize,
19}
20
21impl AuthExtension {
22 pub fn user(&self) -> &UserContext {
24 &self.user
25 }
26
27 pub fn user_mut(&mut self) -> &mut UserContext {
29 &mut self.user
30 }
31}
32
33#[derive(Clone)]
35pub struct AuthState {
36 providers: Arc<Vec<Box<dyn AuthProvider>>>,
37 required: bool,
38 issuer_to_provider: Arc<Option<HashMap<String, AuthProviderType>>>,
40}
41
42impl AuthState {
43 pub fn new(providers: Vec<Box<dyn AuthProvider>>) -> Self {
45 Self {
46 providers: Arc::new(providers),
47 required: false,
48 issuer_to_provider: Arc::new(None),
49 }
50 }
51
52 pub fn required(mut self, required: bool) -> Self {
54 self.required = required;
55 self
56 }
57
58 pub fn with_issuer_mapping(mut self, mapping: HashMap<String, AuthProviderType>) -> Self {
60 self.issuer_to_provider = Arc::new(Some(mapping));
61 self
62 }
63
64 fn extract_issuer_from_token(&self, token: &str) -> Option<String> {
66 use base64::Engine;
67
68 let parts: Vec<&str> = token.split('.').collect();
69 if parts.len() != 3 {
70 return None;
71 }
72
73 let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
75 .decode(parts[1])
76 .ok()?;
77
78 let payload: serde_json::Value = serde_json::from_slice(&payload).ok()?;
79
80 payload
82 .get("iss")
83 .and_then(|v| v.as_str())
84 .map(|s| s.to_string())
85 }
86
87 async fn authenticate_token(&self, token: &AuthToken) -> AuthResult<(UserContext, usize)> {
91 if let Some(ref issuer_map) = *self.issuer_to_provider {
93 if let Some(issuer) = self.extract_issuer_from_token(&token.token) {
94 if let Some(target_provider_type) = issuer_map.get(&issuer) {
95 for (index, provider) in self.providers.iter().enumerate() {
97 if provider.provider_type() == *target_provider_type {
98 tracing::debug!(
99 "Routing to provider {:?} based on issuer: {}",
100 target_provider_type,
101 issuer
102 );
103 let user = provider.authenticate(token).await?;
104 return Ok((user, index));
105 }
106 }
107 tracing::warn!(
108 "Provider {:?} not found for issuer: {}",
109 target_provider_type,
110 issuer
111 );
112 }
113 }
114 }
115
116 for (index, provider) in self.providers.iter().enumerate() {
118 if provider.can_handle(token).await {
119 let user = provider.authenticate(token).await?;
120 return Ok((user, index));
121 }
122 }
123
124 Err(AuthError::ProviderError(
125 "No provider could handle the token".to_string(),
126 ))
127 }
128
129 pub fn get_provider(&self, index: usize) -> Option<&dyn AuthProvider> {
131 self.providers.get(index).map(|p| p.as_ref())
132 }
133}
134
135#[derive(Clone)]
137pub struct AuthLayer {
138 state: AuthState,
139}
140
141impl AuthLayer {
142 pub fn new(providers: Vec<Box<dyn AuthProvider>>) -> Self {
144 Self {
145 state: AuthState::new(providers),
146 }
147 }
148
149 pub fn required(mut self, required: bool) -> Self {
151 self.state = self.state.required(required);
152 self
153 }
154
155 pub fn with_issuer_mapping(mut self, mapping: HashMap<String, AuthProviderType>) -> Self {
170 self.state = self.state.with_issuer_mapping(mapping);
171 self
172 }
173
174 pub fn with_auto_issuer_mapping(mut self) -> Self {
183 let mut mapping = HashMap::new();
184
185 for provider in self.state.providers.iter() {
186 let provider_type = provider.provider_type();
187 for issuer in provider.expected_issuers() {
188 mapping.insert(issuer, provider_type);
189 }
190 }
191
192 if !mapping.is_empty() {
193 self.state = self.state.with_issuer_mapping(mapping);
194 }
195
196 self
197 }
198}
199
200impl<S> tower_layer::Layer<S> for AuthLayer {
201 type Service = AuthMiddleware<S>;
202
203 fn layer(&self, inner: S) -> Self::Service {
204 AuthMiddleware {
205 inner,
206 state: self.state.clone(),
207 }
208 }
209}
210
211#[derive(Clone)]
213pub struct AuthMiddleware<S> {
214 inner: S,
215 state: AuthState,
216}
217
218impl<S> tower::Service<Request> for AuthMiddleware<S>
219where
220 S: tower::Service<Request, Response = Response> + Clone + Send + 'static,
221 S::Future: Send + 'static,
222{
223 type Response = S::Response;
224 type Error = S::Error;
225 type Future = std::pin::Pin<
226 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
227 >;
228
229 fn poll_ready(
230 &mut self,
231 cx: &mut std::task::Context<'_>,
232 ) -> std::task::Poll<Result<(), Self::Error>> {
233 self.inner.poll_ready(cx)
234 }
235
236 fn call(&mut self, mut req: Request) -> Self::Future {
237 let clone = self.inner.clone();
238 let mut inner = std::mem::replace(&mut self.inner, clone);
239 let state = self.state.clone();
240
241 Box::pin(async move {
242 let auth_header = req
244 .headers()
245 .get(axum::http::header::AUTHORIZATION)
246 .and_then(|h| h.to_str().ok());
247
248 let auth_result = if let Some(header_value) = auth_header {
249 if let Some(token) = AuthToken::from_auth_header(header_value) {
250 match state.authenticate_token(&token).await {
251 Ok((user, provider_index)) => {
252 tracing::info!(
254 user_id = %user.user_id,
255 provider = %user.provider,
256 "Authentication successful"
257 );
258 Some((user, provider_index))
259 }
260 Err(e) => {
261 tracing::warn!(
263 error = %e,
264 provider = ?e.provider(),
265 "Authentication failed"
266 );
267 if state.required {
268 return Ok(StatusCode::UNAUTHORIZED.into_response());
269 }
270 None
271 }
272 }
273 } else {
274 tracing::warn!("Malformed Authorization header");
276 if state.required {
277 return Ok(StatusCode::UNAUTHORIZED.into_response());
278 }
279 None
280 }
281 } else {
282 if state.required {
284 tracing::warn!("Missing Authorization header for protected endpoint");
285 return Ok(StatusCode::UNAUTHORIZED.into_response());
286 }
287 None
288 };
289
290 if let Some((user, provider_index)) = auth_result {
292 req.extensions_mut().insert(AuthExtension {
293 user,
294 provider_index,
295 });
296 req.extensions_mut().insert(state.clone());
297 }
298
299 inner.call(req).await
300 })
301 }
302}
303
304impl<S> FromRequestParts<S> for AuthExtension
306where
307 S: Send + Sync,
308{
309 type Rejection = AuthRejection;
310
311 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
312 parts
313 .extensions
314 .get::<AuthExtension>()
315 .cloned()
316 .ok_or(AuthRejection::MissingAuth)
317 }
318}
319
320pub struct OptionalAuth(pub Option<UserContext>);
322
323impl<S> FromRequestParts<S> for OptionalAuth
324where
325 S: Send + Sync,
326{
327 type Rejection = std::convert::Infallible;
328
329 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
330 Ok(OptionalAuth(
331 parts
332 .extensions
333 .get::<AuthExtension>()
334 .map(|ext| ext.user.clone()),
335 ))
336 }
337}
338
339pub struct AuthContext {
342 pub user: UserContext,
343 state: AuthState,
344 provider_index: usize,
345}
346
347impl AuthContext {
348 pub fn user(&self) -> &UserContext {
350 &self.user
351 }
352
353 pub async fn check_permission(&self, permission: &str) -> AuthResult<bool> {
355 if let Some(provider) = self.state.get_provider(self.provider_index) {
356 provider.check_permission(&self.user, permission).await
357 } else {
358 Ok(false)
359 }
360 }
361
362 pub async fn check_role(&self, role: &str) -> AuthResult<bool> {
364 if let Some(provider) = self.state.get_provider(self.provider_index) {
365 provider.check_role(&self.user, role).await
366 } else {
367 Ok(false)
368 }
369 }
370
371 pub async fn get_roles(&self) -> AuthResult<Vec<String>> {
373 if let Some(provider) = self.state.get_provider(self.provider_index) {
374 provider.get_user_roles(&self.user).await
375 } else {
376 Ok(vec![])
377 }
378 }
379
380 pub async fn check_organization(&self, org_id: &str) -> AuthResult<bool> {
382 if let Some(provider) = self.state.get_provider(self.provider_index) {
383 provider.check_organization(&self.user, org_id).await
384 } else {
385 Ok(false)
386 }
387 }
388
389 pub async fn require_permission(&self, permission: &str) -> Result<(), StatusCode> {
391 if self.check_permission(permission).await.unwrap_or(false) {
392 Ok(())
393 } else {
394 Err(StatusCode::FORBIDDEN)
395 }
396 }
397
398 pub async fn require_role(&self, role: &str) -> Result<(), StatusCode> {
400 if self.check_role(role).await.unwrap_or(false) {
401 Ok(())
402 } else {
403 Err(StatusCode::FORBIDDEN)
404 }
405 }
406}
407
408impl<S> FromRequestParts<S> for AuthContext
409where
410 S: Send + Sync,
411{
412 type Rejection = AuthRejection;
413
414 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
415 let auth_ext = parts
416 .extensions
417 .get::<AuthExtension>()
418 .ok_or(AuthRejection::MissingAuth)?;
419
420 let state = parts
421 .extensions
422 .get::<AuthState>()
423 .ok_or(AuthRejection::MissingAuth)?;
424
425 Ok(AuthContext {
426 user: auth_ext.user.clone(),
427 state: state.clone(),
428 provider_index: auth_ext.provider_index,
429 })
430 }
431}
432
/// Rejection produced by this module's extractors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthRejection {
    /// The request carries no authentication data from [`AuthMiddleware`]: no
    /// [`AuthExtension`], or (for [`AuthContext`]) no [`AuthState`].
    MissingAuth,
}

impl std::fmt::Display for AuthRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AuthRejection::MissingAuth => f.write_str("Authentication required"),
        }
    }
}

impl std::error::Error for AuthRejection {}
438
439impl IntoResponse for AuthRejection {
440 fn into_response(self) -> Response {
441 let (status, message) = match self {
442 AuthRejection::MissingAuth => (StatusCode::UNAUTHORIZED, "Authentication required"),
443 };
444
445 (status, message).into_response()
446 }
447}
448
#[doc = r##"
# Example Usage

Using the authentication middleware in an Axum router:

```rust,ignore
use axum::{Router, extract::Path, routing::get, http::StatusCode};
use libauth_rs::middleware::{AuthContext, AuthExtension, AuthLayer, OptionalAuth};
use libauth_rs::providers::{AuthConfig, ClerkProvider, MsalProvider};

#[tokio::main]
async fn main() {
    let config = AuthConfig::default();

    // Create multiple providers
    let clerk = ClerkProvider::new(&config).await.unwrap();
    let msal = MsalProvider::new(&config).await.unwrap();

    // Create auth layer with automatic issuer-based routing
    let auth_layer = AuthLayer::new(vec![Box::new(clerk), Box::new(msal)])
        .with_auto_issuer_mapping() // Automatically route based on JWT issuer
        .required(false);

    let app = Router::new()
        .route("/public", get(public_handler))
        .route("/protected", get(protected_handler))
        .route("/admin", get(admin_handler))
        .layer(auth_layer);

    // Run your server...
}

// Public handler - no auth required
async fn public_handler(OptionalAuth(user): OptionalAuth) -> String {
    match user {
        Some(u) => format!("Hello, {}!", u.user_id),
        None => "Hello, anonymous!".to_string(),
    }
}

// Protected handler - auth required, basic user info
// (the extractor answers 401 when the middleware attached no user)
async fn protected_handler(AuthExtension { user, .. }: AuthExtension) -> String {
    format!("Welcome, {}!", user.user_id)
}

// Admin handler - auth required + role check using AuthContext
async fn admin_handler(auth: AuthContext) -> Result<String, StatusCode> {
    // Use provider-specific authorization
    auth.require_role("admin").await?;

    let roles = auth.get_roles().await.unwrap_or_default();
    Ok(format!("Admin access granted! Your roles: {:?}", roles))
}

// Per-provider routers example
async fn setup_per_provider_routers() -> Router {
    let config = AuthConfig::default();
    let clerk = ClerkProvider::new(&config).await.unwrap();
    let msal = MsalProvider::new(&config).await.unwrap();

    // Clerk-specific routes (axum 0.8 path parameters use `{name}`, not `:name`)
    let clerk_router = Router::new()
        .route("/org/{org_id}/members", get(clerk_org_members))
        .layer(AuthLayer::new(vec![Box::new(clerk)]).required(true));

    // MSAL-specific routes
    let msal_router = Router::new()
        .route("/azure/groups", get(msal_groups))
        .layer(AuthLayer::new(vec![Box::new(msal)]).required(true));

    // Combine routers
    Router::new()
        .nest("/clerk", clerk_router)
        .nest("/msal", msal_router)
}

// Clerk-only handler: proceed only if the provider's organization check passes
async fn clerk_org_members(
    Path(org_id): Path<String>,
    auth: AuthContext,
) -> Result<String, StatusCode> {
    if !auth.check_organization(&org_id).await.unwrap_or(false) {
        return Err(StatusCode::FORBIDDEN);
    }
    Ok(format!("Members of organization {}", org_id))
}

// MSAL-only handler
async fn msal_groups(auth: AuthContext) -> String {
    format!("Hello from Azure, {}!", auth.user().user_id)
}
```
"##]
pub(crate) mod _example {}