axum_oidc_client/auth.rs
1//! Core authentication module for OAuth2/OIDC with PKCE support.
2//!
3//! This module provides the main authentication layer and configuration types
4//! for integrating OAuth2 authentication into Axum applications.
5//!
6//! # Main Types
7//!
8//! - [`AuthLayer`] - Tower layer for adding authentication to your Axum app
9//! - [`OAuthConfiguration`] - Configuration for OAuth2 endpoints and credentials
10//! - [`CodeChallengeMethod`] - PKCE code challenge method (S256 or Plain)
11//! - [`LogoutHandler`] - Trait for implementing custom logout behavior
12//!
13//! # Examples
14//!
15//! ```rust,no_run
16//! use axum::{Router, routing::get};
17//! use axum_oidc_client::{
18//! auth::{AuthLayer, CodeChallengeMethod},
19//! auth_builder::OAuthConfigurationBuilder,
20//! auth_cache::AuthCache,
21//! logout::handle_default_logout::DefaultLogoutHandler,
22//! };
23//! use std::sync::Arc;
24//!
25//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
26//! let config = OAuthConfigurationBuilder::default()
27//! .with_authorization_endpoint("https://provider.com/oauth/authorize")
28//! .with_token_endpoint("https://provider.com/oauth/token")
29//! .with_client_id("client-id")
30//! .with_client_secret("client-secret")
31//! .with_redirect_uri("http://localhost:8080/auth/callback")
32//! .with_private_cookie_key("secret-key")
33//! .with_scopes(vec!["openid", "email"])
34//! .build()?;
35//!
36//! # #[cfg(feature = "redis")]
37//! let cache: Arc<dyn AuthCache + Send + Sync> = Arc::new(
38//! axum_oidc_client::redis::AuthCache::new("redis://127.0.0.1/", 3600)
39//! );
40//!
41//! let logout_handler = Arc::new(DefaultLogoutHandler);
42//!
43//! let app = Router::new()
44//! .route("/", get(|| async { "Hello!" }))
45//! .layer(AuthLayer::new(Arc::new(config), cache, logout_handler));
46//! # Ok(())
47//! # }
48//! ```
49
50use axum::{
51 extract::Request,
52 response::{IntoResponse, Redirect, Response},
53};
54use axum_extra::extract::{cookie::Key, PrivateCookieJar};
55use chrono::{Duration, Local};
56use futures_util::future::BoxFuture;
57use http::request::Parts;
58use pkce_std::Method;
59use reqwest::Client;
60
61use std::{
62 fmt::Display,
63 sync::Arc,
64 task::{Context, Poll},
65};
66use tower::{Layer, Service};
67
68use crate::{
69 auth_cache::AuthCache,
70 auth_router::{
71 handle_auth::handle_auth,
72 handle_callback::{handle_callback, AccessTokenResponse},
73 handle_default::handle_default,
74 },
75 auth_session::AuthSession,
76 errors::Error,
77};
78
79/// PKCE code challenge method.
80///
81/// Defines how the code verifier is transformed into a code challenge
82/// during the OAuth2 PKCE flow.
83///
84/// # Variants
85///
86/// - `S256` - SHA-256 hash of the code verifier (recommended)
87/// - `Plain` - Plain text code verifier (not recommended for production)
88///
89/// # Examples
90///
91/// ```
92/// use axum_oidc_client::auth::CodeChallengeMethod;
93///
94/// let method = CodeChallengeMethod::S256;
95/// assert_eq!(method.to_string(), "S256");
96///
97/// let plain = CodeChallengeMethod::Plain;
98/// assert_eq!(plain.to_string(), "plain");
99/// ```
100#[derive(Debug, Clone, PartialEq, Default)]
101pub enum CodeChallengeMethod {
102 /// SHA-256 hashing method (recommended, default)
103 #[default]
104 S256,
105 /// Plain text method (not recommended for production)
106 Plain,
107}
108
109impl Display for CodeChallengeMethod {
110 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 match self {
112 CodeChallengeMethod::S256 => write!(f, "S256"),
113 CodeChallengeMethod::Plain => write!(f, "plain"),
114 }
115 }
116}
117
118/// Calculate token expiration time based on expires_in and token_max_age.
119///
120/// This function determines when a token should be considered expired,
121/// taking into account both the provider's expiration time and the
122/// application's configured maximum token age.
123///
124/// # Arguments
125///
126/// * `expires_in` - Seconds until token expiration from the OAuth provider
127/// * `token_max_age` - Maximum allowed token age in seconds from configuration
128///
129/// # Returns
130///
131/// The current time plus the calculated expiration duration.
132/// Returns the maximum of:
133/// - 1 second (minimum)
134/// - The maximum of (expires_in - 1) and token_max_age
135///
136/// # Examples
137///
138/// ```ignore
139/// // Token expires in 3600 seconds, max age is 1800
140/// let expiration = calculate_token_expiration(3600, 1800);
141/// // Uses 3599 seconds (expires_in - 1)
142///
143/// // Token expires in 300 seconds, max age is 1800
144/// let expiration = calculate_token_expiration(300, 1800);
145/// // Uses 1800 seconds (token_max_age)
146/// ```
147pub fn calculate_token_expiration(
148 expires_in: i64,
149 token_max_age: Option<i64>,
150) -> chrono::DateTime<Local> {
151 Local::now()
152 + Duration::seconds(std::cmp::max(
153 1,
154 std::cmp::min(expires_in - 1, token_max_age.unwrap_or(0)),
155 ))
156}
157
158impl AuthSession {
159 pub fn new(response: &AccessTokenResponse, conf: &OAuthConfiguration) -> Self {
160 AuthSession {
161 id_token: response.id_token.to_owned(),
162 access_token: response.access_token.to_owned(),
163 token_type: response.token_type.to_owned(),
164 refresh_token: response.refresh_token.to_owned(),
165 scope: response.scope.to_owned(),
166 expires: calculate_token_expiration(response.expires_in, conf.token_max_age),
167 }
168 }
169}
170
171impl From<CodeChallengeMethod> for Method {
172 fn from(method: CodeChallengeMethod) -> Self {
173 match method {
174 CodeChallengeMethod::S256 => Method::Sha256,
175 CodeChallengeMethod::Plain => Method::Plain,
176 }
177 }
178}
179
180/// OAuth2/OIDC configuration.
181///
182/// Contains all necessary configuration for OAuth2 authentication including
183/// endpoints, credentials, and session management settings.
184///
185/// # Fields
186///
187/// * `private_cookie_key` - Secret key for encrypting session cookies
188/// * `client_id` - OAuth2 client identifier
189/// * `client_secret` - OAuth2 client secret
190/// * `redirect_uri` - URI where the provider redirects after authentication
191/// * `authorization_endpoint` - OAuth2 authorization endpoint URL
192/// * `token_endpoint` - OAuth2 token endpoint URL
193/// * `end_session_endpoint` - Optional OIDC end session endpoint URL
194/// * `post_logout_redirect_uri` - URI to redirect to after logout
195/// * `scopes` - Space-separated list of OAuth2 scopes
196/// * `code_challenge_method` - PKCE code challenge method
197/// * `custom_ca_cert` - Optional path to custom CA certificate
198/// * `session_max_age` - Maximum session age in seconds
199/// * `token_max_age` - Optional maximum token age in seconds
200///
201/// # Examples
202///
203/// Use [`crate::auth_builder::OAuthConfigurationBuilder`] to construct:
204///
205/// ```rust,no_run
206/// use axum_oidc_client::auth_builder::OAuthConfigurationBuilder;
207///
208/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
209/// let config = OAuthConfigurationBuilder::default()
210/// .with_authorization_endpoint("https://provider.com/oauth/authorize")
211/// .with_token_endpoint("https://provider.com/oauth/token")
212/// .with_client_id("my-client-id")
213/// .with_client_secret("my-client-secret")
214/// .with_redirect_uri("http://localhost:8080/auth/callback")
215/// .with_private_cookie_key("secret-key-at-least-32-bytes")
216/// .with_scopes(vec!["openid", "email", "profile"])
217/// .build()?;
218/// # Ok(())
219/// # }
220/// ```
221#[derive(Clone)]
222pub struct OAuthConfiguration {
223 /// Secret key for encrypting session cookies
224 pub private_cookie_key: Key,
225 /// OAuth2 client identifier
226 pub client_id: String,
227 /// OAuth2 client secret
228 pub client_secret: String,
229 /// Redirect URI for OAuth2 callback
230 pub redirect_uri: String,
231 /// OAuth2 authorization endpoint URL
232 pub authorization_endpoint: String,
233 /// OAuth2 token endpoint URL
234 pub token_endpoint: String,
235 /// Optional OIDC end session endpoint URL
236 pub end_session_endpoint: Option<String>,
237 /// URI to redirect to after logout
238 pub post_logout_redirect_uri: String,
239 /// Space-separated list of OAuth2 scopes
240 pub scopes: String,
241 /// PKCE code challenge method
242 pub code_challenge_method: CodeChallengeMethod,
243 /// Optional path to custom CA certificate file
244 pub custom_ca_cert: Option<String>,
245 /// Maximum session age in seconds
246 pub session_max_age: i64,
247 /// Optional maximum token age in seconds
248 pub token_max_age: Option<i64>,
249 /// Base path for authentication routes (default: "/auth")
250 pub base_path: String,
251}
252
253/// Session cookie key name.
254///
255/// This constant defines the name of the cookie used to store the session identifier.
256pub const SESSION_KEY: &str = "AUTH_SESSION";
257
258/// Trait for handling logout behavior.
259///
260/// Implement this trait to customize the logout process for your application.
261/// The library provides two built-in implementations:
262/// - [`crate::logout::handle_default_logout::DefaultLogoutHandler`] - Simple logout with session cleanup
263/// - [`crate::logout::handle_oidc_logout::OidcLogoutHandler`] - OIDC logout with provider notification
264///
265/// # Examples
266///
267/// ## Using the Default Handler
268///
269/// ```rust,no_run
270/// use axum_oidc_client::logout::handle_default_logout::DefaultLogoutHandler;
271/// use std::sync::Arc;
272///
273/// let logout_handler = Arc::new(DefaultLogoutHandler);
274/// ```
275///
276/// ## Using the OIDC Handler
277///
278/// ```rust,no_run
279/// use axum_oidc_client::logout::handle_oidc_logout::OidcLogoutHandler;
280/// use std::sync::Arc;
281///
282/// let logout_handler = Arc::new(
283/// OidcLogoutHandler::new("https://provider.com/oauth/logout")
284/// );
285/// ```
286///
287/// ## Custom Implementation
288///
289/// ```rust,no_run
290/// use axum_oidc_client::auth::{LogoutHandler, OAuthConfiguration};
291/// use axum_oidc_client::auth_cache::AuthCache;
292/// use axum_oidc_client::errors::Error;
293/// use axum::response::Response;
294/// use http::request::Parts;
295/// use std::sync::Arc;
296/// use futures_util::future::BoxFuture;
297///
298/// struct CustomLogoutHandler;
299///
300/// impl LogoutHandler for CustomLogoutHandler {
301/// fn handle_logout<'a>(
302/// &'a self,
303/// parts: &'a mut Parts,
304/// configuration: Arc<OAuthConfiguration>,
305/// cache: Arc<dyn AuthCache + Send + Sync>,
306/// ) -> BoxFuture<'a, Result<Response, Error>> {
307/// Box::pin(async move {
308/// // Custom logout logic here
309/// # unimplemented!()
310/// })
311/// }
312/// }
313/// ```
314pub trait LogoutHandler: Send + Sync {
315 /// Handle the logout request.
316 ///
317 /// This method is called when a user requests to log out. Implementations should:
318 /// 1. Remove the session cookie
319 /// 2. Invalidate the session in the cache
320 /// 3. Optionally notify the OAuth provider
321 /// 4. Redirect the user appropriately
322 ///
323 /// # Arguments
324 ///
325 /// * `parts` - The request parts containing headers, extensions, and query parameters
326 /// * `configuration` - The OAuth configuration
327 /// * `cache` - The authentication cache for session storage
328 ///
329 /// # Returns
330 ///
331 /// A future that resolves to either:
332 /// * `Ok(Response)` - A successful logout response (typically a redirect)
333 /// * `Err(Error)` - An error if logout fails
334 ///
335 /// # Returns
336 /// A response that handles the logout (typically a redirect or HTML page)
337 fn handle_logout<'a>(
338 &'a self,
339 parts: &'a mut Parts,
340 configuration: Arc<OAuthConfiguration>,
341 cache: Arc<dyn AuthCache + Send + Sync>,
342 ) -> BoxFuture<'a, Result<Response, Error>>;
343}
344
345#[derive(Clone)]
346pub struct AuthLayer {
347 oauth_client: Arc<Client>,
348 configuration: Arc<OAuthConfiguration>,
349 cache: Arc<dyn AuthCache + Send + Sync>,
350 logout_handler: Arc<dyn LogoutHandler>,
351}
352
353impl AuthLayer {
354 pub fn new(
355 configuration: Arc<OAuthConfiguration>,
356 cache: Arc<dyn AuthCache + Send + Sync>,
357 logout_handler: Arc<dyn LogoutHandler>,
358 ) -> Self {
359 let oauth_client = Arc::new(
360 match configuration.custom_ca_cert.clone() {
361 Some(custom_ca_cert) => {
362 let cert = std::fs::read(custom_ca_cert).unwrap();
363 let cert = reqwest::Certificate::from_pem(&cert).unwrap();
364 reqwest::ClientBuilder::new()
365 .add_root_certificate(cert)
366 .use_rustls_tls()
367 }
368 None => reqwest::ClientBuilder::new(),
369 }
370 .build()
371 .unwrap(),
372 );
373 Self {
374 configuration,
375 cache,
376 oauth_client,
377 logout_handler,
378 }
379 }
380
381 /// Create a new AuthLayer with a custom logout handler
382 ///
383 /// This is an alias for `new()` and is provided for backwards compatibility.
384 pub fn with_logout_handler(
385 configuration: Arc<OAuthConfiguration>,
386 cache: Arc<dyn AuthCache + Send + Sync>,
387 logout_handler: Arc<dyn LogoutHandler>,
388 ) -> Self {
389 Self::new(configuration, cache, logout_handler)
390 }
391}
392
393impl<S> Layer<S> for AuthLayer {
394 type Service = AuthMiddleware<S>;
395
396 fn layer(&self, inner: S) -> Self::Service {
397 AuthMiddleware {
398 inner,
399 configuration: self.configuration.clone(),
400 cache: self.cache.clone(),
401 oauth_client: self.oauth_client.clone(),
402 logout_handler: self.logout_handler.clone(),
403 }
404 }
405}
406
407#[derive(Clone)]
408pub struct AuthMiddleware<S> {
409 inner: S,
410 configuration: Arc<OAuthConfiguration>,
411 cache: Arc<dyn AuthCache + Send + Sync>,
412 oauth_client: Arc<Client>,
413 logout_handler: Arc<dyn LogoutHandler>,
414}
415
416impl<S> Service<Request> for AuthMiddleware<S>
417where
418 S: Service<Request, Response = Response> + Send + 'static,
419 S::Future: Send + 'static,
420{
421 type Response = S::Response;
422 type Error = S::Error;
423
424 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
425
426 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
427 self.inner.poll_ready(cx)
428 }
429
430 fn call(&mut self, mut request: Request) -> Self::Future {
431 let OAuthConfiguration {
432 private_cookie_key, ..
433 } = self.configuration.as_ref();
434 let headers = request.headers().clone();
435 let uri = request.uri().clone();
436 let path = uri.path().to_string();
437 let jar = PrivateCookieJar::from_headers(&headers, private_cookie_key.to_owned());
438
439 let cache = self.cache.clone();
440 let configuration = self.configuration.clone();
441 let client = self.oauth_client.clone();
442
443 // Add extensions to request for extractors
444 request.extensions_mut().insert(cache.clone());
445 request.extensions_mut().insert(configuration.clone());
446 request.extensions_mut().insert(client.clone());
447
448 let session_id = jar
449 .get(SESSION_KEY)
450 .map(|cookie| cookie.value().to_string());
451
452 // Build the auth routes dynamically based on base_path from configuration
453 let base_path = &configuration.base_path;
454 let auth_route = base_path.clone();
455 let callback_route = format!("{}/callback", base_path);
456 let logout_route = format!("{}/logout", base_path);
457
458 match path.as_str() {
459 p if p == auth_route => Box::pin(async move {
460 match handle_auth(configuration, cache).await {
461 Ok(response) => Ok(response),
462 Err(err) => Ok(err.into_response()),
463 }
464 }),
465 p if p == callback_route => {
466 let (mut parts, _) = request.into_parts();
467 Box::pin(async move {
468 match handle_callback(&mut parts, uri).await {
469 Ok(response) => Ok(response),
470 Err(err) => match err {
471 Error::MissingCodeVerifier => {
472 Ok((jar, Redirect::temporary("/MissingCodeVerifier"))
473 .into_response())
474 }
475 _ => Ok(err.into_response()),
476 },
477 }
478 })
479 }
480 p if p == logout_route => {
481 let (mut parts, _) = request.into_parts();
482 let logout_handler = self.logout_handler.clone();
483 Box::pin(async move {
484 match logout_handler
485 .handle_logout(&mut parts, configuration, cache)
486 .await
487 {
488 Ok(response) => Ok(response),
489 Err(err) => Ok(err.into_response()),
490 }
491 })
492 }
493 _ => {
494 let future = self.inner.call(request);
495 Box::pin(async move {
496 handle_default(configuration, cache, jar, session_id, future).await
497 })
498 }
499 }
500 }
501}
502
503#[cfg(test)]
504mod tests {
505 use super::*;
506 use crate::auth_session::AuthSession;
507
508 // Mock cache for testing
509 #[allow(dead_code)]
510 struct MockCache;
511
512 impl AuthCache for MockCache {
513 fn get_code_verifier(
514 &self,
515 _challenge_state: &str,
516 ) -> BoxFuture<'_, Result<Option<String>, Error>> {
517 Box::pin(async { Ok(None) })
518 }
519
520 fn set_code_verifier(
521 &self,
522 _challenge_state: &str,
523 _code_verifier: &str,
524 ) -> BoxFuture<'_, Result<(), Error>> {
525 Box::pin(async { Ok(()) })
526 }
527
528 fn invalidate_code_verifier(
529 &self,
530 _challenge_state: &str,
531 ) -> BoxFuture<'_, Result<(), Error>> {
532 Box::pin(async { Ok(()) })
533 }
534
535 fn get_auth_session(&self, _id: &str) -> BoxFuture<'_, Result<Option<AuthSession>, Error>> {
536 Box::pin(async { Ok(None) })
537 }
538
539 fn set_auth_session(
540 &self,
541 _id: &str,
542 _session: AuthSession,
543 ) -> BoxFuture<'_, Result<(), Error>> {
544 Box::pin(async { Ok(()) })
545 }
546
547 fn invalidate_auth_session(&self, _id: &str) -> BoxFuture<'_, Result<(), Error>> {
548 Box::pin(async { Ok(()) })
549 }
550
551 fn extend_auth_session(&self, _id: &str, _ttl: i64) -> BoxFuture<'_, Result<(), Error>> {
552 Box::pin(async { Ok(()) })
553 }
554 }
555
556 fn create_test_config() -> OAuthConfiguration {
557 use axum_extra::extract::cookie::Key;
558 OAuthConfiguration {
559 private_cookie_key: Key::from(&[0u8; 64]),
560 client_id: "test-client".to_string(),
561 client_secret: "test-secret".to_string(),
562 redirect_uri: "http://localhost:8080/auth/callback".to_string(),
563 authorization_endpoint: "http://localhost/auth".to_string(),
564 token_endpoint: "http://localhost/token".to_string(),
565 end_session_endpoint: None,
566 post_logout_redirect_uri: "/".to_string(),
567 scopes: "openid email".to_string(),
568 code_challenge_method: CodeChallengeMethod::S256,
569 session_max_age: 30,
570 token_max_age: Some(60),
571 custom_ca_cert: None,
572 base_path: "/auth".to_string(),
573 }
574 }
575
576 #[test]
577 fn test_default_base_path() {
578 let config = create_test_config();
579
580 assert_eq!(config.base_path, "/auth");
581 }
582
583 #[test]
584 fn test_custom_base_path() {
585 let mut config = create_test_config();
586 config.base_path = "/api/auth".to_string();
587
588 assert_eq!(config.base_path, "/api/auth");
589 }
590
591 #[test]
592 fn test_base_path_can_be_customized() {
593 let mut config = create_test_config();
594 config.base_path = "/oauth".to_string();
595
596 assert_eq!(config.base_path, "/oauth");
597 }
598
599 #[test]
600 fn test_base_path_with_different_values() {
601 let mut config1 = create_test_config();
602 config1.base_path = "/oauth".to_string();
603 assert_eq!(config1.base_path, "/oauth");
604
605 let mut config2 = create_test_config();
606 config2.base_path = "/api/v1/auth".to_string();
607 assert_eq!(config2.base_path, "/api/v1/auth");
608
609 let mut config3 = create_test_config();
610 config3.base_path = "/auth/custom".to_string();
611 assert_eq!(config3.base_path, "/auth/custom");
612 }
613}