Skip to main content

better_auth_api/plugins/oauth/mod.rs

1use async_trait::async_trait;
2
3use better_auth_core::AuthResult;
4use better_auth_core::adapters::DatabaseAdapter;
5use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
6use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
7
8mod handlers;
9mod providers;
10mod types;
11
12pub use providers::{OAuthConfig, OAuthProvider, OAuthUserInfo};
13
14pub struct OAuthPlugin {
15    config: OAuthConfig,
16}
17
18impl OAuthPlugin {
19    pub fn new() -> Self {
20        Self {
21            config: OAuthConfig::default(),
22        }
23    }
24
25    pub fn with_config(config: OAuthConfig) -> Self {
26        Self { config }
27    }
28
29    pub fn add_provider(mut self, name: &str, provider: OAuthProvider) -> Self {
30        self.config.providers.insert(name.to_string(), provider);
31        self
32    }
33}
34
35impl Default for OAuthPlugin {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41#[async_trait]
42impl<DB: DatabaseAdapter> AuthPlugin<DB> for OAuthPlugin {
43    fn name(&self) -> &'static str {
44        "oauth"
45    }
46
47    fn routes(&self) -> Vec<AuthRoute> {
48        vec![
49            AuthRoute::post("/sign-in/social", "social_sign_in"),
50            AuthRoute::get("/callback/{provider}", "oauth_callback"),
51            AuthRoute::post("/link-social", "link_social"),
52            AuthRoute::post("/get-access-token", "get_access_token"),
53            AuthRoute::post("/refresh-token", "refresh_token"),
54        ]
55    }
56
57    async fn on_request(
58        &self,
59        req: &AuthRequest,
60        ctx: &AuthContext<DB>,
61    ) -> AuthResult<Option<AuthResponse>> {
62        match (req.method(), req.path()) {
63            (HttpMethod::Post, "/sign-in/social") => Ok(Some(
64                handlers::handle_social_sign_in(&self.config, req, ctx).await?,
65            )),
66            (HttpMethod::Get, path) if path_matches_callback(path) => {
67                let provider = extract_provider_from_callback(path);
68                Ok(Some(
69                    handlers::handle_callback(&self.config, &provider, req, ctx).await?,
70                ))
71            }
72            (HttpMethod::Post, "/link-social") => Ok(Some(
73                handlers::handle_link_social(&self.config, req, ctx).await?,
74            )),
75            (HttpMethod::Post, "/get-access-token") => Ok(Some(
76                handlers::handle_get_access_token(&self.config, req, ctx).await?,
77            )),
78            (HttpMethod::Post, "/refresh-token") => Ok(Some(
79                handlers::handle_refresh_token(&self.config, req, ctx).await?,
80            )),
81            _ => Ok(None),
82        }
83    }
84}
85
/// Check if the path matches `/callback/{provider}` (with optional query string).
///
/// `{provider}` must be a single, non-empty path segment, mirroring the
/// `/callback/{provider}` route the plugin registers. Paths such as
/// `/callback/`, `/callback/google/`, or `/callback/google/extra` do not match.
fn path_matches_callback(path: &str) -> bool {
    let path_without_query = path.split_once('?').map_or(path, |(before, _)| before);
    matches!(
        path_without_query.strip_prefix("/callback/"),
        Some(provider) if !provider.is_empty() && !provider.contains('/')
    )
}
91
/// Extract the provider name from `/callback/{provider}?...`.
///
/// Returns the text between the `/callback/` prefix and the first `?`.
/// If `path` does not start with `/callback/`, an empty string is returned
/// rather than panicking on an out-of-range or non-char-boundary slice;
/// callers are expected to check `path_matches_callback` first.
fn extract_provider_from_callback(path: &str) -> String {
    let path_without_query = path.split_once('?').map_or(path, |(before, _)| before);
    path_without_query
        .strip_prefix("/callback/")
        .unwrap_or("")
        .to_string()
}