1use std::sync::Arc;
3
4use axum::{
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use serde::{Deserialize, Serialize};
9
10use crate::auth::{
11 error::{AuthError, Result},
12 jwt::{Claims, JwtValidator},
13 session::SessionStore,
14};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct AuthenticatedUser {
19 pub user_id: String,
21 pub claims: Claims,
23}
24
25impl AuthenticatedUser {
26 pub fn get_custom_claim(&self, key: &str) -> Option<&serde_json::Value> {
28 self.claims.get_custom(key)
29 }
30
31 pub fn has_role(&self, role: &str) -> bool {
33 if let Some(serde_json::Value::String(user_role)) = self.claims.get_custom("role") {
34 user_role == role
35 } else if let Some(serde_json::Value::Array(roles)) = self.claims.get_custom("roles") {
36 roles.iter().any(|r| {
37 if let serde_json::Value::String(r_str) = r {
38 r_str == role
39 } else {
40 false
41 }
42 })
43 } else {
44 false
45 }
46 }
47}
48
49pub struct AuthMiddleware {
51 validator: Arc<JwtValidator>,
52 _session_store: Arc<dyn SessionStore>,
53 public_key: Vec<u8>,
54 _optional: bool,
55}
56
57impl AuthMiddleware {
58 pub fn new(
66 validator: Arc<JwtValidator>,
67 session_store: Arc<dyn SessionStore>,
68 public_key: Vec<u8>,
69 optional: bool,
70 ) -> Self {
71 Self {
72 validator,
73 _session_store: session_store,
74 public_key,
75 _optional: optional,
76 }
77 }
78
79 pub async fn validate_token(&self, token: &str) -> Result<Claims> {
81 self.validator.validate(token, &self.public_key)
82 }
83}
84
85impl IntoResponse for AuthError {
86 fn into_response(self) -> Response {
87 let (status, error, message) = match self {
88 AuthError::TokenExpired => {
89 (StatusCode::UNAUTHORIZED, "token_expired", "Authentication token has expired")
90 },
91 AuthError::InvalidSignature => {
92 (StatusCode::UNAUTHORIZED, "invalid_signature", "Token signature is invalid")
93 },
94 AuthError::InvalidToken { ref reason } => {
95 (StatusCode::UNAUTHORIZED, "invalid_token", reason.as_str())
96 },
97 AuthError::TokenNotFound => {
98 (StatusCode::UNAUTHORIZED, "token_not_found", "Authentication token not found")
99 },
100 AuthError::SessionRevoked => {
101 (StatusCode::UNAUTHORIZED, "session_revoked", "Session has been revoked")
102 },
103 AuthError::Forbidden { ref message } => {
104 (StatusCode::FORBIDDEN, "forbidden", message.as_str())
105 },
106 _ => (
107 StatusCode::INTERNAL_SERVER_ERROR,
108 "auth_error",
109 "An authentication error occurred",
110 ),
111 };
112
113 let body = serde_json::json!({
114 "errors": [{
115 "message": message,
116 "extensions": {
117 "code": error
118 }
119 }]
120 });
121
122 (status, axum::Json(body)).into_response()
123 }
124}
125
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Builds a user with fixed standard claims (`sub` = `"user123"`) plus the given
    /// custom claims inserted into `Claims::extra`.
    fn user_with_extra(extra: &[(&str, serde_json::Value)]) -> AuthenticatedUser {
        let mut claims = Claims {
            sub: "user123".to_string(),
            iat: 1000,
            exp: 2000,
            iss: "https://example.com".to_string(),
            aud: vec!["api".to_string()],
            extra: HashMap::new(),
        };
        for (key, value) in extra {
            claims.extra.insert((*key).to_string(), value.clone());
        }
        AuthenticatedUser {
            user_id: "user123".to_string(),
            claims,
        }
    }

    #[test]
    fn test_authenticated_user_clone() {
        let user = user_with_extra(&[]);

        let cloned = user.clone();

        // The clone must carry the same identity and claims as the original.
        assert_eq!(cloned.user_id, user.user_id);
        assert_eq!(cloned.claims.sub, user.claims.sub);
        assert_eq!(user.user_id, "user123");
    }

    #[test]
    fn test_has_role_single_string() {
        let user = user_with_extra(&[("role", serde_json::json!("admin"))]);

        assert!(user.has_role("admin"));
        assert!(!user.has_role("user"));
    }

    #[test]
    fn test_has_role_array() {
        let user = user_with_extra(&[("roles", serde_json::json!(["admin", "user", "editor"]))]);

        assert!(user.has_role("admin"));
        assert!(user.has_role("user"));
        assert!(user.has_role("editor"));
        assert!(!user.has_role("moderator"));
    }

    #[test]
    fn test_has_role_without_role_claims() {
        let user = user_with_extra(&[]);

        assert!(!user.has_role("admin"));
        assert!(!user.has_role(""));
    }

    #[test]
    fn test_has_role_ignores_non_string_array_entries() {
        let user = user_with_extra(&[("roles", serde_json::json!([1, null, "admin"]))]);

        assert!(user.has_role("admin"));
        assert!(!user.has_role("1"));
    }

    #[test]
    fn test_get_custom_claim() {
        let user = user_with_extra(&[("org_id", serde_json::json!("org_456"))]);

        assert_eq!(user.get_custom_claim("org_id"), Some(&serde_json::json!("org_456")));
        assert_eq!(user.get_custom_claim("nonexistent"), None);
    }
}