better_auth_api/plugins/
account_management.rs1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use validator::Validate;
4
5use better_auth_core::adapters::DatabaseAdapter;
6use better_auth_core::entity::{AuthAccount, AuthSession, AuthUser};
7use better_auth_core::{AuthContext, AuthPlugin, AuthRoute, SessionManager};
8use better_auth_core::{AuthError, AuthResult};
9use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
10
11pub struct AccountManagementPlugin {
13 config: AccountManagementConfig,
14}
15
/// Settings for [`AccountManagementPlugin`].
#[derive(Clone, Debug)]
pub struct AccountManagementConfig {
    /// Whether the plugin's endpoints should demand an authenticated session.
    ///
    /// NOTE(review): the handlers visible in this file resolve a session
    /// unconditionally and never read this flag — confirm whether it is
    /// meant to be honored elsewhere.
    pub require_authentication: bool,
}
20
21#[derive(Debug, Deserialize, Validate)]
22struct UnlinkAccountRequest {
23 #[serde(rename = "providerId")]
24 #[validate(length(min = 1, message = "Provider ID is required"))]
25 provider_id: String,
26}
27
28#[derive(Debug, Serialize)]
29struct AccountResponse {
30 id: String,
31 #[serde(rename = "accountId")]
32 account_id: String,
33 provider: String,
34 #[serde(rename = "createdAt")]
35 created_at: String,
36 #[serde(rename = "updatedAt")]
37 updated_at: String,
38 scopes: Vec<String>,
39}
40
41#[derive(Debug, Serialize)]
42struct StatusResponse {
43 status: bool,
44}
45
46impl AccountManagementPlugin {
47 pub fn new() -> Self {
48 Self {
49 config: AccountManagementConfig::default(),
50 }
51 }
52
53 pub fn with_config(config: AccountManagementConfig) -> Self {
54 Self { config }
55 }
56
57 pub fn require_authentication(mut self, require: bool) -> Self {
58 self.config.require_authentication = require;
59 self
60 }
61}
62
63impl Default for AccountManagementPlugin {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl Default for AccountManagementConfig {
70 fn default() -> Self {
71 Self {
72 require_authentication: true,
73 }
74 }
75}
76
77#[async_trait]
78impl<DB: DatabaseAdapter> AuthPlugin<DB> for AccountManagementPlugin {
79 fn name(&self) -> &'static str {
80 "account-management"
81 }
82
83 fn routes(&self) -> Vec<AuthRoute> {
84 vec![
85 AuthRoute::get("/list-accounts", "list_accounts"),
86 AuthRoute::post("/unlink-account", "unlink_account"),
87 ]
88 }
89
90 async fn on_request(
91 &self,
92 req: &AuthRequest,
93 ctx: &AuthContext<DB>,
94 ) -> AuthResult<Option<AuthResponse>> {
95 match (req.method(), req.path()) {
96 (HttpMethod::Get, "/list-accounts") => {
97 Ok(Some(self.handle_list_accounts(req, ctx).await?))
98 }
99 (HttpMethod::Post, "/unlink-account") => {
100 Ok(Some(self.handle_unlink_account(req, ctx).await?))
101 }
102 _ => Ok(None),
103 }
104 }
105}
106
107impl AccountManagementPlugin {
108 async fn require_session<DB: DatabaseAdapter>(
109 &self,
110 req: &AuthRequest,
111 ctx: &AuthContext<DB>,
112 ) -> AuthResult<(DB::User, DB::Session)> {
113 let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
114
115 if let Some(token) = session_manager.extract_session_token(req)
116 && let Some(session) = session_manager.get_session(&token).await?
117 && let Some(user) = ctx.database.get_user_by_id(session.user_id()).await?
118 {
119 return Ok((user, session));
120 }
121
122 Err(AuthError::Unauthenticated)
123 }
124
125 async fn handle_list_accounts<DB: DatabaseAdapter>(
126 &self,
127 req: &AuthRequest,
128 ctx: &AuthContext<DB>,
129 ) -> AuthResult<AuthResponse> {
130 let (user, _session) = self.require_session(req, ctx).await?;
131
132 let accounts = ctx.database.get_user_accounts(user.id()).await?;
133
134 let filtered: Vec<AccountResponse> = accounts
136 .iter()
137 .map(|acc| AccountResponse {
138 id: acc.id().to_string(),
139 account_id: acc.account_id().to_string(),
140 provider: acc.provider_id().to_string(),
141 created_at: acc.created_at().to_rfc3339(),
142 updated_at: acc.updated_at().to_rfc3339(),
143 scopes: acc
144 .scope()
145 .map(|s| {
146 s.split([' ', ','])
147 .filter(|s| !s.is_empty())
148 .map(|s| s.to_string())
149 .collect()
150 })
151 .unwrap_or_default(),
152 })
153 .collect();
154
155 Ok(AuthResponse::json(200, &filtered)?)
156 }
157
158 async fn handle_unlink_account<DB: DatabaseAdapter>(
159 &self,
160 req: &AuthRequest,
161 ctx: &AuthContext<DB>,
162 ) -> AuthResult<AuthResponse> {
163 let (user, _session) = self.require_session(req, ctx).await?;
164
165 let unlink_req: UnlinkAccountRequest = match better_auth_core::validate_request_body(req) {
166 Ok(v) => v,
167 Err(resp) => return Ok(resp),
168 };
169
170 let accounts = ctx.database.get_user_accounts(user.id()).await?;
171
172 let has_password = user
174 .metadata()
175 .get("password_hash")
176 .and_then(|v| v.as_str())
177 .is_some();
178
179 let remaining_accounts = accounts
181 .iter()
182 .filter(|acc| acc.provider_id() != unlink_req.provider_id)
183 .count();
184
185 if !has_password && remaining_accounts == 0 {
187 return Err(AuthError::bad_request(
188 "Cannot unlink the last account. You must have at least one authentication method.",
189 ));
190 }
191
192 let account_to_remove = accounts
194 .iter()
195 .find(|acc| acc.provider_id() == unlink_req.provider_id)
196 .ok_or_else(|| AuthError::not_found("No account found with this provider"))?;
197
198 ctx.database.delete_account(account_to_remove.id()).await?;
199
200 let response = StatusResponse { status: true };
201 Ok(AuthResponse::json(200, &response)?)
202 }
203}