Skip to main content

allowthem_core/
oauth.rs

1use base64ct::{Base64UrlUnpadded, Encoding};
2use chrono::{DateTime, Duration, Utc};
3use rand::TryRngCore;
4use rand::rngs::OsRng;
5use serde::Serialize;
6use sha2::{Digest, Sha256};
7
8use crate::auth_client::AuthFuture;
9use crate::db::Db;
10use crate::error::AuthError;
11use crate::types::{Email, OAuthAccountId, OAuthStateId, User, UserId};
12use crate::users::map_unique_violation;
13
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
17
/// User profile data reported by an OAuth provider's user-info endpoint.
///
/// Produced by [`OAuthProvider::user_info`] after a successful code exchange.
#[derive(Clone, Debug)]
pub struct OAuthUserInfo {
    /// The provider's stable identifier for this account.
    pub provider_user_id: String,
    /// Email address as reported by the provider.
    pub email: String,
    /// Whether the provider reports `email` as verified.
    pub email_verified: bool,
    /// Display name, when the provider supplies one.
    pub name: Option<String>,
}
26
27/// Stored state returned when validating an OAuth callback.
28#[derive(Debug, Clone, sqlx::FromRow)]
29pub struct OAuthStateInfo {
30    pub provider: String,
31    pub redirect_uri: String,
32    pub pkce_verifier: String,
33    pub post_login_redirect: Option<String>,
34    /// Non-null for the link flow: the authenticated user that initiated linking.
35    /// Null for the standard login/register flow.
36    pub linking_user_id: Option<UserId>,
37}
38
39/// A linked OAuth account for a user — provider name, provider user id, and email.
40#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
41pub struct OAuthAccountInfo {
42    pub provider: String,
43    pub provider_user_id: String,
44    pub email: String,
45    pub created_at: DateTime<Utc>,
46}
47
// ---------------------------------------------------------------------------
// OAuthProvider trait
// ---------------------------------------------------------------------------
51
52/// Abstraction over an OAuth2 authorization code flow provider.
53///
54/// Each provider (Google, GitHub, etc.) implements this trait. The server
55/// crate stores providers in a `HashMap<String, Box<dyn OAuthProvider>>`
56/// keyed by provider name.
57pub trait OAuthProvider: Send + Sync {
58    /// Provider name, lowercase. Used as the URL path segment and the
59    /// `provider` column in `oauth_accounts`.
60    fn name(&self) -> &str;
61
62    /// Build the authorization URL the user should be redirected to.
63    fn authorize_url(&self, redirect_uri: &str, state: &str, pkce_challenge: &str) -> String;
64
65    /// Exchange an authorization code for an access token.
66    fn exchange_code<'a>(
67        &'a self,
68        code: &'a str,
69        redirect_uri: &'a str,
70        pkce_verifier: &'a str,
71    ) -> AuthFuture<'a, String>;
72
73    /// Fetch user information from the provider using the access token.
74    fn user_info<'a>(&'a self, access_token: &'a str) -> AuthFuture<'a, OAuthUserInfo>;
75}
76
// ---------------------------------------------------------------------------
// PKCE utilities
// ---------------------------------------------------------------------------
80
81/// Generate a random PKCE code verifier (43 chars, base64url-unpadded).
82pub fn generate_pkce_verifier() -> String {
83    let mut bytes = [0u8; 32];
84    OsRng
85        .try_fill_bytes(&mut bytes)
86        .expect("OS RNG unavailable");
87    Base64UrlUnpadded::encode_string(&bytes)
88}
89
90/// Derive the S256 PKCE code challenge from a verifier.
91///
92/// `code_challenge = BASE64URL(SHA256(code_verifier))`
93pub fn pkce_challenge(verifier: &str) -> String {
94    let digest = Sha256::digest(verifier.as_bytes());
95    Base64UrlUnpadded::encode_string(&digest)
96}
97
// ---------------------------------------------------------------------------
// State helpers (private)
// ---------------------------------------------------------------------------
101
102/// Generate a random state parameter (43 chars, base64url-unpadded).
103fn generate_state() -> String {
104    let mut bytes = [0u8; 32];
105    OsRng
106        .try_fill_bytes(&mut bytes)
107        .expect("OS RNG unavailable");
108    Base64UrlUnpadded::encode_string(&bytes)
109}
110
111/// SHA-256 hex hash of a raw state string.
112fn hash_state(raw: &str) -> String {
113    let digest = Sha256::digest(raw.as_bytes());
114    format!("{digest:x}")
115}
116
// ---------------------------------------------------------------------------
// Db methods — OAuth state
// ---------------------------------------------------------------------------
120
121impl Db {
122    /// Create an OAuth state record. Returns the raw state value (for the authorize URL).
123    ///
124    /// `linking_user_id` is `Some` when initiating the account-linking flow (the user is
125    /// already authenticated and wants to add a provider). It is `None` for the standard
126    /// login/register flow.
127    pub async fn create_oauth_state(
128        &self,
129        provider: &str,
130        redirect_uri: &str,
131        pkce_verifier: &str,
132        post_login_redirect: Option<&str>,
133        linking_user_id: Option<UserId>,
134    ) -> Result<String, AuthError> {
135        let raw_state = generate_state();
136        let state_hash = hash_state(&raw_state);
137        let id = OAuthStateId::new();
138        let expires_at = (Utc::now() + Duration::minutes(10))
139            .format("%Y-%m-%dT%H:%M:%S%.3fZ")
140            .to_string();
141
142        sqlx::query(
143            "INSERT INTO allowthem_oauth_states \
144             (id, state_hash, provider, redirect_uri, pkce_verifier, post_login_redirect, expires_at, linking_user_id) \
145             VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
146        )
147        .bind(id)
148        .bind(&state_hash)
149        .bind(provider)
150        .bind(redirect_uri)
151        .bind(pkce_verifier)
152        .bind(post_login_redirect)
153        .bind(&expires_at)
154        .bind(linking_user_id)
155        .execute(self.pool())
156        .await?;
157
158        Ok(raw_state)
159    }
160
161    /// Validate and consume an OAuth state. Returns the stored info
162    /// or None if invalid/expired. Atomically deletes to prevent reuse.
163    pub async fn validate_oauth_state(
164        &self,
165        raw_state: &str,
166    ) -> Result<Option<OAuthStateInfo>, AuthError> {
167        let state_hash = hash_state(raw_state);
168        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
169
170        sqlx::query_as::<_, OAuthStateInfo>(
171            "DELETE FROM allowthem_oauth_states \
172             WHERE state_hash = ? AND expires_at > ? \
173             RETURNING provider, redirect_uri, pkce_verifier, post_login_redirect, linking_user_id",
174        )
175        .bind(&state_hash)
176        .bind(&now)
177        .fetch_optional(self.pool())
178        .await
179        .map_err(AuthError::Database)
180    }
181
182    // -----------------------------------------------------------------------
183    // Db methods — OAuth users and accounts
184    // -----------------------------------------------------------------------
185
186    /// Create a user via OAuth -- no password.
187    ///
188    /// Creates the user (password_hash = NULL) and the oauth_accounts row
189    /// in a single transaction. Returns the created User.
190    pub async fn create_oauth_user(
191        &self,
192        email: Email,
193        provider: &str,
194        provider_user_id: &str,
195    ) -> Result<User, AuthError> {
196        let user_id = UserId::new();
197        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
198
199        let mut tx = self.pool().begin().await.map_err(AuthError::Database)?;
200
201        sqlx::query(
202            "INSERT INTO allowthem_users \
203             (id, email, username, password_hash, email_verified, is_active, created_at, updated_at) \
204             VALUES (?, ?, NULL, NULL, 0, 1, ?, ?)",
205        )
206        .bind(user_id)
207        .bind(&email)
208        .bind(&now)
209        .bind(&now)
210        .execute(&mut *tx)
211        .await
212        .map_err(map_unique_violation)?;
213
214        sqlx::query(
215            "INSERT INTO allowthem_oauth_accounts \
216             (id, user_id, provider, provider_user_id, email, created_at) \
217             VALUES (?, ?, ?, ?, ?, ?)",
218        )
219        .bind(OAuthAccountId::new())
220        .bind(user_id)
221        .bind(provider)
222        .bind(provider_user_id)
223        .bind(email.as_str())
224        .bind(&now)
225        .execute(&mut *tx)
226        .await
227        .map_err(map_unique_violation)?;
228
229        tx.commit().await.map_err(AuthError::Database)?;
230
231        self.get_user(user_id).await
232    }
233
234    /// Link an OAuth identity to an existing user.
235    pub async fn link_oauth_account(
236        &self,
237        user_id: UserId,
238        provider: &str,
239        provider_user_id: &str,
240        email: &str,
241    ) -> Result<(), AuthError> {
242        sqlx::query(
243            "INSERT INTO allowthem_oauth_accounts \
244             (id, user_id, provider, provider_user_id, email) \
245             VALUES (?, ?, ?, ?, ?)",
246        )
247        .bind(OAuthAccountId::new())
248        .bind(user_id)
249        .bind(provider)
250        .bind(provider_user_id)
251        .bind(email)
252        .execute(self.pool())
253        .await
254        .map_err(map_unique_violation)?;
255
256        Ok(())
257    }
258
259    /// Find an allowthem user by provider + provider_user_id.
260    pub async fn find_user_by_oauth(
261        &self,
262        provider: &str,
263        provider_user_id: &str,
264    ) -> Result<Option<User>, AuthError> {
265        sqlx::query_as::<_, User>(
266            "SELECT u.id, u.email, u.username, NULL as password_hash, \
267             u.email_verified, u.is_active, u.created_at, u.updated_at, u.custom_data \
268             FROM allowthem_users u \
269             INNER JOIN allowthem_oauth_accounts oa ON oa.user_id = u.id \
270             WHERE oa.provider = ? AND oa.provider_user_id = ?",
271        )
272        .bind(provider)
273        .bind(provider_user_id)
274        .fetch_optional(self.pool())
275        .await
276        .map_err(AuthError::Database)
277    }
278
279    /// List all OAuth accounts linked to a user.
280    pub async fn get_user_oauth_accounts(
281        &self,
282        user_id: UserId,
283    ) -> Result<Vec<OAuthAccountInfo>, AuthError> {
284        sqlx::query_as::<_, OAuthAccountInfo>(
285            "SELECT provider, provider_user_id, email, created_at \
286             FROM allowthem_oauth_accounts \
287             WHERE user_id = ? \
288             ORDER BY created_at ASC",
289        )
290        .bind(user_id)
291        .fetch_all(self.pool())
292        .await
293        .map_err(AuthError::Database)
294    }
295
296    /// Remove an OAuth account link for a user + provider.
297    ///
298    /// Returns `true` if a row was deleted, `false` if no link existed.
299    pub async fn unlink_oauth_account(
300        &self,
301        user_id: UserId,
302        provider: &str,
303    ) -> Result<bool, AuthError> {
304        let result =
305            sqlx::query("DELETE FROM allowthem_oauth_accounts WHERE user_id = ? AND provider = ?")
306                .bind(user_id)
307                .bind(provider)
308                .execute(self.pool())
309                .await
310                .map_err(AuthError::Database)?;
311
312        Ok(result.rows_affected() > 0)
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::Db;

    async fn test_db() -> Db {
        Db::connect("sqlite::memory:").await.expect("in-memory db")
    }

    // --- PKCE tests ---

    #[test]
    fn pkce_verifier_is_43_chars() {
        let v = generate_pkce_verifier();
        assert_eq!(v.len(), 43);
    }

    #[test]
    fn pkce_verifier_uses_only_unreserved_characters() {
        // RFC 7636 §4.1: code_verifier uses only ALPHA / DIGIT / "-" / "." / "_" / "~".
        let v = generate_pkce_verifier();
        assert!(
            v.chars()
                .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '.' | '_' | '~')),
            "verifier has a character outside the RFC 7636 unreserved set: {v}"
        );
    }

    #[test]
    fn pkce_challenge_matches_rfc7636_test_vector() {
        // RFC 7636 Appendix B example.
        assert_eq!(
            pkce_challenge("dBjftJeZ4CK-1bf-6n0mu9Uq4DwTs_KDkkgf9cNYLUc"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URD-ArbaGAeJvQ"
        );
    }

    #[test]
    fn pkce_challenge_is_deterministic() {
        let v = generate_pkce_verifier();
        let c1 = pkce_challenge(&v);
        let c2 = pkce_challenge(&v);
        assert_eq!(c1, c2);
    }

    #[test]
    fn pkce_challenge_is_base64url() {
        let v = generate_pkce_verifier();
        let c = pkce_challenge(&v);
        assert!(!c.contains('+'), "must not contain +");
        assert!(!c.contains('/'), "must not contain /");
        assert!(!c.contains('='), "must not contain =");
    }

    #[test]
    fn pkce_challenge_differs_from_verifier() {
        let v = generate_pkce_verifier();
        let c = pkce_challenge(&v);
        assert_ne!(v, c);
    }

    // --- State helper tests ---

    #[test]
    fn generate_state_is_43_chars_and_unique() {
        let a = generate_state();
        let b = generate_state();
        assert_eq!(a.len(), 43);
        assert_eq!(b.len(), 43);
        assert_ne!(a, b, "two 256-bit random states must not collide");
    }

    #[test]
    fn hash_state_is_lowercase_sha256_hex() {
        // Well-known SHA-256("abc") digest.
        assert_eq!(
            hash_state("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    // --- State lifecycle tests ---

    #[tokio::test]
    async fn create_state_returns_nonempty_string() {
        let db = test_db().await;
        let state = db
            .create_oauth_state(
                "google",
                "https://example.com/callback",
                "verifier123",
                None,
                None,
            )
            .await
            .expect("create state");
        assert!(!state.is_empty());
    }

    #[tokio::test]
    async fn validate_state_returns_info_for_valid_state() {
        let db = test_db().await;
        let raw = db
            .create_oauth_state(
                "google",
                "https://example.com/cb",
                "my-verifier",
                None,
                None,
            )
            .await
            .expect("create");
        let info = db
            .validate_oauth_state(&raw)
            .await
            .expect("validate")
            .expect("freshly created state should be valid");
        assert_eq!(info.provider, "google");
        assert_eq!(info.redirect_uri, "https://example.com/cb");
        assert_eq!(info.pkce_verifier, "my-verifier");
    }

    #[tokio::test]
    async fn validate_state_is_single_use() {
        let db = test_db().await;
        let raw = db
            .create_oauth_state("github", "https://example.com/cb", "v", None, None)
            .await
            .expect("create");
        let first = db.validate_oauth_state(&raw).await.expect("first");
        assert!(first.is_some());
        let second = db.validate_oauth_state(&raw).await.expect("second");
        assert!(second.is_none(), "state must be single-use");
    }

    #[tokio::test]
    async fn validate_state_returns_none_for_garbage() {
        let db = test_db().await;
        let result = db
            .validate_oauth_state("not-a-real-state")
            .await
            .expect("validate");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn validate_state_rejects_expired_state() {
        let db = test_db().await;
        let raw = db
            .create_oauth_state("google", "https://example.com/cb", "v", None, None)
            .await
            .expect("create");
        // Backdate the stored row so it is already past its expiry.
        sqlx::query("UPDATE allowthem_oauth_states SET expires_at = ? WHERE state_hash = ?")
            .bind("2000-01-01T00:00:00.000Z")
            .bind(hash_state(&raw))
            .execute(db.pool())
            .await
            .expect("backdate state");
        let result = db.validate_oauth_state(&raw).await.expect("validate");
        assert!(result.is_none(), "expired state must be rejected");
    }

    #[tokio::test]
    async fn validate_state_preserves_post_login_redirect() {
        let db = test_db().await;
        let raw = db
            .create_oauth_state(
                "google",
                "https://example.com/cb",
                "v",
                Some("/settings"),
                None,
            )
            .await
            .expect("create");
        let info = db
            .validate_oauth_state(&raw)
            .await
            .expect("validate")
            .unwrap();
        assert_eq!(info.post_login_redirect.as_deref(), Some("/settings"));
    }

    #[tokio::test]
    async fn validate_state_returns_none_for_post_login_redirect_when_not_set() {
        let db = test_db().await;
        let raw = db
            .create_oauth_state("google", "https://example.com/cb", "v", None, None)
            .await
            .expect("create");
        let info = db
            .validate_oauth_state(&raw)
            .await
            .expect("validate")
            .unwrap();
        assert!(info.post_login_redirect.is_none());
    }

    // --- OAuth user tests ---

    #[tokio::test]
    async fn create_oauth_user_creates_user_without_password() {
        let db = test_db().await;
        let email = Email::new("oauth@example.com".into()).unwrap();
        let user = db
            .create_oauth_user(email, "google", "gid-123")
            .await
            .expect("create oauth user");
        assert!(user.password_hash.is_none());
        assert_eq!(user.email.as_str(), "oauth@example.com");
    }

    #[tokio::test]
    async fn create_oauth_user_creates_linked_account() {
        let db = test_db().await;
        let email = Email::new("linked@example.com".into()).unwrap();
        let user = db
            .create_oauth_user(email, "google", "gid-456")
            .await
            .expect("create");
        let found = db
            .find_user_by_oauth("google", "gid-456")
            .await
            .expect("find")
            .expect("linked user should be found");
        assert_eq!(found.id, user.id);
    }

    #[tokio::test]
    async fn create_oauth_user_conflict_on_duplicate_email() {
        let db = test_db().await;
        let email = Email::new("dup@example.com".into()).unwrap();
        db.create_user(email.clone(), "password123", None, None)
            .await
            .expect("create password user");
        let result = db.create_oauth_user(email, "google", "gid-789").await;
        assert!(matches!(result, Err(AuthError::Conflict(_))));
    }

    #[tokio::test]
    async fn link_oauth_account_links_to_existing_user() {
        let db = test_db().await;
        let email = Email::new("link@example.com".into()).unwrap();
        let user = db
            .create_user(email, "password123", None, None)
            .await
            .expect("create user");
        db.link_oauth_account(user.id, "github", "gh-111", "link@example.com")
            .await
            .expect("link");
        let found = db
            .find_user_by_oauth("github", "gh-111")
            .await
            .expect("find")
            .expect("linked user should be found");
        assert_eq!(found.id, user.id);
    }

    #[tokio::test]
    async fn link_oauth_account_conflict_on_duplicate_provider_id() {
        let db = test_db().await;
        let email = Email::new("duplink@example.com".into()).unwrap();
        let user = db
            .create_user(email, "password123", None, None)
            .await
            .expect("create");
        db.link_oauth_account(user.id, "github", "gh-dup", "duplink@example.com")
            .await
            .expect("first link");
        let result = db
            .link_oauth_account(user.id, "github", "gh-dup", "duplink@example.com")
            .await;
        assert!(matches!(result, Err(AuthError::Conflict(_))));
    }

    #[tokio::test]
    async fn find_user_by_oauth_returns_none_when_not_linked() {
        let db = test_db().await;
        let result = db
            .find_user_by_oauth("github", "nonexistent")
            .await
            .expect("find");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn find_user_by_oauth_does_not_return_password_hash() {
        let db = test_db().await;
        let email = Email::new("nopw@example.com".into()).unwrap();
        db.create_oauth_user(email, "google", "gid-nopw")
            .await
            .expect("create");
        let user = db
            .find_user_by_oauth("google", "gid-nopw")
            .await
            .expect("find")
            .unwrap();
        assert!(user.password_hash.is_none());
    }

    // --- linking_user_id state tests ---

    #[tokio::test]
    async fn validate_state_preserves_linking_user_id() {
        let db = test_db().await;
        let user_id = UserId::new();
        let raw = db
            .create_oauth_state("google", "https://example.com/cb", "v", None, Some(user_id))
            .await
            .expect("create");
        let info = db
            .validate_oauth_state(&raw)
            .await
            .expect("validate")
            .unwrap();
        assert_eq!(info.linking_user_id, Some(user_id));
    }

    #[tokio::test]
    async fn validate_state_linking_user_id_is_none_for_login_flow() {
        let db = test_db().await;
        let raw = db
            .create_oauth_state("google", "https://example.com/cb", "v", None, None)
            .await
            .expect("create");
        let info = db
            .validate_oauth_state(&raw)
            .await
            .expect("validate")
            .unwrap();
        assert!(info.linking_user_id.is_none());
    }

    // --- get_user_oauth_accounts tests ---

    #[tokio::test]
    async fn get_user_oauth_accounts_returns_linked_providers() {
        let db = test_db().await;
        let email = Email::new("accts@example.com".into()).unwrap();
        let user = db
            .create_user(email, "password123", None, None)
            .await
            .expect("create");
        db.link_oauth_account(user.id, "google", "g-1", "accts@example.com")
            .await
            .expect("link google");
        db.link_oauth_account(user.id, "github", "gh-1", "accts@example.com")
            .await
            .expect("link github");

        let accounts = db
            .get_user_oauth_accounts(user.id)
            .await
            .expect("list accounts");
        assert_eq!(accounts.len(), 2);
        let providers: Vec<&str> = accounts.iter().map(|a| a.provider.as_str()).collect();
        assert!(providers.contains(&"google"));
        assert!(providers.contains(&"github"));
    }

    #[tokio::test]
    async fn get_user_oauth_accounts_returns_empty_for_no_links() {
        let db = test_db().await;
        let email = Email::new("nolinks@example.com".into()).unwrap();
        let user = db
            .create_user(email, "password123", None, None)
            .await
            .expect("create");

        let accounts = db
            .get_user_oauth_accounts(user.id)
            .await
            .expect("list accounts");
        assert!(accounts.is_empty());
    }

    // --- unlink_oauth_account tests ---

    #[tokio::test]
    async fn unlink_oauth_account_removes_link() {
        let db = test_db().await;
        let email = Email::new("unlink@example.com".into()).unwrap();
        let user = db
            .create_user(email, "password123", None, None)
            .await
            .expect("create");
        db.link_oauth_account(user.id, "google", "g-unlink", "unlink@example.com")
            .await
            .expect("link");

        let removed = db
            .unlink_oauth_account(user.id, "google")
            .await
            .expect("unlink");
        assert!(removed, "should return true when row deleted");

        let found = db
            .find_user_by_oauth("google", "g-unlink")
            .await
            .expect("find");
        assert!(found.is_none(), "link should be gone");
    }

    #[tokio::test]
    async fn unlink_oauth_account_returns_false_when_not_linked() {
        let db = test_db().await;
        let email = Email::new("notlinked@example.com".into()).unwrap();
        let user = db
            .create_user(email, "password123", None, None)
            .await
            .expect("create");

        let removed = db
            .unlink_oauth_account(user.id, "google")
            .await
            .expect("unlink");
        assert!(!removed, "should return false when nothing deleted");
    }
}