fraiseql_server/middleware/
oidc_auth.rs1use std::sync::Arc;
7
8use axum::{
9 body::Body,
10 extract::State,
11 http::{Request, StatusCode, header},
12 middleware::Next,
13 response::{IntoResponse, Response},
14};
15use fraiseql_core::security::{AuthenticatedUser, OidcValidator};
16
17#[derive(Clone)]
19pub struct OidcAuthState {
20 pub validator: Arc<OidcValidator>,
22}
23
24impl OidcAuthState {
25 #[must_use]
27 pub const fn new(validator: Arc<OidcValidator>) -> Self {
28 Self { validator }
29 }
30}
31
32#[derive(Clone, Debug)]
37pub struct AuthUser(pub AuthenticatedUser);
38
39#[allow(clippy::cognitive_complexity)] pub async fn oidc_auth_middleware(
63 State(auth_state): State<OidcAuthState>,
64 mut request: Request<Body>,
65 next: Next,
66) -> Response {
67 let auth_header = request
69 .headers()
70 .get(header::AUTHORIZATION)
71 .and_then(|value| value.to_str().ok());
72
73 match auth_header {
74 None => {
75 if auth_state.validator.is_required() {
77 tracing::debug!("Authentication required but no Authorization header");
78 return (
79 StatusCode::UNAUTHORIZED,
80 [(
81 header::WWW_AUTHENTICATE,
82 format!("Bearer realm=\"{}\"", auth_state.validator.issuer()),
83 )],
84 "Authentication required",
85 )
86 .into_response();
87 }
88 next.run(request).await
90 },
91 Some(header_value) => {
92 if !header_value.starts_with("Bearer ") {
94 tracing::debug!("Invalid Authorization header format");
95 return (
96 StatusCode::UNAUTHORIZED,
97 [(header::WWW_AUTHENTICATE, "Bearer error=\"invalid_request\"".to_string())],
98 "Invalid Authorization header format",
99 )
100 .into_response();
101 }
102
103 let token = &header_value[7..];
104
105 match auth_state.validator.validate_token(token).await {
107 Ok(user) => {
108 tracing::debug!(
109 user_id = %user.user_id,
110 scopes = ?user.scopes,
111 "User authenticated successfully"
112 );
113 request.extensions_mut().insert(AuthUser(user));
115 next.run(request).await
116 },
117 Err(e) => {
118 tracing::debug!(error = %e, "Token validation failed");
119 let (www_authenticate, body) = match &e {
120 fraiseql_core::security::SecurityError::TokenExpired { .. } => (
121 "Bearer error=\"invalid_token\", error_description=\"Token has expired\"",
122 "Token has expired",
123 ),
124 fraiseql_core::security::SecurityError::InvalidToken => (
125 "Bearer error=\"invalid_token\", error_description=\"Token is invalid\"",
126 "Token is invalid",
127 ),
128 _ => ("Bearer error=\"invalid_token\"", "Invalid or expired token"),
129 };
130 (
131 StatusCode::UNAUTHORIZED,
132 [(header::WWW_AUTHENTICATE, www_authenticate.to_string())],
133 body,
134 )
135 .into_response()
136 },
137 }
138 },
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Cloning an `AuthUser` must produce a copy with the same user id.
    #[test]
    fn test_auth_user_clone() {
        use chrono::Utc;

        let original = AuthUser(AuthenticatedUser {
            user_id: "user123".to_string(),
            scopes: vec!["read".to_string()],
            expires_at: Utc::now(),
        });
        let duplicate = original.clone();

        assert_eq!(original.0.user_id, duplicate.0.user_id);
    }

    /// Compile-time check that `OidcAuthState` is `Clone` (it is extracted via axum's `State`).
    #[test]
    fn test_oidc_auth_state_clone() {
        fn requires_clone<T: Clone>() {}

        requires_clone::<OidcAuthState>();
    }
}