use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{FromRow, PgPool};
use super::types::{OAuthProvider, OAuthUserInfo};
/// A third-party OAuth identity linked to a local user, mapped from a row of
/// the `oauth_accounts` table.
///
/// The type derives `FromRow`, so the queries in this file select columns whose
/// names match the field names below. It also derives `Serialize` and
/// `Deserialize`, so it can be round-tripped through formats such as JSON.
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct OAuthAccount {
/// Row identifier; `update_info` uses it to select the row to update.
pub id: i64,
/// Identifier of the local user this account is linked to.
pub user_id: i64,
/// Provider that issued the identity. The column is read as a `String` and
/// converted through `TryFrom<String>`; queries bind it via `as_str()`.
#[sqlx(try_from = "String")]
pub provider: OAuthProvider,
/// Identifier of the user on the provider's side. Together with `provider`
/// it is the conflict target of the upsert in `link_account`.
pub provider_user_id: String,
/// Email copied from the `OAuthUserInfo` supplied on link/update.
pub email: String,
/// Optional name copied from the `OAuthUserInfo` supplied on link/update.
pub name: Option<String>,
/// Optional avatar URL copied from the `OAuthUserInfo` supplied on link/update.
pub avatar_url: Option<String>,
/// Creation timestamp; never written by the queries in this file, so it is
/// presumably filled by a column default — confirm against the schema.
pub created_at: DateTime<Utc>,
/// Set to `NOW()` whenever `link_account` (on conflict) or `update_info`
/// modifies the row.
pub updated_at: DateTime<Utc>,
}
impl OAuthAccount {
    /// Looks up the account for a provider identity.
    ///
    /// Matches on both `provider` (bound as its string form) and
    /// `provider_user_id`, returning `Ok(None)` when no row matches.
    ///
    /// # Errors
    ///
    /// Returns any [`sqlx::Error`] raised while running the query or decoding
    /// the row.
    pub async fn find_by_provider(
        pool: &PgPool,
        provider: OAuthProvider,
        provider_user_id: &str,
    ) -> Result<Option<Self>, sqlx::Error> {
        sqlx::query_as::<_, Self>(
            r"
SELECT id, user_id, provider, provider_user_id, email, name, avatar_url,
created_at, updated_at
FROM oauth_accounts
WHERE provider = $1 AND provider_user_id = $2
",
        )
        .bind(provider.as_str())
        .bind(provider_user_id)
        .fetch_optional(pool)
        .await
    }

    /// Returns every OAuth account linked to `user_id`, newest first
    /// (`ORDER BY created_at DESC`). The vector is empty when the user has none.
    ///
    /// # Errors
    ///
    /// Returns any [`sqlx::Error`] raised while running the query or decoding
    /// the rows.
    pub async fn find_by_user_id(pool: &PgPool, user_id: i64) -> Result<Vec<Self>, sqlx::Error> {
        sqlx::query_as::<_, Self>(
            r"
SELECT id, user_id, provider, provider_user_id, email, name, avatar_url,
created_at, updated_at
FROM oauth_accounts
WHERE user_id = $1
ORDER BY created_at DESC
",
        )
        .bind(user_id)
        .fetch_all(pool)
        .await
    }

    /// Links a provider identity to `user_id`, inserting or updating as needed.
    ///
    /// This is an upsert keyed on `(provider, provider_user_id)`, which requires
    /// a unique constraint or index on that pair. When the identity already
    /// exists, its `user_id`, `email`, `name` and `avatar_url` are overwritten
    /// with the supplied values and `updated_at` is set to `NOW()`. The
    /// resulting row is returned.
    ///
    /// NOTE(review): on conflict the existing link is reassigned to the given
    /// `user_id`, even if it previously belonged to a different user. Confirm
    /// that callers verify ownership before calling this.
    ///
    /// # Errors
    ///
    /// Returns any [`sqlx::Error`] raised while running the statement or
    /// decoding the returned row.
    pub async fn link_account(
        pool: &PgPool,
        user_id: i64,
        provider: OAuthProvider,
        user_info: &OAuthUserInfo,
    ) -> Result<Self, sqlx::Error> {
        sqlx::query_as::<_, Self>(
            r"
INSERT INTO oauth_accounts (user_id, provider, provider_user_id, email, name, avatar_url)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (provider, provider_user_id)
DO UPDATE SET
user_id = EXCLUDED.user_id,
email = EXCLUDED.email,
name = EXCLUDED.name,
avatar_url = EXCLUDED.avatar_url,
updated_at = NOW()
RETURNING id, user_id, provider, provider_user_id, email, name, avatar_url,
created_at, updated_at
",
        )
        .bind(user_id)
        .bind(provider.as_str())
        .bind(&user_info.provider_user_id)
        .bind(&user_info.email)
        .bind(&user_info.name)
        .bind(&user_info.avatar_url)
        .fetch_one(pool)
        .await
    }

    /// Removes the link(s) between `user_id` and `provider`.
    ///
    /// The delete filters only on user and provider, so every account the user
    /// holds with that provider is removed. Returns `true` when at least one
    /// row was deleted and `false` when nothing matched.
    ///
    /// # Errors
    ///
    /// Returns any [`sqlx::Error`] raised while running the statement.
    pub async fn unlink_account(
        pool: &PgPool,
        user_id: i64,
        provider: OAuthProvider,
    ) -> Result<bool, sqlx::Error> {
        let result = sqlx::query(
            r"
DELETE FROM oauth_accounts
WHERE user_id = $1 AND provider = $2
",
        )
        .bind(user_id)
        .bind(provider.as_str())
        .execute(pool)
        .await?;
        Ok(result.rows_affected() > 0)
    }

    /// Refreshes the stored profile fields of this account from `user_info`.
    ///
    /// Updates `email`, `name`, `avatar_url` and `updated_at` for the row whose
    /// id equals `self.id`, then replaces `*self` with the row returned by the
    /// database so the in-memory value matches what was persisted.
    ///
    /// # Errors
    ///
    /// Returns any [`sqlx::Error`] raised while running the statement. Because
    /// the row is read with `fetch_one`, this includes the case where no row
    /// with `self.id` exists any more. `self` is left untouched on error.
    pub async fn update_info(
        &mut self,
        pool: &PgPool,
        user_info: &OAuthUserInfo,
    ) -> Result<(), sqlx::Error> {
        let updated = sqlx::query_as::<_, Self>(
            r"
UPDATE oauth_accounts
SET email = $1, name = $2, avatar_url = $3, updated_at = NOW()
WHERE id = $4
RETURNING id, user_id, provider, provider_user_id, email, name, avatar_url,
created_at, updated_at
",
        )
        .bind(&user_info.email)
        .bind(&user_info.name)
        .bind(&user_info.avatar_url)
        .bind(self.id)
        .fetch_one(pool)
        .await?;
        *self = updated;
        Ok(())
    }

    /// Reports whether `user_id` has at least one linked OAuth account.
    ///
    /// # Errors
    ///
    /// Returns any [`sqlx::Error`] raised while running the query.
    pub async fn user_has_oauth_accounts(
        pool: &PgPool,
        user_id: i64,
    ) -> Result<bool, sqlx::Error> {
        // Only existence matters here, so ask for EXISTS instead of COUNT(*):
        // the database can stop at the first matching row rather than counting
        // all of them, and the result is already the boolean we return.
        sqlx::query_scalar::<_, bool>(
            r"
SELECT EXISTS(SELECT 1 FROM oauth_accounts WHERE user_id = $1)
",
        )
        .bind(user_id)
        .fetch_one(pool)
        .await
    }
}
/// Converts an owned string into an [`OAuthProvider`].
///
/// The conversion goes through the provider's string parser, and a parse
/// failure is reported as the error's `to_string()` text. The
/// `#[sqlx(try_from = "String")]` attribute on `OAuthAccount::provider` names
/// this conversion for decoding the column.
impl TryFrom<String> for OAuthProvider {
    type Error = String;

    fn try_from(value: String) -> Result<Self, Self::Error> {
        let parsed: Result<Self, super::types::OAuthError> = value.parse();
        match parsed {
            Ok(provider) => Ok(provider),
            Err(parse_error) => Err(parse_error.to_string()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the Google-linked fixture shared by the serialization and clone tests.
    fn google_account() -> OAuthAccount {
        OAuthAccount {
            id: 1,
            user_id: 100,
            provider: OAuthProvider::Google,
            provider_user_id: "123456".to_string(),
            email: "test@gmail.com".to_string(),
            name: Some("Test User".to_string()),
            avatar_url: Some("https://example.com/avatar.jpg".to_string()),
            created_at: Utc::now(),
            updated_at: Utc::now(),
        }
    }

    /// A JSON round trip keeps the identifying fields intact.
    #[test]
    fn test_oauth_account_serialization() {
        let original = google_account();

        let encoded = serde_json::to_string(&original).unwrap();
        assert!(encoded.contains("123456"));
        assert!(encoded.contains("test@gmail.com"));

        let decoded: OAuthAccount = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.id, 1);
        assert_eq!(decoded.user_id, 100);
        assert_eq!(decoded.provider, OAuthProvider::Google);
        assert_eq!(decoded.provider_user_id, "123456");
        assert_eq!(decoded.email, "test@gmail.com");
    }

    /// Known provider names convert; an unknown name is rejected.
    #[test]
    fn test_oauth_provider_try_from_string() {
        let known = [
            ("google", OAuthProvider::Google),
            ("github", OAuthProvider::GitHub),
            ("oidc", OAuthProvider::Oidc),
        ];
        for (raw, expected) in known {
            assert_eq!(OAuthProvider::try_from(String::from(raw)).unwrap(), expected);
        }

        let rejected = OAuthProvider::try_from(String::from("invalid"));
        assert!(rejected.is_err());
    }

    /// The derived `Debug` output names the type, the provider variant and the id.
    #[test]
    fn test_oauth_account_debug() {
        let github_account = OAuthAccount {
            id: 1,
            user_id: 100,
            provider: OAuthProvider::GitHub,
            provider_user_id: "gh123".to_string(),
            email: "test@github.com".to_string(),
            name: None,
            avatar_url: None,
            created_at: Utc::now(),
            updated_at: Utc::now(),
        };

        let rendered = format!("{:?}", github_account);
        assert!(rendered.contains("OAuthAccount"));
        assert!(rendered.contains("GitHub"));
        assert!(rendered.contains("gh123"));
    }

    /// A clone carries the same field values as its source.
    #[test]
    fn test_oauth_account_clone() {
        let source = google_account();
        let copy = source.clone();

        assert_eq!(copy.id, source.id);
        assert_eq!(copy.user_id, source.user_id);
        assert_eq!(copy.provider, source.provider);
        assert_eq!(copy.provider_user_id, source.provider_user_id);
        assert_eq!(copy.email, source.email);
    }
}