// better_auth_api/plugins/account_management.rs

1use serde::{Deserialize, Serialize};
2use validator::Validate;
3
4use better_auth_core::adapters::DatabaseAdapter;
5use better_auth_core::entity::{AuthAccount, AuthUser};
6use better_auth_core::{AuthContext, AuthError, AuthResult};
7use better_auth_core::{AuthRequest, AuthResponse};
8
9use super::StatusResponse;
10
11/// Account management plugin for listing and unlinking user accounts.
12pub struct AccountManagementPlugin {
13    config: AccountManagementConfig,
14}
15
16#[derive(Debug, Clone, better_auth_core::PluginConfig)]
17#[plugin(name = "AccountManagementPlugin")]
18pub struct AccountManagementConfig {
19    #[config(default = true)]
20    pub require_authentication: bool,
21}
22
23#[derive(Debug, Deserialize, Validate)]
24struct UnlinkAccountRequest {
25    #[serde(rename = "providerId")]
26    #[validate(length(min = 1, message = "Provider ID is required"))]
27    provider_id: String,
28}
29
30#[derive(Debug, Serialize)]
31pub(crate) struct AccountResponse {
32    id: String,
33    #[serde(rename = "accountId")]
34    account_id: String,
35    provider: String,
36    #[serde(rename = "createdAt")]
37    created_at: String,
38    #[serde(rename = "updatedAt")]
39    updated_at: String,
40    scopes: Vec<String>,
41}
42
43better_auth_core::impl_auth_plugin! {
44    AccountManagementPlugin, "account-management";
45    routes {
46        get "/list-accounts" => handle_list_accounts, "list_accounts";
47        post "/unlink-account" => handle_unlink_account, "unlink_account";
48    }
49}
50
// ---------------------------------------------------------------------------
// Core functions — framework-agnostic business logic
// ---------------------------------------------------------------------------
54
55pub(crate) async fn list_accounts_core<DB: DatabaseAdapter>(
56    user: &DB::User,
57    ctx: &AuthContext<DB>,
58) -> AuthResult<Vec<AccountResponse>> {
59    let accounts = ctx.database.get_user_accounts(user.id()).await?;
60
61    let filtered: Vec<AccountResponse> = accounts
62        .iter()
63        .map(|acc| AccountResponse {
64            id: acc.id().to_string(),
65            account_id: acc.account_id().to_string(),
66            provider: acc.provider_id().to_string(),
67            created_at: acc.created_at().to_rfc3339(),
68            updated_at: acc.updated_at().to_rfc3339(),
69            scopes: acc
70                .scope()
71                .map(|s| {
72                    s.split([' ', ','])
73                        .filter(|s| !s.is_empty())
74                        .map(|s| s.to_string())
75                        .collect()
76                })
77                .unwrap_or_default(),
78        })
79        .collect();
80
81    Ok(filtered)
82}
83
84pub(crate) async fn unlink_account_core<DB: DatabaseAdapter>(
85    user: &DB::User,
86    provider_id: &str,
87    ctx: &AuthContext<DB>,
88) -> AuthResult<StatusResponse> {
89    let accounts = ctx.database.get_user_accounts(user.id()).await?;
90
91    let allow_unlinking_all = ctx.config.account.account_linking.allow_unlinking_all;
92
93    // Check if user has a password (credential provider)
94    let has_password = user.password_hash().is_some();
95
96    // Count remaining credentials after unlinking
97    let remaining_accounts = accounts
98        .iter()
99        .filter(|acc| acc.provider_id() != provider_id)
100        .count();
101
102    // Prevent unlinking the last credential (unless allow_unlinking_all is true)
103    if !allow_unlinking_all && !has_password && remaining_accounts == 0 {
104        return Err(AuthError::bad_request(
105            "Cannot unlink the last account. You must have at least one authentication method.",
106        ));
107    }
108
109    // Find and delete the account
110    let account_to_remove = accounts
111        .iter()
112        .find(|acc| acc.provider_id() == provider_id)
113        .ok_or_else(|| AuthError::not_found("No account found with this provider"))?;
114
115    ctx.database.delete_account(account_to_remove.id()).await?;
116
117    Ok(StatusResponse { status: true })
118}
119
// ---------------------------------------------------------------------------
// Legacy handler methods (AuthRequest → AuthResponse) — thin wrappers that
// delegate to the core functions above
// ---------------------------------------------------------------------------
123
124impl AccountManagementPlugin {
125    async fn handle_list_accounts<DB: DatabaseAdapter>(
126        &self,
127        req: &AuthRequest,
128        ctx: &AuthContext<DB>,
129    ) -> AuthResult<AuthResponse> {
130        let (user, _session) = ctx.require_session(req).await?;
131        let filtered = list_accounts_core(&user, ctx).await?;
132        Ok(AuthResponse::json(200, &filtered)?)
133    }
134
135    async fn handle_unlink_account<DB: DatabaseAdapter>(
136        &self,
137        req: &AuthRequest,
138        ctx: &AuthContext<DB>,
139    ) -> AuthResult<AuthResponse> {
140        let (user, _session) = ctx.require_session(req).await?;
141
142        let unlink_req: UnlinkAccountRequest = match better_auth_core::validate_request_body(req) {
143            Ok(v) => v,
144            Err(resp) => return Ok(resp),
145        };
146
147        let response = unlink_account_core(&user, &unlink_req.provider_id, ctx).await?;
148        Ok(AuthResponse::json(200, &response)?)
149    }
150}
151
#[cfg(feature = "axum")]
mod axum_impl {
    //! Axum integration: serves the same endpoints as the legacy handlers,
    //! using axum extractors for the auth state, the current session and the
    //! validated JSON body.

    use super::*;

    use axum::Json;
    use axum::extract::State;
    use better_auth_core::{AuthState, CurrentSession, ValidatedJson};

    /// `GET /list-accounts`: the session user's linked accounts as JSON.
    async fn handle_list_accounts<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
    ) -> Result<Json<Vec<AccountResponse>>, AuthError> {
        let ctx = state.to_context();
        Ok(Json(list_accounts_core(&user, &ctx).await?))
    }

    /// `POST /unlink-account`: unlinks the session user's account for the
    /// body's `providerId`.
    async fn handle_unlink_account<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
        ValidatedJson(request): ValidatedJson<UnlinkAccountRequest>,
    ) -> Result<Json<StatusResponse>, AuthError> {
        let ctx = state.to_context();
        Ok(Json(unlink_account_core(&user, &request.provider_id, &ctx).await?))
    }

    impl<DB: DatabaseAdapter> better_auth_core::AxumPlugin<DB> for AccountManagementPlugin {
        fn name(&self) -> &'static str {
            "account-management"
        }

        /// Mounts `GET /list-accounts` and `POST /unlink-account`.
        fn router(&self) -> axum::Router<AuthState<DB>> {
            axum::Router::new()
                .route("/list-accounts", axum::routing::get(handle_list_accounts::<DB>))
                .route("/unlink-account", axum::routing::post(handle_unlink_account::<DB>))
        }
    }
}