1use crate::error::{AuthError, AuthResult};
2use crate::traits::Authn;
3use crate::types::{AuthProvider as AuthProviderType, AuthToken, UserContext};
4use axum::{
5 extract::{FromRequestParts, Request},
6 http::{request::Parts, StatusCode},
7 response::{IntoResponse, Response},
8};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12#[derive(Debug, Clone)]
14pub struct AuthExtension {
15 pub user: UserContext,
16 provider_index: usize,
19}
20
21impl AuthExtension {
22 pub fn user(&self) -> &UserContext {
24 &self.user
25 }
26
27 pub fn user_mut(&mut self) -> &mut UserContext {
29 &mut self.user
30 }
31}
32
33#[derive(Clone)]
35pub struct AuthState {
36 providers: Arc<Vec<Box<dyn Authn>>>,
37 required: bool,
38 issuer_to_provider: Arc<Option<HashMap<String, AuthProviderType>>>,
40}
41
42impl AuthState {
43 pub fn new(providers: Vec<Box<dyn Authn>>) -> Self {
45 Self {
46 providers: Arc::new(providers),
47 required: false,
48 issuer_to_provider: Arc::new(None),
49 }
50 }
51
52 pub fn required(mut self, required: bool) -> Self {
54 self.required = required;
55 self
56 }
57
58 pub fn with_issuer_mapping(mut self, mapping: HashMap<String, AuthProviderType>) -> Self {
60 self.issuer_to_provider = Arc::new(Some(mapping));
61 self
62 }
63
64 fn extract_issuer_from_token(&self, token: &str) -> Option<String> {
66 use base64::Engine;
67
68 let parts: Vec<&str> = token.split('.').collect();
69 if parts.len() != 3 {
70 return None;
71 }
72
73 let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
75 .decode(parts[1])
76 .ok()?;
77
78 let payload: serde_json::Value = serde_json::from_slice(&payload).ok()?;
79
80 payload
82 .get("iss")
83 .and_then(|v| v.as_str())
84 .map(|s| s.to_string())
85 }
86
87 async fn try_issuer_based_auth(
94 &self,
95 token: &AuthToken,
96 ) -> Option<AuthResult<(UserContext, usize)>> {
97 let issuer_map = self.issuer_to_provider.as_ref().as_ref()?;
98
99 let issuer = match self.extract_issuer_from_token(&token.token) {
101 Some(iss) => iss,
102 None => return None, };
104
105 let target_provider_type = match issuer_map.get(&issuer) {
107 Some(provider_type) => provider_type,
108 None => {
109 tracing::warn!(
112 issuer = %issuer,
113 "Token rejected: issuer not in allowlist"
114 );
115 return Some(Err(AuthError::InvalidToken {
116 message: format!("Token issuer '{}' is not trusted", issuer),
117 provider: None,
118 }));
119 }
120 };
121
122 for (index, provider) in self.providers.iter().enumerate() {
126 tracing::debug!(
127 issuer = %issuer,
128 provider_index = index,
129 "Attempting authentication via issuer-based routing"
130 );
131 match provider.authenticate(&token.token).await {
132 Ok(user) => return Some(Ok((user, index))),
133 Err(e) => {
134 tracing::debug!(
135 error = %e,
136 "Provider authentication failed, trying next"
137 );
138 continue;
139 }
140 }
141 }
142
143 tracing::error!(
144 provider = ?target_provider_type,
145 issuer = %issuer,
146 "Provider not found for issuer - configuration error"
147 );
148 Some(Err(AuthError::ConfigurationError {
149 message: format!("Provider for issuer '{}' not found", issuer),
150 }))
151 }
152
153 async fn try_fallback_auth(&self, token: &AuthToken) -> AuthResult<(UserContext, usize)> {
155 tracing::debug!("Trying fallback authentication");
156
157 for (index, provider) in self.providers.iter().enumerate() {
158 tracing::debug!(
159 provider_index = index,
160 "Attempting authentication with provider"
161 );
162 match provider.authenticate(&token.token).await {
163 Ok(user) => return Ok((user, index)),
164 Err(e) => {
165 tracing::debug!(
166 provider_index = index,
167 error = %e,
168 "Provider authentication failed, trying next"
169 );
170 continue;
171 }
172 }
173 }
174
175 Err(AuthError::ProviderError(
176 "No provider could handle the token".to_string(),
177 ))
178 }
179
180 async fn authenticate_token(&self, token: &AuthToken) -> AuthResult<(UserContext, usize)> {
184 if let Some(result) = self.try_issuer_based_auth(token).await {
186 return result;
187 }
188
189 self.try_fallback_auth(token).await
191 }
192
193 pub fn get_provider(&self, index: usize) -> Option<&dyn Authn> {
195 self.providers.get(index).map(|p| p.as_ref())
196 }
197}
198
199#[derive(Clone)]
201pub struct AuthLayer {
202 state: AuthState,
203}
204
205impl AuthLayer {
206 pub fn new(providers: Vec<Box<dyn Authn>>) -> Self {
208 Self {
209 state: AuthState::new(providers),
210 }
211 }
212
213 pub fn required(mut self, required: bool) -> Self {
215 self.state = self.state.required(required);
216 self
217 }
218
219 pub fn with_issuer_mapping(mut self, mapping: HashMap<String, AuthProviderType>) -> Self {
234 self.state = self.state.with_issuer_mapping(mapping);
235 self
236 }
237
238 #[deprecated(
251 note = "Auto issuer mapping not supported with new trait system. Use with_issuer_mapping() instead."
252 )]
253 pub fn with_auto_issuer_mapping(self) -> Self {
254 tracing::warn!("with_auto_issuer_mapping() is deprecated and has no effect. Use with_issuer_mapping() instead.");
255 self
256 }
257
258 pub async fn from_configs(configs: Vec<crate::provider::Config>) -> AuthResult<Self> {
279 let mut providers: Vec<Box<dyn Authn>> = Vec::new();
280
281 for config in configs {
282 Self::create_providers_from_config(config, &mut providers).await?;
283 }
284
285 Ok(Self::new(providers))
286 }
287
288 async fn create_providers_from_config(
290 config: crate::provider::Config,
291 providers: &mut Vec<Box<dyn Authn>>,
292 ) -> AuthResult<()> {
293 match config {
294 #[cfg(feature = "stytch")]
295 crate::provider::Config::Stytch {
296 project_id,
297 project_secret,
298 m2m_client_id,
299 m2m_client_secret,
300 } => {
301 use crate::provider::{
302 Provider, StytchB2BM2MProvider, StytchConsumerM2MProvider,
303 StytchConsumerSessionProvider, StytchProvider,
304 };
305
306 let auth_config = Provider {
307 config: crate::provider::Config::Stytch {
308 project_id,
309 project_secret,
310 m2m_client_id,
311 m2m_client_secret,
312 },
313 ..Default::default()
314 };
315
316 providers.push(Box::new(StytchProvider::new(&auth_config).await?));
318 providers.push(Box::new(StytchB2BM2MProvider::new(&auth_config).await?));
319 providers.push(Box::new(
320 StytchConsumerSessionProvider::new(&auth_config).await?,
321 ));
322 providers.push(Box::new(
323 StytchConsumerM2MProvider::new(&auth_config).await?,
324 ));
325 }
326
327 #[cfg(feature = "clerk")]
328 crate::provider::Config::Clerk {
329 publishable_key,
330 secret_key,
331 } => {
332 use crate::provider::{ClerkProvider, Provider};
333
334 let auth_config = Provider {
335 config: crate::provider::Config::Clerk {
336 publishable_key,
337 secret_key,
338 },
339 ..Default::default()
340 };
341
342 providers.push(Box::new(ClerkProvider::new(&auth_config).await?));
343 }
344
345 #[cfg(feature = "msal")]
346 crate::provider::Config::Msal {
347 tenant_id,
348 client_id,
349 client_secret,
350 } => {
351 use crate::provider::{MsalProvider, Provider};
352
353 let auth_config = Provider {
354 config: crate::provider::Config::Msal {
355 tenant_id,
356 client_id,
357 client_secret,
358 },
359 ..Default::default()
360 };
361
362 providers.push(Box::new(MsalProvider::new(&auth_config).await?));
363 }
364
365 crate::provider::Config::None => {
366 }
368 }
369
370 Ok(())
371 }
372}
373
374impl<S> tower_layer::Layer<S> for AuthLayer {
375 type Service = AuthMiddleware<S>;
376
377 fn layer(&self, inner: S) -> Self::Service {
378 AuthMiddleware {
379 inner,
380 state: self.state.clone(),
381 }
382 }
383}
384
385#[derive(Clone)]
387pub struct AuthMiddleware<S> {
388 inner: S,
389 state: AuthState,
390}
391
392impl<S> AuthMiddleware<S> {
393 fn extract_auth_token(req: &Request) -> Result<Option<AuthToken>, AuthRejection> {
395 let auth_header = req
396 .headers()
397 .get(axum::http::header::AUTHORIZATION)
398 .and_then(|h| h.to_str().ok());
399
400 match auth_header {
401 Some(header_value) => {
402 if let Some(token) = AuthToken::from_auth_header(header_value) {
403 Ok(Some(token))
404 } else {
405 tracing::warn!("Malformed Authorization header");
406 Err(AuthRejection::MalformedHeader)
407 }
408 }
409 None => Ok(None),
410 }
411 }
412
413 fn handle_auth_success(
415 user: UserContext,
416 provider_index: usize,
417 state: AuthState,
418 req: &mut Request,
419 ) {
420 tracing::info!(
421 user_id = %user.user_id,
422 provider = %user.provider,
423 "Authentication successful"
424 );
425
426 req.extensions_mut().insert(AuthExtension {
427 user,
428 provider_index,
429 });
430 req.extensions_mut().insert(state);
431 }
432
433 fn handle_auth_failure(error: &AuthError, required: bool) -> Option<Response> {
435 tracing::warn!(
436 error = %error,
437 provider = ?error.provider(),
438 "Authentication failed"
439 );
440
441 if required {
442 Some(AuthRejection::InvalidToken.into_response())
443 } else {
444 None
445 }
446 }
447}
448
449impl<S> tower::Service<Request> for AuthMiddleware<S>
450where
451 S: tower::Service<Request, Response = Response> + Clone + Send + 'static,
452 S::Future: Send + 'static,
453{
454 type Response = S::Response;
455 type Error = S::Error;
456 type Future = std::pin::Pin<
457 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
458 >;
459
460 fn poll_ready(
461 &mut self,
462 cx: &mut std::task::Context<'_>,
463 ) -> std::task::Poll<Result<(), Self::Error>> {
464 self.inner.poll_ready(cx)
465 }
466
467 fn call(&mut self, mut req: Request) -> Self::Future {
468 let clone = self.inner.clone();
469 let mut inner = std::mem::replace(&mut self.inner, clone);
470 let state = self.state.clone();
471
472 Box::pin(async move {
473 let token_result = Self::extract_auth_token(&req);
475
476 match token_result {
477 Ok(Some(token)) => {
478 match state.authenticate_token(&token).await {
480 Ok((user, provider_index)) => {
481 Self::handle_auth_success(user, provider_index, state, &mut req);
482 }
483 Err(e) => {
484 if let Some(response) = Self::handle_auth_failure(&e, state.required) {
485 return Ok(response);
486 }
487 }
488 }
489 }
490 Ok(None) => {
491 if state.required {
493 tracing::warn!("Missing Authorization header for protected endpoint");
494 return Ok(AuthRejection::MissingAuth.into_response());
495 }
496 }
497 Err(rejection) => {
498 if state.required {
500 return Ok(rejection.into_response());
501 }
502 }
503 }
504
505 inner.call(req).await
506 })
507 }
508}
509
510impl<S> FromRequestParts<S> for AuthExtension
512where
513 S: Send + Sync,
514{
515 type Rejection = AuthRejection;
516
517 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
518 parts
519 .extensions
520 .get::<AuthExtension>()
521 .cloned()
522 .ok_or(AuthRejection::MissingAuth)
523 }
524}
525
526pub struct OptionalAuth(pub Option<UserContext>);
528
529impl<S> FromRequestParts<S> for OptionalAuth
530where
531 S: Send + Sync,
532{
533 type Rejection = std::convert::Infallible;
534
535 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
536 Ok(OptionalAuth(
537 parts
538 .extensions
539 .get::<AuthExtension>()
540 .map(|ext| ext.user.clone()),
541 ))
542 }
543}
544
545pub struct AuthContext {
548 pub user: UserContext,
549 state: AuthState,
550 provider_index: usize,
551}
552
553impl AuthContext {
554 pub fn user(&self) -> &UserContext {
556 &self.user
557 }
558
559 pub async fn check_permission(&self, permission: &str) -> AuthResult<bool> {
562 if let Some(provider) = self.state.get_provider(self.provider_index) {
563 match provider.authorize(&self.user, permission).await {
564 Ok(()) => Ok(true),
565 Err(AuthError::InsufficientPermissions(_)) => Ok(false),
566 Err(e) => Err(e),
567 }
568 } else {
569 Ok(false)
570 }
571 }
572
573 pub async fn check_role(&self, role: &str) -> AuthResult<bool> {
576 Ok(self.user.has_role(role))
577 }
578
579 pub async fn get_roles(&self) -> AuthResult<Vec<String>> {
582 Ok(self.user.get_roles())
583 }
584
585 pub async fn check_organization(&self, org_id: &str) -> AuthResult<bool> {
588 Ok(self.user.organization_id() == Some(org_id))
589 }
590
591 pub async fn require_permission(&self, permission: &str) -> Result<(), StatusCode> {
593 if self.check_permission(permission).await.unwrap_or(false) {
594 Ok(())
595 } else {
596 Err(StatusCode::FORBIDDEN)
597 }
598 }
599
600 pub async fn require_role(&self, role: &str) -> Result<(), StatusCode> {
602 if self.check_role(role).await.unwrap_or(false) {
603 Ok(())
604 } else {
605 Err(StatusCode::FORBIDDEN)
606 }
607 }
608}
609
610impl<S> FromRequestParts<S> for AuthContext
611where
612 S: Send + Sync,
613{
614 type Rejection = AuthRejection;
615
616 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
617 let auth_ext = parts
618 .extensions
619 .get::<AuthExtension>()
620 .ok_or(AuthRejection::MissingAuth)?;
621
622 let state = parts
623 .extensions
624 .get::<AuthState>()
625 .ok_or(AuthRejection::MissingAuth)?;
626
627 Ok(AuthContext {
628 user: auth_ext.user.clone(),
629 state: state.clone(),
630 provider_index: auth_ext.provider_index,
631 })
632 }
633}
634
/// Reasons the authentication middleware or extractors reject a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthRejection {
    /// No authenticated user, in either of two cases:
    /// - the header is absent on a required route;
    /// - an auth extractor is used on an unauthenticated request.
    MissingAuth,
    /// A token was presented but authentication failed.
    InvalidToken,
    /// The `Authorization` header is present but could not be parsed.
    MalformedHeader,
}
642
643impl IntoResponse for AuthRejection {
644 fn into_response(self) -> Response {
645 let (status, message) = match self {
646 AuthRejection::MissingAuth => (StatusCode::UNAUTHORIZED, "Authentication required"),
647 AuthRejection::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid or expired token"),
648 AuthRejection::MalformedHeader => {
649 (StatusCode::BAD_REQUEST, "Malformed Authorization header")
650 }
651 };
652
653 (status, message).into_response()
654 }
655}
656
#[doc = r##"
# Example Usage

Using the authentication middleware in an Axum router:

```rust,ignore
use axum::{Router, routing::get, http::StatusCode};
use libauth_rs::middleware::{AuthLayer, AuthExtension, AuthContext, OptionalAuth};
use libauth_rs::provider::{Provider, ClerkProvider, MsalProvider};

#[tokio::main]
async fn main() {
    // `Provider::default()` is a placeholder: populate its `config` with real
    // credentials, or build the layer with `AuthLayer::from_configs` instead.
    let config = Provider::default();

    // Create multiple providers
    let clerk = ClerkProvider::new(&config).await.unwrap();
    let msal = MsalProvider::new(&config).await.unwrap();

    // Providers are tried in order until one accepts the token. To reject JWTs from
    // unexpected issuers, add `.with_issuer_mapping(issuer_map)`, where `issuer_map`
    // maps each trusted `iss` value to its provider type.
    let auth_layer = AuthLayer::new(vec![Box::new(clerk), Box::new(msal)])
        .required(false);

    let app = Router::new()
        .route("/public", get(public_handler))
        .route("/protected", get(protected_handler))
        .route("/admin", get(admin_handler))
        .layer(auth_layer);

    // Run your server...
}

// Public handler - no auth required
async fn public_handler(OptionalAuth(user): OptionalAuth) -> String {
    match user {
        Some(u) => format!("Hello, {}!", u.user_id),
        None => "Hello, anonymous!".to_string(),
    }
}

// Protected handler - auth required, basic user info
async fn protected_handler(AuthExtension { user, .. }: AuthExtension) -> String {
    format!("Welcome, {}!", user.user_id)
}

// Admin handler - auth required + role check using AuthContext
async fn admin_handler(auth: AuthContext) -> Result<String, StatusCode> {
    // Checks the roles carried in the user's context; use
    // `auth.require_permission(...)` to ask the authenticating provider instead.
    auth.require_role("admin").await?;

    let roles = auth.get_roles().await.unwrap_or_default();
    Ok(format!("Admin access granted! Your roles: {:?}", roles))
}

// Per-provider routers example
// (`clerk_org_members` and `msal_groups` handlers omitted for brevity)
async fn setup_per_provider_routers() -> Router {
    let config = Provider::default();
    let clerk = ClerkProvider::new(&config).await.unwrap();
    let msal = MsalProvider::new(&config).await.unwrap();

    // Clerk-specific routes
    let clerk_router = Router::new()
        .route("/org/{org_id}/members", get(clerk_org_members))
        .layer(AuthLayer::new(vec![Box::new(clerk)]).required(true));

    // MSAL-specific routes
    let msal_router = Router::new()
        .route("/azure/groups", get(msal_groups))
        .layer(AuthLayer::new(vec![Box::new(msal)]).required(true));

    // Combine routers
    Router::new()
        .nest("/clerk", clerk_router)
        .nest("/msal", msal_router)
}
```
"##]
pub(crate) mod _example {}