better_auth_api/plugins/
account_management.rs1use serde::{Deserialize, Serialize};
2use validator::Validate;
3
4use better_auth_core::adapters::DatabaseAdapter;
5use better_auth_core::entity::{AuthAccount, AuthUser};
6use better_auth_core::{AuthContext, AuthError, AuthResult};
7use better_auth_core::{AuthRequest, AuthResponse};
8
9use super::StatusResponse;
10
11pub struct AccountManagementPlugin {
13 config: AccountManagementConfig,
14}
15
16#[derive(Debug, Clone, better_auth_core::PluginConfig)]
17#[plugin(name = "AccountManagementPlugin")]
18pub struct AccountManagementConfig {
19 #[config(default = true)]
20 pub require_authentication: bool,
21}
22
23#[derive(Debug, Deserialize, Validate)]
24struct UnlinkAccountRequest {
25 #[serde(rename = "providerId")]
26 #[validate(length(min = 1, message = "Provider ID is required"))]
27 provider_id: String,
28}
29
30#[derive(Debug, Serialize)]
31pub(crate) struct AccountResponse {
32 id: String,
33 #[serde(rename = "accountId")]
34 account_id: String,
35 provider: String,
36 #[serde(rename = "createdAt")]
37 created_at: String,
38 #[serde(rename = "updatedAt")]
39 updated_at: String,
40 scopes: Vec<String>,
41}
42
43better_auth_core::impl_auth_plugin! {
44 AccountManagementPlugin, "account-management";
45 routes {
46 get "/list-accounts" => handle_list_accounts, "list_accounts";
47 post "/unlink-account" => handle_unlink_account, "unlink_account";
48 }
49}
50
51pub(crate) async fn list_accounts_core<DB: DatabaseAdapter>(
56 user: &DB::User,
57 ctx: &AuthContext<DB>,
58) -> AuthResult<Vec<AccountResponse>> {
59 let accounts = ctx.database.get_user_accounts(user.id()).await?;
60
61 let filtered: Vec<AccountResponse> = accounts
62 .iter()
63 .map(|acc| AccountResponse {
64 id: acc.id().to_string(),
65 account_id: acc.account_id().to_string(),
66 provider: acc.provider_id().to_string(),
67 created_at: acc.created_at().to_rfc3339(),
68 updated_at: acc.updated_at().to_rfc3339(),
69 scopes: acc
70 .scope()
71 .map(|s| {
72 s.split([' ', ','])
73 .filter(|s| !s.is_empty())
74 .map(|s| s.to_string())
75 .collect()
76 })
77 .unwrap_or_default(),
78 })
79 .collect();
80
81 Ok(filtered)
82}
83
84pub(crate) async fn unlink_account_core<DB: DatabaseAdapter>(
85 user: &DB::User,
86 provider_id: &str,
87 ctx: &AuthContext<DB>,
88) -> AuthResult<StatusResponse> {
89 let accounts = ctx.database.get_user_accounts(user.id()).await?;
90
91 let allow_unlinking_all = ctx.config.account.account_linking.allow_unlinking_all;
92
93 let has_password = user.password_hash().is_some();
95
96 let remaining_accounts = accounts
98 .iter()
99 .filter(|acc| acc.provider_id() != provider_id)
100 .count();
101
102 if !allow_unlinking_all && !has_password && remaining_accounts == 0 {
104 return Err(AuthError::bad_request(
105 "Cannot unlink the last account. You must have at least one authentication method.",
106 ));
107 }
108
109 let account_to_remove = accounts
111 .iter()
112 .find(|acc| acc.provider_id() == provider_id)
113 .ok_or_else(|| AuthError::not_found("No account found with this provider"))?;
114
115 ctx.database.delete_account(account_to_remove.id()).await?;
116
117 Ok(StatusResponse { status: true })
118}
119
120impl AccountManagementPlugin {
125 async fn handle_list_accounts<DB: DatabaseAdapter>(
126 &self,
127 req: &AuthRequest,
128 ctx: &AuthContext<DB>,
129 ) -> AuthResult<AuthResponse> {
130 let (user, _session) = ctx.require_session(req).await?;
131 let filtered = list_accounts_core(&user, ctx).await?;
132 Ok(AuthResponse::json(200, &filtered)?)
133 }
134
135 async fn handle_unlink_account<DB: DatabaseAdapter>(
136 &self,
137 req: &AuthRequest,
138 ctx: &AuthContext<DB>,
139 ) -> AuthResult<AuthResponse> {
140 let (user, _session) = ctx.require_session(req).await?;
141
142 let unlink_req: UnlinkAccountRequest = match better_auth_core::validate_request_body(req) {
143 Ok(v) => v,
144 Err(resp) => return Ok(resp),
145 };
146
147 let response = unlink_account_core(&user, &unlink_req.provider_id, ctx).await?;
148 Ok(AuthResponse::json(200, &response)?)
149 }
150}
151
/// Axum bindings for the account-management endpoints, compiled only when the
/// `axum` feature is enabled. Both handlers delegate to the shared
/// `*_core` functions of the parent module.
#[cfg(feature = "axum")]
mod axum_impl {
    use super::*;

    use axum::Json;
    use axum::extract::State;
    use better_auth_core::{AuthState, CurrentSession, ValidatedJson};

    /// `GET /list-accounts`: returns the session user's linked accounts as JSON.
    async fn handle_list_accounts<DB: DatabaseAdapter>(
        State(app_state): State<AuthState<DB>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
    ) -> Result<Json<Vec<AccountResponse>>, AuthError> {
        let auth_ctx = app_state.to_context();
        list_accounts_core(&user, &auth_ctx).await.map(Json)
    }

    /// `POST /unlink-account`: unlinks the provider named in the validated
    /// JSON body from the session user.
    async fn handle_unlink_account<DB: DatabaseAdapter>(
        State(app_state): State<AuthState<DB>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
        ValidatedJson(payload): ValidatedJson<UnlinkAccountRequest>,
    ) -> Result<Json<StatusResponse>, AuthError> {
        let auth_ctx = app_state.to_context();
        unlink_account_core(&user, &payload.provider_id, &auth_ctx)
            .await
            .map(Json)
    }

    impl<DB: DatabaseAdapter> better_auth_core::AxumPlugin<DB> for AccountManagementPlugin {
        /// Plugin name, identical to the one used for the non-axum plugin.
        fn name(&self) -> &'static str {
            "account-management"
        }

        /// Builds the router exposing both account-management routes.
        fn router(&self) -> axum::Router<AuthState<DB>> {
            axum::Router::new()
                .route(
                    "/list-accounts",
                    axum::routing::get(handle_list_accounts::<DB>),
                )
                .route(
                    "/unlink-account",
                    axum::routing::post(handle_unlink_account::<DB>),
                )
        }
    }
}