Skip to main content

allowthem_core/
authorization.rs

use std::collections::HashSet;

use base64ct::{Base64UrlUnpadded, Encoding};
use chrono::{DateTime, Utc};
use rand::TryRngCore;
use rand::rngs::OsRng;
use sha2::{Digest, Sha256};

use crate::db::Db;
use crate::error::AuthError;
use crate::types::{ApplicationId, AuthorizationCodeId, ConsentId, TokenHash, UserId};
10
11#[derive(Debug, Clone, sqlx::FromRow)]
12pub struct AuthorizationCode {
13    pub id: AuthorizationCodeId,
14    pub application_id: ApplicationId,
15    pub user_id: UserId,
16    pub code_hash: TokenHash,
17    pub redirect_uri: String,
18    pub scopes: String,
19    pub code_challenge: String,
20    pub code_challenge_method: String,
21    pub nonce: Option<String>,
22    pub expires_at: DateTime<Utc>,
23    pub used_at: Option<DateTime<Utc>>,
24    pub created_at: DateTime<Utc>,
25}
26
27#[derive(Debug, Clone, sqlx::FromRow)]
28pub struct Consent {
29    pub id: ConsentId,
30    pub user_id: UserId,
31    pub application_id: ApplicationId,
32    pub scopes: String,
33    pub created_at: DateTime<Utc>,
34    pub updated_at: DateTime<Utc>,
35}
36
37/// The set of supported OIDC scopes.
38const SUPPORTED_SCOPES: &[&str] = &["openid", "profile", "email", "offline_access"];
39
40/// Parse and validate a space-separated scope string.
41///
42/// Rules:
43/// - `openid` must be present (this is an OIDC provider).
44/// - All scopes must be in `SUPPORTED_SCOPES`.
45///
46/// Returns the validated scopes as a `Vec<String>`.
47pub fn validate_scopes(scope_str: &str) -> Result<Vec<String>, AuthError> {
48    let scopes: Vec<String> = scope_str
49        .split_whitespace()
50        .map(|s| s.to_string())
51        .collect();
52
53    if scopes.is_empty() || !scopes.iter().any(|s| s == "openid") {
54        return Err(AuthError::InvalidAuthorizationRequest(
55            "scope must include openid".into(),
56        ));
57    }
58
59    for scope in &scopes {
60        if !SUPPORTED_SCOPES.contains(&scope.as_str()) {
61            return Err(AuthError::InvalidAuthorizationRequest(format!(
62                "unsupported scope: {scope}"
63            )));
64        }
65    }
66
67    Ok(scopes)
68}
69
70/// Generate a raw authorization code: 32 random bytes, base64url-encoded.
71///
72/// Same pattern as `sessions::generate_token()`. Returns the raw code string
73/// to include in the redirect URI. The caller must hash it before storage.
74pub fn generate_authorization_code() -> String {
75    let mut bytes = [0u8; 32];
76    OsRng
77        .try_fill_bytes(&mut bytes)
78        .expect("OS RNG unavailable");
79    Base64UrlUnpadded::encode_string(&bytes)
80}
81
82/// Hash a raw authorization code with SHA-256.
83///
84/// Returns the hex-encoded digest as a `TokenHash`. Same pattern as
85/// `sessions::hash_token()`.
86pub fn hash_authorization_code(raw: &str) -> TokenHash {
87    let digest = Sha256::digest(raw.as_bytes());
88    TokenHash::new_unchecked(format!("{digest:x}"))
89}
90
91impl Db {
92    /// Check whether the user has an existing consent that covers all requested scopes.
93    pub async fn has_sufficient_consent(
94        &self,
95        user_id: UserId,
96        application_id: ApplicationId,
97        requested_scopes: &[String],
98    ) -> Result<bool, AuthError> {
99        let consent = self.get_consent(user_id, application_id).await?;
100        let Some(consent) = consent else {
101            return Ok(false);
102        };
103        let stored: Vec<String> = serde_json::from_str(&consent.scopes)
104            .map_err(|e| AuthError::Database(sqlx::Error::Decode(Box::new(e))))?;
105        let stored_set: std::collections::HashSet<&str> =
106            stored.iter().map(|s| s.as_str()).collect();
107        Ok(requested_scopes
108            .iter()
109            .all(|s| stored_set.contains(s.as_str())))
110    }
111
112    /// Record or update user consent for an application.
113    ///
114    /// Stored scopes become the union of existing and new scopes (consent is additive).
115    pub async fn upsert_consent(
116        &self,
117        user_id: UserId,
118        application_id: ApplicationId,
119        scopes: &[String],
120    ) -> Result<(), AuthError> {
121        let id = ConsentId::new();
122        let scopes_json = serde_json::to_string(scopes).expect("Vec<String> serializes to JSON");
123        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
124
125        let existing = self.get_consent(user_id, application_id).await?;
126        let merged_json = if let Some(existing) = existing {
127            let mut stored: Vec<String> = serde_json::from_str(&existing.scopes)
128                .map_err(|e| AuthError::Database(sqlx::Error::Decode(Box::new(e))))?;
129            for scope in scopes {
130                if !stored.contains(scope) {
131                    stored.push(scope.clone());
132                }
133            }
134            serde_json::to_string(&stored).expect("Vec<String> serializes to JSON")
135        } else {
136            scopes_json
137        };
138
139        sqlx::query(
140            "INSERT INTO allowthem_consents \
141             (id, user_id, application_id, scopes, created_at, updated_at) \
142             VALUES (?1, ?2, ?3, ?4, ?5, ?5) \
143             ON CONFLICT(user_id, application_id) DO UPDATE SET scopes = ?4, updated_at = ?5",
144        )
145        .bind(id)
146        .bind(user_id)
147        .bind(application_id)
148        .bind(&merged_json)
149        .bind(&now)
150        .execute(self.pool())
151        .await?;
152
153        Ok(())
154    }
155
156    /// Get the consent record for a user and application, if any.
157    pub async fn get_consent(
158        &self,
159        user_id: UserId,
160        application_id: ApplicationId,
161    ) -> Result<Option<Consent>, AuthError> {
162        sqlx::query_as::<_, Consent>(
163            "SELECT id, user_id, application_id, scopes, created_at, updated_at \
164             FROM allowthem_consents WHERE user_id = ? AND application_id = ?",
165        )
166        .bind(user_id)
167        .bind(application_id)
168        .fetch_optional(self.pool())
169        .await
170        .map_err(AuthError::Database)
171    }
172
173    /// Create an authorization code record. Expires after 10 minutes.
174    #[allow(clippy::too_many_arguments)]
175    pub async fn create_authorization_code(
176        &self,
177        application_id: ApplicationId,
178        user_id: UserId,
179        code_hash: &TokenHash,
180        redirect_uri: &str,
181        scopes: &[String],
182        code_challenge: &str,
183        code_challenge_method: &str,
184        nonce: Option<&str>,
185    ) -> Result<AuthorizationCode, AuthError> {
186        let id = AuthorizationCodeId::new();
187        let scopes_json = serde_json::to_string(scopes).expect("Vec<String> serializes to JSON");
188        let now = Utc::now();
189        let now_str = now.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
190        let expires_at = now + chrono::Duration::minutes(10);
191        let expires_str = expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
192
193        sqlx::query(
194            "INSERT INTO allowthem_authorization_codes \
195             (id, application_id, user_id, code_hash, redirect_uri, scopes, \
196              code_challenge, code_challenge_method, nonce, expires_at, created_at) \
197             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
198        )
199        .bind(id)
200        .bind(application_id)
201        .bind(user_id)
202        .bind(code_hash)
203        .bind(redirect_uri)
204        .bind(&scopes_json)
205        .bind(code_challenge)
206        .bind(code_challenge_method)
207        .bind(nonce)
208        .bind(&expires_str)
209        .bind(&now_str)
210        .execute(self.pool())
211        .await?;
212
213        sqlx::query_as::<_, AuthorizationCode>(
214            "SELECT id, application_id, user_id, code_hash, redirect_uri, scopes, \
215             code_challenge, code_challenge_method, nonce, expires_at, used_at, created_at \
216             FROM allowthem_authorization_codes WHERE id = ?",
217        )
218        .bind(id)
219        .fetch_one(self.pool())
220        .await
221        .map_err(AuthError::Database)
222    }
223
224    /// Look up an authorization code by its hash.
225    pub async fn get_authorization_code_by_hash(
226        &self,
227        code_hash: &TokenHash,
228    ) -> Result<Option<AuthorizationCode>, AuthError> {
229        sqlx::query_as::<_, AuthorizationCode>(
230            "SELECT id, application_id, user_id, code_hash, redirect_uri, scopes, \
231             code_challenge, code_challenge_method, nonce, expires_at, used_at, created_at \
232             FROM allowthem_authorization_codes WHERE code_hash = ?",
233        )
234        .bind(code_hash)
235        .fetch_optional(self.pool())
236        .await
237        .map_err(AuthError::Database)
238    }
239
240    /// Mark an authorization code as used.
241    pub async fn mark_authorization_code_used(
242        &self,
243        id: AuthorizationCodeId,
244    ) -> Result<(), AuthError> {
245        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
246        let result =
247            sqlx::query("UPDATE allowthem_authorization_codes SET used_at = ? WHERE id = ?")
248                .bind(&now)
249                .bind(id)
250                .execute(self.pool())
251                .await?;
252
253        if result.rows_affected() == 0 {
254            return Err(AuthError::NotFound);
255        }
256        Ok(())
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `input` is rejected as an invalid authorization request.
    fn assert_rejected(input: &str) {
        let err = validate_scopes(input).unwrap_err();
        assert!(matches!(err, AuthError::InvalidAuthorizationRequest(_)));
    }

    #[test]
    fn valid_scopes_openid_only() {
        assert_eq!(validate_scopes("openid").unwrap(), ["openid"]);
    }

    #[test]
    fn valid_scopes_all_three() {
        assert_eq!(
            validate_scopes("openid profile email").unwrap(),
            ["openid", "profile", "email"]
        );
    }

    #[test]
    fn offline_access_is_accepted() {
        let granted = validate_scopes("openid offline_access").unwrap();
        assert!(granted.contains(&"offline_access".to_owned()));
    }

    #[test]
    fn full_default_scope_is_accepted() {
        // Mirrors @allowthem/js's DEFAULT_SCOPE.
        assert_eq!(
            validate_scopes("openid profile email offline_access").unwrap(),
            ["openid", "profile", "email", "offline_access"]
        );
    }

    #[test]
    fn missing_openid_is_rejected() {
        assert_rejected("profile email");
    }

    #[test]
    fn empty_scope_is_rejected() {
        assert_rejected("");
    }

    #[test]
    fn whitespace_only_scope_is_rejected() {
        assert_rejected("   ");
    }

    #[test]
    fn unknown_scope_is_rejected() {
        assert_rejected("openid admin");
    }

    #[test]
    fn duplicate_openid_is_fine() {
        assert_eq!(
            validate_scopes("openid openid profile").unwrap(),
            ["openid", "openid", "profile"]
        );
    }

    #[test]
    fn code_is_43_chars() {
        let length = generate_authorization_code().len();
        assert_eq!(length, 43, "32 bytes base64url = 43 chars");
    }

    #[test]
    fn code_is_url_safe() {
        let code = generate_authorization_code();
        let url_safe = code
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_');
        assert!(url_safe, "code must be URL-safe base64url: got {code}");
    }

    #[test]
    fn two_codes_differ() {
        let (first, second) = (generate_authorization_code(), generate_authorization_code());
        assert_ne!(first, second);
    }

    #[test]
    fn hash_is_deterministic() {
        let code = generate_authorization_code();
        // TokenHash is compared through its Debug rendering.
        let render = |raw: &str| format!("{:?}", hash_authorization_code(raw));
        assert_eq!(render(&code), render(&code));
    }

    #[test]
    fn different_codes_produce_different_hashes() {
        let render = |raw: &str| format!("{:?}", hash_authorization_code(raw));
        let first = generate_authorization_code();
        let second = generate_authorization_code();
        assert_ne!(render(&first), render(&second));
    }
}