Skip to main content

axum_oidc_client/
auth.rs

1//! Core authentication module for OAuth2/OIDC with PKCE support.
2//!
3//! This module provides the main authentication layer and configuration types
4//! for integrating OAuth2 authentication into Axum applications.
5//!
6//! # Main Types
7//!
8//! - [`AuthLayer`] - Tower layer for adding authentication to your Axum app
9//! - [`OAuthConfiguration`] - Configuration for OAuth2 endpoints and credentials
10//! - [`CodeChallengeMethod`] - PKCE code challenge method (S256 or Plain)
11//! - [`LogoutHandler`] - Trait for implementing custom logout behavior
12//!
13//! # Examples
14//!
15//! ```rust,no_run
16//! use axum::{Router, routing::get};
17//! use axum_oidc_client::{
18//!     auth::{AuthLayer, CodeChallengeMethod},
19//!     auth_builder::OAuthConfigurationBuilder,
20//!     auth_cache::AuthCache,
21//!     logout::handle_default_logout::DefaultLogoutHandler,
22//! };
23//! use std::sync::Arc;
24//!
25//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
26//! let config = OAuthConfigurationBuilder::default()
27//!     .with_authorization_endpoint("https://provider.com/oauth/authorize")
28//!     .with_token_endpoint("https://provider.com/oauth/token")
29//!     .with_client_id("client-id")
30//!     .with_client_secret("client-secret")
31//!     .with_redirect_uri("http://localhost:8080/auth/callback")
32//!     .with_private_cookie_key("secret-key")
33//!     .with_scopes(vec!["openid", "email"])
34//!     .build()?;
35//!
36//! # #[cfg(feature = "redis")]
37//! let cache: Arc<dyn AuthCache + Send + Sync> = Arc::new(
38//!     axum_oidc_client::redis::AuthCache::new("redis://127.0.0.1/", 3600)
39//! );
40//!
41//! let logout_handler = Arc::new(DefaultLogoutHandler);
42//!
43//! let app = Router::new()
44//!     .route("/", get(|| async { "Hello!" }))
45//!     .layer(AuthLayer::new(Arc::new(config), cache, logout_handler));
46//! # Ok(())
47//! # }
48//! ```
49
50use axum::{
51    extract::Request,
52    response::{IntoResponse, Redirect, Response},
53};
54use axum_extra::extract::{cookie::Key, PrivateCookieJar};
55use chrono::{Duration, Local};
56use futures_util::future::BoxFuture;
57use http::request::Parts;
58use pkce_std::Method;
59use reqwest::Client;
60
61use std::{
62    fmt::Display,
63    sync::Arc,
64    task::{Context, Poll},
65};
66use tower::{Layer, Service};
67
68use crate::{
69    auth_cache::AuthCache,
70    auth_router::{
71        handle_auth::handle_auth,
72        handle_callback::{handle_callback, AccessTokenResponse},
73        handle_default::handle_default,
74    },
75    auth_session::AuthSession,
76    errors::Error,
77};
78
/// PKCE code challenge method used in the OAuth2 authorization-code flow.
///
/// Selects how the code verifier is turned into the code challenge.
/// [`CodeChallengeMethod::S256`] is the default variant.
///
/// # Examples
///
/// ```
/// use axum_oidc_client::auth::CodeChallengeMethod;
///
/// assert_eq!(CodeChallengeMethod::default(), CodeChallengeMethod::S256);
/// assert_eq!(CodeChallengeMethod::S256.to_string(), "S256");
/// assert_eq!(CodeChallengeMethod::Plain.to_string(), "plain");
/// ```
#[derive(Debug, Clone, PartialEq, Default)]
pub enum CodeChallengeMethod {
    /// SHA-256 hash of the code verifier (recommended; the default).
    #[default]
    S256,
    /// The code verifier is used verbatim (avoid in production).
    Plain,
}

impl Display for CodeChallengeMethod {
    /// Writes `"S256"` for [`CodeChallengeMethod::S256`] and `"plain"` for
    /// [`CodeChallengeMethod::Plain`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::S256 => "S256",
            Self::Plain => "plain",
        };
        f.write_str(label)
    }
}
117
118/// Calculate token expiration time based on expires_in and token_max_age.
119///
120/// This function determines when a token should be considered expired,
121/// taking into account both the provider's expiration time and the
122/// application's configured maximum token age.
123///
124/// # Arguments
125///
126/// * `expires_in` - Seconds until token expiration from the OAuth provider
127/// * `token_max_age` - Maximum allowed token age in seconds from configuration
128///
129/// # Returns
130///
131/// The current time plus the calculated expiration duration.
132/// Returns the maximum of:
133/// - 1 second (minimum)
134/// - The maximum of (expires_in - 1) and token_max_age
135///
136/// # Examples
137///
138/// ```ignore
139/// // Token expires in 3600 seconds, max age is 1800
140/// let expiration = calculate_token_expiration(3600, 1800);
141/// // Uses 3599 seconds (expires_in - 1)
142///
143/// // Token expires in 300 seconds, max age is 1800
144/// let expiration = calculate_token_expiration(300, 1800);
145/// // Uses 1800 seconds (token_max_age)
146/// ```
147pub fn calculate_token_expiration(
148    expires_in: i64,
149    token_max_age: Option<i64>,
150) -> chrono::DateTime<Local> {
151    Local::now()
152        + Duration::seconds(std::cmp::max(
153            1,
154            std::cmp::min(expires_in - 1, token_max_age.unwrap_or(0)),
155        ))
156}
157
158impl AuthSession {
159    pub fn new(response: &AccessTokenResponse, conf: &OAuthConfiguration) -> Self {
160        AuthSession {
161            id_token: response.id_token.to_owned(),
162            access_token: response.access_token.to_owned(),
163            token_type: response.token_type.to_owned(),
164            refresh_token: response.refresh_token.to_owned(),
165            scope: response.scope.to_owned(),
166            expires: calculate_token_expiration(response.expires_in, conf.token_max_age),
167        }
168    }
169}
170
171impl From<CodeChallengeMethod> for Method {
172    fn from(method: CodeChallengeMethod) -> Self {
173        match method {
174            CodeChallengeMethod::S256 => Method::Sha256,
175            CodeChallengeMethod::Plain => Method::Plain,
176        }
177    }
178}
179
180/// OAuth2/OIDC configuration.
181///
182/// Contains all necessary configuration for OAuth2 authentication including
183/// endpoints, credentials, and session management settings.
184///
185/// # Fields
186///
187/// * `private_cookie_key` - Secret key for encrypting session cookies
188/// * `client_id` - OAuth2 client identifier
189/// * `client_secret` - OAuth2 client secret
190/// * `redirect_uri` - URI where the provider redirects after authentication
191/// * `authorization_endpoint` - OAuth2 authorization endpoint URL
192/// * `token_endpoint` - OAuth2 token endpoint URL
193/// * `end_session_endpoint` - Optional OIDC end session endpoint URL
194/// * `post_logout_redirect_uri` - URI to redirect to after logout
195/// * `scopes` - Space-separated list of OAuth2 scopes
196/// * `code_challenge_method` - PKCE code challenge method
197/// * `custom_ca_cert` - Optional path to custom CA certificate
198/// * `session_max_age` - Maximum session age in seconds
199/// * `token_max_age` - Optional maximum token age in seconds
200///
201/// # Examples
202///
203/// Use [`crate::auth_builder::OAuthConfigurationBuilder`] to construct:
204///
205/// ```rust,no_run
206/// use axum_oidc_client::auth_builder::OAuthConfigurationBuilder;
207///
208/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
209/// let config = OAuthConfigurationBuilder::default()
210///     .with_authorization_endpoint("https://provider.com/oauth/authorize")
211///     .with_token_endpoint("https://provider.com/oauth/token")
212///     .with_client_id("my-client-id")
213///     .with_client_secret("my-client-secret")
214///     .with_redirect_uri("http://localhost:8080/auth/callback")
215///     .with_private_cookie_key("secret-key-at-least-32-bytes")
216///     .with_scopes(vec!["openid", "email", "profile"])
217///     .build()?;
218/// # Ok(())
219/// # }
220/// ```
221#[derive(Clone)]
222pub struct OAuthConfiguration {
223    /// Secret key for encrypting session cookies
224    pub private_cookie_key: Key,
225    /// OAuth2 client identifier
226    pub client_id: String,
227    /// OAuth2 client secret
228    pub client_secret: String,
229    /// Redirect URI for OAuth2 callback
230    pub redirect_uri: String,
231    /// OAuth2 authorization endpoint URL
232    pub authorization_endpoint: String,
233    /// OAuth2 token endpoint URL
234    pub token_endpoint: String,
235    /// Optional OIDC end session endpoint URL
236    pub end_session_endpoint: Option<String>,
237    /// URI to redirect to after logout
238    pub post_logout_redirect_uri: String,
239    /// Space-separated list of OAuth2 scopes
240    pub scopes: String,
241    /// PKCE code challenge method
242    pub code_challenge_method: CodeChallengeMethod,
243    /// Optional path to custom CA certificate file
244    pub custom_ca_cert: Option<String>,
245    /// Maximum session age in seconds
246    pub session_max_age: i64,
247    /// Optional maximum token age in seconds
248    pub token_max_age: Option<i64>,
249    /// Base path for authentication routes (default: "/auth")
250    pub base_path: String,
251}
252
/// Name of the cookie that carries the session identifier.
///
/// The middleware reads this cookie from the private (encrypted) cookie jar to
/// obtain the current session id.
pub const SESSION_KEY: &str = "AUTH_SESSION";
257
258/// Trait for handling logout behavior.
259///
260/// Implement this trait to customize the logout process for your application.
261/// The library provides two built-in implementations:
262/// - [`crate::logout::handle_default_logout::DefaultLogoutHandler`] - Simple logout with session cleanup
263/// - [`crate::logout::handle_oidc_logout::OidcLogoutHandler`] - OIDC logout with provider notification
264///
265/// # Examples
266///
267/// ## Using the Default Handler
268///
269/// ```rust,no_run
270/// use axum_oidc_client::logout::handle_default_logout::DefaultLogoutHandler;
271/// use std::sync::Arc;
272///
273/// let logout_handler = Arc::new(DefaultLogoutHandler);
274/// ```
275///
276/// ## Using the OIDC Handler
277///
278/// ```rust,no_run
279/// use axum_oidc_client::logout::handle_oidc_logout::OidcLogoutHandler;
280/// use std::sync::Arc;
281///
282/// let logout_handler = Arc::new(
283///     OidcLogoutHandler::new("https://provider.com/oauth/logout")
284/// );
285/// ```
286///
287/// ## Custom Implementation
288///
289/// ```rust,no_run
290/// use axum_oidc_client::auth::{LogoutHandler, OAuthConfiguration};
291/// use axum_oidc_client::auth_cache::AuthCache;
292/// use axum_oidc_client::errors::Error;
293/// use axum::response::Response;
294/// use http::request::Parts;
295/// use std::sync::Arc;
296/// use futures_util::future::BoxFuture;
297///
298/// struct CustomLogoutHandler;
299///
300/// impl LogoutHandler for CustomLogoutHandler {
301///     fn handle_logout<'a>(
302///         &'a self,
303///         parts: &'a mut Parts,
304///         configuration: Arc<OAuthConfiguration>,
305///         cache: Arc<dyn AuthCache + Send + Sync>,
306///     ) -> BoxFuture<'a, Result<Response, Error>> {
307///         Box::pin(async move {
308///             // Custom logout logic here
309///             # unimplemented!()
310///         })
311///     }
312/// }
313/// ```
314pub trait LogoutHandler: Send + Sync {
315    /// Handle the logout request.
316    ///
317    /// This method is called when a user requests to log out. Implementations should:
318    /// 1. Remove the session cookie
319    /// 2. Invalidate the session in the cache
320    /// 3. Optionally notify the OAuth provider
321    /// 4. Redirect the user appropriately
322    ///
323    /// # Arguments
324    ///
325    /// * `parts` - The request parts containing headers, extensions, and query parameters
326    /// * `configuration` - The OAuth configuration
327    /// * `cache` - The authentication cache for session storage
328    ///
329    /// # Returns
330    ///
331    /// A future that resolves to either:
332    /// * `Ok(Response)` - A successful logout response (typically a redirect)
333    /// * `Err(Error)` - An error if logout fails
334    ///
335    /// # Returns
336    /// A response that handles the logout (typically a redirect or HTML page)
337    fn handle_logout<'a>(
338        &'a self,
339        parts: &'a mut Parts,
340        configuration: Arc<OAuthConfiguration>,
341        cache: Arc<dyn AuthCache + Send + Sync>,
342    ) -> BoxFuture<'a, Result<Response, Error>>;
343}
344
345#[derive(Clone)]
346pub struct AuthLayer {
347    oauth_client: Arc<Client>,
348    configuration: Arc<OAuthConfiguration>,
349    cache: Arc<dyn AuthCache + Send + Sync>,
350    logout_handler: Arc<dyn LogoutHandler>,
351}
352
353impl AuthLayer {
354    pub fn new(
355        configuration: Arc<OAuthConfiguration>,
356        cache: Arc<dyn AuthCache + Send + Sync>,
357        logout_handler: Arc<dyn LogoutHandler>,
358    ) -> Self {
359        let oauth_client = Arc::new(
360            match configuration.custom_ca_cert.clone() {
361                Some(custom_ca_cert) => {
362                    let cert = std::fs::read(custom_ca_cert).unwrap();
363                    let cert = reqwest::Certificate::from_pem(&cert).unwrap();
364                    reqwest::ClientBuilder::new()
365                        .add_root_certificate(cert)
366                        .use_rustls_tls()
367                }
368                None => reqwest::ClientBuilder::new(),
369            }
370            .build()
371            .unwrap(),
372        );
373        Self {
374            configuration,
375            cache,
376            oauth_client,
377            logout_handler,
378        }
379    }
380
381    /// Create a new AuthLayer with a custom logout handler
382    ///
383    /// This is an alias for `new()` and is provided for backwards compatibility.
384    pub fn with_logout_handler(
385        configuration: Arc<OAuthConfiguration>,
386        cache: Arc<dyn AuthCache + Send + Sync>,
387        logout_handler: Arc<dyn LogoutHandler>,
388    ) -> Self {
389        Self::new(configuration, cache, logout_handler)
390    }
391}
392
393impl<S> Layer<S> for AuthLayer {
394    type Service = AuthMiddleware<S>;
395
396    fn layer(&self, inner: S) -> Self::Service {
397        AuthMiddleware {
398            inner,
399            configuration: self.configuration.clone(),
400            cache: self.cache.clone(),
401            oauth_client: self.oauth_client.clone(),
402            logout_handler: self.logout_handler.clone(),
403        }
404    }
405}
406
407#[derive(Clone)]
408pub struct AuthMiddleware<S> {
409    inner: S,
410    configuration: Arc<OAuthConfiguration>,
411    cache: Arc<dyn AuthCache + Send + Sync>,
412    oauth_client: Arc<Client>,
413    logout_handler: Arc<dyn LogoutHandler>,
414}
415
416impl<S> Service<Request> for AuthMiddleware<S>
417where
418    S: Service<Request, Response = Response> + Send + 'static,
419    S::Future: Send + 'static,
420{
421    type Response = S::Response;
422    type Error = S::Error;
423
424    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
425
426    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
427        self.inner.poll_ready(cx)
428    }
429
430    fn call(&mut self, mut request: Request) -> Self::Future {
431        let OAuthConfiguration {
432            private_cookie_key, ..
433        } = self.configuration.as_ref();
434        let headers = request.headers().clone();
435        let uri = request.uri().clone();
436        let path = uri.path().to_string();
437        let jar = PrivateCookieJar::from_headers(&headers, private_cookie_key.to_owned());
438
439        let cache = self.cache.clone();
440        let configuration = self.configuration.clone();
441        let client = self.oauth_client.clone();
442
443        // Add extensions to request for extractors
444        request.extensions_mut().insert(cache.clone());
445        request.extensions_mut().insert(configuration.clone());
446        request.extensions_mut().insert(client.clone());
447
448        let session_id = jar
449            .get(SESSION_KEY)
450            .map(|cookie| cookie.value().to_string());
451
452        // Build the auth routes dynamically based on base_path from configuration
453        let base_path = &configuration.base_path;
454        let auth_route = base_path.clone();
455        let callback_route = format!("{}/callback", base_path);
456        let logout_route = format!("{}/logout", base_path);
457
458        match path.as_str() {
459            p if p == auth_route => Box::pin(async move {
460                match handle_auth(configuration, cache).await {
461                    Ok(response) => Ok(response),
462                    Err(err) => Ok(err.into_response()),
463                }
464            }),
465            p if p == callback_route => {
466                let (mut parts, _) = request.into_parts();
467                Box::pin(async move {
468                    match handle_callback(&mut parts, uri).await {
469                        Ok(response) => Ok(response),
470                        Err(err) => match err {
471                            Error::MissingCodeVerifier => {
472                                Ok((jar, Redirect::temporary("/MissingCodeVerifier"))
473                                    .into_response())
474                            }
475                            _ => Ok(err.into_response()),
476                        },
477                    }
478                })
479            }
480            p if p == logout_route => {
481                let (mut parts, _) = request.into_parts();
482                let logout_handler = self.logout_handler.clone();
483                Box::pin(async move {
484                    match logout_handler
485                        .handle_logout(&mut parts, configuration, cache)
486                        .await
487                    {
488                        Ok(response) => Ok(response),
489                        Err(err) => Ok(err.into_response()),
490                    }
491                })
492            }
493            _ => {
494                let future = self.inner.call(request);
495                Box::pin(async move {
496                    handle_default(configuration, cache, jar, session_id, future).await
497                })
498            }
499        }
500    }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth_session::AuthSession;

    /// Mock cache for testing.
    ///
    /// Not constructed by any test yet; it still checks at compile time that
    /// `AuthCache` can be implemented with these signatures.
    #[allow(dead_code)]
    struct MockCache;

    impl AuthCache for MockCache {
        fn get_code_verifier(
            &self,
            _challenge_state: &str,
        ) -> BoxFuture<'_, Result<Option<String>, Error>> {
            Box::pin(async { Ok(None) })
        }

        fn set_code_verifier(
            &self,
            _challenge_state: &str,
            _code_verifier: &str,
        ) -> BoxFuture<'_, Result<(), Error>> {
            Box::pin(async { Ok(()) })
        }

        fn invalidate_code_verifier(
            &self,
            _challenge_state: &str,
        ) -> BoxFuture<'_, Result<(), Error>> {
            Box::pin(async { Ok(()) })
        }

        fn get_auth_session(&self, _id: &str) -> BoxFuture<'_, Result<Option<AuthSession>, Error>> {
            Box::pin(async { Ok(None) })
        }

        fn set_auth_session(
            &self,
            _id: &str,
            _session: AuthSession,
        ) -> BoxFuture<'_, Result<(), Error>> {
            Box::pin(async { Ok(()) })
        }

        fn invalidate_auth_session(&self, _id: &str) -> BoxFuture<'_, Result<(), Error>> {
            Box::pin(async { Ok(()) })
        }

        fn extend_auth_session(&self, _id: &str, _ttl: i64) -> BoxFuture<'_, Result<(), Error>> {
            Box::pin(async { Ok(()) })
        }
    }

    fn create_test_config() -> OAuthConfiguration {
        use axum_extra::extract::cookie::Key;
        OAuthConfiguration {
            private_cookie_key: Key::from(&[0u8; 64]),
            client_id: "test-client".to_string(),
            client_secret: "test-secret".to_string(),
            redirect_uri: "http://localhost:8080/auth/callback".to_string(),
            authorization_endpoint: "http://localhost/auth".to_string(),
            token_endpoint: "http://localhost/token".to_string(),
            end_session_endpoint: None,
            post_logout_redirect_uri: "/".to_string(),
            scopes: "openid email".to_string(),
            code_challenge_method: CodeChallengeMethod::S256,
            session_max_age: 30,
            token_max_age: Some(60),
            custom_ca_cert: None,
            base_path: "/auth".to_string(),
        }
    }

    #[test]
    fn test_default_base_path() {
        let config = create_test_config();

        assert_eq!(config.base_path, "/auth");
    }

    #[test]
    fn test_custom_base_path() {
        let mut config = create_test_config();
        config.base_path = "/api/auth".to_string();

        assert_eq!(config.base_path, "/api/auth");
    }

    #[test]
    fn test_base_path_can_be_customized() {
        let mut config = create_test_config();
        config.base_path = "/oauth".to_string();

        assert_eq!(config.base_path, "/oauth");
    }

    #[test]
    fn test_base_path_with_different_values() {
        let mut config1 = create_test_config();
        config1.base_path = "/oauth".to_string();
        assert_eq!(config1.base_path, "/oauth");

        let mut config2 = create_test_config();
        config2.base_path = "/api/v1/auth".to_string();
        assert_eq!(config2.base_path, "/api/v1/auth");

        let mut config3 = create_test_config();
        config3.base_path = "/auth/custom".to_string();
        assert_eq!(config3.base_path, "/auth/custom");
    }

    #[test]
    fn test_code_challenge_method_display_and_default() {
        assert_eq!(CodeChallengeMethod::default(), CodeChallengeMethod::S256);
        assert_eq!(CodeChallengeMethod::S256.to_string(), "S256");
        assert_eq!(CodeChallengeMethod::Plain.to_string(), "plain");
    }

    #[test]
    fn test_code_challenge_method_into_pkce_method() {
        assert!(matches!(
            Method::from(CodeChallengeMethod::S256),
            Method::Sha256
        ));
        assert!(matches!(
            Method::from(CodeChallengeMethod::Plain),
            Method::Plain
        ));
    }

    #[test]
    fn test_token_expiration_is_capped_by_token_max_age() {
        let before = Local::now();
        let expires = calculate_token_expiration(3600, Some(60));
        let after = Local::now();

        assert!(expires >= before + Duration::seconds(60));
        assert!(expires <= after + Duration::seconds(60));
    }

    #[test]
    fn test_token_expiration_never_below_one_second() {
        let before = Local::now();
        let expires = calculate_token_expiration(0, Some(60));
        let after = Local::now();

        assert!(expires >= before + Duration::seconds(1));
        assert!(expires <= after + Duration::seconds(1));
    }
}