1use base64ct::{Base64UrlUnpadded, Encoding};
2use chrono::{DateTime, Duration, Utc};
3use rand::TryRngCore;
4use rand::rngs::OsRng;
5use serde::Serialize;
6use sha2::{Digest, Sha256};
7
8use crate::auth_client::AuthFuture;
9use crate::db::Db;
10use crate::error::AuthError;
11use crate::types::{Email, OAuthAccountId, OAuthStateId, User, UserId};
12use crate::users::map_unique_violation;
13
/// Profile data an [`OAuthProvider`] reports for the account that signed in
/// (the return type of [`OAuthProvider::user_info`]).
#[derive(Debug, Clone)]
pub struct OAuthUserInfo {
    /// The provider's own identifier for this account.
    /// NOTE(review): presumably the value passed as `provider_user_id` to
    /// `Db::find_user_by_oauth` / `Db::link_oauth_account` — confirm in callers.
    pub provider_user_id: String,
    /// Email address reported by the provider.
    pub email: String,
    /// Whether the provider reports `email` as verified.
    /// NOTE(review): presumably meant to gate any email-based account
    /// matching; confirm callers check it.
    pub email_verified: bool,
    /// Display name, if the provider supplied one.
    pub name: Option<String>,
}
26
/// The data recorded with an OAuth `state` value by `Db::create_oauth_state`
/// and handed back, at most once, by `Db::validate_oauth_state`.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct OAuthStateInfo {
    /// Provider name the flow was started for.
    pub provider: String,
    /// Redirect URI recorded when the flow started.
    pub redirect_uri: String,
    /// PKCE code verifier generated when the flow started.
    // NOTE(review): this is a secret, and the derived `Debug` prints it
    // verbatim; avoid `{:?}`-logging this struct (or give it a redacting
    // manual `Debug` impl).
    pub pkce_verifier: String,
    /// Optional post-login destination supplied when the flow started.
    pub post_login_redirect: Option<String>,
    /// `Some(user)` when the flow links a provider account to an existing
    /// user; `None` for a plain login flow.
    pub linking_user_id: Option<UserId>,
}
38
/// One provider account linked to a user, as listed by
/// `Db::get_user_oauth_accounts`. Serializable for API responses.
#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
pub struct OAuthAccountInfo {
    /// Provider name, e.g. `"google"` or `"github"`.
    pub provider: String,
    /// The provider's identifier for the linked account.
    pub provider_user_id: String,
    /// Email stored when the account was linked.
    pub email: String,
    /// When the link row was created.
    pub created_at: DateTime<Utc>,
}
47
/// An external OAuth identity provider used in an authorization-code flow
/// with PKCE. Implementors must be `Send + Sync` so they can be shared across
/// tasks.
pub trait OAuthProvider: Send + Sync {
    /// Short identifier for this provider.
    /// NOTE(review): presumably the same string stored in the `provider`
    /// column of state and account rows (e.g. `"google"`) — confirm.
    fn name(&self) -> &str;

    /// Builds the provider authorization URL for `redirect_uri`, the opaque
    /// `state` value, and the S256 PKCE challenge (see `pkce_challenge`).
    fn authorize_url(&self, redirect_uri: &str, state: &str, pkce_challenge: &str) -> String;

    /// Exchanges an authorization `code` for a token string.
    ///
    /// `redirect_uri` and `pkce_verifier` should be the values recorded when
    /// the flow began (see `OAuthStateInfo`).
    /// NOTE(review): the returned string is presumably the access token later
    /// passed to `user_info`; `AuthFuture` is defined in `auth_client` and
    /// presumably resolves to `Result<String, AuthError>` — confirm.
    fn exchange_code<'a>(
        &'a self,
        code: &'a str,
        redirect_uri: &'a str,
        pkce_verifier: &'a str,
    ) -> AuthFuture<'a, String>;

    /// Fetches the profile of the account that `access_token` belongs to.
    fn user_info<'a>(&'a self, access_token: &'a str) -> AuthFuture<'a, OAuthUserInfo>;
}
76
77pub fn generate_pkce_verifier() -> String {
83 let mut bytes = [0u8; 32];
84 OsRng
85 .try_fill_bytes(&mut bytes)
86 .expect("OS RNG unavailable");
87 Base64UrlUnpadded::encode_string(&bytes)
88}
89
90pub fn pkce_challenge(verifier: &str) -> String {
94 let digest = Sha256::digest(verifier.as_bytes());
95 Base64UrlUnpadded::encode_string(&digest)
96}
97
98fn generate_state() -> String {
104 let mut bytes = [0u8; 32];
105 OsRng
106 .try_fill_bytes(&mut bytes)
107 .expect("OS RNG unavailable");
108 Base64UrlUnpadded::encode_string(&bytes)
109}
110
111fn hash_state(raw: &str) -> String {
113 let digest = Sha256::digest(raw.as_bytes());
114 format!("{digest:x}")
115}
116
117impl Db {
122 pub async fn create_oauth_state(
128 &self,
129 provider: &str,
130 redirect_uri: &str,
131 pkce_verifier: &str,
132 post_login_redirect: Option<&str>,
133 linking_user_id: Option<UserId>,
134 ) -> Result<String, AuthError> {
135 let raw_state = generate_state();
136 let state_hash = hash_state(&raw_state);
137 let id = OAuthStateId::new();
138 let expires_at = (Utc::now() + Duration::minutes(10))
139 .format("%Y-%m-%dT%H:%M:%S%.3fZ")
140 .to_string();
141
142 sqlx::query(
143 "INSERT INTO allowthem_oauth_states \
144 (id, state_hash, provider, redirect_uri, pkce_verifier, post_login_redirect, expires_at, linking_user_id) \
145 VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
146 )
147 .bind(id)
148 .bind(&state_hash)
149 .bind(provider)
150 .bind(redirect_uri)
151 .bind(pkce_verifier)
152 .bind(post_login_redirect)
153 .bind(&expires_at)
154 .bind(linking_user_id)
155 .execute(self.pool())
156 .await?;
157
158 Ok(raw_state)
159 }
160
161 pub async fn validate_oauth_state(
164 &self,
165 raw_state: &str,
166 ) -> Result<Option<OAuthStateInfo>, AuthError> {
167 let state_hash = hash_state(raw_state);
168 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
169
170 sqlx::query_as::<_, OAuthStateInfo>(
171 "DELETE FROM allowthem_oauth_states \
172 WHERE state_hash = ? AND expires_at > ? \
173 RETURNING provider, redirect_uri, pkce_verifier, post_login_redirect, linking_user_id",
174 )
175 .bind(&state_hash)
176 .bind(&now)
177 .fetch_optional(self.pool())
178 .await
179 .map_err(AuthError::Database)
180 }
181
182 pub async fn create_oauth_user(
191 &self,
192 email: Email,
193 provider: &str,
194 provider_user_id: &str,
195 ) -> Result<User, AuthError> {
196 let user_id = UserId::new();
197 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
198
199 let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
200
201 sqlx::query(
202 "INSERT INTO allowthem_users \
203 (id, email, username, password_hash, email_verified, is_active, created_at, updated_at) \
204 VALUES (?, ?, NULL, NULL, 0, 1, ?, ?)",
205 )
206 .bind(user_id)
207 .bind(&email)
208 .bind(&now)
209 .bind(&now)
210 .execute(&mut *tx)
211 .await
212 .map_err(map_unique_violation)?;
213
214 sqlx::query(
215 "INSERT INTO allowthem_oauth_accounts \
216 (id, user_id, provider, provider_user_id, email, created_at) \
217 VALUES (?, ?, ?, ?, ?, ?)",
218 )
219 .bind(OAuthAccountId::new())
220 .bind(user_id)
221 .bind(provider)
222 .bind(provider_user_id)
223 .bind(email.as_str())
224 .bind(&now)
225 .execute(&mut *tx)
226 .await
227 .map_err(map_unique_violation)?;
228
229 tx.commit().await.map_err(AuthError::Database)?;
230
231 self.get_user(user_id).await
232 }
233
234 pub async fn link_oauth_account(
236 &self,
237 user_id: UserId,
238 provider: &str,
239 provider_user_id: &str,
240 email: &str,
241 ) -> Result<(), AuthError> {
242 sqlx::query(
243 "INSERT INTO allowthem_oauth_accounts \
244 (id, user_id, provider, provider_user_id, email) \
245 VALUES (?, ?, ?, ?, ?)",
246 )
247 .bind(OAuthAccountId::new())
248 .bind(user_id)
249 .bind(provider)
250 .bind(provider_user_id)
251 .bind(email)
252 .execute(self.pool())
253 .await
254 .map_err(map_unique_violation)?;
255
256 Ok(())
257 }
258
259 pub async fn find_user_by_oauth(
261 &self,
262 provider: &str,
263 provider_user_id: &str,
264 ) -> Result<Option<User>, AuthError> {
265 sqlx::query_as::<_, User>(
266 "SELECT u.id, u.email, u.username, NULL as password_hash, \
267 u.email_verified, u.is_active, u.created_at, u.updated_at, u.custom_data \
268 FROM allowthem_users u \
269 INNER JOIN allowthem_oauth_accounts oa ON oa.user_id = u.id \
270 WHERE oa.provider = ? AND oa.provider_user_id = ?",
271 )
272 .bind(provider)
273 .bind(provider_user_id)
274 .fetch_optional(self.pool())
275 .await
276 .map_err(AuthError::Database)
277 }
278
279 pub async fn get_user_oauth_accounts(
281 &self,
282 user_id: UserId,
283 ) -> Result<Vec<OAuthAccountInfo>, AuthError> {
284 sqlx::query_as::<_, OAuthAccountInfo>(
285 "SELECT provider, provider_user_id, email, created_at \
286 FROM allowthem_oauth_accounts \
287 WHERE user_id = ? \
288 ORDER BY created_at ASC",
289 )
290 .bind(user_id)
291 .fetch_all(self.pool())
292 .await
293 .map_err(AuthError::Database)
294 }
295
296 pub async fn unlink_oauth_account(
300 &self,
301 user_id: UserId,
302 provider: &str,
303 ) -> Result<bool, AuthError> {
304 let result =
305 sqlx::query("DELETE FROM allowthem_oauth_accounts WHERE user_id = ? AND provider = ?")
306 .bind(user_id)
307 .bind(provider)
308 .execute(self.pool())
309 .await
310 .map_err(AuthError::Database)?;
311
312 Ok(result.rows_affected() > 0)
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::Db;

    /// Opens a fresh in-memory database for a single test.
    async fn test_db() -> Db {
        Db::connect("sqlite::memory:").await.expect("in-memory db")
    }

    // --- PKCE and state helpers -------------------------------------------

    #[test]
    fn pkce_verifier_is_43_chars() {
        let v = generate_pkce_verifier();
        assert_eq!(v.len(), 43);
    }

    #[test]
    fn pkce_challenge_is_deterministic() {
        let v = generate_pkce_verifier();
        let c1 = pkce_challenge(&v);
        let c2 = pkce_challenge(&v);
        assert_eq!(c1, c2);
    }

    #[test]
    fn pkce_challenge_is_base64url() {
        let v = generate_pkce_verifier();
        let c = pkce_challenge(&v);
        assert!(!c.contains('+'), "must not contain +");
        assert!(!c.contains('/'), "must not contain /");
        assert!(!c.contains('='), "must not contain =");
    }

    #[test]
    fn pkce_challenge_differs_from_verifier() {
        let v = generate_pkce_verifier();
        let c = pkce_challenge(&v);
        assert_ne!(v, c);
    }

    #[test]
    fn pkce_challenge_matches_rfc7636_appendix_b() {
        // Known-answer vector from RFC 7636, Appendix B (S256 method). The
        // tests above only check shape; this one checks correctness.
        assert_eq!(
            pkce_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
    }

    #[test]
    fn hash_state_is_lowercase_hex_sha256() {
        // SHA-256("abc") from FIPS 180-2.
        assert_eq!(
            hash_state("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    // --- OAuth state lifecycle ----------------------------------------------

    #[tokio::test]
    async fn create_state_returns_nonempty_string() {
        let db = test_db().await;
        let state = db
            .create_oauth_state(
                "google",
                "https://example.com/callback",
                "verifier123",
                None,
                None,
            )
            .await
            .expect("create state");
        assert!(!state.is_empty());
    }

    #[tokio::test]
    async fn validate_state_returns_info_for_valid_state() {
        let db = test_db().await;
        let raw = db
            .create_oauth_state(
                "google",
                "https://example.com/cb",
                "my-verifier",
                None,
                None,
            )
            .await
            .expect("create");
        let info = db
            .validate_oauth_state(&raw)
            .await
            .expect("validate")
            .expect("fresh state should be valid");
        assert_eq!(info.provider, "google");
        assert_eq!(info.redirect_uri, "https://example.com/cb");
        assert_eq!(info.pkce_verifier, "my-verifier");
    }

    #[tokio::test]
    async fn validate_state_is_single_use() {
        let db = test_db().await;
        let raw = db
            .create_oauth_state("github", "https://example.com/cb", "v", None, None)
            .await
            .expect("create");
        let first = db.validate_oauth_state(&raw).await.expect("first");
        assert!(first.is_some());
        let second = db.validate_oauth_state(&raw).await.expect("second");
        assert!(second.is_none(), "state must be single-use");
    }

    #[tokio::test]
    async fn validate_state_returns_none_for_garbage() {
        let db = test_db().await;
        let result = db
            .validate_oauth_state("not-a-real-state")
            .await
            .expect("validate");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn validate_state_rejects_expired_state() {
        let db = test_db().await;
        let raw = db
            .create_oauth_state("google", "https://example.com/cb", "v", None, None)
            .await
            .expect("create");
        // Backdate the stored expiry; the public API always uses now + 10 min.
        sqlx::query("UPDATE allowthem_oauth_states SET expires_at = ? WHERE state_hash = ?")
            .bind("2000-01-01T00:00:00.000Z")
            .bind(hash_state(&raw))
            .execute(db.pool())
            .await
            .expect("backdate state");
        let info = db.validate_oauth_state(&raw).await.expect("validate");
        assert!(info.is_none(), "expired state must be rejected");
    }

    #[tokio::test]
    async fn validate_state_preserves_post_login_redirect() {
        let db = test_db().await;
        let raw = db
            .create_oauth_state(
                "google",
                "https://example.com/cb",
                "v",
                Some("/settings"),
                None,
            )
            .await
            .expect("create");
        let info = db
            .validate_oauth_state(&raw)
            .await
            .expect("validate")
            .unwrap();
        assert_eq!(info.post_login_redirect.as_deref(), Some("/settings"));
    }

    #[tokio::test]
    async fn validate_state_returns_none_for_post_login_redirect_when_not_set() {
        let db = test_db().await;
        let raw = db
            .create_oauth_state("google", "https://example.com/cb", "v", None, None)
            .await
            .expect("create");
        let info = db
            .validate_oauth_state(&raw)
            .await
            .expect("validate")
            .unwrap();
        assert!(info.post_login_redirect.is_none());
    }

    // --- OAuth users and account links --------------------------------------

    #[tokio::test]
    async fn create_oauth_user_creates_user_without_password() {
        let db = test_db().await;
        let email = Email::new("oauth@example.com".into()).unwrap();
        let user = db
            .create_oauth_user(email, "google", "gid-123")
            .await
            .expect("create oauth user");
        assert!(user.password_hash.is_none());
        assert_eq!(user.email.as_str(), "oauth@example.com");
    }

    #[tokio::test]
    async fn create_oauth_user_creates_linked_account() {
        let db = test_db().await;
        let email = Email::new("linked@example.com".into()).unwrap();
        let user = db
            .create_oauth_user(email, "google", "gid-456")
            .await
            .expect("create");
        let found = db
            .find_user_by_oauth("google", "gid-456")
            .await
            .expect("find")
            .expect("linked user should be found");
        assert_eq!(found.id, user.id);
    }

    #[tokio::test]
    async fn create_oauth_user_conflict_on_duplicate_email() {
        let db = test_db().await;
        let email = Email::new("dup@example.com".into()).unwrap();
        db.create_user(email.clone(), "password123", None, None)
            .await
            .expect("create password user");
        let result = db.create_oauth_user(email, "google", "gid-789").await;
        assert!(matches!(result, Err(AuthError::Conflict(_))));
    }

    #[tokio::test]
    async fn link_oauth_account_links_to_existing_user() {
        let db = test_db().await;
        let email = Email::new("link@example.com".into()).unwrap();
        let user = db
            .create_user(email, "password123", None, None)
            .await
            .expect("create user");
        db.link_oauth_account(user.id, "github", "gh-111", "link@example.com")
            .await
            .expect("link");
        let found = db
            .find_user_by_oauth("github", "gh-111")
            .await
            .expect("find")
            .expect("linked user should be found");
        assert_eq!(found.id, user.id);
    }

    #[tokio::test]
    async fn link_oauth_account_conflict_on_duplicate_provider_id() {
        let db = test_db().await;
        let email = Email::new("duplink@example.com".into()).unwrap();
        let user = db
            .create_user(email, "password123", None, None)
            .await
            .expect("create");
        db.link_oauth_account(user.id, "github", "gh-dup", "duplink@example.com")
            .await
            .expect("first link");
        let result = db
            .link_oauth_account(user.id, "github", "gh-dup", "duplink@example.com")
            .await;
        assert!(matches!(result, Err(AuthError::Conflict(_))));
    }

    #[tokio::test]
    async fn find_user_by_oauth_returns_none_when_not_linked() {
        let db = test_db().await;
        let result = db
            .find_user_by_oauth("github", "nonexistent")
            .await
            .expect("find");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn find_user_by_oauth_does_not_return_password_hash() {
        let db = test_db().await;
        let email = Email::new("nopw@example.com".into()).unwrap();
        db.create_oauth_user(email, "google", "gid-nopw")
            .await
            .expect("create");
        let user = db
            .find_user_by_oauth("google", "gid-nopw")
            .await
            .expect("find")
            .unwrap();
        assert!(user.password_hash.is_none());
    }

    // --- Linking flow state --------------------------------------------------

    #[tokio::test]
    async fn validate_state_preserves_linking_user_id() {
        let db = test_db().await;
        let user_id = UserId::new();
        let raw = db
            .create_oauth_state("google", "https://example.com/cb", "v", None, Some(user_id))
            .await
            .expect("create");
        let info = db
            .validate_oauth_state(&raw)
            .await
            .expect("validate")
            .unwrap();
        assert_eq!(info.linking_user_id, Some(user_id));
    }

    #[tokio::test]
    async fn validate_state_linking_user_id_is_none_for_login_flow() {
        let db = test_db().await;
        let raw = db
            .create_oauth_state("google", "https://example.com/cb", "v", None, None)
            .await
            .expect("create");
        let info = db
            .validate_oauth_state(&raw)
            .await
            .expect("validate")
            .unwrap();
        assert!(info.linking_user_id.is_none());
    }

    // --- Listing and unlinking accounts --------------------------------------

    #[tokio::test]
    async fn get_user_oauth_accounts_returns_linked_providers() {
        let db = test_db().await;
        let email = Email::new("accts@example.com".into()).unwrap();
        let user = db
            .create_user(email, "password123", None, None)
            .await
            .expect("create");
        db.link_oauth_account(user.id, "google", "g-1", "accts@example.com")
            .await
            .expect("link google");
        db.link_oauth_account(user.id, "github", "gh-1", "accts@example.com")
            .await
            .expect("link github");

        let accounts = db
            .get_user_oauth_accounts(user.id)
            .await
            .expect("list accounts");
        assert_eq!(accounts.len(), 2);
        let providers: Vec<&str> = accounts.iter().map(|a| a.provider.as_str()).collect();
        assert!(providers.contains(&"google"));
        assert!(providers.contains(&"github"));
    }

    #[tokio::test]
    async fn get_user_oauth_accounts_returns_empty_for_no_links() {
        let db = test_db().await;
        let email = Email::new("nolinks@example.com".into()).unwrap();
        let user = db
            .create_user(email, "password123", None, None)
            .await
            .expect("create");

        let accounts = db
            .get_user_oauth_accounts(user.id)
            .await
            .expect("list accounts");
        assert!(accounts.is_empty());
    }

    #[tokio::test]
    async fn unlink_oauth_account_removes_link() {
        let db = test_db().await;
        let email = Email::new("unlink@example.com".into()).unwrap();
        let user = db
            .create_user(email, "password123", None, None)
            .await
            .expect("create");
        db.link_oauth_account(user.id, "google", "g-unlink", "unlink@example.com")
            .await
            .expect("link");

        let removed = db
            .unlink_oauth_account(user.id, "google")
            .await
            .expect("unlink");
        assert!(removed, "should return true when row deleted");

        let found = db
            .find_user_by_oauth("google", "g-unlink")
            .await
            .expect("find");
        assert!(found.is_none(), "link should be gone");
    }

    #[tokio::test]
    async fn unlink_oauth_account_returns_false_when_not_linked() {
        let db = test_db().await;
        let email = Email::new("notlinked@example.com".into()).unwrap();
        let user = db
            .create_user(email, "password123", None, None)
            .await
            .expect("create");

        let removed = db
            .unlink_oauth_account(user.id, "google")
            .await
            .expect("unlink");
        assert!(!removed, "should return false when nothing deleted");
    }
}