// libauth_rs/middleware/axum.rs

1use crate::error::{AuthError, AuthResult};
2use crate::traits::Authn;
3use crate::types::{AuthProvider as AuthProviderType, AuthToken, UserContext};
4use axum::{
5    extract::{FromRequestParts, Request},
6    http::{request::Parts, StatusCode},
7    response::{IntoResponse, Response},
8};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12/// Extension for accessing UserContext in handlers
13#[derive(Debug, Clone)]
14pub struct AuthExtension {
15    pub user: UserContext,
16    /// Reference to the provider that authenticated this user
17    /// This allows for provider-specific authorization checks
18    provider_index: usize,
19}
20
21impl AuthExtension {
22    /// Get the user context
23    pub fn user(&self) -> &UserContext {
24        &self.user
25    }
26
27    /// Get mutable user context
28    pub fn user_mut(&mut self) -> &mut UserContext {
29        &mut self.user
30    }
31}
32
33/// Authentication state that holds configured providers
34#[derive(Clone)]
35pub struct AuthState {
36    providers: Arc<Vec<Box<dyn Authn>>>,
37    required: bool,
38    /// Optional mapping from issuer to provider type for issuer-based routing
39    issuer_to_provider: Arc<Option<HashMap<String, AuthProviderType>>>,
40}
41
42impl AuthState {
43    /// Create a new auth state with providers
44    pub fn new(providers: Vec<Box<dyn Authn>>) -> Self {
45        Self {
46            providers: Arc::new(providers),
47            required: false,
48            issuer_to_provider: Arc::new(None),
49        }
50    }
51
52    /// Set whether authentication is required (returns 401 if not authenticated)
53    pub fn required(mut self, required: bool) -> Self {
54        self.required = required;
55        self
56    }
57
58    /// Set issuer-to-provider mapping for routing based on JWT issuer (iss) claim
59    pub fn with_issuer_mapping(mut self, mapping: HashMap<String, AuthProviderType>) -> Self {
60        self.issuer_to_provider = Arc::new(Some(mapping));
61        self
62    }
63
64    /// Extract issuer (iss) from JWT token without full verification
65    fn extract_issuer_from_token(&self, token: &str) -> Option<String> {
66        use base64::Engine;
67
68        let parts: Vec<&str> = token.split('.').collect();
69        if parts.len() != 3 {
70            return None;
71        }
72
73        // Decode payload (JWT part 1 is payload)
74        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
75            .decode(parts[1])
76            .ok()?;
77
78        let payload: serde_json::Value = serde_json::from_slice(&payload).ok()?;
79
80        // Extract issuer (iss) claim
81        payload
82            .get("iss")
83            .and_then(|v| v.as_str())
84            .map(|s| s.to_string())
85    }
86
87    /// Try to authenticate using issuer-based routing (fast path)
88    ///
89    /// Returns:
90    /// - Some(Ok(...)) if issuer found and authentication succeeded
91    /// - Some(Err(...)) if issuer found but authentication failed or issuer not in allowlist
92    /// - None if issuer-based routing is not configured or token has no issuer
93    async fn try_issuer_based_auth(
94        &self,
95        token: &AuthToken,
96    ) -> Option<AuthResult<(UserContext, usize)>> {
97        let issuer_map = self.issuer_to_provider.as_ref().as_ref()?;
98
99        // Extract issuer from token - if no issuer, return None to fallback
100        let issuer = match self.extract_issuer_from_token(&token.token) {
101            Some(iss) => iss,
102            None => return None, // Token has no issuer, use fallback
103        };
104
105        // Check if issuer is in our allowlist
106        let target_provider_type = match issuer_map.get(&issuer) {
107            Some(provider_type) => provider_type,
108            None => {
109                // Issuer-based routing IS configured and token HAS an issuer,
110                // but issuer is NOT in allowlist - reject immediately for security
111                tracing::warn!(
112                    issuer = %issuer,
113                    "Token rejected: issuer not in allowlist"
114                );
115                return Some(Err(AuthError::InvalidToken {
116                    message: format!("Token issuer '{}' is not trusted", issuer),
117                    provider: None,
118                }));
119            }
120        };
121
122        // Try authenticating with the first provider (issuer routing simplified)
123        // Note: With the new trait system, we can't query provider_type anymore
124        // so we just try to authenticate and see if it works
125        for (index, provider) in self.providers.iter().enumerate() {
126            tracing::debug!(
127                issuer = %issuer,
128                provider_index = index,
129                "Attempting authentication via issuer-based routing"
130            );
131            match provider.authenticate(&token.token).await {
132                Ok(user) => return Some(Ok((user, index))),
133                Err(e) => {
134                    tracing::debug!(
135                        error = %e,
136                        "Provider authentication failed, trying next"
137                    );
138                    continue;
139                }
140            }
141        }
142
143        tracing::error!(
144            provider = ?target_provider_type,
145            issuer = %issuer,
146            "Provider not found for issuer - configuration error"
147        );
148        Some(Err(AuthError::ConfigurationError {
149            message: format!("Provider for issuer '{}' not found", issuer),
150        }))
151    }
152
153    /// Try to authenticate using all providers (fallback path)
154    async fn try_fallback_auth(&self, token: &AuthToken) -> AuthResult<(UserContext, usize)> {
155        tracing::debug!("Trying fallback authentication");
156
157        for (index, provider) in self.providers.iter().enumerate() {
158            tracing::debug!(
159                provider_index = index,
160                "Attempting authentication with provider"
161            );
162            match provider.authenticate(&token.token).await {
163                Ok(user) => return Ok((user, index)),
164                Err(e) => {
165                    tracing::debug!(
166                        provider_index = index,
167                        error = %e,
168                        "Provider authentication failed, trying next"
169                    );
170                    continue;
171                }
172            }
173        }
174
175        Err(AuthError::ProviderError(
176            "No provider could handle the token".to_string(),
177        ))
178    }
179
180    /// Try to authenticate a token using configured providers
181    /// If issuer-based routing is configured, routes to the appropriate provider first
182    /// Returns (UserContext, provider_index)
183    async fn authenticate_token(&self, token: &AuthToken) -> AuthResult<(UserContext, usize)> {
184        // Try issuer-based routing first if configured (fast path)
185        if let Some(result) = self.try_issuer_based_auth(token).await {
186            return result;
187        }
188
189        // Fallback: try all providers using can_handle
190        self.try_fallback_auth(token).await
191    }
192
193    /// Get a provider by index (used for authorization checks)
194    pub fn get_provider(&self, index: usize) -> Option<&dyn Authn> {
195        self.providers.get(index).map(|p| p.as_ref())
196    }
197}
198
199/// Tower Layer for authentication middleware
200#[derive(Clone)]
201pub struct AuthLayer {
202    state: AuthState,
203}
204
205impl AuthLayer {
206    /// Create a new auth layer with providers
207    pub fn new(providers: Vec<Box<dyn Authn>>) -> Self {
208        Self {
209            state: AuthState::new(providers),
210        }
211    }
212
213    /// Set whether authentication is required
214    pub fn required(mut self, required: bool) -> Self {
215        self.state = self.state.required(required);
216        self
217    }
218
219    /// Set issuer-to-provider mapping for routing based on JWT issuer (iss) claim
220    ///
221    /// # Example
222    /// ```rust,ignore
223    /// let config = AuthConfig::default();
224    /// let tenant_id = config.azure_tenant_id.as_ref().unwrap();
225    ///
226    /// let auth_layer = AuthLayer::new(providers)
227    ///     .with_issuer_mapping(HashMap::from([
228    ///         ("https://clerk.example.com".to_string(), AuthProvider::Clerk),
229    ///         (format!("https://login.microsoftonline.com/{}/v2.0", tenant_id), AuthProvider::Msal),
230    ///         (format!("https://sts.windows.net/{}/", tenant_id), AuthProvider::Msal),
231    ///     ]));
232    /// ```
233    pub fn with_issuer_mapping(mut self, mapping: HashMap<String, AuthProviderType>) -> Self {
234        self.state = self.state.with_issuer_mapping(mapping);
235        self
236    }
237
238    /// Automatically build issuer mapping from the configured providers
239    ///
240    /// NOTE: This method is currently disabled with the new trait system.
241    /// Use `with_issuer_mapping()` to manually configure issuer routing.
242    ///
243    /// # Example
244    /// ```rust,ignore
245    /// let auth_layer = AuthLayer::new(providers)
246    ///     .with_issuer_mapping([
247    ///         ("stytch.com/project-id".to_string(), AuthProviderType::Stytch),
248    ///     ]);
249    /// ```
250    #[deprecated(
251        note = "Auto issuer mapping not supported with new trait system. Use with_issuer_mapping() instead."
252    )]
253    pub fn with_auto_issuer_mapping(self) -> Self {
254        tracing::warn!("with_auto_issuer_mapping() is deprecated and has no effect. Use with_issuer_mapping() instead.");
255        self
256    }
257
258    /// Create AuthLayer from modern ProviderConfig enums
259    ///
260    /// This is the recommended way to configure authentication providers.
261    ///
262    /// # Example
263    /// ```rust,ignore
264    /// use libauth_rs::provider::Config;
265    /// use libauth_rs::middleware::AuthLayer;
266    ///
267    /// let configs = vec![
268    ///     Config::Stytch {
269    ///         project_id: env::var("STYTCH_PROJECT_ID").unwrap(),
270    ///         project_secret: env::var("STYTCH_SECRET").unwrap(),
271    ///         m2m_client_id: None,
272    ///         m2m_client_secret: None,
273    ///     },
274    /// ];
275    ///
276    /// let auth_layer = AuthLayer::from_configs(configs).await?;
277    /// ```
278    pub async fn from_configs(configs: Vec<crate::provider::Config>) -> AuthResult<Self> {
279        let mut providers: Vec<Box<dyn Authn>> = Vec::new();
280
281        for config in configs {
282            Self::create_providers_from_config(config, &mut providers).await?;
283        }
284
285        Ok(Self::new(providers))
286    }
287
288    /// Helper to create providers from a single config
289    async fn create_providers_from_config(
290        config: crate::provider::Config,
291        providers: &mut Vec<Box<dyn Authn>>,
292    ) -> AuthResult<()> {
293        match config {
294            #[cfg(feature = "stytch")]
295            crate::provider::Config::Stytch {
296                project_id,
297                project_secret,
298                m2m_client_id,
299                m2m_client_secret,
300            } => {
301                use crate::provider::{
302                    Provider, StytchB2BM2MProvider, StytchConsumerM2MProvider,
303                    StytchConsumerSessionProvider, StytchProvider,
304                };
305
306                let auth_config = Provider {
307                    config: crate::provider::Config::Stytch {
308                        project_id,
309                        project_secret,
310                        m2m_client_id,
311                        m2m_client_secret,
312                    },
313                    ..Default::default()
314                };
315
316                // Instantiate all 4 Stytch providers
317                providers.push(Box::new(StytchProvider::new(&auth_config).await?));
318                providers.push(Box::new(StytchB2BM2MProvider::new(&auth_config).await?));
319                providers.push(Box::new(
320                    StytchConsumerSessionProvider::new(&auth_config).await?,
321                ));
322                providers.push(Box::new(
323                    StytchConsumerM2MProvider::new(&auth_config).await?,
324                ));
325            }
326
327            #[cfg(feature = "clerk")]
328            crate::provider::Config::Clerk {
329                publishable_key,
330                secret_key,
331            } => {
332                use crate::provider::{ClerkProvider, Provider};
333
334                let auth_config = Provider {
335                    config: crate::provider::Config::Clerk {
336                        publishable_key,
337                        secret_key,
338                    },
339                    ..Default::default()
340                };
341
342                providers.push(Box::new(ClerkProvider::new(&auth_config).await?));
343            }
344
345            #[cfg(feature = "msal")]
346            crate::provider::Config::Msal {
347                tenant_id,
348                client_id,
349                client_secret,
350            } => {
351                use crate::provider::{MsalProvider, Provider};
352
353                let auth_config = Provider {
354                    config: crate::provider::Config::Msal {
355                        tenant_id,
356                        client_id,
357                        client_secret,
358                    },
359                    ..Default::default()
360                };
361
362                providers.push(Box::new(MsalProvider::new(&auth_config).await?));
363            }
364
365            crate::provider::Config::None => {
366                // Skip None variant
367            }
368        }
369
370        Ok(())
371    }
372}
373
374impl<S> tower_layer::Layer<S> for AuthLayer {
375    type Service = AuthMiddleware<S>;
376
377    fn layer(&self, inner: S) -> Self::Service {
378        AuthMiddleware {
379            inner,
380            state: self.state.clone(),
381        }
382    }
383}
384
385/// Middleware service
386#[derive(Clone)]
387pub struct AuthMiddleware<S> {
388    inner: S,
389    state: AuthState,
390}
391
392impl<S> AuthMiddleware<S> {
393    /// Extract and parse the Authorization header
394    fn extract_auth_token(req: &Request) -> Result<Option<AuthToken>, AuthRejection> {
395        let auth_header = req
396            .headers()
397            .get(axum::http::header::AUTHORIZATION)
398            .and_then(|h| h.to_str().ok());
399
400        match auth_header {
401            Some(header_value) => {
402                if let Some(token) = AuthToken::from_auth_header(header_value) {
403                    Ok(Some(token))
404                } else {
405                    tracing::warn!("Malformed Authorization header");
406                    Err(AuthRejection::MalformedHeader)
407                }
408            }
409            None => Ok(None),
410        }
411    }
412
413    /// Handle successful authentication
414    fn handle_auth_success(
415        user: UserContext,
416        provider_index: usize,
417        state: AuthState,
418        req: &mut Request,
419    ) {
420        tracing::info!(
421            user_id = %user.user_id,
422            provider = %user.provider,
423            "Authentication successful"
424        );
425
426        req.extensions_mut().insert(AuthExtension {
427            user,
428            provider_index,
429        });
430        req.extensions_mut().insert(state);
431    }
432
433    /// Handle authentication failure
434    fn handle_auth_failure(error: &AuthError, required: bool) -> Option<Response> {
435        tracing::warn!(
436            error = %error,
437            provider = ?error.provider(),
438            "Authentication failed"
439        );
440
441        if required {
442            Some(AuthRejection::InvalidToken.into_response())
443        } else {
444            None
445        }
446    }
447}
448
449impl<S> tower::Service<Request> for AuthMiddleware<S>
450where
451    S: tower::Service<Request, Response = Response> + Clone + Send + 'static,
452    S::Future: Send + 'static,
453{
454    type Response = S::Response;
455    type Error = S::Error;
456    type Future = std::pin::Pin<
457        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
458    >;
459
460    fn poll_ready(
461        &mut self,
462        cx: &mut std::task::Context<'_>,
463    ) -> std::task::Poll<Result<(), Self::Error>> {
464        self.inner.poll_ready(cx)
465    }
466
467    fn call(&mut self, mut req: Request) -> Self::Future {
468        let clone = self.inner.clone();
469        let mut inner = std::mem::replace(&mut self.inner, clone);
470        let state = self.state.clone();
471
472        Box::pin(async move {
473            // Extract token from Authorization header
474            let token_result = Self::extract_auth_token(&req);
475
476            match token_result {
477                Ok(Some(token)) => {
478                    // Token present, attempt authentication
479                    match state.authenticate_token(&token).await {
480                        Ok((user, provider_index)) => {
481                            Self::handle_auth_success(user, provider_index, state, &mut req);
482                        }
483                        Err(e) => {
484                            if let Some(response) = Self::handle_auth_failure(&e, state.required) {
485                                return Ok(response);
486                            }
487                        }
488                    }
489                }
490                Ok(None) => {
491                    // No token present
492                    if state.required {
493                        tracing::warn!("Missing Authorization header for protected endpoint");
494                        return Ok(AuthRejection::MissingAuth.into_response());
495                    }
496                }
497                Err(rejection) => {
498                    // Malformed token
499                    if state.required {
500                        return Ok(rejection.into_response());
501                    }
502                }
503            }
504
505            inner.call(req).await
506        })
507    }
508}
509
510/// Extractor for UserContext from request
511impl<S> FromRequestParts<S> for AuthExtension
512where
513    S: Send + Sync,
514{
515    type Rejection = AuthRejection;
516
517    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
518        parts
519            .extensions
520            .get::<AuthExtension>()
521            .cloned()
522            .ok_or(AuthRejection::MissingAuth)
523    }
524}
525
526/// Extractor for optional UserContext
527pub struct OptionalAuth(pub Option<UserContext>);
528
529impl<S> FromRequestParts<S> for OptionalAuth
530where
531    S: Send + Sync,
532{
533    type Rejection = std::convert::Infallible;
534
535    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
536        Ok(OptionalAuth(
537            parts
538                .extensions
539                .get::<AuthExtension>()
540                .map(|ext| ext.user.clone()),
541        ))
542    }
543}
544
545/// Extractor for AuthContext with provider access
546/// This allows handlers to perform provider-specific authorization checks
547pub struct AuthContext {
548    pub user: UserContext,
549    state: AuthState,
550    provider_index: usize,
551}
552
553impl AuthContext {
554    /// Get the user context
555    pub fn user(&self) -> &UserContext {
556        &self.user
557    }
558
559    /// Check if user has a specific permission
560    /// Uses the Authorizer trait to check permissions
561    pub async fn check_permission(&self, permission: &str) -> AuthResult<bool> {
562        if let Some(provider) = self.state.get_provider(self.provider_index) {
563            match provider.authorize(&self.user, permission).await {
564                Ok(()) => Ok(true),
565                Err(AuthError::InsufficientPermissions(_)) => Ok(false),
566                Err(e) => Err(e),
567            }
568        } else {
569            Ok(false)
570        }
571    }
572
573    /// Check if user has a specific role
574    /// Checks the user's metadata for role information
575    pub async fn check_role(&self, role: &str) -> AuthResult<bool> {
576        Ok(self.user.has_role(role))
577    }
578
579    /// Get all user roles
580    /// Extracts roles from user's metadata
581    pub async fn get_roles(&self) -> AuthResult<Vec<String>> {
582        Ok(self.user.get_roles())
583    }
584
585    /// Check if user belongs to a specific organization
586    /// Checks the user's organization_id in metadata
587    pub async fn check_organization(&self, org_id: &str) -> AuthResult<bool> {
588        Ok(self.user.organization_id() == Some(org_id))
589    }
590
591    /// Require a specific permission or return 403 Forbidden
592    pub async fn require_permission(&self, permission: &str) -> Result<(), StatusCode> {
593        if self.check_permission(permission).await.unwrap_or(false) {
594            Ok(())
595        } else {
596            Err(StatusCode::FORBIDDEN)
597        }
598    }
599
600    /// Require a specific role or return 403 Forbidden
601    pub async fn require_role(&self, role: &str) -> Result<(), StatusCode> {
602        if self.check_role(role).await.unwrap_or(false) {
603            Ok(())
604        } else {
605            Err(StatusCode::FORBIDDEN)
606        }
607    }
608}
609
610impl<S> FromRequestParts<S> for AuthContext
611where
612    S: Send + Sync,
613{
614    type Rejection = AuthRejection;
615
616    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
617        let auth_ext = parts
618            .extensions
619            .get::<AuthExtension>()
620            .ok_or(AuthRejection::MissingAuth)?;
621
622        let state = parts
623            .extensions
624            .get::<AuthState>()
625            .ok_or(AuthRejection::MissingAuth)?;
626
627        Ok(AuthContext {
628            user: auth_ext.user.clone(),
629            state: state.clone(),
630            provider_index: auth_ext.provider_index,
631        })
632    }
633}
634
635/// Rejection type for auth extraction
636#[derive(Debug)]
637pub enum AuthRejection {
638    MissingAuth,
639    InvalidToken,
640    MalformedHeader,
641}
642
643impl IntoResponse for AuthRejection {
644    fn into_response(self) -> Response {
645        let (status, message) = match self {
646            AuthRejection::MissingAuth => (StatusCode::UNAUTHORIZED, "Authentication required"),
647            AuthRejection::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid or expired token"),
648            AuthRejection::MalformedHeader => {
649                (StatusCode::BAD_REQUEST, "Malformed Authorization header")
650            }
651        };
652
653        (status, message).into_response()
654    }
655}
656
// Example usage documentation module
#[doc = r##"
# Example Usage

Using the authentication middleware in an Axum router:

```rust,ignore
use std::collections::HashMap;

use axum::{Router, routing::get, http::StatusCode};
use libauth_rs::middleware::{AuthLayer, AuthExtension, AuthContext, OptionalAuth};
use libauth_rs::provider::{Provider, ClerkProvider, MsalProvider};
use libauth_rs::types::AuthProvider;

#[tokio::main]
async fn main() {
    let config = Provider::default();
    let tenant_id = "your-azure-tenant-id";

    // Create multiple providers
    let clerk = ClerkProvider::new(&config).await.unwrap();
    let msal = MsalProvider::new(&config).await.unwrap();

    // Issuer-based routing: tokens whose `iss` claim is not listed are rejected;
    // tokens without an `iss` claim fall back to trying every provider.
    let auth_layer = AuthLayer::new(vec![Box::new(clerk), Box::new(msal)])
        .with_issuer_mapping(HashMap::from([
            ("https://clerk.example.com".to_string(), AuthProvider::Clerk),
            (format!("https://login.microsoftonline.com/{}/v2.0", tenant_id), AuthProvider::Msal),
        ]))
        .required(false);

    let app = Router::new()
        .route("/public", get(public_handler))
        .route("/protected", get(protected_handler))
        .route("/admin", get(admin_handler))
        .layer(auth_layer);

    // Run your server...
}

// Public handler - no auth required
async fn public_handler(OptionalAuth(user): OptionalAuth) -> String {
    match user {
        Some(u) => format!("Hello, {}!", u.user_id),
        None => "Hello, anonymous!".to_string(),
    }
}

// Protected handler - auth required, basic user info
async fn protected_handler(AuthExtension { user, .. }: AuthExtension) -> String {
    format!("Welcome, {}!", user.user_id)
}

// Admin handler - auth required + role check using AuthContext
async fn admin_handler(auth: AuthContext) -> Result<String, StatusCode> {
    // Role check against the authenticated user's roles (403 if missing);
    // use `auth.require_permission(..)` for a provider-backed check.
    auth.require_role("admin").await?;

    let roles = auth.get_roles().await.unwrap_or_default();
    Ok(format!("Admin access granted! Your roles: {:?}", roles))
}

// Per-provider routers example
// (`clerk_org_members` and `msal_groups` are application handlers, not shown.)
async fn setup_per_provider_routers() -> Router {
    let config = Provider::default();
    let clerk = ClerkProvider::new(&config).await.unwrap();
    let msal = MsalProvider::new(&config).await.unwrap();

    // Clerk-specific routes (axum 0.8 path-parameter syntax: `{name}`)
    let clerk_router = Router::new()
        .route("/org/{org_id}/members", get(clerk_org_members))
        .layer(AuthLayer::new(vec![Box::new(clerk)]).required(true));

    // MSAL-specific routes
    let msal_router = Router::new()
        .route("/azure/groups", get(msal_groups))
        .layer(AuthLayer::new(vec![Box::new(msal)]).required(true));

    // Combine routers
    Router::new()
        .nest("/clerk", clerk_router)
        .nest("/msal", msal_router)
}
```
"##]
pub(crate) mod _example {}