better_auth_api/plugins/
account_management.rs1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use validator::Validate;
4
5use better_auth_core::{AuthContext, AuthPlugin, AuthRoute, SessionManager};
6use better_auth_core::{AuthError, AuthResult};
7use better_auth_core::{AuthRequest, AuthResponse, HttpMethod, Session, User};
8
9pub struct AccountManagementPlugin {
11 config: AccountManagementConfig,
12}
13
/// Configuration for the account-management plugin.
///
/// NOTE(review): the list/unlink handlers in this file always require a valid session,
/// independent of `require_authentication`; the flag is currently stored but not consulted.
#[derive(Clone, Debug)]
pub struct AccountManagementConfig {
    /// Whether the plugin's routes should require an authenticated session.
    pub require_authentication: bool,
}
18
19#[derive(Debug, Deserialize, Validate)]
20struct UnlinkAccountRequest {
21 #[serde(rename = "providerId")]
22 #[validate(length(min = 1, message = "Provider ID is required"))]
23 provider_id: String,
24}
25
26#[derive(Debug, Serialize)]
27struct AccountResponse {
28 id: String,
29 #[serde(rename = "accountId")]
30 account_id: String,
31 #[serde(rename = "providerId")]
32 provider_id: String,
33 #[serde(rename = "userId")]
34 user_id: String,
35 #[serde(rename = "createdAt")]
36 created_at: String,
37 #[serde(rename = "updatedAt")]
38 updated_at: String,
39 scope: Option<String>,
40}
41
42#[derive(Debug, Serialize)]
43struct StatusResponse {
44 status: bool,
45}
46
47impl AccountManagementPlugin {
48 pub fn new() -> Self {
49 Self {
50 config: AccountManagementConfig::default(),
51 }
52 }
53
54 pub fn with_config(config: AccountManagementConfig) -> Self {
55 Self { config }
56 }
57
58 pub fn require_authentication(mut self, require: bool) -> Self {
59 self.config.require_authentication = require;
60 self
61 }
62}
63
64impl Default for AccountManagementPlugin {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl Default for AccountManagementConfig {
71 fn default() -> Self {
72 Self {
73 require_authentication: true,
74 }
75 }
76}
77
78#[async_trait]
79impl AuthPlugin for AccountManagementPlugin {
80 fn name(&self) -> &'static str {
81 "account-management"
82 }
83
84 fn routes(&self) -> Vec<AuthRoute> {
85 vec![
86 AuthRoute::get("/list-accounts", "list_accounts"),
87 AuthRoute::post("/unlink-account", "unlink_account"),
88 ]
89 }
90
91 async fn on_request(
92 &self,
93 req: &AuthRequest,
94 ctx: &AuthContext,
95 ) -> AuthResult<Option<AuthResponse>> {
96 match (req.method(), req.path()) {
97 (HttpMethod::Get, "/list-accounts") => {
98 Ok(Some(self.handle_list_accounts(req, ctx).await?))
99 }
100 (HttpMethod::Post, "/unlink-account") => {
101 Ok(Some(self.handle_unlink_account(req, ctx).await?))
102 }
103 _ => Ok(None),
104 }
105 }
106}
107
108impl AccountManagementPlugin {
109 async fn require_session(
110 &self,
111 req: &AuthRequest,
112 ctx: &AuthContext,
113 ) -> AuthResult<(User, Session)> {
114 let session_manager = SessionManager::new(ctx.config.clone(), ctx.database.clone());
115
116 if let Some(token) = session_manager.extract_session_token(req)
117 && let Some(session) = session_manager.get_session(&token).await?
118 && let Some(user) = ctx.database.get_user_by_id(&session.user_id).await?
119 {
120 return Ok((user, session));
121 }
122
123 Err(AuthError::Unauthenticated)
124 }
125
126 async fn handle_list_accounts(
127 &self,
128 req: &AuthRequest,
129 ctx: &AuthContext,
130 ) -> AuthResult<AuthResponse> {
131 let (user, _session) = self.require_session(req, ctx).await?;
132
133 let accounts = ctx.database.get_user_accounts(&user.id).await?;
134
135 let filtered: Vec<AccountResponse> = accounts
137 .iter()
138 .map(|acc| AccountResponse {
139 id: acc.id.clone(),
140 account_id: acc.account_id.clone(),
141 provider_id: acc.provider_id.clone(),
142 user_id: acc.user_id.clone(),
143 created_at: acc.created_at.to_rfc3339(),
144 updated_at: acc.updated_at.to_rfc3339(),
145 scope: acc.scope.clone(),
146 })
147 .collect();
148
149 Ok(AuthResponse::json(200, &filtered)?)
150 }
151
152 async fn handle_unlink_account(
153 &self,
154 req: &AuthRequest,
155 ctx: &AuthContext,
156 ) -> AuthResult<AuthResponse> {
157 let (user, _session) = self.require_session(req, ctx).await?;
158
159 let unlink_req: UnlinkAccountRequest = match better_auth_core::validate_request_body(req) {
160 Ok(v) => v,
161 Err(resp) => return Ok(resp),
162 };
163
164 let accounts = ctx.database.get_user_accounts(&user.id).await?;
165
166 let has_password = user
168 .metadata
169 .get("password_hash")
170 .and_then(|v| v.as_str())
171 .is_some();
172
173 let remaining_accounts = accounts
175 .iter()
176 .filter(|acc| acc.provider_id != unlink_req.provider_id)
177 .count();
178
179 if !has_password && remaining_accounts == 0 {
181 return Err(AuthError::bad_request(
182 "Cannot unlink the last account. You must have at least one authentication method.",
183 ));
184 }
185
186 let account_to_remove = accounts
188 .iter()
189 .find(|acc| acc.provider_id == unlink_req.provider_id)
190 .ok_or_else(|| AuthError::not_found("No account found with this provider"))?;
191
192 ctx.database.delete_account(&account_to_remove.id).await?;
193
194 let response = StatusResponse { status: true };
195 Ok(AuthResponse::json(200, &response)?)
196 }
197}