1use async_trait::async_trait;
2use rs_auth_core::error::AuthError;
3use rs_auth_core::store::AccountStore;
4use rs_auth_core::types::{Account, NewAccount};
5use sqlx::Row;
6use time::OffsetDateTime;
7
8use crate::db::AuthDb;
9
10#[async_trait]
11impl AccountStore for AuthDb {
12 async fn create_account(&self, account: NewAccount) -> Result<Account, AuthError> {
13 let row = sqlx::query(
14 r#"
15 INSERT INTO accounts (user_id, provider_id, account_id, access_token, refresh_token, access_token_expires_at, scope)
16 VALUES ($1, $2, $3, $4, $5, $6, $7)
17 RETURNING id, user_id, provider_id, account_id, access_token, refresh_token, access_token_expires_at, scope, created_at, updated_at
18 "#,
19 )
20 .bind(account.user_id)
21 .bind(&account.provider_id)
22 .bind(&account.account_id)
23 .bind(&account.access_token)
24 .bind(&account.refresh_token)
25 .bind(account.access_token_expires_at)
26 .bind(&account.scope)
27 .fetch_one(&self.pool)
28 .await
29 .map_err(|e| AuthError::Store(e.to_string()))?;
30
31 Ok(Account {
32 id: row.get("id"),
33 user_id: row.get("user_id"),
34 provider_id: row.get("provider_id"),
35 account_id: row.get("account_id"),
36 access_token: row.get("access_token"),
37 refresh_token: row.get("refresh_token"),
38 access_token_expires_at: row.get("access_token_expires_at"),
39 scope: row.get("scope"),
40 created_at: row.get("created_at"),
41 updated_at: row.get("updated_at"),
42 })
43 }
44
45 async fn find_by_provider(
46 &self,
47 provider_id: &str,
48 account_id: &str,
49 ) -> Result<Option<Account>, AuthError> {
50 let row = sqlx::query(
51 r#"
52 SELECT id, user_id, provider_id, account_id, access_token, refresh_token, access_token_expires_at, scope, created_at, updated_at
53 FROM accounts
54 WHERE provider_id = $1 AND account_id = $2
55 "#,
56 )
57 .bind(provider_id)
58 .bind(account_id)
59 .fetch_optional(&self.pool)
60 .await
61 .map_err(|e| AuthError::Store(e.to_string()))?;
62
63 Ok(row.map(|row| Account {
64 id: row.get("id"),
65 user_id: row.get("user_id"),
66 provider_id: row.get("provider_id"),
67 account_id: row.get("account_id"),
68 access_token: row.get("access_token"),
69 refresh_token: row.get("refresh_token"),
70 access_token_expires_at: row.get("access_token_expires_at"),
71 scope: row.get("scope"),
72 created_at: row.get("created_at"),
73 updated_at: row.get("updated_at"),
74 }))
75 }
76
77 async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Account>, AuthError> {
78 let rows = sqlx::query(
79 r#"
80 SELECT id, user_id, provider_id, account_id, access_token, refresh_token, access_token_expires_at, scope, created_at, updated_at
81 FROM accounts
82 WHERE user_id = $1
83 ORDER BY created_at DESC
84 "#,
85 )
86 .bind(user_id)
87 .fetch_all(&self.pool)
88 .await
89 .map_err(|e| AuthError::Store(e.to_string()))?;
90
91 Ok(rows
92 .into_iter()
93 .map(|row| Account {
94 id: row.get("id"),
95 user_id: row.get("user_id"),
96 provider_id: row.get("provider_id"),
97 account_id: row.get("account_id"),
98 access_token: row.get("access_token"),
99 refresh_token: row.get("refresh_token"),
100 access_token_expires_at: row.get("access_token_expires_at"),
101 scope: row.get("scope"),
102 created_at: row.get("created_at"),
103 updated_at: row.get("updated_at"),
104 })
105 .collect())
106 }
107
108 async fn delete_account(&self, id: i64) -> Result<(), AuthError> {
109 sqlx::query(r#"DELETE FROM accounts WHERE id = $1"#)
110 .bind(id)
111 .execute(&self.pool)
112 .await
113 .map_err(|e| AuthError::Store(e.to_string()))?;
114 Ok(())
115 }
116
117 async fn update_account(
118 &self,
119 id: i64,
120 access_token: Option<String>,
121 refresh_token: Option<String>,
122 access_token_expires_at: Option<OffsetDateTime>,
123 scope: Option<String>,
124 ) -> Result<(), AuthError> {
125 sqlx::query(
126 r#"
127 UPDATE accounts
128 SET access_token = $2, refresh_token = $3,
129 access_token_expires_at = $4, scope = $5, updated_at = now()
130 WHERE id = $1
131 "#,
132 )
133 .bind(id)
134 .bind(access_token)
135 .bind(refresh_token)
136 .bind(access_token_expires_at)
137 .bind(scope)
138 .execute(&self.pool)
139 .await
140 .map_err(|e| AuthError::Store(e.to_string()))?;
141 Ok(())
142 }
143}