Skip to main content

better_auth_api/plugins/
account_management.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use validator::Validate;
4
5use better_auth_core::adapters::DatabaseAdapter;
6use better_auth_core::entity::{AuthAccount, AuthSession, AuthUser};
7use better_auth_core::{AuthContext, AuthPlugin, AuthRoute, SessionManager};
8use better_auth_core::{AuthError, AuthResult};
9use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
10
11/// Account management plugin for listing and unlinking user accounts.
12pub struct AccountManagementPlugin {
13    config: AccountManagementConfig,
14}
15
/// Configuration for the account management plugin.
#[derive(Clone, Debug)]
pub struct AccountManagementConfig {
    /// Whether the plugin's routes should require an authenticated session.
    ///
    /// NOTE(review): nothing in this file reads this flag; the list and
    /// unlink handlers always require a session regardless of its value.
    pub require_authentication: bool,
}
20
21#[derive(Debug, Deserialize, Validate)]
22struct UnlinkAccountRequest {
23    #[serde(rename = "providerId")]
24    #[validate(length(min = 1, message = "Provider ID is required"))]
25    provider_id: String,
26}
27
28#[derive(Debug, Serialize)]
29struct AccountResponse {
30    id: String,
31    #[serde(rename = "accountId")]
32    account_id: String,
33    #[serde(rename = "providerId")]
34    provider_id: String,
35    #[serde(rename = "userId")]
36    user_id: String,
37    #[serde(rename = "createdAt")]
38    created_at: String,
39    #[serde(rename = "updatedAt")]
40    updated_at: String,
41    scope: Option<String>,
42}
43
44#[derive(Debug, Serialize)]
45struct StatusResponse {
46    status: bool,
47}
48
49impl AccountManagementPlugin {
50    pub fn new() -> Self {
51        Self {
52            config: AccountManagementConfig::default(),
53        }
54    }
55
56    pub fn with_config(config: AccountManagementConfig) -> Self {
57        Self { config }
58    }
59
60    pub fn require_authentication(mut self, require: bool) -> Self {
61        self.config.require_authentication = require;
62        self
63    }
64}
65
66impl Default for AccountManagementPlugin {
67    fn default() -> Self {
68        Self::new()
69    }
70}
71
72impl Default for AccountManagementConfig {
73    fn default() -> Self {
74        Self {
75            require_authentication: true,
76        }
77    }
78}
79
80#[async_trait]
81impl<DB: DatabaseAdapter> AuthPlugin<DB> for AccountManagementPlugin {
82    fn name(&self) -> &'static str {
83        "account-management"
84    }
85
86    fn routes(&self) -> Vec<AuthRoute> {
87        vec![
88            AuthRoute::get("/list-accounts", "list_accounts"),
89            AuthRoute::post("/unlink-account", "unlink_account"),
90        ]
91    }
92
93    async fn on_request(
94        &self,
95        req: &AuthRequest,
96        ctx: &AuthContext<DB>,
97    ) -> AuthResult<Option<AuthResponse>> {
98        match (req.method(), req.path()) {
99            (HttpMethod::Get, "/list-accounts") => {
100                Ok(Some(self.handle_list_accounts(req, ctx).await?))
101            }
102            (HttpMethod::Post, "/unlink-account") => {
103                Ok(Some(self.handle_unlink_account(req, ctx).await?))
104            }
105            _ => Ok(None),
106        }
107    }
108}
109
110impl AccountManagementPlugin {
111    async fn require_session<DB: DatabaseAdapter>(
112        &self,
113        req: &AuthRequest,
114        ctx: &AuthContext<DB>,
115    ) -> AuthResult<(DB::User, DB::Session)> {
116        let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
117
118        if let Some(token) = session_manager.extract_session_token(req)
119            && let Some(session) = session_manager.get_session(&token).await?
120            && let Some(user) = ctx.database.get_user_by_id(session.user_id()).await?
121        {
122            return Ok((user, session));
123        }
124
125        Err(AuthError::Unauthenticated)
126    }
127
128    async fn handle_list_accounts<DB: DatabaseAdapter>(
129        &self,
130        req: &AuthRequest,
131        ctx: &AuthContext<DB>,
132    ) -> AuthResult<AuthResponse> {
133        let (user, _session) = self.require_session(req, ctx).await?;
134
135        let accounts = ctx.database.get_user_accounts(user.id()).await?;
136
137        // Filter sensitive fields (password, tokens)
138        let filtered: Vec<AccountResponse> = accounts
139            .iter()
140            .map(|acc| AccountResponse {
141                id: acc.id().to_string(),
142                account_id: acc.account_id().to_string(),
143                provider_id: acc.provider_id().to_string(),
144                user_id: acc.user_id().to_string(),
145                created_at: acc.created_at().to_rfc3339(),
146                updated_at: acc.updated_at().to_rfc3339(),
147                scope: acc.scope().map(|s| s.to_string()),
148            })
149            .collect();
150
151        Ok(AuthResponse::json(200, &filtered)?)
152    }
153
154    async fn handle_unlink_account<DB: DatabaseAdapter>(
155        &self,
156        req: &AuthRequest,
157        ctx: &AuthContext<DB>,
158    ) -> AuthResult<AuthResponse> {
159        let (user, _session) = self.require_session(req, ctx).await?;
160
161        let unlink_req: UnlinkAccountRequest = match better_auth_core::validate_request_body(req) {
162            Ok(v) => v,
163            Err(resp) => return Ok(resp),
164        };
165
166        let accounts = ctx.database.get_user_accounts(user.id()).await?;
167
168        // Check if user has a password (credential provider)
169        let has_password = user
170            .metadata()
171            .get("password_hash")
172            .and_then(|v| v.as_str())
173            .is_some();
174
175        // Count remaining credentials after unlinking
176        let remaining_accounts = accounts
177            .iter()
178            .filter(|acc| acc.provider_id() != unlink_req.provider_id)
179            .count();
180
181        // Prevent unlinking the last credential
182        if !has_password && remaining_accounts == 0 {
183            return Err(AuthError::bad_request(
184                "Cannot unlink the last account. You must have at least one authentication method.",
185            ));
186        }
187
188        // Find and delete the account
189        let account_to_remove = accounts
190            .iter()
191            .find(|acc| acc.provider_id() == unlink_req.provider_id)
192            .ok_or_else(|| AuthError::not_found("No account found with this provider"))?;
193
194        ctx.database.delete_account(account_to_remove.id()).await?;
195
196        let response = StatusResponse { status: true };
197        Ok(AuthResponse::json(200, &response)?)
198    }
199}