acton_htmx/oauth2/
models.rs

1//! OAuth2 account database models
2//!
3//! This module provides the `OAuthAccount` model for managing OAuth2 provider
4//! accounts linked to users.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use sqlx::{FromRow, PgPool};
9
10use super::types::{OAuthProvider, OAuthUserInfo};
11
12/// OAuth2 account linked to a user
13///
14/// This represents a connection between a local user account and an OAuth2
15/// provider account (Google, GitHub, or generic OIDC).
16#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
17pub struct OAuthAccount {
18    /// Primary key
19    pub id: i64,
20    /// Local user ID
21    pub user_id: i64,
22    /// OAuth2 provider
23    #[sqlx(try_from = "String")]
24    pub provider: OAuthProvider,
25    /// Provider-specific user ID
26    pub provider_user_id: String,
27    /// Email from OAuth provider
28    pub email: String,
29    /// Display name from OAuth provider
30    pub name: Option<String>,
31    /// Avatar URL from OAuth provider
32    pub avatar_url: Option<String>,
33    /// When the account was linked
34    pub created_at: DateTime<Utc>,
35    /// When the account was last updated
36    pub updated_at: DateTime<Utc>,
37}
38
39impl OAuthAccount {
40    /// Find an OAuth account by provider and provider user ID
41    ///
42    /// # Errors
43    ///
44    /// Returns error if the database query fails
45    pub async fn find_by_provider(
46        pool: &PgPool,
47        provider: OAuthProvider,
48        provider_user_id: &str,
49    ) -> Result<Option<Self>, sqlx::Error> {
50        sqlx::query_as::<_, Self>(
51            r"
52            SELECT id, user_id, provider, provider_user_id, email, name, avatar_url,
53                   created_at, updated_at
54            FROM oauth_accounts
55            WHERE provider = $1 AND provider_user_id = $2
56            ",
57        )
58        .bind(provider.as_str())
59        .bind(provider_user_id)
60        .fetch_optional(pool)
61        .await
62    }
63
64    /// Find all OAuth accounts for a user
65    ///
66    /// # Errors
67    ///
68    /// Returns error if the database query fails
69    pub async fn find_by_user_id(pool: &PgPool, user_id: i64) -> Result<Vec<Self>, sqlx::Error> {
70        sqlx::query_as::<_, Self>(
71            r"
72            SELECT id, user_id, provider, provider_user_id, email, name, avatar_url,
73                   created_at, updated_at
74            FROM oauth_accounts
75            WHERE user_id = $1
76            ORDER BY created_at DESC
77            ",
78        )
79        .bind(user_id)
80        .fetch_all(pool)
81        .await
82    }
83
84    /// Link an OAuth account to a user
85    ///
86    /// # Errors
87    ///
88    /// Returns error if the database query fails or if the OAuth account
89    /// is already linked to a different user
90    pub async fn link_account(
91        pool: &PgPool,
92        user_id: i64,
93        provider: OAuthProvider,
94        user_info: &OAuthUserInfo,
95    ) -> Result<Self, sqlx::Error> {
96        sqlx::query_as::<_, Self>(
97            r"
98            INSERT INTO oauth_accounts (user_id, provider, provider_user_id, email, name, avatar_url)
99            VALUES ($1, $2, $3, $4, $5, $6)
100            ON CONFLICT (provider, provider_user_id)
101            DO UPDATE SET
102                user_id = EXCLUDED.user_id,
103                email = EXCLUDED.email,
104                name = EXCLUDED.name,
105                avatar_url = EXCLUDED.avatar_url,
106                updated_at = NOW()
107            RETURNING id, user_id, provider, provider_user_id, email, name, avatar_url,
108                      created_at, updated_at
109            ",
110        )
111        .bind(user_id)
112        .bind(provider.as_str())
113        .bind(&user_info.provider_user_id)
114        .bind(&user_info.email)
115        .bind(&user_info.name)
116        .bind(&user_info.avatar_url)
117        .fetch_one(pool)
118        .await
119    }
120
121    /// Unlink an OAuth account
122    ///
123    /// # Errors
124    ///
125    /// Returns error if the database query fails
126    pub async fn unlink_account(
127        pool: &PgPool,
128        user_id: i64,
129        provider: OAuthProvider,
130    ) -> Result<bool, sqlx::Error> {
131        let result = sqlx::query(
132            r"
133            DELETE FROM oauth_accounts
134            WHERE user_id = $1 AND provider = $2
135            ",
136        )
137        .bind(user_id)
138        .bind(provider.as_str())
139        .execute(pool)
140        .await?;
141
142        Ok(result.rows_affected() > 0)
143    }
144
145    /// Update OAuth account information
146    ///
147    /// # Errors
148    ///
149    /// Returns error if the database query fails
150    pub async fn update_info(
151        &mut self,
152        pool: &PgPool,
153        user_info: &OAuthUserInfo,
154    ) -> Result<(), sqlx::Error> {
155        let updated = sqlx::query_as::<_, Self>(
156            r"
157            UPDATE oauth_accounts
158            SET email = $1, name = $2, avatar_url = $3, updated_at = NOW()
159            WHERE id = $4
160            RETURNING id, user_id, provider, provider_user_id, email, name, avatar_url,
161                      created_at, updated_at
162            ",
163        )
164        .bind(&user_info.email)
165        .bind(&user_info.name)
166        .bind(&user_info.avatar_url)
167        .bind(self.id)
168        .fetch_one(pool)
169        .await?;
170
171        *self = updated;
172        Ok(())
173    }
174
175    /// Check if a user has any OAuth accounts linked
176    ///
177    /// # Errors
178    ///
179    /// Returns error if the database query fails
180    pub async fn user_has_oauth_accounts(
181        pool: &PgPool,
182        user_id: i64,
183    ) -> Result<bool, sqlx::Error> {
184        let count: (i64,) = sqlx::query_as(
185            r"
186            SELECT COUNT(*) FROM oauth_accounts WHERE user_id = $1
187            ",
188        )
189        .bind(user_id)
190        .fetch_one(pool)
191        .await?;
192
193        Ok(count.0 > 0)
194    }
195}
196
197// SQLx type conversion for OAuthProvider
198impl TryFrom<String> for OAuthProvider {
199    type Error = String;
200
201    fn try_from(value: String) -> Result<Self, Self::Error> {
202        value.parse().map_err(|e: super::types::OAuthError| e.to_string())
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone;

    /// A fixed, whole-second instant so fixtures are deterministic and
    /// round-trip through JSON exactly.
    fn fixed_time() -> DateTime<Utc> {
        Utc.timestamp_opt(1_700_000_000, 0)
            .single()
            .expect("1_700_000_000 is a valid, unambiguous UTC timestamp")
    }

    /// A fully populated Google-linked account used by several tests.
    fn google_account() -> OAuthAccount {
        OAuthAccount {
            id: 1,
            user_id: 100,
            provider: OAuthProvider::Google,
            provider_user_id: "123456".to_string(),
            email: "test@gmail.com".to_string(),
            name: Some("Test User".to_string()),
            avatar_url: Some("https://example.com/avatar.jpg".to_string()),
            created_at: fixed_time(),
            updated_at: fixed_time(),
        }
    }

    #[test]
    fn test_oauth_account_serialization() {
        let account = google_account();

        // Test serialization
        let json = serde_json::to_string(&account).unwrap();
        assert!(json.contains("123456"));
        assert!(json.contains("test@gmail.com"));

        // Test deserialization: every field must survive the round trip.
        let deserialized: OAuthAccount = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.id, 1);
        assert_eq!(deserialized.user_id, 100);
        assert_eq!(deserialized.provider, OAuthProvider::Google);
        assert_eq!(deserialized.provider_user_id, "123456");
        assert_eq!(deserialized.email, "test@gmail.com");
        assert_eq!(deserialized.name.as_deref(), Some("Test User"));
        assert_eq!(
            deserialized.avatar_url.as_deref(),
            Some("https://example.com/avatar.jpg")
        );
        assert_eq!(deserialized.created_at, account.created_at);
        assert_eq!(deserialized.updated_at, account.updated_at);
    }

    #[test]
    fn test_oauth_account_serialization_without_optional_fields() {
        let account = OAuthAccount {
            provider: OAuthProvider::GitHub,
            name: None,
            avatar_url: None,
            ..google_account()
        };

        let json = serde_json::to_string(&account).unwrap();
        let deserialized: OAuthAccount = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.provider, OAuthProvider::GitHub);
        assert!(deserialized.name.is_none());
        assert!(deserialized.avatar_url.is_none());
    }

    #[test]
    fn test_oauth_provider_try_from_string() {
        // Test valid conversions
        assert_eq!(
            OAuthProvider::try_from("google".to_string()).unwrap(),
            OAuthProvider::Google
        );
        assert_eq!(
            OAuthProvider::try_from("github".to_string()).unwrap(),
            OAuthProvider::GitHub
        );
        assert_eq!(
            OAuthProvider::try_from("oidc".to_string()).unwrap(),
            OAuthProvider::Oidc
        );

        // Test invalid conversion
        assert!(OAuthProvider::try_from("invalid".to_string()).is_err());
    }

    #[test]
    fn test_oauth_account_debug() {
        let account = OAuthAccount {
            provider: OAuthProvider::GitHub,
            provider_user_id: "gh123".to_string(),
            email: "test@github.com".to_string(),
            name: None,
            avatar_url: None,
            ..google_account()
        };

        let debug_str = format!("{account:?}");
        assert!(debug_str.contains("OAuthAccount"));
        assert!(debug_str.contains("GitHub"));
        assert!(debug_str.contains("gh123"));
    }

    #[test]
    fn test_oauth_account_clone() {
        let account = google_account();

        let cloned = account.clone();
        assert_eq!(cloned.id, account.id);
        assert_eq!(cloned.user_id, account.user_id);
        assert_eq!(cloned.provider, account.provider);
        assert_eq!(cloned.provider_user_id, account.provider_user_id);
        assert_eq!(cloned.email, account.email);
        assert_eq!(cloned.name, account.name);
        assert_eq!(cloned.avatar_url, account.avatar_url);
        assert_eq!(cloned.created_at, account.created_at);
        assert_eq!(cloned.updated_at, account.updated_at);
    }
}