// Source: better_auth_api/plugins/oauth.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use validator::Validate;
5
6use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
7use better_auth_core::{AuthError, AuthResult};
8use better_auth_core::{AuthRequest, AuthResponse, CreateAccount, CreateUser, HttpMethod, User};
9
10/// OAuth authentication plugin for social sign-in
11pub struct OAuthPlugin {
12    config: OAuthConfig,
13}
14
15#[derive(Debug, Clone, Default)]
16pub struct OAuthConfig {
17    pub providers: HashMap<String, OAuthProvider>,
18}
19
/// Settings for a single OAuth provider.
///
/// `Debug` is implemented by hand so that `client_secret` is never written
/// to logs or panic messages; every other field is printed as usual.
#[derive(Clone)]
pub struct OAuthProvider {
    /// Client identifier sent as the `client_id` query parameter.
    pub client_id: String,
    /// Client secret. Not read by the handlers in this file.
    pub client_secret: String,
    /// Base URL of the provider's authorization endpoint.
    pub auth_url: String,
    /// Token endpoint URL. Not read by the handlers in this file.
    pub token_url: String,
    /// User-info endpoint URL. Not read by the handlers in this file.
    pub user_info_url: String,
    /// Scopes requested when the incoming request does not specify any.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for OAuthProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthProvider")
            .field("client_id", &self.client_id)
            // Never print the secret itself; only indicate that the field exists.
            .field("client_secret", &"[REDACTED]")
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .field("user_info_url", &self.user_info_url)
            .field("scopes", &self.scopes)
            .finish()
    }
}
29
30impl OAuthPlugin {
31    pub fn new() -> Self {
32        Self {
33            config: OAuthConfig::default(),
34        }
35    }
36
37    pub fn with_config(config: OAuthConfig) -> Self {
38        Self { config }
39    }
40
41    pub fn add_provider(mut self, name: &str, provider: OAuthProvider) -> Self {
42        self.config.providers.insert(name.to_string(), provider);
43        self
44    }
45}
46
47impl Default for OAuthPlugin {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53// Request/Response structures for social authentication
54#[derive(Debug, Deserialize, Validate)]
55#[allow(dead_code)]
56struct SocialSignInRequest {
57    #[serde(rename = "callbackURL")]
58    callback_url: Option<String>,
59    #[serde(rename = "newUserCallbackURL")]
60    new_user_callback_url: Option<String>,
61    #[serde(rename = "errorCallbackURL")]
62    error_callback_url: Option<String>,
63    #[validate(length(min = 1, message = "Provider is required"))]
64    provider: String,
65    #[serde(rename = "disableRedirect")]
66    disable_redirect: Option<String>,
67    #[serde(rename = "idToken")]
68    id_token: Option<String>,
69    scopes: Option<String>,
70    #[serde(rename = "requestSignUp")]
71    request_sign_up: Option<String>,
72    #[serde(rename = "loginHint")]
73    login_hint: Option<String>,
74}
75
76#[derive(Debug, Deserialize, Validate)]
77struct LinkSocialRequest {
78    #[serde(rename = "callbackURL")]
79    callback_url: Option<String>,
80    #[validate(length(min = 1, message = "Provider is required"))]
81    provider: String,
82    scopes: Option<String>,
83}
84
85#[derive(Debug, Serialize)]
86struct SocialSignInResponse {
87    redirect: bool,
88    token: String,
89    url: Option<String>,
90    user: User,
91}
92
93#[derive(Debug, Serialize)]
94struct LinkSocialResponse {
95    url: String,
96    redirect: bool,
97}
98
99#[async_trait]
100impl AuthPlugin for OAuthPlugin {
101    fn name(&self) -> &'static str {
102        "oauth"
103    }
104
105    fn routes(&self) -> Vec<AuthRoute> {
106        vec![
107            AuthRoute::post("/sign-in/social", "social_sign_in"),
108            AuthRoute::post("/link-social", "link_social"),
109        ]
110    }
111
112    async fn on_request(
113        &self,
114        req: &AuthRequest,
115        ctx: &AuthContext,
116    ) -> AuthResult<Option<AuthResponse>> {
117        match (req.method(), req.path()) {
118            (HttpMethod::Post, "/sign-in/social") => {
119                Ok(Some(self.handle_social_sign_in(req, ctx).await?))
120            }
121            (HttpMethod::Post, "/link-social") => {
122                Ok(Some(self.handle_link_social(req, ctx).await?))
123            }
124            _ => Ok(None),
125        }
126    }
127}
128
129// Implementation methods outside the trait
130impl OAuthPlugin {
131    async fn handle_social_sign_in(
132        &self,
133        req: &AuthRequest,
134        ctx: &AuthContext,
135    ) -> AuthResult<AuthResponse> {
136        let signin_req: SocialSignInRequest = match better_auth_core::validate_request_body(req) {
137            Ok(v) => v,
138            Err(resp) => return Ok(resp),
139        };
140
141        // Validate provider
142        if !self.config.providers.contains_key(&signin_req.provider) {
143            return Err(AuthError::bad_request(format!(
144                "Provider '{}' is not configured",
145                signin_req.provider
146            )));
147        }
148
149        // If id_token is provided, verify and create session directly
150        if let Some(id_token) = &signin_req.id_token {
151            return self
152                .handle_id_token_sign_in(id_token, &signin_req, ctx)
153                .await;
154        }
155
156        // Otherwise, generate authorization URL for OAuth flow
157        self.generate_auth_url(&signin_req, ctx).await
158    }
159
160    async fn handle_link_social(
161        &self,
162        req: &AuthRequest,
163        ctx: &AuthContext,
164    ) -> AuthResult<AuthResponse> {
165        let link_req: LinkSocialRequest = match better_auth_core::validate_request_body(req) {
166            Ok(v) => v,
167            Err(resp) => return Ok(resp),
168        };
169
170        // Validate provider
171        if !self.config.providers.contains_key(&link_req.provider) {
172            return Err(AuthError::bad_request(format!(
173                "Provider '{}' is not configured",
174                link_req.provider
175            )));
176        }
177
178        // Generate authorization URL for linking
179        let provider = &self.config.providers[&link_req.provider];
180        let callback_url = link_req.callback_url.unwrap_or_else(|| {
181            format!(
182                "{}/oauth/{}/callback",
183                ctx.config.base_url, link_req.provider
184            )
185        });
186
187        let scopes = if let Some(scopes) = &link_req.scopes {
188            scopes.split(',').map(|s| s.trim().to_string()).collect()
189        } else {
190            provider.scopes.clone()
191        };
192
193        let auth_url = format!(
194            "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state=link_{}",
195            provider.auth_url,
196            provider.client_id,
197            urlencoding::encode(&callback_url),
198            urlencoding::encode(&scopes.join(" ")),
199            uuid::Uuid::new_v4()
200        );
201
202        let response = LinkSocialResponse {
203            url: auth_url,
204            redirect: true,
205        };
206
207        Ok(AuthResponse::json(200, &response)?)
208    }
209
210    async fn handle_id_token_sign_in(
211        &self,
212        id_token: &str,
213        signin_req: &SocialSignInRequest,
214        ctx: &AuthContext,
215    ) -> AuthResult<AuthResponse> {
216        // TODO: Implement proper JWT verification
217        // For now, return a mock implementation that creates a user
218
219        // Mock user creation from ID token
220        let email = format!("user+{}@{}.com", uuid::Uuid::new_v4(), signin_req.provider);
221        let name = format!("User from {}", signin_req.provider);
222
223        // Check if user already exists
224        let existing_user = ctx.database.get_user_by_email(&email).await?;
225        let user = if let Some(user) = existing_user {
226            user
227        } else {
228            // Create new user
229            let create_user = CreateUser::new()
230                .with_email(&email)
231                .with_name(&name)
232                .with_email_verified(true); // Social providers typically verify email
233
234            ctx.database.create_user(create_user).await?
235        };
236
237        // Create account record for this social provider
238        let create_account = CreateAccount {
239            account_id: format!("{}_{}", signin_req.provider, uuid::Uuid::new_v4()),
240            provider_id: signin_req.provider.clone(),
241            user_id: user.id.clone(),
242            access_token: Some("mock_access_token".to_string()),
243            refresh_token: None,
244            id_token: Some(id_token.to_string()),
245            access_token_expires_at: None,
246            refresh_token_expires_at: None,
247            scope: None,
248            password: None,
249        };
250
251        // Check if account already exists
252        if ctx
253            .database
254            .get_account(&signin_req.provider, &create_account.account_id)
255            .await?
256            .is_none()
257        {
258            ctx.database.create_account(create_account).await?;
259        }
260
261        // Create session
262        let session_manager =
263            better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
264        let session = session_manager.create_session(&user, None, None).await?;
265
266        let response = SocialSignInResponse {
267            redirect: false,
268            token: session.token,
269            url: None,
270            user,
271        };
272
273        Ok(AuthResponse::json(200, &response)?)
274    }
275
276    async fn generate_auth_url(
277        &self,
278        signin_req: &SocialSignInRequest,
279        ctx: &AuthContext,
280    ) -> AuthResult<AuthResponse> {
281        let provider = &self.config.providers[&signin_req.provider];
282        let callback_url = signin_req.callback_url.clone().unwrap_or_else(|| {
283            format!(
284                "{}/oauth/{}/callback",
285                ctx.config.base_url, signin_req.provider
286            )
287        });
288
289        let scopes = if let Some(scopes) = &signin_req.scopes {
290            scopes.split(',').map(|s| s.trim().to_string()).collect()
291        } else {
292            provider.scopes.clone()
293        };
294
295        let mut auth_url = format!(
296            "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}",
297            provider.auth_url,
298            provider.client_id,
299            urlencoding::encode(&callback_url),
300            urlencoding::encode(&scopes.join(" ")),
301            uuid::Uuid::new_v4()
302        );
303
304        if let Some(login_hint) = &signin_req.login_hint {
305            auth_url.push_str(&format!("&login_hint={}", urlencoding::encode(login_hint)));
306        }
307
308        // Return redirect response
309        Ok(AuthResponse::json(
310            200,
311            &serde_json::json!({
312                "url": auth_url,
313                "redirect": true
314            }),
315        )?)
316    }
317}