// better_auth_api/plugins/account_management.rs
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use validator::Validate;

use better_auth_core::adapters::DatabaseAdapter;
use better_auth_core::entity::{AuthAccount, AuthSession, AuthUser};
use better_auth_core::{AuthContext, AuthPlugin, AuthRoute, SessionManager};
use better_auth_core::{AuthError, AuthResult};
use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};

11pub struct AccountManagementPlugin {
13 config: AccountManagementConfig,
14}
15
/// Configuration for [`AccountManagementPlugin`].
#[derive(Debug, Clone)]
pub struct AccountManagementConfig {
    /// Whether the plugin's endpoints should require an authenticated session.
    ///
    /// NOTE(review): the request handlers in this file call `require_session`
    /// unconditionally and never read this flag — confirm whether it is meant
    /// to be consulted somewhere.
    pub require_authentication: bool,
}

21#[derive(Debug, Deserialize, Validate)]
22struct UnlinkAccountRequest {
23 #[serde(rename = "providerId")]
24 #[validate(length(min = 1, message = "Provider ID is required"))]
25 provider_id: String,
26}
27
28#[derive(Debug, Serialize)]
29struct AccountResponse {
30 id: String,
31 #[serde(rename = "accountId")]
32 account_id: String,
33 #[serde(rename = "providerId")]
34 provider_id: String,
35 #[serde(rename = "userId")]
36 user_id: String,
37 #[serde(rename = "createdAt")]
38 created_at: String,
39 #[serde(rename = "updatedAt")]
40 updated_at: String,
41 scope: Option<String>,
42}
43
44#[derive(Debug, Serialize)]
45struct StatusResponse {
46 status: bool,
47}
48
49impl AccountManagementPlugin {
50 pub fn new() -> Self {
51 Self {
52 config: AccountManagementConfig::default(),
53 }
54 }
55
56 pub fn with_config(config: AccountManagementConfig) -> Self {
57 Self { config }
58 }
59
60 pub fn require_authentication(mut self, require: bool) -> Self {
61 self.config.require_authentication = require;
62 self
63 }
64}
65
66impl Default for AccountManagementPlugin {
67 fn default() -> Self {
68 Self::new()
69 }
70}
71
72impl Default for AccountManagementConfig {
73 fn default() -> Self {
74 Self {
75 require_authentication: true,
76 }
77 }
78}
79
80#[async_trait]
81impl<DB: DatabaseAdapter> AuthPlugin<DB> for AccountManagementPlugin {
82 fn name(&self) -> &'static str {
83 "account-management"
84 }
85
86 fn routes(&self) -> Vec<AuthRoute> {
87 vec![
88 AuthRoute::get("/list-accounts", "list_accounts"),
89 AuthRoute::post("/unlink-account", "unlink_account"),
90 ]
91 }
92
93 async fn on_request(
94 &self,
95 req: &AuthRequest,
96 ctx: &AuthContext<DB>,
97 ) -> AuthResult<Option<AuthResponse>> {
98 match (req.method(), req.path()) {
99 (HttpMethod::Get, "/list-accounts") => {
100 Ok(Some(self.handle_list_accounts(req, ctx).await?))
101 }
102 (HttpMethod::Post, "/unlink-account") => {
103 Ok(Some(self.handle_unlink_account(req, ctx).await?))
104 }
105 _ => Ok(None),
106 }
107 }
108}
109
110impl AccountManagementPlugin {
111 async fn require_session<DB: DatabaseAdapter>(
112 &self,
113 req: &AuthRequest,
114 ctx: &AuthContext<DB>,
115 ) -> AuthResult<(DB::User, DB::Session)> {
116 let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
117
118 if let Some(token) = session_manager.extract_session_token(req)
119 && let Some(session) = session_manager.get_session(&token).await?
120 && let Some(user) = ctx.database.get_user_by_id(session.user_id()).await?
121 {
122 return Ok((user, session));
123 }
124
125 Err(AuthError::Unauthenticated)
126 }
127
128 async fn handle_list_accounts<DB: DatabaseAdapter>(
129 &self,
130 req: &AuthRequest,
131 ctx: &AuthContext<DB>,
132 ) -> AuthResult<AuthResponse> {
133 let (user, _session) = self.require_session(req, ctx).await?;
134
135 let accounts = ctx.database.get_user_accounts(user.id()).await?;
136
137 let filtered: Vec<AccountResponse> = accounts
139 .iter()
140 .map(|acc| AccountResponse {
141 id: acc.id().to_string(),
142 account_id: acc.account_id().to_string(),
143 provider_id: acc.provider_id().to_string(),
144 user_id: acc.user_id().to_string(),
145 created_at: acc.created_at().to_rfc3339(),
146 updated_at: acc.updated_at().to_rfc3339(),
147 scope: acc.scope().map(|s| s.to_string()),
148 })
149 .collect();
150
151 Ok(AuthResponse::json(200, &filtered)?)
152 }
153
154 async fn handle_unlink_account<DB: DatabaseAdapter>(
155 &self,
156 req: &AuthRequest,
157 ctx: &AuthContext<DB>,
158 ) -> AuthResult<AuthResponse> {
159 let (user, _session) = self.require_session(req, ctx).await?;
160
161 let unlink_req: UnlinkAccountRequest = match better_auth_core::validate_request_body(req) {
162 Ok(v) => v,
163 Err(resp) => return Ok(resp),
164 };
165
166 let accounts = ctx.database.get_user_accounts(user.id()).await?;
167
168 let has_password = user
170 .metadata()
171 .get("password_hash")
172 .and_then(|v| v.as_str())
173 .is_some();
174
175 let remaining_accounts = accounts
177 .iter()
178 .filter(|acc| acc.provider_id() != unlink_req.provider_id)
179 .count();
180
181 if !has_password && remaining_accounts == 0 {
183 return Err(AuthError::bad_request(
184 "Cannot unlink the last account. You must have at least one authentication method.",
185 ));
186 }
187
188 let account_to_remove = accounts
190 .iter()
191 .find(|acc| acc.provider_id() == unlink_req.provider_id)
192 .ok_or_else(|| AuthError::not_found("No account found with this provider"))?;
193
194 ctx.database.delete_account(account_to_remove.id()).await?;
195
196 let response = StatusResponse { status: true };
197 Ok(AuthResponse::json(200, &response)?)
198 }
199}