1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use validator::Validate;
5
6use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
7use better_auth_core::{AuthError, AuthResult};
8use better_auth_core::{AuthRequest, AuthResponse, CreateAccount, CreateUser, HttpMethod, User};
9
/// Authentication plugin that serves the social (OAuth) sign-in and account-linking
/// endpoints.
///
/// Holds the set of configured OAuth providers; requests that name a provider which is not
/// part of this configuration are rejected by the handlers.
pub struct OAuthPlugin {
    // Provider registry, looked up by the `provider` field of incoming requests.
    config: OAuthConfig,
}
14
/// Configuration for [`OAuthPlugin`].
///
/// The `Default` value contains no providers.
#[derive(Debug, Clone, Default)]
pub struct OAuthConfig {
    /// Configured providers keyed by name; the key is matched against the `provider`
    /// field of sign-in and link requests.
    pub providers: HashMap<String, OAuthProvider>,
}
19
/// Endpoint and credential settings for a single OAuth provider.
///
/// `Debug` is implemented by hand so that `client_secret` is never written to logs or
/// panic messages; every other field is printed normally.
#[derive(Clone)]
pub struct OAuthProvider {
    /// Public client identifier, sent as the `client_id` query parameter.
    pub client_id: String,
    /// Confidential client secret. Redacted in `Debug` output.
    pub client_secret: String,
    /// Authorization endpoint the user is redirected to.
    pub auth_url: String,
    /// Token endpoint.
    pub token_url: String,
    /// User-info endpoint.
    pub user_info_url: String,
    /// Scopes requested when the client does not supply its own.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for OAuthProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthProvider")
            .field("client_id", &self.client_id)
            // Never print the secret itself; the derived impl used to leak it.
            .field("client_secret", &"<redacted>")
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .field("user_info_url", &self.user_info_url)
            .field("scopes", &self.scopes)
            .finish()
    }
}
29
30impl OAuthPlugin {
31 pub fn new() -> Self {
32 Self {
33 config: OAuthConfig::default(),
34 }
35 }
36
37 pub fn with_config(config: OAuthConfig) -> Self {
38 Self { config }
39 }
40
41 pub fn add_provider(mut self, name: &str, provider: OAuthProvider) -> Self {
42 self.config.providers.insert(name.to_string(), provider);
43 self
44 }
45}
46
47impl Default for OAuthPlugin {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
/// JSON body of `POST /sign-in/social`.
///
/// Wire names are camelCase, mapped explicitly with `serde(rename)`.
#[derive(Debug, Deserialize, Validate)]
// Some fields (`new_user_callback_url`, `error_callback_url`, `disable_redirect`,
// `request_sign_up`) are accepted from clients but not read by the handlers in this file.
#[allow(dead_code)]
struct SocialSignInRequest {
    /// Redirect URI sent to the provider; when absent the handler uses
    /// `<base_url>/oauth/<provider>/callback`.
    #[serde(rename = "callbackURL")]
    callback_url: Option<String>,
    #[serde(rename = "newUserCallbackURL")]
    new_user_callback_url: Option<String>,
    #[serde(rename = "errorCallbackURL")]
    error_callback_url: Option<String>,
    /// Name of a configured provider; validation rejects an empty string.
    #[validate(length(min = 1, message = "Provider is required"))]
    provider: String,
    // NOTE(review): typed as `Option<String>`, so a JSON boolean (e.g. `true`) here fails
    // to deserialize — confirm clients send a string.
    #[serde(rename = "disableRedirect")]
    disable_redirect: Option<String>,
    /// When present, the ID-token flow is used instead of the redirect flow.
    #[serde(rename = "idToken")]
    id_token: Option<String>,
    /// Comma-separated scope list that, when supplied, takes precedence over the
    /// provider's configured scopes.
    scopes: Option<String>,
    // NOTE(review): same string-vs-boolean caveat as `disable_redirect`.
    #[serde(rename = "requestSignUp")]
    request_sign_up: Option<String>,
    /// Forwarded to the provider as the `login_hint` query parameter.
    #[serde(rename = "loginHint")]
    login_hint: Option<String>,
}
75
/// JSON body of `POST /link-social`.
#[derive(Debug, Deserialize, Validate)]
struct LinkSocialRequest {
    /// Redirect URI sent to the provider; when absent the handler uses
    /// `<base_url>/oauth/<provider>/callback`.
    #[serde(rename = "callbackURL")]
    callback_url: Option<String>,
    /// Name of a configured provider; validation rejects an empty string.
    #[validate(length(min = 1, message = "Provider is required"))]
    provider: String,
    /// Comma-separated scope list that, when supplied, takes precedence over the
    /// provider's configured scopes.
    scopes: Option<String>,
}
84
/// Response body for a completed ID-token sign-in.
///
/// Field order is the JSON key order of the serialized response.
#[derive(Debug, Serialize)]
struct SocialSignInResponse {
    /// `false` as constructed by the ID-token flow in this file.
    redirect: bool,
    /// Token of the session created for `user`.
    token: String,
    /// Redirect target; `None` as constructed by the ID-token flow in this file.
    url: Option<String>,
    /// The signed-in user.
    user: User,
}
92
/// Response body for `POST /link-social`.
///
/// Field order is the JSON key order of the serialized response.
#[derive(Debug, Serialize)]
struct LinkSocialResponse {
    /// Provider authorization URL the client should navigate to.
    url: String,
    /// `true` as constructed by the link handler in this file.
    redirect: bool,
}
98
99#[async_trait]
100impl AuthPlugin for OAuthPlugin {
101 fn name(&self) -> &'static str {
102 "oauth"
103 }
104
105 fn routes(&self) -> Vec<AuthRoute> {
106 vec![
107 AuthRoute::post("/sign-in/social", "social_sign_in"),
108 AuthRoute::post("/link-social", "link_social"),
109 ]
110 }
111
112 async fn on_request(
113 &self,
114 req: &AuthRequest,
115 ctx: &AuthContext,
116 ) -> AuthResult<Option<AuthResponse>> {
117 match (req.method(), req.path()) {
118 (HttpMethod::Post, "/sign-in/social") => {
119 Ok(Some(self.handle_social_sign_in(req, ctx).await?))
120 }
121 (HttpMethod::Post, "/link-social") => {
122 Ok(Some(self.handle_link_social(req, ctx).await?))
123 }
124 _ => Ok(None),
125 }
126 }
127}
128
129impl OAuthPlugin {
131 async fn handle_social_sign_in(
132 &self,
133 req: &AuthRequest,
134 ctx: &AuthContext,
135 ) -> AuthResult<AuthResponse> {
136 let signin_req: SocialSignInRequest = match better_auth_core::validate_request_body(req) {
137 Ok(v) => v,
138 Err(resp) => return Ok(resp),
139 };
140
141 if !self.config.providers.contains_key(&signin_req.provider) {
143 return Err(AuthError::bad_request(format!(
144 "Provider '{}' is not configured",
145 signin_req.provider
146 )));
147 }
148
149 if let Some(id_token) = &signin_req.id_token {
151 return self
152 .handle_id_token_sign_in(id_token, &signin_req, ctx)
153 .await;
154 }
155
156 self.generate_auth_url(&signin_req, ctx).await
158 }
159
160 async fn handle_link_social(
161 &self,
162 req: &AuthRequest,
163 ctx: &AuthContext,
164 ) -> AuthResult<AuthResponse> {
165 let link_req: LinkSocialRequest = match better_auth_core::validate_request_body(req) {
166 Ok(v) => v,
167 Err(resp) => return Ok(resp),
168 };
169
170 if !self.config.providers.contains_key(&link_req.provider) {
172 return Err(AuthError::bad_request(format!(
173 "Provider '{}' is not configured",
174 link_req.provider
175 )));
176 }
177
178 let provider = &self.config.providers[&link_req.provider];
180 let callback_url = link_req.callback_url.unwrap_or_else(|| {
181 format!(
182 "{}/oauth/{}/callback",
183 ctx.config.base_url, link_req.provider
184 )
185 });
186
187 let scopes = if let Some(scopes) = &link_req.scopes {
188 scopes.split(',').map(|s| s.trim().to_string()).collect()
189 } else {
190 provider.scopes.clone()
191 };
192
193 let auth_url = format!(
194 "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state=link_{}",
195 provider.auth_url,
196 provider.client_id,
197 urlencoding::encode(&callback_url),
198 urlencoding::encode(&scopes.join(" ")),
199 uuid::Uuid::new_v4()
200 );
201
202 let response = LinkSocialResponse {
203 url: auth_url,
204 redirect: true,
205 };
206
207 Ok(AuthResponse::json(200, &response)?)
208 }
209
210 async fn handle_id_token_sign_in(
211 &self,
212 id_token: &str,
213 signin_req: &SocialSignInRequest,
214 ctx: &AuthContext,
215 ) -> AuthResult<AuthResponse> {
216 let email = format!("user+{}@{}.com", uuid::Uuid::new_v4(), signin_req.provider);
221 let name = format!("User from {}", signin_req.provider);
222
223 let existing_user = ctx.database.get_user_by_email(&email).await?;
225 let user = if let Some(user) = existing_user {
226 user
227 } else {
228 let create_user = CreateUser::new()
230 .with_email(&email)
231 .with_name(&name)
232 .with_email_verified(true); ctx.database.create_user(create_user).await?
235 };
236
237 let create_account = CreateAccount {
239 account_id: format!("{}_{}", signin_req.provider, uuid::Uuid::new_v4()),
240 provider_id: signin_req.provider.clone(),
241 user_id: user.id.clone(),
242 access_token: Some("mock_access_token".to_string()),
243 refresh_token: None,
244 id_token: Some(id_token.to_string()),
245 access_token_expires_at: None,
246 refresh_token_expires_at: None,
247 scope: None,
248 password: None,
249 };
250
251 if ctx
253 .database
254 .get_account(&signin_req.provider, &create_account.account_id)
255 .await?
256 .is_none()
257 {
258 ctx.database.create_account(create_account).await?;
259 }
260
261 let session_manager =
263 better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
264 let session = session_manager.create_session(&user, None, None).await?;
265
266 let response = SocialSignInResponse {
267 redirect: false,
268 token: session.token,
269 url: None,
270 user,
271 };
272
273 Ok(AuthResponse::json(200, &response)?)
274 }
275
276 async fn generate_auth_url(
277 &self,
278 signin_req: &SocialSignInRequest,
279 ctx: &AuthContext,
280 ) -> AuthResult<AuthResponse> {
281 let provider = &self.config.providers[&signin_req.provider];
282 let callback_url = signin_req.callback_url.clone().unwrap_or_else(|| {
283 format!(
284 "{}/oauth/{}/callback",
285 ctx.config.base_url, signin_req.provider
286 )
287 });
288
289 let scopes = if let Some(scopes) = &signin_req.scopes {
290 scopes.split(',').map(|s| s.trim().to_string()).collect()
291 } else {
292 provider.scopes.clone()
293 };
294
295 let mut auth_url = format!(
296 "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}",
297 provider.auth_url,
298 provider.client_id,
299 urlencoding::encode(&callback_url),
300 urlencoding::encode(&scopes.join(" ")),
301 uuid::Uuid::new_v4()
302 );
303
304 if let Some(login_hint) = &signin_req.login_hint {
305 auth_url.push_str(&format!("&login_hint={}", urlencoding::encode(login_hint)));
306 }
307
308 Ok(AuthResponse::json(
310 200,
311 &serde_json::json!({
312 "url": auth_url,
313 "redirect": true
314 }),
315 )?)
316 }
317}