Skip to main content

better_auth_api/plugins/
oauth.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use validator::Validate;
5
6use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
7use better_auth_core::{AuthError, AuthResult};
8use better_auth_core::{AuthRequest, AuthResponse, CreateAccount, CreateUser, HttpMethod};
9use better_auth_core::{AuthSession, AuthUser, DatabaseAdapter};
10
11/// OAuth authentication plugin for social sign-in
12pub struct OAuthPlugin {
13    config: OAuthConfig,
14}
15
16#[derive(Debug, Clone, Default)]
17pub struct OAuthConfig {
18    pub providers: HashMap<String, OAuthProvider>,
19}
20
/// Static settings for one OAuth provider.
///
/// `Debug` is implemented by hand so that `client_secret` never ends up in
/// logs or panic messages; every other field is printed as usual.
#[derive(Clone)]
pub struct OAuthProvider {
    /// Client identifier sent as the `client_id` query parameter.
    pub client_id: String,
    /// Client secret issued by the provider. Redacted in `Debug` output.
    pub client_secret: String,
    /// Authorization endpoint that sign-in and link URLs are built from.
    pub auth_url: String,
    /// Token endpoint (not read by the handlers in this file).
    pub token_url: String,
    /// User-info endpoint (not read by the handlers in this file).
    pub user_info_url: String,
    /// Scopes requested when the client does not supply its own list.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for OAuthProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthProvider")
            .field("client_id", &self.client_id)
            // Never print the real secret: Debug output routinely reaches logs.
            .field("client_secret", &"[REDACTED]")
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .field("user_info_url", &self.user_info_url)
            .field("scopes", &self.scopes)
            .finish()
    }
}
30
31impl OAuthPlugin {
32    pub fn new() -> Self {
33        Self {
34            config: OAuthConfig::default(),
35        }
36    }
37
38    pub fn with_config(config: OAuthConfig) -> Self {
39        Self { config }
40    }
41
42    pub fn add_provider(mut self, name: &str, provider: OAuthProvider) -> Self {
43        self.config.providers.insert(name.to_string(), provider);
44        self
45    }
46}
47
48impl Default for OAuthPlugin {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54// Request/Response structures for social authentication
55#[derive(Debug, Deserialize, Validate)]
56#[allow(dead_code)]
57struct SocialSignInRequest {
58    #[serde(rename = "callbackURL")]
59    callback_url: Option<String>,
60    #[serde(rename = "newUserCallbackURL")]
61    new_user_callback_url: Option<String>,
62    #[serde(rename = "errorCallbackURL")]
63    error_callback_url: Option<String>,
64    #[validate(length(min = 1, message = "Provider is required"))]
65    provider: String,
66    #[serde(rename = "disableRedirect")]
67    disable_redirect: Option<String>,
68    #[serde(rename = "idToken")]
69    id_token: Option<String>,
70    scopes: Option<String>,
71    #[serde(rename = "requestSignUp")]
72    request_sign_up: Option<String>,
73    #[serde(rename = "loginHint")]
74    login_hint: Option<String>,
75}
76
77#[derive(Debug, Deserialize, Validate)]
78struct LinkSocialRequest {
79    #[serde(rename = "callbackURL")]
80    callback_url: Option<String>,
81    #[validate(length(min = 1, message = "Provider is required"))]
82    provider: String,
83    scopes: Option<String>,
84}
85
86#[derive(Debug, Serialize)]
87struct SocialSignInResponse<U: Serialize> {
88    redirect: bool,
89    token: String,
90    url: Option<String>,
91    user: U,
92}
93
94#[derive(Debug, Serialize)]
95struct LinkSocialResponse {
96    url: String,
97    redirect: bool,
98}
99
100#[async_trait]
101impl<DB: DatabaseAdapter> AuthPlugin<DB> for OAuthPlugin {
102    fn name(&self) -> &'static str {
103        "oauth"
104    }
105
106    fn routes(&self) -> Vec<AuthRoute> {
107        vec![
108            AuthRoute::post("/sign-in/social", "social_sign_in"),
109            AuthRoute::post("/link-social", "link_social"),
110        ]
111    }
112
113    async fn on_request(
114        &self,
115        req: &AuthRequest,
116        ctx: &AuthContext<DB>,
117    ) -> AuthResult<Option<AuthResponse>> {
118        match (req.method(), req.path()) {
119            (HttpMethod::Post, "/sign-in/social") => {
120                Ok(Some(self.handle_social_sign_in(req, ctx).await?))
121            }
122            (HttpMethod::Post, "/link-social") => {
123                Ok(Some(self.handle_link_social(req, ctx).await?))
124            }
125            _ => Ok(None),
126        }
127    }
128}
129
130// Implementation methods outside the trait
131impl OAuthPlugin {
132    async fn handle_social_sign_in<DB: DatabaseAdapter>(
133        &self,
134        req: &AuthRequest,
135        ctx: &AuthContext<DB>,
136    ) -> AuthResult<AuthResponse> {
137        let signin_req: SocialSignInRequest = match better_auth_core::validate_request_body(req) {
138            Ok(v) => v,
139            Err(resp) => return Ok(resp),
140        };
141
142        // Validate provider
143        if !self.config.providers.contains_key(&signin_req.provider) {
144            return Err(AuthError::bad_request(format!(
145                "Provider '{}' is not configured",
146                signin_req.provider
147            )));
148        }
149
150        // If id_token is provided, verify and create session directly
151        if let Some(id_token) = &signin_req.id_token {
152            return self
153                .handle_id_token_sign_in(id_token, &signin_req, ctx)
154                .await;
155        }
156
157        // Otherwise, generate authorization URL for OAuth flow
158        self.generate_auth_url(&signin_req, ctx).await
159    }
160
161    async fn handle_link_social<DB: DatabaseAdapter>(
162        &self,
163        req: &AuthRequest,
164        ctx: &AuthContext<DB>,
165    ) -> AuthResult<AuthResponse> {
166        let link_req: LinkSocialRequest = match better_auth_core::validate_request_body(req) {
167            Ok(v) => v,
168            Err(resp) => return Ok(resp),
169        };
170
171        // Validate provider
172        if !self.config.providers.contains_key(&link_req.provider) {
173            return Err(AuthError::bad_request(format!(
174                "Provider '{}' is not configured",
175                link_req.provider
176            )));
177        }
178
179        // Generate authorization URL for linking
180        let provider = &self.config.providers[&link_req.provider];
181        let callback_url = link_req.callback_url.unwrap_or_else(|| {
182            format!(
183                "{}/oauth/{}/callback",
184                ctx.config.base_url, link_req.provider
185            )
186        });
187
188        let scopes = if let Some(scopes) = &link_req.scopes {
189            scopes.split(',').map(|s| s.trim().to_string()).collect()
190        } else {
191            provider.scopes.clone()
192        };
193
194        let auth_url = format!(
195            "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state=link_{}",
196            provider.auth_url,
197            provider.client_id,
198            urlencoding::encode(&callback_url),
199            urlencoding::encode(&scopes.join(" ")),
200            uuid::Uuid::new_v4()
201        );
202
203        let response = LinkSocialResponse {
204            url: auth_url,
205            redirect: true,
206        };
207
208        Ok(AuthResponse::json(200, &response)?)
209    }
210
211    async fn handle_id_token_sign_in<DB: DatabaseAdapter>(
212        &self,
213        id_token: &str,
214        signin_req: &SocialSignInRequest,
215        ctx: &AuthContext<DB>,
216    ) -> AuthResult<AuthResponse> {
217        // TODO: Implement proper JWT verification
218        // For now, return a mock implementation that creates a user
219
220        // Mock user creation from ID token
221        let email = format!("user+{}@{}.com", uuid::Uuid::new_v4(), signin_req.provider);
222        let name = format!("User from {}", signin_req.provider);
223
224        // Check if user already exists
225        let existing_user = ctx.database.get_user_by_email(&email).await?;
226        let user = if let Some(user) = existing_user {
227            user
228        } else {
229            // Create new user
230            let create_user = CreateUser::new()
231                .with_email(&email)
232                .with_name(&name)
233                .with_email_verified(true); // Social providers typically verify email
234
235            ctx.database.create_user(create_user).await?
236        };
237
238        // Create account record for this social provider
239        let create_account = CreateAccount {
240            account_id: format!("{}_{}", signin_req.provider, uuid::Uuid::new_v4()),
241            provider_id: signin_req.provider.clone(),
242            user_id: user.id().to_string(),
243            access_token: Some("mock_access_token".to_string()),
244            refresh_token: None,
245            id_token: Some(id_token.to_string()),
246            access_token_expires_at: None,
247            refresh_token_expires_at: None,
248            scope: None,
249            password: None,
250        };
251
252        // Check if account already exists
253        if ctx
254            .database
255            .get_account(&signin_req.provider, &create_account.account_id)
256            .await?
257            .is_none()
258        {
259            ctx.database.create_account(create_account).await?;
260        }
261
262        // Create session
263        let session_manager =
264            better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
265        let session = session_manager.create_session(&user, None, None).await?;
266
267        let response = SocialSignInResponse {
268            redirect: false,
269            token: session.token().to_string(),
270            url: None,
271            user,
272        };
273
274        Ok(AuthResponse::json(200, &response)?)
275    }
276
277    async fn generate_auth_url<DB: DatabaseAdapter>(
278        &self,
279        signin_req: &SocialSignInRequest,
280        ctx: &AuthContext<DB>,
281    ) -> AuthResult<AuthResponse> {
282        let provider = &self.config.providers[&signin_req.provider];
283        let callback_url = signin_req.callback_url.clone().unwrap_or_else(|| {
284            format!(
285                "{}/oauth/{}/callback",
286                ctx.config.base_url, signin_req.provider
287            )
288        });
289
290        let scopes = if let Some(scopes) = &signin_req.scopes {
291            scopes.split(',').map(|s| s.trim().to_string()).collect()
292        } else {
293            provider.scopes.clone()
294        };
295
296        let mut auth_url = format!(
297            "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}",
298            provider.auth_url,
299            provider.client_id,
300            urlencoding::encode(&callback_url),
301            urlencoding::encode(&scopes.join(" ")),
302            uuid::Uuid::new_v4()
303        );
304
305        if let Some(login_hint) = &signin_req.login_hint {
306            auth_url.push_str(&format!("&login_hint={}", urlencoding::encode(login_hint)));
307        }
308
309        // Return redirect response
310        Ok(AuthResponse::json(
311            200,
312            &serde_json::json!({
313                "url": auth_url,
314                "redirect": true
315            }),
316        )?)
317    }
318}