better_auth/plugins/oauth.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5use crate::core::{AuthPlugin, AuthRoute, AuthContext};
6use crate::types::{AuthRequest, AuthResponse, HttpMethod, User, CreateUser, CreateAccount};
7use crate::error::{AuthError, AuthResult};
8
/// OAuth authentication plugin for social sign-in.
///
/// Serves `POST /sign-in/social` and `POST /link-social` for the providers registered in its
/// [`OAuthConfig`].
pub struct OAuthPlugin {
    config: OAuthConfig,
}

/// Provider table for [`OAuthPlugin`], keyed by the provider name clients send in requests.
#[derive(Debug, Clone, Default)]
pub struct OAuthConfig {
    pub providers: HashMap<String, OAuthProvider>,
}

/// Endpoints, credentials and default scopes for a single OAuth provider.
///
/// `Debug` is implemented by hand so that `client_secret` never ends up in log output
/// (including when an [`OAuthConfig`] or [`OAuthPlugin`] is debug-printed).
#[derive(Clone)]
pub struct OAuthProvider {
    /// Client identifier placed in the authorization URL as `client_id`.
    pub client_id: String,
    /// Confidential client secret; shown as `[REDACTED]` in `Debug` output.
    pub client_secret: String,
    /// Authorization endpoint the user is sent to.
    pub auth_url: String,
    /// Token endpoint (not read by the handlers in this module).
    pub token_url: String,
    /// User-info endpoint (not read by the handlers in this module).
    pub user_info_url: String,
    /// Scopes requested when the client does not supply its own list.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for OAuthProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthProvider")
            .field("client_id", &self.client_id)
            // Never print the real secret.
            .field("client_secret", &"[REDACTED]")
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .field("user_info_url", &self.user_info_url)
            .field("scopes", &self.scopes)
            .finish()
    }
}

impl OAuthPlugin {
    /// Creates a plugin with no providers registered.
    pub fn new() -> Self {
        Self::with_config(OAuthConfig::default())
    }

    /// Creates a plugin that uses an existing provider table.
    pub fn with_config(config: OAuthConfig) -> Self {
        Self { config }
    }

    /// Registers `provider` under `name`, replacing any provider already stored under that
    /// name, and returns the plugin for chaining.
    #[must_use]
    pub fn add_provider(mut self, name: &str, provider: OAuthProvider) -> Self {
        self.config.providers.insert(name.to_string(), provider);
        self
    }
}

impl Default for OAuthPlugin {
    fn default() -> Self {
        Self::new()
    }
}

60// Request/Response structures for social authentication
61#[derive(Debug, Deserialize)]
62struct SocialSignInRequest {
63    #[serde(rename = "callbackURL")]
64    callback_url: Option<String>,
65    #[serde(rename = "newUserCallbackURL")]
66    new_user_callback_url: Option<String>,
67    #[serde(rename = "errorCallbackURL")]
68    error_callback_url: Option<String>,
69    provider: String,
70    #[serde(rename = "disableRedirect")]
71    disable_redirect: Option<String>,
72    #[serde(rename = "idToken")]
73    id_token: Option<String>,
74    scopes: Option<String>,
75    #[serde(rename = "requestSignUp")]
76    request_sign_up: Option<String>,
77    #[serde(rename = "loginHint")]
78    login_hint: Option<String>,
79}
80
81#[derive(Debug, Deserialize)]
82struct LinkSocialRequest {
83    #[serde(rename = "callbackURL")]
84    callback_url: Option<String>,
85    provider: String,
86    scopes: Option<String>,
87}
88
89#[derive(Debug, Serialize)]
90struct SocialSignInResponse {
91    redirect: bool,
92    token: String,
93    url: Option<String>,
94    user: User,
95}
96
97#[derive(Debug, Serialize)]
98struct LinkSocialResponse {
99    url: String,
100    redirect: bool,
101}
102
103#[async_trait]
104impl AuthPlugin for OAuthPlugin {
105    fn name(&self) -> &'static str {
106        "oauth"
107    }
108    
109    fn routes(&self) -> Vec<AuthRoute> {
110        vec![
111            AuthRoute::post("/sign-in/social", "social_sign_in"),
112            AuthRoute::post("/link-social", "link_social"),
113        ]
114    }
115    
116    async fn on_request(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<Option<AuthResponse>> {
117        match (req.method(), req.path()) {
118            (HttpMethod::Post, "/sign-in/social") => {
119                Ok(Some(self.handle_social_sign_in(req, ctx).await?))
120            },
121            (HttpMethod::Post, "/link-social") => {
122                Ok(Some(self.handle_link_social(req, ctx).await?))
123            },
124            _ => Ok(None),
125        }
126    }
127}
128
129// Implementation methods outside the trait
130impl OAuthPlugin {
131    async fn handle_social_sign_in(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<AuthResponse> {
132        let signin_req: SocialSignInRequest = match req.body_as_json() {
133            Ok(req) => req,
134            Err(e) => {
135                return Ok(AuthResponse::json(400, &serde_json::json!({
136                    "error": "Invalid request",
137                    "message": format!("Invalid JSON: {}", e)
138                }))?);
139            }
140        };
141        
142        // Validate provider
143        if !self.config.providers.contains_key(&signin_req.provider) {
144            return Ok(AuthResponse::json(400, &serde_json::json!({
145                "error": "Invalid provider",
146                "message": format!("Provider '{}' is not configured", signin_req.provider)
147            }))?);
148        }
149        
150        // If id_token is provided, verify and create session directly
151        if let Some(id_token) = &signin_req.id_token {
152            return self.handle_id_token_sign_in(id_token, &signin_req, ctx).await;
153        }
154        
155        // Otherwise, generate authorization URL for OAuth flow
156        self.generate_auth_url(&signin_req, ctx).await
157    }
158    
159    async fn handle_link_social(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<AuthResponse> {
160        let link_req: LinkSocialRequest = match req.body_as_json() {
161            Ok(req) => req,
162            Err(e) => {
163                return Ok(AuthResponse::json(400, &serde_json::json!({
164                    "error": "Invalid request",
165                    "message": format!("Invalid JSON: {}", e)
166                }))?);
167            }
168        };
169        
170        // Validate provider
171        if !self.config.providers.contains_key(&link_req.provider) {
172            return Ok(AuthResponse::json(400, &serde_json::json!({
173                "error": "Invalid provider",
174                "message": format!("Provider '{}' is not configured", link_req.provider)
175            }))?);
176        }
177        
178        // Generate authorization URL for linking
179        let provider = &self.config.providers[&link_req.provider];
180        let callback_url = link_req.callback_url.unwrap_or_else(|| {
181            format!("{}/oauth/{}/callback", ctx.config.base_url, link_req.provider)
182        });
183        
184        let scopes = if let Some(scopes) = &link_req.scopes {
185            scopes.split(',').map(|s| s.trim().to_string()).collect()
186        } else {
187            provider.scopes.clone()
188        };
189        
190        let auth_url = format!(
191            "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state=link_{}",
192            provider.auth_url,
193            provider.client_id,
194            urlencoding::encode(&callback_url),
195            urlencoding::encode(&scopes.join(" ")),
196            uuid::Uuid::new_v4()
197        );
198        
199        let response = LinkSocialResponse {
200            url: auth_url,
201            redirect: true,
202        };
203        
204        Ok(AuthResponse::json(200, &response)?)
205    }
206    
207    async fn handle_id_token_sign_in(
208        &self,
209        id_token: &str,
210        signin_req: &SocialSignInRequest,
211        ctx: &AuthContext,
212    ) -> AuthResult<AuthResponse> {
213        // TODO: Implement proper JWT verification
214        // For now, return a mock implementation that creates a user
215        
216        // Mock user creation from ID token
217        let email = format!("user+{}@{}.com", uuid::Uuid::new_v4(), signin_req.provider);
218        let name = format!("User from {}", signin_req.provider);
219        
220        // Check if user already exists
221        let existing_user = ctx.database.get_user_by_email(&email).await?;
222        let user = if let Some(user) = existing_user {
223            user
224        } else {
225            // Create new user
226            let create_user = CreateUser::new()
227                .with_email(&email)
228                .with_name(&name)
229                .with_email_verified(true); // Social providers typically verify email
230            
231            ctx.database.create_user(create_user).await?
232        };
233        
234        // Create account record for this social provider
235        let create_account = CreateAccount {
236            account_id: format!("{}_{}", signin_req.provider, uuid::Uuid::new_v4()),
237            provider_id: signin_req.provider.clone(),
238            user_id: user.id.clone(),
239            access_token: Some("mock_access_token".to_string()),
240            refresh_token: None,
241            id_token: Some(id_token.to_string()),
242            access_token_expires_at: None,
243            refresh_token_expires_at: None,
244            scope: None,
245            password: None,
246        };
247        
248        // Check if account already exists
249        if ctx.database.get_account(&signin_req.provider, &create_account.account_id).await?.is_none() {
250            ctx.database.create_account(create_account).await?;
251        }
252        
253        // Create session
254        let session_manager = crate::core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
255        let session = session_manager.create_session(&user, None, None).await?;
256        
257        let response = SocialSignInResponse {
258            redirect: false,
259            token: session.token,
260            url: None,
261            user,
262        };
263        
264        Ok(AuthResponse::json(200, &response)?)
265    }
266    
267    async fn generate_auth_url(
268        &self,
269        signin_req: &SocialSignInRequest,
270        ctx: &AuthContext,
271    ) -> AuthResult<AuthResponse> {
272        let provider = &self.config.providers[&signin_req.provider];
273        let callback_url = signin_req.callback_url.clone().unwrap_or_else(|| {
274            format!("{}/oauth/{}/callback", ctx.config.base_url, signin_req.provider)
275        });
276        
277        let scopes = if let Some(scopes) = &signin_req.scopes {
278            scopes.split(',').map(|s| s.trim().to_string()).collect()
279        } else {
280            provider.scopes.clone()
281        };
282        
283        let mut auth_url = format!(
284            "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}",
285            provider.auth_url,
286            provider.client_id,
287            urlencoding::encode(&callback_url),
288            urlencoding::encode(&scopes.join(" ")),
289            uuid::Uuid::new_v4()
290        );
291        
292        if let Some(login_hint) = &signin_req.login_hint {
293            auth_url.push_str(&format!("&login_hint={}", urlencoding::encode(login_hint)));
294        }
295        
296        // Return redirect response
297        Ok(AuthResponse::json(200, &serde_json::json!({
298            "url": auth_url,
299            "redirect": true
300        }))?)
301    }
302}