Skip to main content

better_auth_api/plugins/
account_management.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use validator::Validate;
4
5use better_auth_core::{AuthContext, AuthPlugin, AuthRoute, SessionManager};
6use better_auth_core::{AuthError, AuthResult};
7use better_auth_core::{AuthRequest, AuthResponse, HttpMethod, Session, User};
8
9/// Account management plugin for listing and unlinking user accounts.
10pub struct AccountManagementPlugin {
11    config: AccountManagementConfig,
12}
13
/// Configuration for the account management plugin.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AccountManagementConfig {
    /// Whether the plugin's routes should require an authenticated session
    /// (defaults to `true`).
    ///
    /// NOTE(review): no handler in this module reads this flag; both routes always
    /// require a session. Confirm whether it is consulted elsewhere before relying on it.
    pub require_authentication: bool,
}
18
19#[derive(Debug, Deserialize, Validate)]
20struct UnlinkAccountRequest {
21    #[serde(rename = "providerId")]
22    #[validate(length(min = 1, message = "Provider ID is required"))]
23    provider_id: String,
24}
25
26#[derive(Debug, Serialize)]
27struct AccountResponse {
28    id: String,
29    #[serde(rename = "accountId")]
30    account_id: String,
31    #[serde(rename = "providerId")]
32    provider_id: String,
33    #[serde(rename = "userId")]
34    user_id: String,
35    #[serde(rename = "createdAt")]
36    created_at: String,
37    #[serde(rename = "updatedAt")]
38    updated_at: String,
39    scope: Option<String>,
40}
41
42#[derive(Debug, Serialize)]
43struct StatusResponse {
44    status: bool,
45}
46
47impl AccountManagementPlugin {
48    pub fn new() -> Self {
49        Self {
50            config: AccountManagementConfig::default(),
51        }
52    }
53
54    pub fn with_config(config: AccountManagementConfig) -> Self {
55        Self { config }
56    }
57
58    pub fn require_authentication(mut self, require: bool) -> Self {
59        self.config.require_authentication = require;
60        self
61    }
62}
63
64impl Default for AccountManagementPlugin {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl Default for AccountManagementConfig {
71    fn default() -> Self {
72        Self {
73            require_authentication: true,
74        }
75    }
76}
77
78#[async_trait]
79impl AuthPlugin for AccountManagementPlugin {
80    fn name(&self) -> &'static str {
81        "account-management"
82    }
83
84    fn routes(&self) -> Vec<AuthRoute> {
85        vec![
86            AuthRoute::get("/list-accounts", "list_accounts"),
87            AuthRoute::post("/unlink-account", "unlink_account"),
88        ]
89    }
90
91    async fn on_request(
92        &self,
93        req: &AuthRequest,
94        ctx: &AuthContext,
95    ) -> AuthResult<Option<AuthResponse>> {
96        match (req.method(), req.path()) {
97            (HttpMethod::Get, "/list-accounts") => {
98                Ok(Some(self.handle_list_accounts(req, ctx).await?))
99            }
100            (HttpMethod::Post, "/unlink-account") => {
101                Ok(Some(self.handle_unlink_account(req, ctx).await?))
102            }
103            _ => Ok(None),
104        }
105    }
106}
107
108impl AccountManagementPlugin {
109    async fn require_session(
110        &self,
111        req: &AuthRequest,
112        ctx: &AuthContext,
113    ) -> AuthResult<(User, Session)> {
114        let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
115
116        if let Some(token) = session_manager.extract_session_token(req)
117            && let Some(session) = session_manager.get_session(&token).await?
118            && let Some(user) = ctx.database.get_user_by_id(&session.user_id).await?
119        {
120            return Ok((user, session));
121        }
122
123        Err(AuthError::Unauthenticated)
124    }
125
126    async fn handle_list_accounts(
127        &self,
128        req: &AuthRequest,
129        ctx: &AuthContext,
130    ) -> AuthResult<AuthResponse> {
131        let (user, _session) = self.require_session(req, ctx).await?;
132
133        let accounts = ctx.database.get_user_accounts(&user.id).await?;
134
135        // Filter sensitive fields (password, tokens)
136        let filtered: Vec<AccountResponse> = accounts
137            .iter()
138            .map(|acc| AccountResponse {
139                id: acc.id.clone(),
140                account_id: acc.account_id.clone(),
141                provider_id: acc.provider_id.clone(),
142                user_id: acc.user_id.clone(),
143                created_at: acc.created_at.to_rfc3339(),
144                updated_at: acc.updated_at.to_rfc3339(),
145                scope: acc.scope.clone(),
146            })
147            .collect();
148
149        Ok(AuthResponse::json(200, &filtered)?)
150    }
151
152    async fn handle_unlink_account(
153        &self,
154        req: &AuthRequest,
155        ctx: &AuthContext,
156    ) -> AuthResult<AuthResponse> {
157        let (user, _session) = self.require_session(req, ctx).await?;
158
159        let unlink_req: UnlinkAccountRequest = match better_auth_core::validate_request_body(req) {
160            Ok(v) => v,
161            Err(resp) => return Ok(resp),
162        };
163
164        let accounts = ctx.database.get_user_accounts(&user.id).await?;
165
166        // Check if user has a password (credential provider)
167        let has_password = user
168            .metadata
169            .get("password_hash")
170            .and_then(|v| v.as_str())
171            .is_some();
172
173        // Count remaining credentials after unlinking
174        let remaining_accounts = accounts
175            .iter()
176            .filter(|acc| acc.provider_id != unlink_req.provider_id)
177            .count();
178
179        // Prevent unlinking the last credential
180        if !has_password && remaining_accounts == 0 {
181            return Err(AuthError::bad_request(
182                "Cannot unlink the last account. You must have at least one authentication method.",
183            ));
184        }
185
186        // Find and delete the account
187        let account_to_remove = accounts
188            .iter()
189            .find(|acc| acc.provider_id == unlink_req.provider_id)
190            .ok_or_else(|| AuthError::not_found("No account found with this provider"))?;
191
192        ctx.database.delete_account(&account_to_remove.id).await?;
193
194        let response = StatusResponse { status: true };
195        Ok(AuthResponse::json(200, &response)?)
196    }
197}