better_auth_api/plugins/oauth/
mod.rs1use async_trait::async_trait;
2
3use better_auth_core::AuthResult;
4use better_auth_core::adapters::DatabaseAdapter;
5use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
6use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
7
8mod handlers;
9mod providers;
10mod types;
11
12pub use providers::{OAuthConfig, OAuthProvider, OAuthUserInfo};
13
14pub struct OAuthPlugin {
15 config: OAuthConfig,
16}
17
18impl OAuthPlugin {
19 pub fn new() -> Self {
20 Self {
21 config: OAuthConfig::default(),
22 }
23 }
24
25 pub fn with_config(config: OAuthConfig) -> Self {
26 Self { config }
27 }
28
29 pub fn add_provider(mut self, name: &str, provider: OAuthProvider) -> Self {
30 self.config.providers.insert(name.to_string(), provider);
31 self
32 }
33}
34
35impl Default for OAuthPlugin {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41#[async_trait]
42impl<DB: DatabaseAdapter> AuthPlugin<DB> for OAuthPlugin {
43 fn name(&self) -> &'static str {
44 "oauth"
45 }
46
47 fn routes(&self) -> Vec<AuthRoute> {
48 vec![
49 AuthRoute::post("/sign-in/social", "social_sign_in"),
50 AuthRoute::get("/callback/{provider}", "oauth_callback"),
51 AuthRoute::post("/link-social", "link_social"),
52 AuthRoute::post("/get-access-token", "get_access_token"),
53 AuthRoute::post("/refresh-token", "refresh_token"),
54 ]
55 }
56
57 async fn on_request(
58 &self,
59 req: &AuthRequest,
60 ctx: &AuthContext<DB>,
61 ) -> AuthResult<Option<AuthResponse>> {
62 match (req.method(), req.path()) {
63 (HttpMethod::Post, "/sign-in/social") => Ok(Some(
64 handlers::handle_social_sign_in(&self.config, req, ctx).await?,
65 )),
66 (HttpMethod::Get, path) if path_matches_callback(path) => {
67 let provider = extract_provider_from_callback(path);
68 Ok(Some(
69 handlers::handle_callback(&self.config, &provider, req, ctx).await?,
70 ))
71 }
72 (HttpMethod::Post, "/link-social") => Ok(Some(
73 handlers::handle_link_social(&self.config, req, ctx).await?,
74 )),
75 (HttpMethod::Post, "/get-access-token") => Ok(Some(
76 handlers::handle_get_access_token(&self.config, req, ctx).await?,
77 )),
78 (HttpMethod::Post, "/refresh-token") => Ok(Some(
79 handlers::handle_refresh_token(&self.config, req, ctx).await?,
80 )),
81 _ => Ok(None),
82 }
83 }
84}
85
/// Returns `true` when `path` targets the OAuth callback route
/// `/callback/{provider}`.
///
/// Any query string (`?...`) is ignored. The provider part must be a single,
/// non-empty path segment (no further `/`), mirroring the one `{provider}`
/// segment declared in `routes()`. Paths like `/callback/google/extra` are
/// therefore rejected instead of yielding a provider name of `"google/extra"`.
fn path_matches_callback(path: &str) -> bool {
    let path_without_query = path.split('?').next().unwrap_or(path);
    matches!(
        path_without_query.strip_prefix("/callback/"),
        Some(provider) if !provider.is_empty() && !provider.contains('/')
    )
}
91
/// Extracts the `{provider}` part from a callback path such as
/// `/callback/google?code=...` (giving `"google"`).
///
/// Any query string is dropped first. Intended to be called only after
/// `path_matches_callback` has accepted `path`. For any other input it returns
/// an empty string rather than panicking on an out-of-range or
/// non-char-boundary slice.
fn extract_provider_from_callback(path: &str) -> String {
    let path_without_query = path.split('?').next().unwrap_or(path);
    path_without_query
        .strip_prefix("/callback/")
        .unwrap_or_default()
        .to_string()
}