1use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
2use serde::Deserialize;
3use uuid::Uuid;
4
5use crate::db::Db;
6use crate::error::{AccessTokenError, AuthError};
7use crate::types::UserId;
8
/// Claims of an access token that passed [`Db::validate_access_token`].
///
/// By the time a value of this type exists, the token's RS256 signature, its
/// `exp` (with zero leeway) and its `iss` have been checked, and `sub` has been
/// parsed into a [`UserId`]. `aud` is carried through but was not checked.
#[derive(Debug, Clone)]
pub struct AccessTokenClaims {
    /// Subject: the user the token was issued to.
    pub sub: UserId,
    /// Space-delimited list of granted scopes; query it with [`has_scope`].
    pub scope: String,
    /// Issuer; equal to the `expected_issuer` given to validation.
    pub iss: String,
    /// Audience as carried in the token (not validated).
    pub aud: String,
    /// Expiry as a Unix timestamp in seconds.
    pub exp: i64,
    /// Issued-at time as a Unix timestamp in seconds (not validated).
    pub iat: i64,
    /// Email claim; empty string when the token omits it.
    pub email: String,
    /// `email_verified` claim; `false` when the token omits it.
    pub email_verified: bool,
    /// Username claim, if present.
    pub username: Option<String>,
    /// Role names; empty when the token omits the claim.
    pub roles: Vec<String>,
    /// Permission strings; empty when the token omits the claim.
    pub permissions: Vec<String>,
}
28
/// Access-token payload exactly as deserialized from the JWT.
///
/// `sub`, `scope`, `iss`, `aud`, `exp` and `iat` carry no serde default, so a
/// token missing any of them fails to deserialize. The remaining claims fall
/// back to their `Default` value (empty string, `false`, `None`, empty list)
/// when absent. Converted into [`AccessTokenClaims`] once `sub` has been parsed.
///
/// NOTE(review): `aud` is a plain `String`, so a token whose `aud` is a JSON
/// array (permitted by RFC 7519 §4.1.3) is rejected at deserialization. This is
/// fine if this service is the only issuer — confirm.
#[derive(Debug, Deserialize)]
struct RawAccessTokenClaims {
    /// Subject as text; must parse as a UUID during validation.
    sub: String,
    scope: String,
    iss: String,
    aud: String,
    exp: i64,
    iat: i64,
    #[serde(default)]
    email: String,
    #[serde(default)]
    email_verified: bool,
    #[serde(default)]
    username: Option<String>,
    #[serde(default)]
    roles: Vec<String>,
    #[serde(default)]
    permissions: Vec<String>,
}
49
/// Returns `true` if `target` is one of the space-delimited scope tokens in
/// `scope_string`.
///
/// Matching is exact per token, so `"profile"` does not match
/// `"profile_extended"`. Empty tokens are ignored. These are produced by an
/// empty `scope_string` or by leading, trailing or repeated spaces. As a
/// result, an empty `target` never matches and an empty `scope_string` grants
/// nothing.
pub fn has_scope(scope_string: &str, target: &str) -> bool {
    // RFC 6749 §3.3: scope tokens are non-empty (1*NQCHAR) and separated by
    // a space, so an empty piece between separators is never a real scope.
    scope_string
        .split(' ')
        .filter(|token| !token.is_empty())
        .any(|token| token == target)
}
56
57impl Db {
58 pub async fn validate_access_token(
68 &self,
69 token: &str,
70 expected_issuer: &str,
71 ) -> Result<AccessTokenClaims, AuthError> {
72 let header = decode_header(token)
74 .map_err(|e| AuthError::AccessToken(AccessTokenError::MalformedToken(e.to_string())))?;
75
76 let kid_str = header.kid.ok_or_else(|| {
78 AuthError::AccessToken(AccessTokenError::MalformedToken("missing kid".into()))
79 })?;
80
81 let kid_uuid = Uuid::parse_str(&kid_str)
83 .map_err(|_| AuthError::AccessToken(AccessTokenError::UnknownKid(kid_str.clone())))?;
84 let kid_id = crate::types::SigningKeyId::from_uuid(kid_uuid);
85
86 let key = self.get_signing_key(kid_id).await.map_err(|e| match e {
88 AuthError::NotFound => AuthError::AccessToken(AccessTokenError::UnknownKid(kid_str)),
89 other => other,
90 })?;
91
92 let decoding_key = DecodingKey::from_rsa_pem(key.public_key_pem.as_bytes())
94 .map_err(|e| AuthError::SigningKey(e.to_string()))?;
95
96 let mut validation = Validation::new(Algorithm::RS256);
98 validation.set_issuer(&[expected_issuer]);
99 validation.validate_aud = false;
100 validation.leeway = 0;
101
102 let token_data = decode::<RawAccessTokenClaims>(token, &decoding_key, &validation)
104 .map_err(|e| {
105 let err = match e.kind() {
106 jsonwebtoken::errors::ErrorKind::ExpiredSignature => AccessTokenError::Expired,
107 jsonwebtoken::errors::ErrorKind::InvalidSignature => {
108 AccessTokenError::InvalidSignature
109 }
110 jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
111 AccessTokenError::InvalidClaims("invalid issuer".into())
112 }
113 _ => AccessTokenError::InvalidClaims(e.to_string()),
114 };
115 AuthError::AccessToken(err)
116 })?;
117
118 let raw = token_data.claims;
119
120 let sub_uuid = Uuid::parse_str(&raw.sub).map_err(|_| {
122 AuthError::AccessToken(AccessTokenError::InvalidClaims("invalid sub".into()))
123 })?;
124
125 Ok(AccessTokenClaims {
126 sub: UserId::from_uuid(sub_uuid),
127 scope: raw.scope,
128 iss: raw.iss,
129 aud: raw.aud,
130 exp: raw.exp,
131 iat: raw.iat,
132 email: raw.email,
133 email_verified: raw.email_verified,
134 username: raw.username,
135 roles: raw.roles,
136 permissions: raw.permissions,
137 })
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use crate::signing_keys::decrypt_private_key;
    use base64ct::Encoding as _;
    use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};
    use serde::Serialize;
    use sqlx::SqlitePool;
    use sqlx::sqlite::SqliteConnectOptions;
    use std::str::FromStr;
    use uuid::Uuid;

    /// Key handed to `create_signing_key` / `decrypt_private_key` for test keys.
    const ENC_KEY: [u8; 32] = [0x42; 32];
    /// Issuer that `validate_access_token` is told to expect.
    const ISSUER: &str = "https://auth.example.com";
    /// Audience written into every test token (validation does not check it).
    const AUDIENCE: &str = "ath_test_client";

    /// Fresh in-memory SQLite database with foreign keys enforced.
    async fn test_db() -> Db {
        let opts = SqliteConnectOptions::from_str("sqlite::memory:")
            .unwrap()
            .pragma("foreign_keys", "ON");
        let pool = SqlitePool::connect_with(opts).await.unwrap();
        Db::new(pool).await.unwrap()
    }

    /// Serializable mirror of the access-token payload, used to mint tokens.
    #[derive(Serialize)]
    struct TestClaims {
        sub: String,
        scope: String,
        iss: String,
        aud: String,
        exp: i64,
        iat: i64,
        email: String,
        email_verified: bool,
        #[serde(skip_serializing_if = "Option::is_none")]
        username: Option<String>,
        roles: Vec<String>,
        permissions: Vec<String>,
    }

    /// Claims for `sub` from [`ISSUER`], scoped to `openid`, expiring in five
    /// minutes, with no username, roles or permissions.
    fn minimal_claims(sub: &str) -> TestClaims {
        let now = chrono::Utc::now().timestamp();
        TestClaims {
            sub: sub.to_string(),
            scope: "openid".to_string(),
            iss: ISSUER.to_string(),
            aud: AUDIENCE.to_string(),
            exp: now + 300,
            iat: now,
            email: "test@example.com".to_string(),
            email_verified: true,
            username: None,
            roles: vec![],
            permissions: vec![],
        }
    }

    /// Creates and activates a new signing key, then signs an RS256 token with it.
    ///
    /// `exp` is the current time plus `exp_offset_secs`; a negative offset
    /// yields an already-expired token. Returns the token and the key's id.
    async fn sign_test_jwt(
        db: &Db,
        sub: &str,
        scope: &str,
        issuer: &str,
        exp_offset_secs: i64,
    ) -> (String, crate::types::SigningKeyId) {
        let key = db.create_signing_key(&ENC_KEY).await.unwrap();
        db.activate_signing_key(key.id).await.unwrap();

        let pem = decrypt_private_key(&key, &ENC_KEY).unwrap();
        let encoding_key = EncodingKey::from_rsa_pem(pem.as_bytes()).unwrap();

        let now = chrono::Utc::now().timestamp();
        let claims = TestClaims {
            sub: sub.to_string(),
            scope: scope.to_string(),
            iss: issuer.to_string(),
            aud: AUDIENCE.to_string(),
            exp: now + exp_offset_secs,
            iat: now,
            email: "test@example.com".to_string(),
            email_verified: true,
            username: Some("testuser".to_string()),
            roles: vec!["admin".to_string()],
            permissions: vec!["posts:write".to_string()],
        };

        let mut header = Header::new(Algorithm::RS256);
        header.kid = Some(key.id.to_string());

        let token = encode(&header, &claims, &encoding_key).unwrap();
        (token, key.id)
    }

    #[tokio::test]
    async fn validate_access_token_valid() {
        let db = test_db().await;
        let sub = UserId::new().to_string();
        let (token, _) = sign_test_jwt(&db, &sub, "openid profile", ISSUER, 300).await;

        let claims = db.validate_access_token(&token, ISSUER).await.unwrap();
        assert_eq!(claims.sub.to_string(), sub);
        assert_eq!(claims.scope, "openid profile");
        assert_eq!(claims.iss, ISSUER);
        // The remaining claims must survive the round trip unchanged as well.
        assert_eq!(claims.aud, AUDIENCE);
        assert_eq!(claims.email, "test@example.com");
        assert!(claims.email_verified);
        assert_eq!(claims.username.as_deref(), Some("testuser"));
        assert_eq!(claims.roles, vec!["admin".to_string()]);
        assert_eq!(claims.permissions, vec!["posts:write".to_string()]);
        assert_eq!(claims.exp - claims.iat, 300);
    }

    #[tokio::test]
    async fn validate_access_token_expired() {
        let db = test_db().await;
        let sub = UserId::new().to_string();
        let (token, _) = sign_test_jwt(&db, &sub, "openid", ISSUER, -60).await;

        let err = db.validate_access_token(&token, ISSUER).await.unwrap_err();
        assert!(matches!(
            err,
            AuthError::AccessToken(AccessTokenError::Expired)
        ));
    }

    #[tokio::test]
    async fn validate_access_token_wrong_issuer() {
        let db = test_db().await;
        let sub = UserId::new().to_string();
        let (token, _) = sign_test_jwt(&db, &sub, "openid", "https://wrong.example.com", 300).await;

        let err = db.validate_access_token(&token, ISSUER).await.unwrap_err();
        let AuthError::AccessToken(AccessTokenError::InvalidClaims(msg)) = &err else {
            panic!("expected AccessTokenError::InvalidClaims");
        };
        assert!(msg.contains("issuer"), "unexpected message: {msg}");
    }

    #[tokio::test]
    async fn validate_access_token_unknown_kid() {
        let db = test_db().await;
        let sub = UserId::new().to_string();
        let (token, _) = sign_test_jwt(&db, &sub, "openid", ISSUER, 300).await;

        // Re-point the header at a kid that no stored key has.
        let random_kid = Uuid::new_v4().to_string();
        let parts: Vec<&str> = token.splitn(3, '.').collect();
        let fake_header = base64ct::Base64UrlUnpadded::encode_string(
            format!(r#"{{"alg":"RS256","kid":"{random_kid}","typ":"JWT"}}"#).as_bytes(),
        );
        let tampered = format!("{}.{}.{}", fake_header, parts[1], parts[2]);

        let err = db
            .validate_access_token(&tampered, ISSUER)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            AuthError::AccessToken(AccessTokenError::UnknownKid(_))
        ));
    }

    #[tokio::test]
    async fn validate_access_token_tampered_payload() {
        let db = test_db().await;
        let sub = UserId::new().to_string();
        let (token, _) = sign_test_jwt(&db, &sub, "openid", ISSUER, 300).await;

        // Keep the genuine header and signature, but swap in a payload that
        // claims extra scope; the signature no longer covers the payload.
        let parts: Vec<&str> = token.splitn(3, '.').collect();
        let now = chrono::Utc::now().timestamp();
        let forged_payload = base64ct::Base64UrlUnpadded::encode_string(
            format!(
                r#"{{"sub":"{sub}","scope":"openid admin","iss":"{iss}","aud":"{aud}","exp":{exp},"iat":{now}}}"#,
                iss = ISSUER,
                aud = AUDIENCE,
                exp = now + 300,
            )
            .as_bytes(),
        );
        let tampered = format!("{}.{}.{}", parts[0], forged_payload, parts[2]);

        let err = db
            .validate_access_token(&tampered, ISSUER)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            AuthError::AccessToken(AccessTokenError::InvalidSignature)
        ));
    }

    #[tokio::test]
    async fn validate_access_token_bad_signature() {
        let db = test_db().await;
        let sub = UserId::new().to_string();

        let key1 = db.create_signing_key(&ENC_KEY).await.unwrap();
        db.activate_signing_key(key1.id).await.unwrap();

        let key2 = db.create_signing_key(&ENC_KEY).await.unwrap();
        db.activate_signing_key(key2.id).await.unwrap();

        let pem2 = decrypt_private_key(&key2, &ENC_KEY).unwrap();
        let encoding_key2 = EncodingKey::from_rsa_pem(pem2.as_bytes()).unwrap();

        // Header names key1, but the token is signed with key2.
        let mut header = Header::new(Algorithm::RS256);
        header.kid = Some(key1.id.to_string());
        let token = encode(&header, &minimal_claims(&sub), &encoding_key2).unwrap();

        let err = db.validate_access_token(&token, ISSUER).await.unwrap_err();
        assert!(matches!(
            err,
            AuthError::AccessToken(AccessTokenError::InvalidSignature)
        ));
    }

    #[tokio::test]
    async fn validate_access_token_missing_kid() {
        let db = test_db().await;
        let sub = UserId::new().to_string();
        let key = db.create_signing_key(&ENC_KEY).await.unwrap();
        db.activate_signing_key(key.id).await.unwrap();

        let pem = decrypt_private_key(&key, &ENC_KEY).unwrap();
        let encoding_key = EncodingKey::from_rsa_pem(pem.as_bytes()).unwrap();

        // `Header::new` leaves `kid` unset.
        let header = Header::new(Algorithm::RS256);
        let token = encode(&header, &minimal_claims(&sub), &encoding_key).unwrap();

        let err = db.validate_access_token(&token, ISSUER).await.unwrap_err();
        assert!(matches!(
            err,
            AuthError::AccessToken(AccessTokenError::MalformedToken(_))
        ));
    }

    #[test]
    fn has_scope_present() {
        assert!(has_scope("openid profile email", "profile"));
    }

    #[test]
    fn has_scope_absent() {
        assert!(!has_scope("openid profile", "email"));
    }

    #[test]
    fn has_scope_no_partial_match() {
        assert!(!has_scope("openid profile_extended", "profile"));
    }
}