better_auth_api/plugins/oauth/
mod.rs1use async_trait::async_trait;
2
3use better_auth_core::AuthResult;
4use better_auth_core::adapters::DatabaseAdapter;
5use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
6use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
7
8#[cfg(feature = "axum")]
9use better_auth_core::plugin::{AuthState, AxumPlugin};
10
11pub mod encryption;
12mod handlers;
13mod providers;
14#[cfg(test)]
15mod tests;
16mod types;
17
18pub use providers::{OAuthConfig, OAuthProvider, OAuthStateStrategy, OAuthUserInfo};
19
20pub struct OAuthPlugin {
21 config: OAuthConfig,
22}
23
24impl OAuthPlugin {
25 pub fn new() -> Self {
26 Self {
27 config: OAuthConfig::default(),
28 }
29 }
30
31 pub fn with_config(config: OAuthConfig) -> Self {
32 Self { config }
33 }
34
35 pub fn add_provider(mut self, name: &str, provider: OAuthProvider) -> Self {
36 self.config.providers.insert(name.to_string(), provider);
37 self
38 }
39}
40
41impl Default for OAuthPlugin {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47#[async_trait]
48impl<DB: DatabaseAdapter> AuthPlugin<DB> for OAuthPlugin {
49 fn name(&self) -> &'static str {
50 "oauth"
51 }
52
53 fn routes(&self) -> Vec<AuthRoute> {
54 vec![
55 AuthRoute::post("/sign-in/social", "social_sign_in"),
56 AuthRoute::get("/callback/{provider}", "oauth_callback"),
57 AuthRoute::post("/link-social", "link_social"),
58 AuthRoute::post("/get-access-token", "get_access_token"),
59 AuthRoute::post("/refresh-token", "refresh_token"),
60 ]
61 }
62
63 async fn on_request(
64 &self,
65 req: &AuthRequest,
66 ctx: &AuthContext<DB>,
67 ) -> AuthResult<Option<AuthResponse>> {
68 match (req.method(), req.path()) {
69 (HttpMethod::Post, "/sign-in/social") => Ok(Some(
70 handlers::handle_social_sign_in(&self.config, req, ctx).await?,
71 )),
72 (HttpMethod::Get, path) if path_matches_callback(path) => {
73 let provider = extract_provider_from_callback(path);
74 Ok(Some(
75 handlers::handle_callback(&self.config, &provider, req, ctx).await?,
76 ))
77 }
78 (HttpMethod::Post, "/link-social") => Ok(Some(
79 handlers::handle_link_social(&self.config, req, ctx).await?,
80 )),
81 (HttpMethod::Post, "/get-access-token") => Ok(Some(
82 handlers::handle_get_access_token(&self.config, req, ctx).await?,
83 )),
84 (HttpMethod::Post, "/refresh-token") => Ok(Some(
85 handlers::handle_refresh_token(&self.config, req, ctx).await?,
86 )),
87 _ => Ok(None),
88 }
89 }
90}
91
/// Returns `true` when `path` is a provider callback: `/callback/<provider>`,
/// optionally followed by a `?query`.
///
/// The provider must be a single, non-empty path segment. Paths such as
/// `/callback/google/extra` or `/callback/google/` are rejected so this check
/// agrees with the declared single-segment route `/callback/{provider}`.
fn path_matches_callback(path: &str) -> bool {
    // Ignore everything from the first '?' onwards.
    let path = path.split_once('?').map_or(path, |(before, _)| before);
    matches!(
        path.strip_prefix("/callback/"),
        Some(provider) if !provider.is_empty() && !provider.contains('/')
    )
}
97
/// Extracts the provider name from a callback path like `/callback/<provider>?query`.
///
/// Any query string is ignored. If `path` does not start with `/callback/`,
/// an empty string is returned instead of panicking on an out-of-range or
/// non-char-boundary slice. Callers are expected to have checked
/// `path_matches_callback` first.
fn extract_provider_from_callback(path: &str) -> String {
    let path = path.split_once('?').map_or(path, |(before, _)| before);
    path.strip_prefix("/callback/").unwrap_or("").to_string()
}
103
#[cfg(feature = "axum")]
mod axum_impl {
    //! Axum integration: exposes the OAuth routes as a typed `axum::Router`
    //! that delegates to the shared `*_core` handler functions.

    use super::*;
    use std::sync::Arc;

    use axum::Json;
    use axum::extract::{Extension, Path, Query, State};
    use axum::http::header;
    use axum::response::IntoResponse;
    use better_auth_core::error::AuthError;
    use better_auth_core::extractors::{CurrentSession, ValidatedJson};

    use super::handlers::{
        callback_core, get_access_token_core, link_social_core, refresh_token_core,
        social_sign_in_core,
    };
    use super::types::{
        AccessTokenResponse, GetAccessTokenRequest, LinkSocialRequest, RefreshTokenRequest,
        RefreshTokenResponse, SocialSignInRequest, SocialSignInResponse,
    };

    /// Query string of the provider redirect to `/callback/:provider`.
    ///
    /// Both fields are required. A redirect missing either one is rejected by
    /// the `Query` extractor before `handle_callback` runs.
    #[derive(serde::Deserialize)]
    struct CallbackQuery {
        code: String,
        state: String,
    }

    /// Plugin-specific state made available to handlers through an `Extension` layer.
    #[derive(Clone)]
    struct PluginState {
        config: OAuthConfig,
    }

    /// `POST /sign-in/social`: delegates to `social_sign_in_core`.
    async fn handle_social_sign_in<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(ps): Extension<Arc<PluginState>>,
        ValidatedJson(body): ValidatedJson<SocialSignInRequest>,
    ) -> Result<Json<SocialSignInResponse>, AuthError> {
        let ctx = state.to_context();
        let result = social_sign_in_core(&body, &ps.config, &ctx).await?;
        Ok(Json(result))
    }

    /// `GET /callback/:provider`: completes the redirect via `callback_core`.
    ///
    /// Returns the core's JSON response together with a `Set-Cookie` header
    /// built from the returned session token.
    async fn handle_callback<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(ps): Extension<Arc<PluginState>>,
        Path(provider): Path<String>,
        Query(params): Query<CallbackQuery>,
    ) -> Result<impl IntoResponse, AuthError> {
        let ctx = state.to_context();
        let (response, token) =
            callback_core(&params.code, &params.state, &provider, &ps.config, &ctx).await?;
        let cookie = state.session_cookie(&token);
        Ok(([(header::SET_COOKIE, cookie)], Json(response)))
    }

    /// `POST /link-social`: requires an authenticated session; delegates to `link_social_core`.
    async fn handle_link_social<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(ps): Extension<Arc<PluginState>>,
        CurrentSession { session, .. }: CurrentSession<DB>,
        ValidatedJson(body): ValidatedJson<LinkSocialRequest>,
    ) -> Result<Json<SocialSignInResponse>, AuthError> {
        let ctx = state.to_context();
        let result = link_social_core(&body, &session, &ps.config, &ctx).await?;
        Ok(Json(result))
    }

    /// `POST /get-access-token`: requires an authenticated session.
    ///
    /// Unlike the other handlers, `get_access_token_core` takes no plugin config,
    /// so no `Extension` is extracted here.
    async fn handle_get_access_token<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        CurrentSession { session, .. }: CurrentSession<DB>,
        ValidatedJson(body): ValidatedJson<GetAccessTokenRequest>,
    ) -> Result<Json<AccessTokenResponse>, AuthError> {
        let ctx = state.to_context();
        let result = get_access_token_core(&body, &session, &ctx).await?;
        Ok(Json(result))
    }

    /// `POST /refresh-token`: requires an authenticated session; delegates to `refresh_token_core`.
    async fn handle_refresh_token<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(ps): Extension<Arc<PluginState>>,
        CurrentSession { session, .. }: CurrentSession<DB>,
        ValidatedJson(body): ValidatedJson<RefreshTokenRequest>,
    ) -> Result<Json<RefreshTokenResponse>, AuthError> {
        let ctx = state.to_context();
        let result = refresh_token_core(&body, &session, &ps.config, &ctx).await?;
        Ok(Json(result))
    }

    #[async_trait]
    impl<DB: DatabaseAdapter> AxumPlugin<DB> for OAuthPlugin {
        fn name(&self) -> &'static str {
            "oauth"
        }

        /// Builds the axum router for the same routes `AuthPlugin::routes` declares.
        ///
        /// The config is cloned once into an `Arc<PluginState>` and shared with
        /// all handlers via an `Extension` layer.
        fn router(&self) -> axum::Router<AuthState<DB>> {
            use axum::routing::{get, post};

            let plugin_state = Arc::new(PluginState {
                config: self.config.clone(),
            });

            // NOTE(review): `:provider` is axum 0.7 path-parameter syntax; axum 0.8
            // switched to `{provider}`. Keep this in step with the axum version used.
            axum::Router::new()
                .route("/sign-in/social", post(handle_social_sign_in::<DB>))
                .route("/callback/:provider", get(handle_callback::<DB>))
                .route("/link-social", post(handle_link_social::<DB>))
                .route("/get-access-token", post(handle_get_access_token::<DB>))
                .route("/refresh-token", post(handle_refresh_token::<DB>))
                .layer(Extension(plugin_state))
        }
    }
}