Skip to main content

rs_auth_postgres/
account.rs

1use async_trait::async_trait;
2use rs_auth_core::error::AuthError;
3use rs_auth_core::store::AccountStore;
4use rs_auth_core::types::{Account, NewAccount};
5use sqlx::Row;
6use time::OffsetDateTime;
7
8use crate::db::AuthDb;
9
10#[async_trait]
11impl AccountStore for AuthDb {
12    async fn create_account(&self, account: NewAccount) -> Result<Account, AuthError> {
13        let row = sqlx::query(
14            r#"
15            INSERT INTO accounts (user_id, provider_id, account_id, access_token, refresh_token, access_token_expires_at, scope)
16            VALUES ($1, $2, $3, $4, $5, $6, $7)
17            RETURNING id, user_id, provider_id, account_id, access_token, refresh_token, access_token_expires_at, scope, created_at, updated_at
18            "#,
19        )
20        .bind(account.user_id)
21        .bind(&account.provider_id)
22        .bind(&account.account_id)
23        .bind(&account.access_token)
24        .bind(&account.refresh_token)
25        .bind(account.access_token_expires_at)
26        .bind(&account.scope)
27        .fetch_one(&self.pool)
28        .await
29        .map_err(|e| AuthError::Store(e.to_string()))?;
30
31        Ok(Account {
32            id: row.get("id"),
33            user_id: row.get("user_id"),
34            provider_id: row.get("provider_id"),
35            account_id: row.get("account_id"),
36            access_token: row.get("access_token"),
37            refresh_token: row.get("refresh_token"),
38            access_token_expires_at: row.get("access_token_expires_at"),
39            scope: row.get("scope"),
40            created_at: row.get("created_at"),
41            updated_at: row.get("updated_at"),
42        })
43    }
44
45    async fn find_by_provider(
46        &self,
47        provider_id: &str,
48        account_id: &str,
49    ) -> Result<Option<Account>, AuthError> {
50        let row = sqlx::query(
51            r#"
52            SELECT id, user_id, provider_id, account_id, access_token, refresh_token, access_token_expires_at, scope, created_at, updated_at
53            FROM accounts
54            WHERE provider_id = $1 AND account_id = $2
55            "#,
56        )
57        .bind(provider_id)
58        .bind(account_id)
59        .fetch_optional(&self.pool)
60        .await
61        .map_err(|e| AuthError::Store(e.to_string()))?;
62
63        Ok(row.map(|row| Account {
64            id: row.get("id"),
65            user_id: row.get("user_id"),
66            provider_id: row.get("provider_id"),
67            account_id: row.get("account_id"),
68            access_token: row.get("access_token"),
69            refresh_token: row.get("refresh_token"),
70            access_token_expires_at: row.get("access_token_expires_at"),
71            scope: row.get("scope"),
72            created_at: row.get("created_at"),
73            updated_at: row.get("updated_at"),
74        }))
75    }
76
77    async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Account>, AuthError> {
78        let rows = sqlx::query(
79            r#"
80            SELECT id, user_id, provider_id, account_id, access_token, refresh_token, access_token_expires_at, scope, created_at, updated_at
81            FROM accounts
82            WHERE user_id = $1
83            ORDER BY created_at DESC
84            "#,
85        )
86        .bind(user_id)
87        .fetch_all(&self.pool)
88        .await
89        .map_err(|e| AuthError::Store(e.to_string()))?;
90
91        Ok(rows
92            .into_iter()
93            .map(|row| Account {
94                id: row.get("id"),
95                user_id: row.get("user_id"),
96                provider_id: row.get("provider_id"),
97                account_id: row.get("account_id"),
98                access_token: row.get("access_token"),
99                refresh_token: row.get("refresh_token"),
100                access_token_expires_at: row.get("access_token_expires_at"),
101                scope: row.get("scope"),
102                created_at: row.get("created_at"),
103                updated_at: row.get("updated_at"),
104            })
105            .collect())
106    }
107
108    async fn delete_account(&self, id: i64) -> Result<(), AuthError> {
109        sqlx::query(r#"DELETE FROM accounts WHERE id = $1"#)
110            .bind(id)
111            .execute(&self.pool)
112            .await
113            .map_err(|e| AuthError::Store(e.to_string()))?;
114        Ok(())
115    }
116
117    async fn update_account(
118        &self,
119        id: i64,
120        access_token: Option<String>,
121        refresh_token: Option<String>,
122        access_token_expires_at: Option<OffsetDateTime>,
123        scope: Option<String>,
124    ) -> Result<(), AuthError> {
125        sqlx::query(
126            r#"
127            UPDATE accounts
128            SET access_token = $2, refresh_token = $3,
129                access_token_expires_at = $4, scope = $5, updated_at = now()
130            WHERE id = $1
131            "#,
132        )
133        .bind(id)
134        .bind(access_token)
135        .bind(refresh_token)
136        .bind(access_token_expires_at)
137        .bind(scope)
138        .execute(&self.pool)
139        .await
140        .map_err(|e| AuthError::Store(e.to_string()))?;
141        Ok(())
142    }
143}