1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use sqlx::{FromRow, PgPool};
9
10use super::types::{OAuthProvider, OAuthUserInfo};
11
12#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
17pub struct OAuthAccount {
18 pub id: i64,
20 pub user_id: i64,
22 #[sqlx(try_from = "String")]
24 pub provider: OAuthProvider,
25 pub provider_user_id: String,
27 pub email: String,
29 pub name: Option<String>,
31 pub avatar_url: Option<String>,
33 pub created_at: DateTime<Utc>,
35 pub updated_at: DateTime<Utc>,
37}
38
39impl OAuthAccount {
40 pub async fn find_by_provider(
46 pool: &PgPool,
47 provider: OAuthProvider,
48 provider_user_id: &str,
49 ) -> Result<Option<Self>, sqlx::Error> {
50 sqlx::query_as::<_, Self>(
51 r"
52 SELECT id, user_id, provider, provider_user_id, email, name, avatar_url,
53 created_at, updated_at
54 FROM oauth_accounts
55 WHERE provider = $1 AND provider_user_id = $2
56 ",
57 )
58 .bind(provider.as_str())
59 .bind(provider_user_id)
60 .fetch_optional(pool)
61 .await
62 }
63
64 pub async fn find_by_user_id(pool: &PgPool, user_id: i64) -> Result<Vec<Self>, sqlx::Error> {
70 sqlx::query_as::<_, Self>(
71 r"
72 SELECT id, user_id, provider, provider_user_id, email, name, avatar_url,
73 created_at, updated_at
74 FROM oauth_accounts
75 WHERE user_id = $1
76 ORDER BY created_at DESC
77 ",
78 )
79 .bind(user_id)
80 .fetch_all(pool)
81 .await
82 }
83
84 pub async fn link_account(
91 pool: &PgPool,
92 user_id: i64,
93 provider: OAuthProvider,
94 user_info: &OAuthUserInfo,
95 ) -> Result<Self, sqlx::Error> {
96 sqlx::query_as::<_, Self>(
97 r"
98 INSERT INTO oauth_accounts (user_id, provider, provider_user_id, email, name, avatar_url)
99 VALUES ($1, $2, $3, $4, $5, $6)
100 ON CONFLICT (provider, provider_user_id)
101 DO UPDATE SET
102 user_id = EXCLUDED.user_id,
103 email = EXCLUDED.email,
104 name = EXCLUDED.name,
105 avatar_url = EXCLUDED.avatar_url,
106 updated_at = NOW()
107 RETURNING id, user_id, provider, provider_user_id, email, name, avatar_url,
108 created_at, updated_at
109 ",
110 )
111 .bind(user_id)
112 .bind(provider.as_str())
113 .bind(&user_info.provider_user_id)
114 .bind(&user_info.email)
115 .bind(&user_info.name)
116 .bind(&user_info.avatar_url)
117 .fetch_one(pool)
118 .await
119 }
120
121 pub async fn unlink_account(
127 pool: &PgPool,
128 user_id: i64,
129 provider: OAuthProvider,
130 ) -> Result<bool, sqlx::Error> {
131 let result = sqlx::query(
132 r"
133 DELETE FROM oauth_accounts
134 WHERE user_id = $1 AND provider = $2
135 ",
136 )
137 .bind(user_id)
138 .bind(provider.as_str())
139 .execute(pool)
140 .await?;
141
142 Ok(result.rows_affected() > 0)
143 }
144
145 pub async fn update_info(
151 &mut self,
152 pool: &PgPool,
153 user_info: &OAuthUserInfo,
154 ) -> Result<(), sqlx::Error> {
155 let updated = sqlx::query_as::<_, Self>(
156 r"
157 UPDATE oauth_accounts
158 SET email = $1, name = $2, avatar_url = $3, updated_at = NOW()
159 WHERE id = $4
160 RETURNING id, user_id, provider, provider_user_id, email, name, avatar_url,
161 created_at, updated_at
162 ",
163 )
164 .bind(&user_info.email)
165 .bind(&user_info.name)
166 .bind(&user_info.avatar_url)
167 .bind(self.id)
168 .fetch_one(pool)
169 .await?;
170
171 *self = updated;
172 Ok(())
173 }
174
175 pub async fn user_has_oauth_accounts(
181 pool: &PgPool,
182 user_id: i64,
183 ) -> Result<bool, sqlx::Error> {
184 let count: (i64,) = sqlx::query_as(
185 r"
186 SELECT COUNT(*) FROM oauth_accounts WHERE user_id = $1
187 ",
188 )
189 .bind(user_id)
190 .fetch_one(pool)
191 .await?;
192
193 Ok(count.0 > 0)
194 }
195}
196
197impl TryFrom<String> for OAuthProvider {
199 type Error = String;
200
201 fn try_from(value: String) -> Result<Self, Self::Error> {
202 value.parse().map_err(|e: super::types::OAuthError| e.to_string())
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `OAuthAccount` with fixed ids (`id = 1`, `user_id = 100`) and the
    /// given provider-specific fields.
    fn account_fixture(
        provider: OAuthProvider,
        provider_user_id: &str,
        email: &str,
        name: Option<&str>,
        avatar_url: Option<&str>,
    ) -> OAuthAccount {
        OAuthAccount {
            id: 1,
            user_id: 100,
            provider,
            provider_user_id: provider_user_id.to_owned(),
            email: email.to_owned(),
            name: name.map(str::to_owned),
            avatar_url: avatar_url.map(str::to_owned),
            created_at: Utc::now(),
            updated_at: Utc::now(),
        }
    }

    /// A fully populated Google account with the values shared by several tests.
    fn google_fixture() -> OAuthAccount {
        account_fixture(
            OAuthProvider::Google,
            "123456",
            "test@gmail.com",
            Some("Test User"),
            Some("https://example.com/avatar.jpg"),
        )
    }

    #[test]
    fn test_oauth_account_serialization() {
        let original = google_fixture();

        let encoded = serde_json::to_string(&original).unwrap();
        for needle in ["123456", "test@gmail.com"] {
            assert!(encoded.contains(needle));
        }

        let decoded: OAuthAccount = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.id, 1);
        assert_eq!(decoded.user_id, 100);
        assert_eq!(decoded.provider, OAuthProvider::Google);
        assert_eq!(decoded.provider_user_id, "123456");
        assert_eq!(decoded.email, "test@gmail.com");
    }

    #[test]
    fn test_oauth_provider_try_from_string() {
        let cases = [
            ("google", OAuthProvider::Google),
            ("github", OAuthProvider::GitHub),
            ("oidc", OAuthProvider::Oidc),
        ];
        for (raw, expected) in cases {
            assert_eq!(OAuthProvider::try_from(raw.to_string()).unwrap(), expected);
        }

        assert!(OAuthProvider::try_from("invalid".to_string()).is_err());
    }

    #[test]
    fn test_oauth_account_debug() {
        let github = account_fixture(OAuthProvider::GitHub, "gh123", "test@github.com", None, None);

        let rendered = format!("{github:?}");
        for needle in ["OAuthAccount", "GitHub", "gh123"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_oauth_account_clone() {
        let source = google_fixture();

        let copy = source.clone();
        assert_eq!(copy.id, source.id);
        assert_eq!(copy.user_id, source.user_id);
        assert_eq!(copy.provider, source.provider);
        assert_eq!(copy.provider_user_id, source.provider_user_id);
        assert_eq!(copy.email, source.email);
    }
}