1use base64ct::{Base64UrlUnpadded, Encoding};
2use chrono::{DateTime, Utc};
3use rand::TryRngCore;
4use rand::rngs::OsRng;
5use sha2::{Digest, Sha256};
6
7use crate::db::Db;
8use crate::error::AuthError;
9use crate::types::{ApplicationId, AuthorizationCodeId, ConsentId, TokenHash, UserId};
10
11#[derive(Debug, Clone, sqlx::FromRow)]
12pub struct AuthorizationCode {
13 pub id: AuthorizationCodeId,
14 pub application_id: ApplicationId,
15 pub user_id: UserId,
16 pub code_hash: TokenHash,
17 pub redirect_uri: String,
18 pub scopes: String,
19 pub code_challenge: String,
20 pub code_challenge_method: String,
21 pub nonce: Option<String>,
22 pub expires_at: DateTime<Utc>,
23 pub used_at: Option<DateTime<Utc>>,
24 pub created_at: DateTime<Utc>,
25}
26
27#[derive(Debug, Clone, sqlx::FromRow)]
28pub struct Consent {
29 pub id: ConsentId,
30 pub user_id: UserId,
31 pub application_id: ApplicationId,
32 pub scopes: String,
33 pub created_at: DateTime<Utc>,
34 pub updated_at: DateTime<Utc>,
35}
36
37const SUPPORTED_SCOPES: &[&str] = &["openid", "profile", "email"];
39
40pub fn validate_scopes(scope_str: &str) -> Result<Vec<String>, AuthError> {
48 let scopes: Vec<String> = scope_str
49 .split_whitespace()
50 .map(|s| s.to_string())
51 .collect();
52
53 if scopes.is_empty() || !scopes.iter().any(|s| s == "openid") {
54 return Err(AuthError::InvalidAuthorizationRequest(
55 "scope must include openid".into(),
56 ));
57 }
58
59 for scope in &scopes {
60 if !SUPPORTED_SCOPES.contains(&scope.as_str()) {
61 return Err(AuthError::InvalidAuthorizationRequest(format!(
62 "unsupported scope: {scope}"
63 )));
64 }
65 }
66
67 Ok(scopes)
68}
69
70pub fn generate_authorization_code() -> String {
75 let mut bytes = [0u8; 32];
76 OsRng
77 .try_fill_bytes(&mut bytes)
78 .expect("OS RNG unavailable");
79 Base64UrlUnpadded::encode_string(&bytes)
80}
81
82pub fn hash_authorization_code(raw: &str) -> TokenHash {
87 let digest = Sha256::digest(raw.as_bytes());
88 TokenHash::new_unchecked(format!("{digest:x}"))
89}
90
91impl Db {
92 pub async fn has_sufficient_consent(
94 &self,
95 user_id: UserId,
96 application_id: ApplicationId,
97 requested_scopes: &[String],
98 ) -> Result<bool, AuthError> {
99 let consent = self.get_consent(user_id, application_id).await?;
100 let Some(consent) = consent else {
101 return Ok(false);
102 };
103 let stored: Vec<String> = serde_json::from_str(&consent.scopes)
104 .map_err(|e| AuthError::Database(sqlx::Error::Decode(Box::new(e))))?;
105 let stored_set: std::collections::HashSet<&str> =
106 stored.iter().map(|s| s.as_str()).collect();
107 Ok(requested_scopes
108 .iter()
109 .all(|s| stored_set.contains(s.as_str())))
110 }
111
112 pub async fn upsert_consent(
116 &self,
117 user_id: UserId,
118 application_id: ApplicationId,
119 scopes: &[String],
120 ) -> Result<(), AuthError> {
121 let id = ConsentId::new();
122 let scopes_json = serde_json::to_string(scopes).expect("Vec<String> serializes to JSON");
123 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
124
125 let existing = self.get_consent(user_id, application_id).await?;
126 let merged_json = if let Some(existing) = existing {
127 let mut stored: Vec<String> = serde_json::from_str(&existing.scopes)
128 .map_err(|e| AuthError::Database(sqlx::Error::Decode(Box::new(e))))?;
129 for scope in scopes {
130 if !stored.contains(scope) {
131 stored.push(scope.clone());
132 }
133 }
134 serde_json::to_string(&stored).expect("Vec<String> serializes to JSON")
135 } else {
136 scopes_json
137 };
138
139 sqlx::query(
140 "INSERT INTO allowthem_consents \
141 (id, user_id, application_id, scopes, created_at, updated_at) \
142 VALUES (?1, ?2, ?3, ?4, ?5, ?5) \
143 ON CONFLICT(user_id, application_id) DO UPDATE SET scopes = ?4, updated_at = ?5",
144 )
145 .bind(id)
146 .bind(user_id)
147 .bind(application_id)
148 .bind(&merged_json)
149 .bind(&now)
150 .execute(self.pool())
151 .await?;
152
153 Ok(())
154 }
155
156 pub async fn get_consent(
158 &self,
159 user_id: UserId,
160 application_id: ApplicationId,
161 ) -> Result<Option<Consent>, AuthError> {
162 sqlx::query_as::<_, Consent>(
163 "SELECT id, user_id, application_id, scopes, created_at, updated_at \
164 FROM allowthem_consents WHERE user_id = ? AND application_id = ?",
165 )
166 .bind(user_id)
167 .bind(application_id)
168 .fetch_optional(self.pool())
169 .await
170 .map_err(AuthError::Database)
171 }
172
173 #[allow(clippy::too_many_arguments)]
175 pub async fn create_authorization_code(
176 &self,
177 application_id: ApplicationId,
178 user_id: UserId,
179 code_hash: &TokenHash,
180 redirect_uri: &str,
181 scopes: &[String],
182 code_challenge: &str,
183 code_challenge_method: &str,
184 nonce: Option<&str>,
185 ) -> Result<AuthorizationCode, AuthError> {
186 let id = AuthorizationCodeId::new();
187 let scopes_json = serde_json::to_string(scopes).expect("Vec<String> serializes to JSON");
188 let now = Utc::now();
189 let now_str = now.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
190 let expires_at = now + chrono::Duration::minutes(10);
191 let expires_str = expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
192
193 sqlx::query(
194 "INSERT INTO allowthem_authorization_codes \
195 (id, application_id, user_id, code_hash, redirect_uri, scopes, \
196 code_challenge, code_challenge_method, nonce, expires_at, created_at) \
197 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
198 )
199 .bind(id)
200 .bind(application_id)
201 .bind(user_id)
202 .bind(code_hash)
203 .bind(redirect_uri)
204 .bind(&scopes_json)
205 .bind(code_challenge)
206 .bind(code_challenge_method)
207 .bind(nonce)
208 .bind(&expires_str)
209 .bind(&now_str)
210 .execute(self.pool())
211 .await?;
212
213 sqlx::query_as::<_, AuthorizationCode>(
214 "SELECT id, application_id, user_id, code_hash, redirect_uri, scopes, \
215 code_challenge, code_challenge_method, nonce, expires_at, used_at, created_at \
216 FROM allowthem_authorization_codes WHERE id = ?",
217 )
218 .bind(id)
219 .fetch_one(self.pool())
220 .await
221 .map_err(AuthError::Database)
222 }
223
224 pub async fn get_authorization_code_by_hash(
226 &self,
227 code_hash: &TokenHash,
228 ) -> Result<Option<AuthorizationCode>, AuthError> {
229 sqlx::query_as::<_, AuthorizationCode>(
230 "SELECT id, application_id, user_id, code_hash, redirect_uri, scopes, \
231 code_challenge, code_challenge_method, nonce, expires_at, used_at, created_at \
232 FROM allowthem_authorization_codes WHERE code_hash = ?",
233 )
234 .bind(code_hash)
235 .fetch_optional(self.pool())
236 .await
237 .map_err(AuthError::Database)
238 }
239
240 pub async fn mark_authorization_code_used(
242 &self,
243 id: AuthorizationCodeId,
244 ) -> Result<(), AuthError> {
245 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
246 let result =
247 sqlx::query("UPDATE allowthem_authorization_codes SET used_at = ? WHERE id = ?")
248 .bind(&now)
249 .bind(id)
250 .execute(self.pool())
251 .await?;
252
253 if result.rows_affected() == 0 {
254 return Err(AuthError::NotFound);
255 }
256 Ok(())
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `input` is refused as an invalid authorization request.
    fn assert_rejected(input: &str) {
        let failure = validate_scopes(input).unwrap_err();
        assert!(matches!(failure, AuthError::InvalidAuthorizationRequest(_)));
    }

    // --- validate_scopes -------------------------------------------------

    #[test]
    fn valid_scopes_openid_only() {
        let parsed = validate_scopes("openid").unwrap();
        assert_eq!(parsed, vec!["openid"]);
    }

    #[test]
    fn valid_scopes_all_three() {
        let parsed = validate_scopes("openid profile email").unwrap();
        assert_eq!(parsed, vec!["openid", "profile", "email"]);
    }

    #[test]
    fn missing_openid_is_rejected() {
        assert_rejected("profile email");
    }

    #[test]
    fn empty_scope_is_rejected() {
        assert_rejected("");
    }

    #[test]
    fn whitespace_only_scope_is_rejected() {
        assert_rejected(" ");
    }

    #[test]
    fn unknown_scope_is_rejected() {
        assert_rejected("openid admin");
    }

    #[test]
    fn duplicate_openid_is_fine() {
        // Duplicates are passed through untouched, not collapsed.
        let parsed = validate_scopes("openid openid profile").unwrap();
        assert_eq!(parsed, vec!["openid", "openid", "profile"]);
    }

    // --- generate_authorization_code -------------------------------------

    #[test]
    fn code_is_43_chars() {
        let generated = generate_authorization_code();
        assert_eq!(generated.len(), 43, "32 bytes base64url = 43 chars");
    }

    #[test]
    fn code_is_url_safe() {
        let code = generate_authorization_code();
        let url_safe = code
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_');
        assert!(url_safe, "code must be URL-safe base64url: got {code}");
    }

    #[test]
    fn two_codes_differ() {
        let first = generate_authorization_code();
        let second = generate_authorization_code();
        assert_ne!(first, second);
    }

    // --- hash_authorization_code -----------------------------------------

    #[test]
    fn hash_is_deterministic() {
        let code = generate_authorization_code();
        let once = hash_authorization_code(&code);
        let again = hash_authorization_code(&code);
        // Hashes are compared through their `Debug` rendering.
        assert_eq!(format!("{once:?}"), format!("{again:?}"));
    }

    #[test]
    fn different_codes_produce_different_hashes() {
        let first_hash = hash_authorization_code(&generate_authorization_code());
        let second_hash = hash_authorization_code(&generate_authorization_code());
        assert_ne!(format!("{first_hash:?}"), format!("{second_hash:?}"));
    }
}