better_auth_api/plugins/oauth/
mod.rs1use async_trait::async_trait;
2
3use better_auth_core::AuthResult;
4use better_auth_core::adapters::DatabaseAdapter;
5use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
6use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
7
8pub mod encryption;
9mod handlers;
10mod providers;
11mod types;
12
13pub use providers::{OAuthConfig, OAuthProvider, OAuthStateStrategy, OAuthUserInfo};
14
15pub struct OAuthPlugin {
16 config: OAuthConfig,
17}
18
19impl OAuthPlugin {
20 pub fn new() -> Self {
21 Self {
22 config: OAuthConfig::default(),
23 }
24 }
25
26 pub fn with_config(config: OAuthConfig) -> Self {
27 Self { config }
28 }
29
30 pub fn add_provider(mut self, name: &str, provider: OAuthProvider) -> Self {
31 self.config.providers.insert(name.to_string(), provider);
32 self
33 }
34}
35
36impl Default for OAuthPlugin {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42#[async_trait]
43impl<DB: DatabaseAdapter> AuthPlugin<DB> for OAuthPlugin {
44 fn name(&self) -> &'static str {
45 "oauth"
46 }
47
48 fn routes(&self) -> Vec<AuthRoute> {
49 vec![
50 AuthRoute::post("/sign-in/social", "social_sign_in"),
51 AuthRoute::get("/callback/{provider}", "oauth_callback"),
52 AuthRoute::post("/link-social", "link_social"),
53 AuthRoute::post("/get-access-token", "get_access_token"),
54 AuthRoute::post("/refresh-token", "refresh_token"),
55 ]
56 }
57
58 async fn on_request(
59 &self,
60 req: &AuthRequest,
61 ctx: &AuthContext<DB>,
62 ) -> AuthResult<Option<AuthResponse>> {
63 match (req.method(), req.path()) {
64 (HttpMethod::Post, "/sign-in/social") => Ok(Some(
65 handlers::handle_social_sign_in(&self.config, req, ctx).await?,
66 )),
67 (HttpMethod::Get, path) if path_matches_callback(path) => {
68 let provider = extract_provider_from_callback(path);
69 Ok(Some(
70 handlers::handle_callback(&self.config, &provider, req, ctx).await?,
71 ))
72 }
73 (HttpMethod::Post, "/link-social") => Ok(Some(
74 handlers::handle_link_social(&self.config, req, ctx).await?,
75 )),
76 (HttpMethod::Post, "/get-access-token") => Ok(Some(
77 handlers::handle_get_access_token(&self.config, req, ctx).await?,
78 )),
79 (HttpMethod::Post, "/refresh-token") => Ok(Some(
80 handlers::handle_refresh_token(&self.config, req, ctx).await?,
81 )),
82 _ => Ok(None),
83 }
84 }
85}
86
87fn path_matches_callback(path: &str) -> bool {
89 let path_without_query = path.split('?').next().unwrap_or(path);
90 path_without_query.starts_with("/callback/") && path_without_query.len() > "/callback/".len()
91}
92
93fn extract_provider_from_callback(path: &str) -> String {
95 let path_without_query = path.split('?').next().unwrap_or(path);
96 path_without_query["/callback/".len()..].to_string()
97}