fraiseql_server/middleware/
oidc_auth.rs1use std::sync::Arc;
7
8use axum::{
9 body::Body,
10 extract::State,
11 http::{Request, StatusCode, header},
12 middleware::Next,
13 response::{IntoResponse, Response},
14};
15use fraiseql_core::security::{AuthenticatedUser, OidcValidator};
16
17#[derive(Clone)]
19pub struct OidcAuthState {
20 pub validator: Arc<OidcValidator>,
22}
23
24impl OidcAuthState {
25 #[must_use]
27 pub fn new(validator: Arc<OidcValidator>) -> Self {
28 Self { validator }
29 }
30}
31
32#[derive(Clone, Debug)]
37pub struct AuthUser(pub AuthenticatedUser);
38
39pub async fn oidc_auth_middleware(
61 State(auth_state): State<OidcAuthState>,
62 mut request: Request<Body>,
63 next: Next,
64) -> Response {
65 let auth_header = request
67 .headers()
68 .get(header::AUTHORIZATION)
69 .and_then(|value| value.to_str().ok());
70
71 match auth_header {
72 None => {
73 if auth_state.validator.is_required() {
75 tracing::debug!("Authentication required but no Authorization header");
76 return (
77 StatusCode::UNAUTHORIZED,
78 [(
79 header::WWW_AUTHENTICATE,
80 format!("Bearer realm=\"{}\"", auth_state.validator.issuer()),
81 )],
82 "Authentication required",
83 )
84 .into_response();
85 }
86 next.run(request).await
88 },
89 Some(header_value) => {
90 if !header_value.starts_with("Bearer ") {
92 tracing::debug!("Invalid Authorization header format");
93 return (
94 StatusCode::UNAUTHORIZED,
95 [(header::WWW_AUTHENTICATE, "Bearer error=\"invalid_request\"".to_string())],
96 "Invalid Authorization header format",
97 )
98 .into_response();
99 }
100
101 let token = &header_value[7..];
102
103 match auth_state.validator.validate_token(token).await {
105 Ok(user) => {
106 tracing::debug!(
107 user_id = %user.user_id,
108 scopes = ?user.scopes,
109 "User authenticated successfully"
110 );
111 request.extensions_mut().insert(AuthUser(user));
113 next.run(request).await
114 },
115 Err(e) => {
116 tracing::debug!(error = %e, "Token validation failed");
117 let error_description = match &e {
118 fraiseql_core::security::SecurityError::TokenExpired { .. } => {
119 "Bearer error=\"invalid_token\", error_description=\"Token has expired\""
120 },
121 fraiseql_core::security::SecurityError::InvalidToken => {
122 "Bearer error=\"invalid_token\", error_description=\"Token is invalid\""
123 },
124 _ => "Bearer error=\"invalid_token\"",
125 };
126 (
127 StatusCode::UNAUTHORIZED,
128 [(header::WWW_AUTHENTICATE, error_description.to_string())],
129 "Invalid or expired token",
130 )
131 .into_response()
132 },
133 }
134 },
135 }
136}
137
#[cfg(test)]
mod tests {
    use chrono::Utc;

    use super::*;

    /// A cloned `AuthUser` must carry the same wrapped user identity.
    #[test]
    fn test_auth_user_clone() {
        let original = AuthUser(AuthenticatedUser {
            user_id: "user123".to_string(),
            scopes: vec!["read".to_string()],
            expires_at: Utc::now(),
        });

        let duplicate = original.clone();

        assert_eq!(original.0.user_id, duplicate.0.user_id);
    }

    /// Compile-time check that `OidcAuthState` implements `Clone`.
    #[test]
    fn test_oidc_auth_state_clone() {
        fn requires_clone<S: Clone>() {}

        requires_clone::<OidcAuthState>();
    }
}