Skip to main content

rs_auth_postgres/
account.rs

1use async_trait::async_trait;
2use rs_auth_core::error::AuthError;
3use rs_auth_core::store::AccountStore;
4use rs_auth_core::types::{Account, NewAccount};
5use sqlx::Row;
6
7use crate::db::AuthDb;
8
9#[async_trait]
10impl AccountStore for AuthDb {
11    async fn create_account(&self, account: NewAccount) -> Result<Account, AuthError> {
12        let row = sqlx::query(
13            r#"
14            INSERT INTO accounts (user_id, provider_id, account_id, access_token, refresh_token, access_token_expires_at, scope)
15            VALUES ($1, $2, $3, $4, $5, $6, $7)
16            RETURNING id, user_id, provider_id, account_id, access_token, refresh_token, access_token_expires_at, scope, created_at, updated_at
17            "#,
18        )
19        .bind(account.user_id)
20        .bind(&account.provider_id)
21        .bind(&account.account_id)
22        .bind(&account.access_token)
23        .bind(&account.refresh_token)
24        .bind(account.access_token_expires_at)
25        .bind(&account.scope)
26        .fetch_one(&self.pool)
27        .await
28        .map_err(|e| AuthError::Store(e.to_string()))?;
29
30        Ok(Account {
31            id: row.get("id"),
32            user_id: row.get("user_id"),
33            provider_id: row.get("provider_id"),
34            account_id: row.get("account_id"),
35            access_token: row.get("access_token"),
36            refresh_token: row.get("refresh_token"),
37            access_token_expires_at: row.get("access_token_expires_at"),
38            scope: row.get("scope"),
39            created_at: row.get("created_at"),
40            updated_at: row.get("updated_at"),
41        })
42    }
43
44    async fn find_by_provider(
45        &self,
46        provider_id: &str,
47        account_id: &str,
48    ) -> Result<Option<Account>, AuthError> {
49        let row = sqlx::query(
50            r#"
51            SELECT id, user_id, provider_id, account_id, access_token, refresh_token, access_token_expires_at, scope, created_at, updated_at
52            FROM accounts
53            WHERE provider_id = $1 AND account_id = $2
54            "#,
55        )
56        .bind(provider_id)
57        .bind(account_id)
58        .fetch_optional(&self.pool)
59        .await
60        .map_err(|e| AuthError::Store(e.to_string()))?;
61
62        Ok(row.map(|row| Account {
63            id: row.get("id"),
64            user_id: row.get("user_id"),
65            provider_id: row.get("provider_id"),
66            account_id: row.get("account_id"),
67            access_token: row.get("access_token"),
68            refresh_token: row.get("refresh_token"),
69            access_token_expires_at: row.get("access_token_expires_at"),
70            scope: row.get("scope"),
71            created_at: row.get("created_at"),
72            updated_at: row.get("updated_at"),
73        }))
74    }
75
76    async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Account>, AuthError> {
77        let rows = sqlx::query(
78            r#"
79            SELECT id, user_id, provider_id, account_id, access_token, refresh_token, access_token_expires_at, scope, created_at, updated_at
80            FROM accounts
81            WHERE user_id = $1
82            ORDER BY created_at DESC
83            "#,
84        )
85        .bind(user_id)
86        .fetch_all(&self.pool)
87        .await
88        .map_err(|e| AuthError::Store(e.to_string()))?;
89
90        Ok(rows
91            .into_iter()
92            .map(|row| Account {
93                id: row.get("id"),
94                user_id: row.get("user_id"),
95                provider_id: row.get("provider_id"),
96                account_id: row.get("account_id"),
97                access_token: row.get("access_token"),
98                refresh_token: row.get("refresh_token"),
99                access_token_expires_at: row.get("access_token_expires_at"),
100                scope: row.get("scope"),
101                created_at: row.get("created_at"),
102                updated_at: row.get("updated_at"),
103            })
104            .collect())
105    }
106
107    async fn delete_account(&self, id: i64) -> Result<(), AuthError> {
108        sqlx::query(r#"DELETE FROM accounts WHERE id = $1"#)
109            .bind(id)
110            .execute(&self.pool)
111            .await
112            .map_err(|e| AuthError::Store(e.to_string()))?;
113        Ok(())
114    }
115}