Skip to main content

better_auth_api/plugins/oauth/
mod.rs

1use async_trait::async_trait;
2
3use better_auth_core::AuthResult;
4use better_auth_core::adapters::DatabaseAdapter;
5use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
6use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
7
8pub mod encryption;
9mod handlers;
10mod providers;
11mod types;
12
13pub use providers::{OAuthConfig, OAuthProvider, OAuthStateStrategy, OAuthUserInfo};
14
15pub struct OAuthPlugin {
16    config: OAuthConfig,
17}
18
19impl OAuthPlugin {
20    pub fn new() -> Self {
21        Self {
22            config: OAuthConfig::default(),
23        }
24    }
25
26    pub fn with_config(config: OAuthConfig) -> Self {
27        Self { config }
28    }
29
30    pub fn add_provider(mut self, name: &str, provider: OAuthProvider) -> Self {
31        self.config.providers.insert(name.to_string(), provider);
32        self
33    }
34}
35
36impl Default for OAuthPlugin {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42#[async_trait]
43impl<DB: DatabaseAdapter> AuthPlugin<DB> for OAuthPlugin {
44    fn name(&self) -> &'static str {
45        "oauth"
46    }
47
48    fn routes(&self) -> Vec<AuthRoute> {
49        vec![
50            AuthRoute::post("/sign-in/social", "social_sign_in"),
51            AuthRoute::get("/callback/{provider}", "oauth_callback"),
52            AuthRoute::post("/link-social", "link_social"),
53            AuthRoute::post("/get-access-token", "get_access_token"),
54            AuthRoute::post("/refresh-token", "refresh_token"),
55        ]
56    }
57
58    async fn on_request(
59        &self,
60        req: &AuthRequest,
61        ctx: &AuthContext<DB>,
62    ) -> AuthResult<Option<AuthResponse>> {
63        match (req.method(), req.path()) {
64            (HttpMethod::Post, "/sign-in/social") => Ok(Some(
65                handlers::handle_social_sign_in(&self.config, req, ctx).await?,
66            )),
67            (HttpMethod::Get, path) if path_matches_callback(path) => {
68                let provider = extract_provider_from_callback(path);
69                Ok(Some(
70                    handlers::handle_callback(&self.config, &provider, req, ctx).await?,
71                ))
72            }
73            (HttpMethod::Post, "/link-social") => Ok(Some(
74                handlers::handle_link_social(&self.config, req, ctx).await?,
75            )),
76            (HttpMethod::Post, "/get-access-token") => Ok(Some(
77                handlers::handle_get_access_token(&self.config, req, ctx).await?,
78            )),
79            (HttpMethod::Post, "/refresh-token") => Ok(Some(
80                handlers::handle_refresh_token(&self.config, req, ctx).await?,
81            )),
82            _ => Ok(None),
83        }
84    }
85}
86
/// Check whether `path` is `/callback/{provider}`, optionally followed by a
/// query string (`/callback/google?code=...&state=...`).
///
/// `{provider}` must be a single, non-empty path segment, mirroring the
/// `/callback/{provider}` route this plugin declares. Nested paths such as
/// `/callback/google/extra` and trailing slashes (`/callback/google/`) are
/// rejected rather than being treated as a provider named `google/extra`.
fn path_matches_callback(path: &str) -> bool {
    // Everything after the first `?` is the query string; ignore it.
    let route = path.split_once('?').map_or(path, |(route, _)| route);
    matches!(
        route.strip_prefix("/callback/"),
        Some(provider) if !provider.is_empty() && !provider.contains('/')
    )
}
92
/// Extract the provider name from `/callback/{provider}?...`.
///
/// The query string, if any, is discarded. When `path` does not start with
/// `/callback/`, an empty string is returned. (Callers are expected to check
/// `path_matches_callback` first.) The previous implementation took an
/// unchecked byte slice here, which panicked on short paths or non-ASCII
/// input and returned garbage for other non-matching paths.
fn extract_provider_from_callback(path: &str) -> String {
    // Everything after the first `?` is the query string; ignore it.
    let route = path.split_once('?').map_or(path, |(route, _)| route);
    route
        .strip_prefix("/callback/")
        .unwrap_or("")
        .to_string()
}