Skip to main content

better_auth_api/plugins/
account_management.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use validator::Validate;
4
5use better_auth_core::adapters::DatabaseAdapter;
6use better_auth_core::entity::{AuthAccount, AuthSession, AuthUser};
7use better_auth_core::{AuthContext, AuthPlugin, AuthRoute, SessionManager};
8use better_auth_core::{AuthError, AuthResult};
9use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
10
11/// Account management plugin for listing and unlinking user accounts.
12pub struct AccountManagementPlugin {
13    config: AccountManagementConfig,
14}
15
/// Configuration for `AccountManagementPlugin`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AccountManagementConfig {
    /// Whether the plugin's endpoints should require an authenticated session.
    ///
    /// NOTE(review): the request handlers in this file always require a session and never
    /// read this flag — confirm whether it is consulted anywhere else.
    pub require_authentication: bool,
}
20
21#[derive(Debug, Deserialize, Validate)]
22struct UnlinkAccountRequest {
23    #[serde(rename = "providerId")]
24    #[validate(length(min = 1, message = "Provider ID is required"))]
25    provider_id: String,
26}
27
28#[derive(Debug, Serialize)]
29struct AccountResponse {
30    id: String,
31    #[serde(rename = "accountId")]
32    account_id: String,
33    provider: String,
34    #[serde(rename = "createdAt")]
35    created_at: String,
36    #[serde(rename = "updatedAt")]
37    updated_at: String,
38    scopes: Vec<String>,
39}
40
41#[derive(Debug, Serialize)]
42struct StatusResponse {
43    status: bool,
44}
45
46impl AccountManagementPlugin {
47    pub fn new() -> Self {
48        Self {
49            config: AccountManagementConfig::default(),
50        }
51    }
52
53    pub fn with_config(config: AccountManagementConfig) -> Self {
54        Self { config }
55    }
56
57    pub fn require_authentication(mut self, require: bool) -> Self {
58        self.config.require_authentication = require;
59        self
60    }
61}
62
63impl Default for AccountManagementPlugin {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69impl Default for AccountManagementConfig {
70    fn default() -> Self {
71        Self {
72            require_authentication: true,
73        }
74    }
75}
76
77#[async_trait]
78impl<DB: DatabaseAdapter> AuthPlugin<DB> for AccountManagementPlugin {
79    fn name(&self) -> &'static str {
80        "account-management"
81    }
82
83    fn routes(&self) -> Vec<AuthRoute> {
84        vec![
85            AuthRoute::get("/list-accounts", "list_accounts"),
86            AuthRoute::post("/unlink-account", "unlink_account"),
87        ]
88    }
89
90    async fn on_request(
91        &self,
92        req: &AuthRequest,
93        ctx: &AuthContext<DB>,
94    ) -> AuthResult<Option<AuthResponse>> {
95        match (req.method(), req.path()) {
96            (HttpMethod::Get, "/list-accounts") => {
97                Ok(Some(self.handle_list_accounts(req, ctx).await?))
98            }
99            (HttpMethod::Post, "/unlink-account") => {
100                Ok(Some(self.handle_unlink_account(req, ctx).await?))
101            }
102            _ => Ok(None),
103        }
104    }
105}
106
107impl AccountManagementPlugin {
108    async fn require_session<DB: DatabaseAdapter>(
109        &self,
110        req: &AuthRequest,
111        ctx: &AuthContext<DB>,
112    ) -> AuthResult<(DB::User, DB::Session)> {
113        let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
114
115        if let Some(token) = session_manager.extract_session_token(req)
116            && let Some(session) = session_manager.get_session(&token).await?
117            && let Some(user) = ctx.database.get_user_by_id(session.user_id()).await?
118        {
119            return Ok((user, session));
120        }
121
122        Err(AuthError::Unauthenticated)
123    }
124
125    async fn handle_list_accounts<DB: DatabaseAdapter>(
126        &self,
127        req: &AuthRequest,
128        ctx: &AuthContext<DB>,
129    ) -> AuthResult<AuthResponse> {
130        let (user, _session) = self.require_session(req, ctx).await?;
131
132        let accounts = ctx.database.get_user_accounts(user.id()).await?;
133
134        // Filter sensitive fields (password, tokens)
135        let filtered: Vec<AccountResponse> = accounts
136            .iter()
137            .map(|acc| AccountResponse {
138                id: acc.id().to_string(),
139                account_id: acc.account_id().to_string(),
140                provider: acc.provider_id().to_string(),
141                created_at: acc.created_at().to_rfc3339(),
142                updated_at: acc.updated_at().to_rfc3339(),
143                scopes: acc
144                    .scope()
145                    .map(|s| {
146                        s.split([' ', ','])
147                            .filter(|s| !s.is_empty())
148                            .map(|s| s.to_string())
149                            .collect()
150                    })
151                    .unwrap_or_default(),
152            })
153            .collect();
154
155        Ok(AuthResponse::json(200, &filtered)?)
156    }
157
158    async fn handle_unlink_account<DB: DatabaseAdapter>(
159        &self,
160        req: &AuthRequest,
161        ctx: &AuthContext<DB>,
162    ) -> AuthResult<AuthResponse> {
163        let (user, _session) = self.require_session(req, ctx).await?;
164
165        let unlink_req: UnlinkAccountRequest = match better_auth_core::validate_request_body(req) {
166            Ok(v) => v,
167            Err(resp) => return Ok(resp),
168        };
169
170        let accounts = ctx.database.get_user_accounts(user.id()).await?;
171
172        // Check if user has a password (credential provider)
173        let has_password = user
174            .metadata()
175            .get("password_hash")
176            .and_then(|v| v.as_str())
177            .is_some();
178
179        // Count remaining credentials after unlinking
180        let remaining_accounts = accounts
181            .iter()
182            .filter(|acc| acc.provider_id() != unlink_req.provider_id)
183            .count();
184
185        // Prevent unlinking the last credential
186        if !has_password && remaining_accounts == 0 {
187            return Err(AuthError::bad_request(
188                "Cannot unlink the last account. You must have at least one authentication method.",
189            ));
190        }
191
192        // Find and delete the account
193        let account_to_remove = accounts
194            .iter()
195            .find(|acc| acc.provider_id() == unlink_req.provider_id)
196            .ok_or_else(|| AuthError::not_found("No account found with this provider"))?;
197
198        ctx.database.delete_account(account_to_remove.id()).await?;
199
200        let response = StatusResponse { status: true };
201        Ok(AuthResponse::json(200, &response)?)
202    }
203}