Skip to main content

better_auth_api/plugins/
account_management.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use validator::Validate;
4
5use better_auth_core::adapters::DatabaseAdapter;
6use better_auth_core::entity::{AuthAccount, AuthUser};
7use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
8use better_auth_core::{AuthError, AuthResult};
9use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
10
11use super::StatusResponse;
12
13/// Account management plugin for listing and unlinking user accounts.
14pub struct AccountManagementPlugin {
15    config: AccountManagementConfig,
16}
17
/// Settings for the account management plugin.
#[derive(Clone, Debug)]
pub struct AccountManagementConfig {
    /// Whether the plugin's endpoints should require an authenticated session.
    ///
    /// NOTE(review): the request handlers in this file always call
    /// `ctx.require_session`, so this flag is not consulted there — confirm intent.
    pub require_authentication: bool,
}
22
23#[derive(Debug, Deserialize, Validate)]
24struct UnlinkAccountRequest {
25    #[serde(rename = "providerId")]
26    #[validate(length(min = 1, message = "Provider ID is required"))]
27    provider_id: String,
28}
29
30#[derive(Debug, Serialize)]
31struct AccountResponse {
32    id: String,
33    #[serde(rename = "accountId")]
34    account_id: String,
35    provider: String,
36    #[serde(rename = "createdAt")]
37    created_at: String,
38    #[serde(rename = "updatedAt")]
39    updated_at: String,
40    scopes: Vec<String>,
41}
42
43impl AccountManagementPlugin {
44    pub fn new() -> Self {
45        Self {
46            config: AccountManagementConfig::default(),
47        }
48    }
49
50    pub fn with_config(config: AccountManagementConfig) -> Self {
51        Self { config }
52    }
53
54    pub fn require_authentication(mut self, require: bool) -> Self {
55        self.config.require_authentication = require;
56        self
57    }
58}
59
60impl Default for AccountManagementPlugin {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl Default for AccountManagementConfig {
67    fn default() -> Self {
68        Self {
69            require_authentication: true,
70        }
71    }
72}
73
74#[async_trait]
75impl<DB: DatabaseAdapter> AuthPlugin<DB> for AccountManagementPlugin {
76    fn name(&self) -> &'static str {
77        "account-management"
78    }
79
80    fn routes(&self) -> Vec<AuthRoute> {
81        vec![
82            AuthRoute::get("/list-accounts", "list_accounts"),
83            AuthRoute::post("/unlink-account", "unlink_account"),
84        ]
85    }
86
87    async fn on_request(
88        &self,
89        req: &AuthRequest,
90        ctx: &AuthContext<DB>,
91    ) -> AuthResult<Option<AuthResponse>> {
92        match (req.method(), req.path()) {
93            (HttpMethod::Get, "/list-accounts") => {
94                Ok(Some(self.handle_list_accounts(req, ctx).await?))
95            }
96            (HttpMethod::Post, "/unlink-account") => {
97                Ok(Some(self.handle_unlink_account(req, ctx).await?))
98            }
99            _ => Ok(None),
100        }
101    }
102}
103
104impl AccountManagementPlugin {
105    async fn handle_list_accounts<DB: DatabaseAdapter>(
106        &self,
107        req: &AuthRequest,
108        ctx: &AuthContext<DB>,
109    ) -> AuthResult<AuthResponse> {
110        let (user, _session) = ctx.require_session(req).await?;
111
112        let accounts = ctx.database.get_user_accounts(user.id()).await?;
113
114        // Filter sensitive fields (password, tokens)
115        let filtered: Vec<AccountResponse> = accounts
116            .iter()
117            .map(|acc| AccountResponse {
118                id: acc.id().to_string(),
119                account_id: acc.account_id().to_string(),
120                provider: acc.provider_id().to_string(),
121                created_at: acc.created_at().to_rfc3339(),
122                updated_at: acc.updated_at().to_rfc3339(),
123                scopes: acc
124                    .scope()
125                    .map(|s| {
126                        s.split([' ', ','])
127                            .filter(|s| !s.is_empty())
128                            .map(|s| s.to_string())
129                            .collect()
130                    })
131                    .unwrap_or_default(),
132            })
133            .collect();
134
135        Ok(AuthResponse::json(200, &filtered)?)
136    }
137
138    async fn handle_unlink_account<DB: DatabaseAdapter>(
139        &self,
140        req: &AuthRequest,
141        ctx: &AuthContext<DB>,
142    ) -> AuthResult<AuthResponse> {
143        let (user, _session) = ctx.require_session(req).await?;
144
145        let unlink_req: UnlinkAccountRequest = match better_auth_core::validate_request_body(req) {
146            Ok(v) => v,
147            Err(resp) => return Ok(resp),
148        };
149
150        let accounts = ctx.database.get_user_accounts(user.id()).await?;
151
152        let allow_unlinking_all = ctx.config.account.account_linking.allow_unlinking_all;
153
154        // Check if user has a password (credential provider)
155        let has_password = user
156            .metadata()
157            .get("password_hash")
158            .and_then(|v| v.as_str())
159            .is_some();
160
161        // Count remaining credentials after unlinking
162        let remaining_accounts = accounts
163            .iter()
164            .filter(|acc| acc.provider_id() != unlink_req.provider_id)
165            .count();
166
167        // Prevent unlinking the last credential (unless allow_unlinking_all is true)
168        if !allow_unlinking_all && !has_password && remaining_accounts == 0 {
169            return Err(AuthError::bad_request(
170                "Cannot unlink the last account. You must have at least one authentication method.",
171            ));
172        }
173
174        // Find and delete the account
175        let account_to_remove = accounts
176            .iter()
177            .find(|acc| acc.provider_id() == unlink_req.provider_id)
178            .ok_or_else(|| AuthError::not_found("No account found with this provider"))?;
179
180        ctx.database.delete_account(account_to_remove.id()).await?;
181
182        let response = StatusResponse { status: true };
183        Ok(AuthResponse::json(200, &response)?)
184    }
185}