better_auth_api/plugins/
account_management.rs1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use validator::Validate;
4
5use better_auth_core::adapters::DatabaseAdapter;
6use better_auth_core::entity::{AuthAccount, AuthUser};
7use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
8use better_auth_core::{AuthError, AuthResult};
9use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
10
11use super::StatusResponse;
12
13pub struct AccountManagementPlugin {
15 config: AccountManagementConfig,
16}
17
/// Settings for [`AccountManagementPlugin`].
#[derive(Clone, Debug)]
pub struct AccountManagementConfig {
    /// Authentication flag; the `Default` implementation sets it to `true`.
    ///
    /// NOTE(review): presumably meant to gate the endpoints on an
    /// authenticated session — confirm where (or whether) it is read.
    pub require_authentication: bool,
}
22
23#[derive(Debug, Deserialize, Validate)]
24struct UnlinkAccountRequest {
25 #[serde(rename = "providerId")]
26 #[validate(length(min = 1, message = "Provider ID is required"))]
27 provider_id: String,
28}
29
30#[derive(Debug, Serialize)]
31struct AccountResponse {
32 id: String,
33 #[serde(rename = "accountId")]
34 account_id: String,
35 provider: String,
36 #[serde(rename = "createdAt")]
37 created_at: String,
38 #[serde(rename = "updatedAt")]
39 updated_at: String,
40 scopes: Vec<String>,
41}
42
43impl AccountManagementPlugin {
44 pub fn new() -> Self {
45 Self {
46 config: AccountManagementConfig::default(),
47 }
48 }
49
50 pub fn with_config(config: AccountManagementConfig) -> Self {
51 Self { config }
52 }
53
54 pub fn require_authentication(mut self, require: bool) -> Self {
55 self.config.require_authentication = require;
56 self
57 }
58}
59
60impl Default for AccountManagementPlugin {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl Default for AccountManagementConfig {
67 fn default() -> Self {
68 Self {
69 require_authentication: true,
70 }
71 }
72}
73
74#[async_trait]
75impl<DB: DatabaseAdapter> AuthPlugin<DB> for AccountManagementPlugin {
76 fn name(&self) -> &'static str {
77 "account-management"
78 }
79
80 fn routes(&self) -> Vec<AuthRoute> {
81 vec![
82 AuthRoute::get("/list-accounts", "list_accounts"),
83 AuthRoute::post("/unlink-account", "unlink_account"),
84 ]
85 }
86
87 async fn on_request(
88 &self,
89 req: &AuthRequest,
90 ctx: &AuthContext<DB>,
91 ) -> AuthResult<Option<AuthResponse>> {
92 match (req.method(), req.path()) {
93 (HttpMethod::Get, "/list-accounts") => {
94 Ok(Some(self.handle_list_accounts(req, ctx).await?))
95 }
96 (HttpMethod::Post, "/unlink-account") => {
97 Ok(Some(self.handle_unlink_account(req, ctx).await?))
98 }
99 _ => Ok(None),
100 }
101 }
102}
103
104impl AccountManagementPlugin {
105 async fn handle_list_accounts<DB: DatabaseAdapter>(
106 &self,
107 req: &AuthRequest,
108 ctx: &AuthContext<DB>,
109 ) -> AuthResult<AuthResponse> {
110 let (user, _session) = ctx.require_session(req).await?;
111
112 let accounts = ctx.database.get_user_accounts(user.id()).await?;
113
114 let filtered: Vec<AccountResponse> = accounts
116 .iter()
117 .map(|acc| AccountResponse {
118 id: acc.id().to_string(),
119 account_id: acc.account_id().to_string(),
120 provider: acc.provider_id().to_string(),
121 created_at: acc.created_at().to_rfc3339(),
122 updated_at: acc.updated_at().to_rfc3339(),
123 scopes: acc
124 .scope()
125 .map(|s| {
126 s.split([' ', ','])
127 .filter(|s| !s.is_empty())
128 .map(|s| s.to_string())
129 .collect()
130 })
131 .unwrap_or_default(),
132 })
133 .collect();
134
135 Ok(AuthResponse::json(200, &filtered)?)
136 }
137
138 async fn handle_unlink_account<DB: DatabaseAdapter>(
139 &self,
140 req: &AuthRequest,
141 ctx: &AuthContext<DB>,
142 ) -> AuthResult<AuthResponse> {
143 let (user, _session) = ctx.require_session(req).await?;
144
145 let unlink_req: UnlinkAccountRequest = match better_auth_core::validate_request_body(req) {
146 Ok(v) => v,
147 Err(resp) => return Ok(resp),
148 };
149
150 let accounts = ctx.database.get_user_accounts(user.id()).await?;
151
152 let allow_unlinking_all = ctx.config.account.account_linking.allow_unlinking_all;
153
154 let has_password = user
156 .metadata()
157 .get("password_hash")
158 .and_then(|v| v.as_str())
159 .is_some();
160
161 let remaining_accounts = accounts
163 .iter()
164 .filter(|acc| acc.provider_id() != unlink_req.provider_id)
165 .count();
166
167 if !allow_unlinking_all && !has_password && remaining_accounts == 0 {
169 return Err(AuthError::bad_request(
170 "Cannot unlink the last account. You must have at least one authentication method.",
171 ));
172 }
173
174 let account_to_remove = accounts
176 .iter()
177 .find(|acc| acc.provider_id() == unlink_req.provider_id)
178 .ok_or_else(|| AuthError::not_found("No account found with this provider"))?;
179
180 ctx.database.delete_account(account_to_remove.id()).await?;
181
182 let response = StatusResponse { status: true };
183 Ok(AuthResponse::json(200, &response)?)
184 }
185}