1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use validator::Validate;
5
6use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
7use better_auth_core::{AuthError, AuthResult};
8use better_auth_core::{AuthRequest, AuthResponse, CreateAccount, CreateUser, HttpMethod};
9use better_auth_core::{AuthSession, AuthUser, DatabaseAdapter};
10
11pub struct OAuthPlugin {
13 config: OAuthConfig,
14}
15
16#[derive(Debug, Clone, Default)]
17pub struct OAuthConfig {
18 pub providers: HashMap<String, OAuthProvider>,
19}
20
/// Credentials and endpoints for a single OAuth provider.
///
/// `Debug` is implemented by hand so that `client_secret` never appears in
/// logs or panic messages.
#[derive(Clone)]
pub struct OAuthProvider {
    /// Client identifier sent as the `client_id` query parameter.
    pub client_id: String,
    /// Client secret issued by the provider. Redacted from `Debug` output.
    pub client_secret: String,
    /// Authorization endpoint the user is redirected to.
    pub auth_url: String,
    /// Token endpoint of the provider.
    pub token_url: String,
    /// User-info endpoint of the provider.
    pub user_info_url: String,
    /// Scopes requested when a request does not specify its own.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for OAuthProvider {
    /// Formats every field except `client_secret`, which is replaced by a
    /// fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthProvider")
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .field("user_info_url", &self.user_info_url)
            .field("scopes", &self.scopes)
            .finish()
    }
}
30
31impl OAuthPlugin {
32 pub fn new() -> Self {
33 Self {
34 config: OAuthConfig::default(),
35 }
36 }
37
38 pub fn with_config(config: OAuthConfig) -> Self {
39 Self { config }
40 }
41
42 pub fn add_provider(mut self, name: &str, provider: OAuthProvider) -> Self {
43 self.config.providers.insert(name.to_string(), provider);
44 self
45 }
46}
47
48impl Default for OAuthPlugin {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54#[derive(Debug, Deserialize, Validate)]
56#[allow(dead_code)]
57struct SocialSignInRequest {
58 #[serde(rename = "callbackURL")]
59 callback_url: Option<String>,
60 #[serde(rename = "newUserCallbackURL")]
61 new_user_callback_url: Option<String>,
62 #[serde(rename = "errorCallbackURL")]
63 error_callback_url: Option<String>,
64 #[validate(length(min = 1, message = "Provider is required"))]
65 provider: String,
66 #[serde(rename = "disableRedirect")]
67 disable_redirect: Option<String>,
68 #[serde(rename = "idToken")]
69 id_token: Option<String>,
70 scopes: Option<String>,
71 #[serde(rename = "requestSignUp")]
72 request_sign_up: Option<String>,
73 #[serde(rename = "loginHint")]
74 login_hint: Option<String>,
75}
76
77#[derive(Debug, Deserialize, Validate)]
78struct LinkSocialRequest {
79 #[serde(rename = "callbackURL")]
80 callback_url: Option<String>,
81 #[validate(length(min = 1, message = "Provider is required"))]
82 provider: String,
83 scopes: Option<String>,
84}
85
86#[derive(Debug, Serialize)]
87struct SocialSignInResponse<U: Serialize> {
88 redirect: bool,
89 token: String,
90 url: Option<String>,
91 user: U,
92}
93
94#[derive(Debug, Serialize)]
95struct LinkSocialResponse {
96 url: String,
97 redirect: bool,
98}
99
100#[async_trait]
101impl<DB: DatabaseAdapter> AuthPlugin<DB> for OAuthPlugin {
102 fn name(&self) -> &'static str {
103 "oauth"
104 }
105
106 fn routes(&self) -> Vec<AuthRoute> {
107 vec![
108 AuthRoute::post("/sign-in/social", "social_sign_in"),
109 AuthRoute::post("/link-social", "link_social"),
110 ]
111 }
112
113 async fn on_request(
114 &self,
115 req: &AuthRequest,
116 ctx: &AuthContext<DB>,
117 ) -> AuthResult<Option<AuthResponse>> {
118 match (req.method(), req.path()) {
119 (HttpMethod::Post, "/sign-in/social") => {
120 Ok(Some(self.handle_social_sign_in(req, ctx).await?))
121 }
122 (HttpMethod::Post, "/link-social") => {
123 Ok(Some(self.handle_link_social(req, ctx).await?))
124 }
125 _ => Ok(None),
126 }
127 }
128}
129
130impl OAuthPlugin {
132 async fn handle_social_sign_in<DB: DatabaseAdapter>(
133 &self,
134 req: &AuthRequest,
135 ctx: &AuthContext<DB>,
136 ) -> AuthResult<AuthResponse> {
137 let signin_req: SocialSignInRequest = match better_auth_core::validate_request_body(req) {
138 Ok(v) => v,
139 Err(resp) => return Ok(resp),
140 };
141
142 if !self.config.providers.contains_key(&signin_req.provider) {
144 return Err(AuthError::bad_request(format!(
145 "Provider '{}' is not configured",
146 signin_req.provider
147 )));
148 }
149
150 if let Some(id_token) = &signin_req.id_token {
152 return self
153 .handle_id_token_sign_in(id_token, &signin_req, ctx)
154 .await;
155 }
156
157 self.generate_auth_url(&signin_req, ctx).await
159 }
160
161 async fn handle_link_social<DB: DatabaseAdapter>(
162 &self,
163 req: &AuthRequest,
164 ctx: &AuthContext<DB>,
165 ) -> AuthResult<AuthResponse> {
166 let link_req: LinkSocialRequest = match better_auth_core::validate_request_body(req) {
167 Ok(v) => v,
168 Err(resp) => return Ok(resp),
169 };
170
171 if !self.config.providers.contains_key(&link_req.provider) {
173 return Err(AuthError::bad_request(format!(
174 "Provider '{}' is not configured",
175 link_req.provider
176 )));
177 }
178
179 let provider = &self.config.providers[&link_req.provider];
181 let callback_url = link_req.callback_url.unwrap_or_else(|| {
182 format!(
183 "{}/oauth/{}/callback",
184 ctx.config.base_url, link_req.provider
185 )
186 });
187
188 let scopes = if let Some(scopes) = &link_req.scopes {
189 scopes.split(',').map(|s| s.trim().to_string()).collect()
190 } else {
191 provider.scopes.clone()
192 };
193
194 let auth_url = format!(
195 "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state=link_{}",
196 provider.auth_url,
197 provider.client_id,
198 urlencoding::encode(&callback_url),
199 urlencoding::encode(&scopes.join(" ")),
200 uuid::Uuid::new_v4()
201 );
202
203 let response = LinkSocialResponse {
204 url: auth_url,
205 redirect: true,
206 };
207
208 Ok(AuthResponse::json(200, &response)?)
209 }
210
211 async fn handle_id_token_sign_in<DB: DatabaseAdapter>(
212 &self,
213 id_token: &str,
214 signin_req: &SocialSignInRequest,
215 ctx: &AuthContext<DB>,
216 ) -> AuthResult<AuthResponse> {
217 let email = format!("user+{}@{}.com", uuid::Uuid::new_v4(), signin_req.provider);
222 let name = format!("User from {}", signin_req.provider);
223
224 let existing_user = ctx.database.get_user_by_email(&email).await?;
226 let user = if let Some(user) = existing_user {
227 user
228 } else {
229 let create_user = CreateUser::new()
231 .with_email(&email)
232 .with_name(&name)
233 .with_email_verified(true); ctx.database.create_user(create_user).await?
236 };
237
238 let create_account = CreateAccount {
240 account_id: format!("{}_{}", signin_req.provider, uuid::Uuid::new_v4()),
241 provider_id: signin_req.provider.clone(),
242 user_id: user.id().to_string(),
243 access_token: Some("mock_access_token".to_string()),
244 refresh_token: None,
245 id_token: Some(id_token.to_string()),
246 access_token_expires_at: None,
247 refresh_token_expires_at: None,
248 scope: None,
249 password: None,
250 };
251
252 if ctx
254 .database
255 .get_account(&signin_req.provider, &create_account.account_id)
256 .await?
257 .is_none()
258 {
259 ctx.database.create_account(create_account).await?;
260 }
261
262 let session_manager =
264 better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
265 let session = session_manager.create_session(&user, None, None).await?;
266
267 let response = SocialSignInResponse {
268 redirect: false,
269 token: session.token().to_string(),
270 url: None,
271 user,
272 };
273
274 Ok(AuthResponse::json(200, &response)?)
275 }
276
277 async fn generate_auth_url<DB: DatabaseAdapter>(
278 &self,
279 signin_req: &SocialSignInRequest,
280 ctx: &AuthContext<DB>,
281 ) -> AuthResult<AuthResponse> {
282 let provider = &self.config.providers[&signin_req.provider];
283 let callback_url = signin_req.callback_url.clone().unwrap_or_else(|| {
284 format!(
285 "{}/oauth/{}/callback",
286 ctx.config.base_url, signin_req.provider
287 )
288 });
289
290 let scopes = if let Some(scopes) = &signin_req.scopes {
291 scopes.split(',').map(|s| s.trim().to_string()).collect()
292 } else {
293 provider.scopes.clone()
294 };
295
296 let mut auth_url = format!(
297 "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}",
298 provider.auth_url,
299 provider.client_id,
300 urlencoding::encode(&callback_url),
301 urlencoding::encode(&scopes.join(" ")),
302 uuid::Uuid::new_v4()
303 );
304
305 if let Some(login_hint) = &signin_req.login_hint {
306 auth_url.push_str(&format!("&login_hint={}", urlencoding::encode(login_hint)));
307 }
308
309 Ok(AuthResponse::json(
311 200,
312 &serde_json::json!({
313 "url": auth_url,
314 "redirect": true
315 }),
316 )?)
317 }
318}