Skip to main content

allowthem_core/
access_tokens.rs

1use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
2use serde::Deserialize;
3use uuid::Uuid;
4
5use crate::db::Db;
6use crate::error::{AccessTokenError, AuthError};
7use crate::types::UserId;
8
/// Validated claims extracted from an RS256-signed access token.
///
/// The `sub` value is `User.id.to_string()` and MUST remain identical
/// to the `sub` claim in ID tokens issued by the token endpoint (M41).
/// OIDC Core Section 5.3.4 requires this consistency.
///
/// Built by [`Db::validate_access_token`] only after the signature, `exp`
/// and `iss` have been verified. Optional claims that the token does not
/// carry are filled with defaults (empty string, `false`, `None`, empty
/// vectors), so an empty value here can mean "claim absent".
#[derive(Debug, Clone)]
pub struct AccessTokenClaims {
    /// Subject: the user the token was issued to, parsed from the `sub` claim as a UUID.
    pub sub: UserId,
    /// Granted scopes as a single space-delimited string; test membership with [`has_scope`].
    pub scope: String,
    /// Issuer; checked against the caller's expected issuer during validation.
    pub iss: String,
    /// Audience, copied as-is. NOT checked by `Db::validate_access_token`
    /// (it sets `validate_aud = false`); callers that need audience
    /// restriction must compare this themselves.
    pub aud: String,
    /// Expiry as a JWT NumericDate (seconds since the Unix epoch); verified
    /// with zero leeway during validation.
    pub exp: i64,
    /// Issued-at as a JWT NumericDate; copied as-is, not checked against the clock.
    pub iat: i64,
    /// Email address; empty when the token has no `email` claim.
    pub email: String,
    /// Whether the email address is verified; `false` when the claim is absent.
    pub email_verified: bool,
    /// Username; `None` when the claim is absent.
    pub username: Option<String>,
    /// Role names; empty when the claim is absent.
    pub roles: Vec<String>,
    /// Permission strings; empty when the claim is absent.
    pub permissions: Vec<String>,
}
28
/// Raw claims for `jsonwebtoken::decode()`. Private — callers use `AccessTokenClaims`.
///
/// `sub`, `scope`, `iss`, `aud`, `exp` and `iat` have no serde default, so a
/// payload missing any of them fails to deserialize and the token is rejected.
/// An ID token that carries no `scope` claim, for instance, cannot pass as an
/// access token. `aud` must be a single JSON string. An array-valued `aud`
/// (permitted by RFC 7519) also fails to deserialize. The remaining claims are
/// optional and fall back to `Default` when absent.
#[derive(Debug, Deserialize)]
struct RawAccessTokenClaims {
    // Wire form of the subject; the validator parses it into a `UserId`.
    sub: String,
    scope: String,
    iss: String,
    aud: String,
    exp: i64,
    iat: i64,
    // Empty string when absent.
    #[serde(default)]
    email: String,
    // `false` when absent.
    #[serde(default)]
    email_verified: bool,
    // `None` when absent.
    #[serde(default)]
    username: Option<String>,
    // Empty when absent.
    #[serde(default)]
    roles: Vec<String>,
    // Empty when absent.
    #[serde(default)]
    permissions: Vec<String>,
}
49
/// Check whether a space-delimited OAuth scope string grants `target`.
///
/// The string is split on single U+0020 spaces (the delimiter RFC 6749 §3.3
/// prescribes) and each token is compared exactly, so `"profile"` does not
/// match `"profile_extended"`. Empty tokens are never scopes: they come from
/// an empty scope string or from leading, trailing or repeated spaces.
/// An empty `target` therefore always yields `false`.
///
/// This utility may move to `authorization.rs` when that module exists (M39).
pub fn has_scope(scope_string: &str, target: &str) -> bool {
    // RFC 6749 defines scope-token = 1*NQCHAR, so "" can never be granted.
    // Without this guard, `split(' ')` produces "" tokens and
    // `has_scope("", "")` / `has_scope("a  b", "")` would return true.
    !target.is_empty() && scope_string.split(' ').any(|token| token == target)
}
56
57impl Db {
58    /// Validate an RS256-signed access token JWT.
59    ///
60    /// Steps:
61    /// 1. Decode the JWT header to extract `kid`.
62    /// 2. Look up the signing key by `kid` in the database.
63    /// 3. Verify the RS256 signature using the public key PEM.
64    /// 4. Check `exp` against the current time.
65    /// 5. Verify `iss` matches `expected_issuer`.
66    /// 6. Parse `sub` as `UserId` and return `AccessTokenClaims`.
67    pub async fn validate_access_token(
68        &self,
69        token: &str,
70        expected_issuer: &str,
71    ) -> Result<AccessTokenClaims, AuthError> {
72        // Step 1: decode header to extract kid
73        let header = decode_header(token)
74            .map_err(|e| AuthError::AccessToken(AccessTokenError::MalformedToken(e.to_string())))?;
75
76        // Step 2: extract kid
77        let kid_str = header.kid.ok_or_else(|| {
78            AuthError::AccessToken(AccessTokenError::MalformedToken("missing kid".into()))
79        })?;
80
81        // Step 3: parse kid as UUID then SigningKeyId
82        let kid_uuid = Uuid::parse_str(&kid_str)
83            .map_err(|_| AuthError::AccessToken(AccessTokenError::UnknownKid(kid_str.clone())))?;
84        let kid_id = crate::types::SigningKeyId::from_uuid(kid_uuid);
85
86        // Step 4: fetch signing key
87        let key = self.get_signing_key(kid_id).await.map_err(|e| match e {
88            AuthError::NotFound => AuthError::AccessToken(AccessTokenError::UnknownKid(kid_str)),
89            other => other,
90        })?;
91
92        // Step 5: build decoding key from public PEM
93        let decoding_key = DecodingKey::from_rsa_pem(key.public_key_pem.as_bytes())
94            .map_err(|e| AuthError::SigningKey(e.to_string()))?;
95
96        // Step 6: build validation
97        let mut validation = Validation::new(Algorithm::RS256);
98        validation.set_issuer(&[expected_issuer]);
99        validation.validate_aud = false;
100        validation.leeway = 0;
101
102        // Step 7: decode and verify
103        let token_data = decode::<RawAccessTokenClaims>(token, &decoding_key, &validation)
104            .map_err(|e| {
105                let err = match e.kind() {
106                    jsonwebtoken::errors::ErrorKind::ExpiredSignature => AccessTokenError::Expired,
107                    jsonwebtoken::errors::ErrorKind::InvalidSignature => {
108                        AccessTokenError::InvalidSignature
109                    }
110                    jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
111                        AccessTokenError::InvalidClaims("invalid issuer".into())
112                    }
113                    _ => AccessTokenError::InvalidClaims(e.to_string()),
114                };
115                AuthError::AccessToken(err)
116            })?;
117
118        let raw = token_data.claims;
119
120        // Step 8: parse sub as UUID
121        let sub_uuid = Uuid::parse_str(&raw.sub).map_err(|_| {
122            AuthError::AccessToken(AccessTokenError::InvalidClaims("invalid sub".into()))
123        })?;
124
125        Ok(AccessTokenClaims {
126            sub: UserId::from_uuid(sub_uuid),
127            scope: raw.scope,
128            iss: raw.iss,
129            aud: raw.aud,
130            exp: raw.exp,
131            iat: raw.iat,
132            email: raw.email,
133            email_verified: raw.email_verified,
134            username: raw.username,
135            roles: raw.roles,
136            permissions: raw.permissions,
137        })
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use crate::signing_keys::decrypt_private_key;
    use base64ct::Encoding as _;
    use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};
    use serde::Serialize;
    use sqlx::SqlitePool;
    use sqlx::sqlite::SqliteConnectOptions;
    use std::str::FromStr;
    use uuid::Uuid;

    /// Key-encryption key protecting signing-key private PEMs in the test DB.
    const ENC_KEY: [u8; 32] = [0x42; 32];
    /// Issuer the "good" tokens are minted for and validated against.
    const ISSUER: &str = "https://auth.example.com";

    /// Open a fresh in-memory SQLite database with foreign keys enforced.
    async fn test_db() -> Db {
        let opts = SqliteConnectOptions::from_str("sqlite::memory:")
            .unwrap()
            .pragma("foreign_keys", "ON");
        let pool = SqlitePool::connect_with(opts).await.unwrap();
        Db::new(pool).await.unwrap()
    }

    /// Payload the tests sign; field names match the claims the validator reads.
    #[derive(Serialize)]
    struct TestClaims {
        sub: String,
        scope: String,
        iss: String,
        aud: String,
        exp: i64,
        iat: i64,
        email: String,
        email_verified: bool,
        #[serde(skip_serializing_if = "Option::is_none")]
        username: Option<String>,
        roles: Vec<String>,
        permissions: Vec<String>,
    }

    /// Minimal claims for `sub`: fixed audience and email, no username, roles or permissions.
    ///
    /// `iat` is now and `exp` is now + `exp_offset_secs` (negative = already expired).
    fn test_claims(sub: &str, scope: &str, issuer: &str, exp_offset_secs: i64) -> TestClaims {
        let now = chrono::Utc::now().timestamp();
        TestClaims {
            sub: sub.to_string(),
            scope: scope.to_string(),
            iss: issuer.to_string(),
            aud: "ath_test_client".to_string(),
            exp: now + exp_offset_secs,
            iat: now,
            email: "test@example.com".to_string(),
            email_verified: true,
            username: None,
            roles: vec![],
            permissions: vec![],
        }
    }

    /// Create and activate a signing key; return its id and decrypted private key.
    async fn active_signing_key(db: &Db) -> (crate::types::SigningKeyId, EncodingKey) {
        let key = db.create_signing_key(&ENC_KEY).await.unwrap();
        db.activate_signing_key(key.id).await.unwrap();

        let pem = decrypt_private_key(&key, &ENC_KEY).unwrap();
        let encoding_key = EncodingKey::from_rsa_pem(pem.as_bytes()).unwrap();
        (key.id, encoding_key)
    }

    /// RS256-sign `claims`, putting `kid` (if any) in the JOSE header.
    fn sign(claims: &impl Serialize, kid: Option<String>, encoding_key: &EncodingKey) -> String {
        let mut header = Header::new(Algorithm::RS256);
        header.kid = kid;
        encode(&header, claims, encoding_key).unwrap()
    }

    /// Create a signing key in the DB and return a JWT it signed, plus the key's id.
    ///
    /// The token also carries a username, one role and one permission so the
    /// optional claims can be checked after validation.
    async fn sign_test_jwt(
        db: &Db,
        sub: &str,
        scope: &str,
        issuer: &str,
        exp_offset_secs: i64,
    ) -> (String, crate::types::SigningKeyId) {
        let (kid, encoding_key) = active_signing_key(db).await;
        let claims = TestClaims {
            username: Some("testuser".to_string()),
            roles: vec!["admin".to_string()],
            permissions: vec!["posts:write".to_string()],
            ..test_claims(sub, scope, issuer, exp_offset_secs)
        };
        (sign(&claims, Some(kid.to_string()), &encoding_key), kid)
    }

    #[tokio::test]
    async fn validate_access_token_valid() {
        let db = test_db().await;
        let sub = UserId::new().to_string();
        let (token, _) = sign_test_jwt(&db, &sub, "openid profile", ISSUER, 300).await;

        let claims = db.validate_access_token(&token, ISSUER).await.unwrap();
        assert_eq!(claims.sub.to_string(), sub);
        assert_eq!(claims.scope, "openid profile");
        assert_eq!(claims.iss, ISSUER);
        assert_eq!(claims.aud, "ath_test_client");
        assert_eq!(claims.exp - claims.iat, 300);
        // Optional claims survive the round trip.
        assert_eq!(claims.email, "test@example.com");
        assert!(claims.email_verified);
        assert_eq!(claims.username.as_deref(), Some("testuser"));
        assert_eq!(claims.roles, ["admin"]);
        assert_eq!(claims.permissions, ["posts:write"]);
    }

    #[tokio::test]
    async fn validate_access_token_expired() {
        let db = test_db().await;
        let sub = UserId::new().to_string();
        let (token, _) = sign_test_jwt(&db, &sub, "openid", ISSUER, -60).await;

        let err = db.validate_access_token(&token, ISSUER).await.unwrap_err();
        assert!(matches!(
            err,
            AuthError::AccessToken(AccessTokenError::Expired)
        ));
    }

    #[tokio::test]
    async fn validate_access_token_wrong_issuer() {
        let db = test_db().await;
        let sub = UserId::new().to_string();
        let (token, _) = sign_test_jwt(&db, &sub, "openid", "https://wrong.example.com", 300).await;

        match db.validate_access_token(&token, ISSUER).await.unwrap_err() {
            AuthError::AccessToken(AccessTokenError::InvalidClaims(msg)) => {
                assert_eq!(msg, "invalid issuer");
            }
            other => panic!("expected InvalidClaims(\"invalid issuer\"), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn validate_access_token_unknown_kid() {
        let db = test_db().await;
        let sub = UserId::new().to_string();
        let (token, _) = sign_test_jwt(&db, &sub, "openid", ISSUER, 300).await;

        // Tamper: swap in a header whose random kid is not in the DB. The kid
        // lookup fails before the (now mismatched) signature is ever checked.
        let random_kid = Uuid::new_v4().to_string();
        let parts: Vec<&str> = token.splitn(3, '.').collect();
        let fake_header = base64ct::Base64UrlUnpadded::encode_string(
            format!(r#"{{"alg":"RS256","kid":"{random_kid}","typ":"JWT"}}"#).as_bytes(),
        );
        let tampered = format!("{}.{}.{}", fake_header, parts[1], parts[2]);

        let err = db
            .validate_access_token(&tampered, ISSUER)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            AuthError::AccessToken(AccessTokenError::UnknownKid(_))
        ));
    }

    #[tokio::test]
    async fn validate_access_token_bad_signature() {
        let db = test_db().await;
        let sub = UserId::new().to_string();

        // Two independent key pairs: the header names key1, but key2 signs.
        let (key1_id, _) = active_signing_key(&db).await;
        let (_, key2_private) = active_signing_key(&db).await;

        let claims = test_claims(&sub, "openid", ISSUER, 300);
        let token = sign(&claims, Some(key1_id.to_string()), &key2_private);

        let err = db.validate_access_token(&token, ISSUER).await.unwrap_err();
        assert!(matches!(
            err,
            AuthError::AccessToken(AccessTokenError::InvalidSignature)
        ));
    }

    #[tokio::test]
    async fn validate_access_token_missing_kid() {
        let db = test_db().await;
        let sub = UserId::new().to_string();
        let (_, encoding_key) = active_signing_key(&db).await;

        // No kid in header
        let token = sign(&test_claims(&sub, "openid", ISSUER, 300), None, &encoding_key);

        let err = db.validate_access_token(&token, ISSUER).await.unwrap_err();
        assert!(matches!(
            err,
            AuthError::AccessToken(AccessTokenError::MalformedToken(_))
        ));
    }

    #[tokio::test]
    async fn validate_access_token_invalid_sub() {
        let db = test_db().await;
        let (token, _) = sign_test_jwt(&db, "not-a-uuid", "openid", ISSUER, 300).await;

        match db.validate_access_token(&token, ISSUER).await.unwrap_err() {
            AuthError::AccessToken(AccessTokenError::InvalidClaims(msg)) => {
                assert_eq!(msg, "invalid sub");
            }
            other => panic!("expected InvalidClaims(\"invalid sub\"), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn validate_access_token_missing_scope() {
        // ID-token-shaped payload: correctly signed, right issuer, but no `scope`.
        #[derive(Serialize)]
        struct NoScopeClaims {
            sub: String,
            iss: String,
            aud: String,
            exp: i64,
            iat: i64,
        }

        let db = test_db().await;
        let (kid, encoding_key) = active_signing_key(&db).await;
        let now = chrono::Utc::now().timestamp();
        let claims = NoScopeClaims {
            sub: UserId::new().to_string(),
            iss: ISSUER.to_string(),
            aud: "ath_test_client".to_string(),
            exp: now + 300,
            iat: now,
        };
        let token = sign(&claims, Some(kid.to_string()), &encoding_key);

        let err = db.validate_access_token(&token, ISSUER).await.unwrap_err();
        assert!(matches!(
            err,
            AuthError::AccessToken(AccessTokenError::InvalidClaims(_))
        ));
    }

    // `has_scope` is synchronous; these need no async runtime.

    #[test]
    fn has_scope_present() {
        assert!(has_scope("openid profile email", "profile"));
    }

    #[test]
    fn has_scope_absent() {
        assert!(!has_scope("openid profile", "email"));
    }

    #[test]
    fn has_scope_no_partial_match() {
        assert!(!has_scope("openid profile_extended", "profile"));
    }
}