1use base64ct::{Base64UrlUnpadded, Encoding};
2use chrono::{DateTime, Utc};
3use rand::TryRngCore;
4use rand::rngs::OsRng;
5use sha2::{Digest, Sha256};
6
7use crate::db::Db;
8use crate::error::AuthError;
9use crate::types::{ApplicationId, AuthorizationCodeId, ConsentId, TokenHash, UserId};
10
/// A persisted authorization code row from `allowthem_authorization_codes`.
///
/// Rows are written by [`Db::create_authorization_code`], looked up by hash with
/// [`Db::get_authorization_code_by_hash`], and stamped as redeemed by
/// [`Db::mark_authorization_code_used`]. Timestamp columns are written by this
/// module as UTC strings and decoded into [`DateTime<Utc>`] by sqlx.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct AuthorizationCode {
    /// Primary key of the row.
    pub id: AuthorizationCodeId,
    /// Application the code was issued to.
    pub application_id: ApplicationId,
    /// User the code was issued for.
    pub user_id: UserId,
    /// Hash of the raw code; the raw code itself is not part of this struct.
    pub code_hash: TokenHash,
    /// Redirect URI recorded when the code was issued — presumably compared
    /// against the one presented at token exchange (TODO confirm in caller).
    pub redirect_uri: String,
    /// Granted scopes, stored as a JSON-encoded array of strings.
    pub scopes: String,
    /// PKCE code challenge, stored as received with the authorization request.
    pub code_challenge: String,
    /// PKCE challenge method string, stored as received; not validated in this module.
    pub code_challenge_method: String,
    /// Optional OpenID Connect `nonce` carried over from the authorization request.
    pub nonce: Option<String>,
    /// Expiry instant; `create_authorization_code` sets it 10 minutes after creation.
    pub expires_at: DateTime<Utc>,
    /// When the code was redeemed; not set at creation, stamped by
    /// `mark_authorization_code_used`.
    pub used_at: Option<DateTime<Utc>>,
    /// When the row was created.
    pub created_at: DateTime<Utc>,
}
26
/// A user's recorded consent for an application, from `allowthem_consents`.
///
/// [`Db::upsert_consent`] writes with `ON CONFLICT(user_id, application_id)`,
/// which requires a uniqueness constraint on that pair, so there is at most one
/// row per user/application.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct Consent {
    /// Primary key of the row.
    pub id: ConsentId,
    /// User who granted the consent.
    pub user_id: UserId,
    /// Application the consent applies to.
    pub application_id: ApplicationId,
    /// Granted scopes as a JSON-encoded array of strings; `upsert_consent`
    /// only ever appends to this list.
    pub scopes: String,
    /// When consent was first recorded.
    pub created_at: DateTime<Utc>,
    /// When the granted scopes were last written.
    pub updated_at: DateTime<Utc>,
}
36
/// Scope values accepted by [`validate_scopes`]; any other token in a request's
/// `scope` parameter is rejected.
const SUPPORTED_SCOPES: &[&str] = &["openid", "profile", "email", "offline_access"];
39
40pub fn validate_scopes(scope_str: &str) -> Result<Vec<String>, AuthError> {
48 let scopes: Vec<String> = scope_str
49 .split_whitespace()
50 .map(|s| s.to_string())
51 .collect();
52
53 if scopes.is_empty() || !scopes.iter().any(|s| s == "openid") {
54 return Err(AuthError::InvalidAuthorizationRequest(
55 "scope must include openid".into(),
56 ));
57 }
58
59 for scope in &scopes {
60 if !SUPPORTED_SCOPES.contains(&scope.as_str()) {
61 return Err(AuthError::InvalidAuthorizationRequest(format!(
62 "unsupported scope: {scope}"
63 )));
64 }
65 }
66
67 Ok(scopes)
68}
69
70pub fn generate_authorization_code() -> String {
75 let mut bytes = [0u8; 32];
76 OsRng
77 .try_fill_bytes(&mut bytes)
78 .expect("OS RNG unavailable");
79 Base64UrlUnpadded::encode_string(&bytes)
80}
81
82pub fn hash_authorization_code(raw: &str) -> TokenHash {
87 let digest = Sha256::digest(raw.as_bytes());
88 TokenHash::new_unchecked(format!("{digest:x}"))
89}
90
91impl Db {
92 pub async fn has_sufficient_consent(
94 &self,
95 user_id: UserId,
96 application_id: ApplicationId,
97 requested_scopes: &[String],
98 ) -> Result<bool, AuthError> {
99 let consent = self.get_consent(user_id, application_id).await?;
100 let Some(consent) = consent else {
101 return Ok(false);
102 };
103 let stored: Vec<String> = serde_json::from_str(&consent.scopes)
104 .map_err(|e| AuthError::Database(sqlx::Error::Decode(Box::new(e))))?;
105 let stored_set: std::collections::HashSet<&str> =
106 stored.iter().map(|s| s.as_str()).collect();
107 Ok(requested_scopes
108 .iter()
109 .all(|s| stored_set.contains(s.as_str())))
110 }
111
112 pub async fn upsert_consent(
116 &self,
117 user_id: UserId,
118 application_id: ApplicationId,
119 scopes: &[String],
120 ) -> Result<(), AuthError> {
121 let id = ConsentId::new();
122 let scopes_json = serde_json::to_string(scopes).expect("Vec<String> serializes to JSON");
123 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
124
125 let existing = self.get_consent(user_id, application_id).await?;
126 let merged_json = if let Some(existing) = existing {
127 let mut stored: Vec<String> = serde_json::from_str(&existing.scopes)
128 .map_err(|e| AuthError::Database(sqlx::Error::Decode(Box::new(e))))?;
129 for scope in scopes {
130 if !stored.contains(scope) {
131 stored.push(scope.clone());
132 }
133 }
134 serde_json::to_string(&stored).expect("Vec<String> serializes to JSON")
135 } else {
136 scopes_json
137 };
138
139 sqlx::query(
140 "INSERT INTO allowthem_consents \
141 (id, user_id, application_id, scopes, created_at, updated_at) \
142 VALUES (?1, ?2, ?3, ?4, ?5, ?5) \
143 ON CONFLICT(user_id, application_id) DO UPDATE SET scopes = ?4, updated_at = ?5",
144 )
145 .bind(id)
146 .bind(user_id)
147 .bind(application_id)
148 .bind(&merged_json)
149 .bind(&now)
150 .execute(self.pool())
151 .await?;
152
153 Ok(())
154 }
155
156 pub async fn get_consent(
158 &self,
159 user_id: UserId,
160 application_id: ApplicationId,
161 ) -> Result<Option<Consent>, AuthError> {
162 sqlx::query_as::<_, Consent>(
163 "SELECT id, user_id, application_id, scopes, created_at, updated_at \
164 FROM allowthem_consents WHERE user_id = ? AND application_id = ?",
165 )
166 .bind(user_id)
167 .bind(application_id)
168 .fetch_optional(self.pool())
169 .await
170 .map_err(AuthError::Database)
171 }
172
173 #[allow(clippy::too_many_arguments)]
175 pub async fn create_authorization_code(
176 &self,
177 application_id: ApplicationId,
178 user_id: UserId,
179 code_hash: &TokenHash,
180 redirect_uri: &str,
181 scopes: &[String],
182 code_challenge: &str,
183 code_challenge_method: &str,
184 nonce: Option<&str>,
185 ) -> Result<AuthorizationCode, AuthError> {
186 let id = AuthorizationCodeId::new();
187 let scopes_json = serde_json::to_string(scopes).expect("Vec<String> serializes to JSON");
188 let now = Utc::now();
189 let now_str = now.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
190 let expires_at = now + chrono::Duration::minutes(10);
191 let expires_str = expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
192
193 sqlx::query(
194 "INSERT INTO allowthem_authorization_codes \
195 (id, application_id, user_id, code_hash, redirect_uri, scopes, \
196 code_challenge, code_challenge_method, nonce, expires_at, created_at) \
197 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
198 )
199 .bind(id)
200 .bind(application_id)
201 .bind(user_id)
202 .bind(code_hash)
203 .bind(redirect_uri)
204 .bind(&scopes_json)
205 .bind(code_challenge)
206 .bind(code_challenge_method)
207 .bind(nonce)
208 .bind(&expires_str)
209 .bind(&now_str)
210 .execute(self.pool())
211 .await?;
212
213 sqlx::query_as::<_, AuthorizationCode>(
214 "SELECT id, application_id, user_id, code_hash, redirect_uri, scopes, \
215 code_challenge, code_challenge_method, nonce, expires_at, used_at, created_at \
216 FROM allowthem_authorization_codes WHERE id = ?",
217 )
218 .bind(id)
219 .fetch_one(self.pool())
220 .await
221 .map_err(AuthError::Database)
222 }
223
224 pub async fn get_authorization_code_by_hash(
226 &self,
227 code_hash: &TokenHash,
228 ) -> Result<Option<AuthorizationCode>, AuthError> {
229 sqlx::query_as::<_, AuthorizationCode>(
230 "SELECT id, application_id, user_id, code_hash, redirect_uri, scopes, \
231 code_challenge, code_challenge_method, nonce, expires_at, used_at, created_at \
232 FROM allowthem_authorization_codes WHERE code_hash = ?",
233 )
234 .bind(code_hash)
235 .fetch_optional(self.pool())
236 .await
237 .map_err(AuthError::Database)
238 }
239
240 pub async fn mark_authorization_code_used(
242 &self,
243 id: AuthorizationCodeId,
244 ) -> Result<(), AuthError> {
245 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
246 let result =
247 sqlx::query("UPDATE allowthem_authorization_codes SET used_at = ? WHERE id = ?")
248 .bind(&now)
249 .bind(id)
250 .execute(self.pool())
251 .await?;
252
253 if result.rows_affected() == 0 {
254 return Err(AuthError::NotFound);
255 }
256 Ok(())
257 }
258}
259
#[cfg(test)]
mod tests {
    // Unit tests for scope validation and for authorization-code generation
    // and hashing. The database-backed `Db` methods are not exercised here.

    use super::*;

    /// Asserts that `input` is refused as an invalid authorization request.
    #[track_caller]
    fn assert_rejected(input: &str) {
        assert!(matches!(
            validate_scopes(input),
            Err(AuthError::InvalidAuthorizationRequest(_))
        ));
    }

    #[test]
    fn valid_scopes_openid_only() {
        assert_eq!(validate_scopes("openid").unwrap(), ["openid"]);
    }

    #[test]
    fn valid_scopes_all_three() {
        assert_eq!(
            validate_scopes("openid profile email").unwrap(),
            ["openid", "profile", "email"]
        );
    }

    #[test]
    fn offline_access_is_accepted() {
        let accepted = validate_scopes("openid offline_access").unwrap();
        assert!(accepted.contains(&String::from("offline_access")));
    }

    #[test]
    fn full_default_scope_is_accepted() {
        let requested = "openid profile email offline_access";
        assert_eq!(
            validate_scopes(requested).unwrap(),
            ["openid", "profile", "email", "offline_access"]
        );
    }

    #[test]
    fn missing_openid_is_rejected() {
        assert_rejected("profile email");
    }

    #[test]
    fn empty_scope_is_rejected() {
        assert_rejected("");
    }

    #[test]
    fn whitespace_only_scope_is_rejected() {
        assert_rejected(" ");
    }

    #[test]
    fn unknown_scope_is_rejected() {
        assert_rejected("openid admin");
    }

    #[test]
    fn duplicate_openid_is_fine() {
        assert_eq!(
            validate_scopes("openid openid profile").unwrap(),
            ["openid", "openid", "profile"]
        );
    }

    #[test]
    fn code_is_43_chars() {
        let code = generate_authorization_code();
        assert_eq!(code.len(), 43, "32 bytes base64url = 43 chars");
    }

    #[test]
    fn code_is_url_safe() {
        let code = generate_authorization_code();
        let is_url_safe = |c: char| c.is_ascii_alphanumeric() || matches!(c, '-' | '_');
        assert!(
            code.chars().all(is_url_safe),
            "code must be URL-safe base64url: got {code}"
        );
    }

    #[test]
    fn two_codes_differ() {
        assert_ne!(generate_authorization_code(), generate_authorization_code());
    }

    #[test]
    fn hash_is_deterministic() {
        let code = generate_authorization_code();
        // Compared through `Debug`, presumably because `TokenHash` has no
        // `PartialEq` impl (TODO confirm).
        let first = format!("{:?}", hash_authorization_code(&code));
        let second = format!("{:?}", hash_authorization_code(&code));
        assert_eq!(first, second);
    }

    #[test]
    fn different_codes_produce_different_hashes() {
        let debug_hash = |code: &str| format!("{:?}", hash_authorization_code(code));
        let [a, b] = [generate_authorization_code(), generate_authorization_code()];
        assert_ne!(debug_hash(&a), debug_hash(&b));
    }
}