better_auth_api/plugins/oauth/
mod.rs1use async_trait::async_trait;
2
3use better_auth_core::AuthResult;
4use better_auth_core::adapters::DatabaseAdapter;
5use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
6use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
7
8#[cfg(feature = "axum")]
9use better_auth_core::plugin::{AuthState, AxumPlugin};
10
11pub mod encryption;
12mod handlers;
13mod providers;
14mod types;
15
16pub use providers::{OAuthConfig, OAuthProvider, OAuthStateStrategy, OAuthUserInfo};
17
18pub struct OAuthPlugin {
19 config: OAuthConfig,
20}
21
22impl OAuthPlugin {
23 pub fn new() -> Self {
24 Self {
25 config: OAuthConfig::default(),
26 }
27 }
28
29 pub fn with_config(config: OAuthConfig) -> Self {
30 Self { config }
31 }
32
33 pub fn add_provider(mut self, name: &str, provider: OAuthProvider) -> Self {
34 self.config.providers.insert(name.to_string(), provider);
35 self
36 }
37}
38
39impl Default for OAuthPlugin {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45#[async_trait]
46impl<DB: DatabaseAdapter> AuthPlugin<DB> for OAuthPlugin {
47 fn name(&self) -> &'static str {
48 "oauth"
49 }
50
51 fn routes(&self) -> Vec<AuthRoute> {
52 vec![
53 AuthRoute::post("/sign-in/social", "social_sign_in"),
54 AuthRoute::get("/callback/{provider}", "oauth_callback"),
55 AuthRoute::post("/link-social", "link_social"),
56 AuthRoute::post("/get-access-token", "get_access_token"),
57 AuthRoute::post("/refresh-token", "refresh_token"),
58 ]
59 }
60
61 async fn on_request(
62 &self,
63 req: &AuthRequest,
64 ctx: &AuthContext<DB>,
65 ) -> AuthResult<Option<AuthResponse>> {
66 match (req.method(), req.path()) {
67 (HttpMethod::Post, "/sign-in/social") => Ok(Some(
68 handlers::handle_social_sign_in(&self.config, req, ctx).await?,
69 )),
70 (HttpMethod::Get, path) if path_matches_callback(path) => {
71 let provider = extract_provider_from_callback(path);
72 Ok(Some(
73 handlers::handle_callback(&self.config, &provider, req, ctx).await?,
74 ))
75 }
76 (HttpMethod::Post, "/link-social") => Ok(Some(
77 handlers::handle_link_social(&self.config, req, ctx).await?,
78 )),
79 (HttpMethod::Post, "/get-access-token") => Ok(Some(
80 handlers::handle_get_access_token(&self.config, req, ctx).await?,
81 )),
82 (HttpMethod::Post, "/refresh-token") => Ok(Some(
83 handlers::handle_refresh_token(&self.config, req, ctx).await?,
84 )),
85 _ => Ok(None),
86 }
87 }
88}
89
/// Returns `true` when `path` is an OAuth callback of the form
/// `/callback/<provider>`, optionally followed by a `?query` string.
///
/// `<provider>` must be a single, non-empty path segment. This keeps the
/// manual dispatcher consistent with the single-segment `/callback/:provider`
/// Axum route. Therefore `/callback`, `/callback/`, `/callback/google/` and
/// `/callback/google/extra` are all rejected.
fn path_matches_callback(path: &str) -> bool {
    // Only the path component decides the match; the query carries `code`/`state`.
    let path_without_query = path.split_once('?').map_or(path, |(before, _)| before);

    matches!(
        path_without_query.strip_prefix("/callback/"),
        Some(provider) if !provider.is_empty() && !provider.contains('/')
    )
}
95
/// Extracts the provider name from a callback path.
///
/// For example, `/callback/google?code=…&state=…` yields `"google"`.
///
/// Everything between `/callback/` and the first `?` is returned. The function
/// is meant to be used after `path_matches_callback` has accepted `path`, but
/// it never panics. A path that does not start with `/callback/` (including
/// one shorter than the prefix) yields an empty string instead of a
/// byte-offset slice of unrelated text.
fn extract_provider_from_callback(path: &str) -> String {
    let path_without_query = path.split_once('?').map_or(path, |(before, _)| before);

    path_without_query
        .strip_prefix("/callback/")
        .unwrap_or("")
        .to_string()
}
101
#[cfg(feature = "axum")]
mod axum_impl {
    //! Axum integration for [`OAuthPlugin`].
    //!
    //! Exposes the OAuth endpoints as a typed `axum::Router` instead of going
    //! through `AuthPlugin::on_request`. Every handler turns the shared
    //! `AuthState` into an auth context and delegates to the matching
    //! `*_core` function in `handlers`.

    use super::*;
    use std::sync::Arc;

    use axum::Json;
    use axum::extract::{Extension, Path, Query, State};
    use axum::http::header;
    use axum::response::IntoResponse;
    use better_auth_core::error::AuthError;
    use better_auth_core::extractors::{CurrentSession, ValidatedJson};

    use super::handlers::{
        callback_core, get_access_token_core, link_social_core, refresh_token_core,
        social_sign_in_core,
    };
    use super::types::{
        AccessTokenResponse, GetAccessTokenRequest, LinkSocialRequest, RefreshTokenRequest,
        RefreshTokenResponse, SocialSignInRequest, SocialSignInResponse,
    };

    /// Query string of the provider redirect to `/callback/:provider`.
    ///
    /// Both fields are required `String`s, so a redirect missing either one is
    /// rejected by the `Query` extractor before `handle_callback` runs.
    /// NOTE(review): per RFC 6749 §4.1.2.1 a provider reports denial as
    /// `?error=…&state=…` without `code`; such redirects currently surface as
    /// an extractor rejection rather than an `AuthError`. Confirm intended.
    #[derive(serde::Deserialize)]
    struct CallbackQuery {
        code: String,
        state: String,
    }

    /// Router-scoped data injected through an `Extension` layer: a copy of the
    /// plugin's OAuth configuration.
    #[derive(Clone)]
    struct PluginState {
        config: OAuthConfig,
    }

    /// `POST /sign-in/social`: validates the JSON body, runs
    /// `social_sign_in_core`, and returns its result as JSON.
    async fn handle_social_sign_in<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(ps): Extension<Arc<PluginState>>,
        ValidatedJson(body): ValidatedJson<SocialSignInRequest>,
    ) -> Result<Json<SocialSignInResponse>, AuthError> {
        let ctx = state.to_context();
        let result = social_sign_in_core(&body, &ps.config, &ctx).await?;
        Ok(Json(result))
    }

    /// `GET /callback/:provider`: passes `code`, `state` and the path's
    /// provider to `callback_core`.
    ///
    /// On success, sets `Set-Cookie` to `AuthState::session_cookie` for the
    /// returned token and returns the JSON response body.
    async fn handle_callback<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(ps): Extension<Arc<PluginState>>,
        Path(provider): Path<String>,
        Query(params): Query<CallbackQuery>,
    ) -> Result<impl IntoResponse, AuthError> {
        let ctx = state.to_context();
        let (response, token) =
            callback_core(&params.code, &params.state, &provider, &ps.config, &ctx).await?;
        let cookie = state.session_cookie(&token);
        Ok(([(header::SET_COOKIE, cookie)], Json(response)))
    }

    /// `POST /link-social`: requires an authenticated session
    /// (`CurrentSession`) and delegates to `link_social_core`.
    async fn handle_link_social<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(ps): Extension<Arc<PluginState>>,
        CurrentSession { session, .. }: CurrentSession<DB>,
        ValidatedJson(body): ValidatedJson<LinkSocialRequest>,
    ) -> Result<Json<SocialSignInResponse>, AuthError> {
        let ctx = state.to_context();
        let result = link_social_core(&body, &session, &ps.config, &ctx).await?;
        Ok(Json(result))
    }

    /// `POST /get-access-token`: requires an authenticated session and
    /// delegates to `get_access_token_core`.
    ///
    /// This handler does not read the plugin config.
    async fn handle_get_access_token<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        CurrentSession { session, .. }: CurrentSession<DB>,
        ValidatedJson(body): ValidatedJson<GetAccessTokenRequest>,
    ) -> Result<Json<AccessTokenResponse>, AuthError> {
        let ctx = state.to_context();
        let result = get_access_token_core(&body, &session, &ctx).await?;
        Ok(Json(result))
    }

    /// `POST /refresh-token`: requires an authenticated session and
    /// delegates to `refresh_token_core` with the plugin config.
    async fn handle_refresh_token<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(ps): Extension<Arc<PluginState>>,
        CurrentSession { session, .. }: CurrentSession<DB>,
        ValidatedJson(body): ValidatedJson<RefreshTokenRequest>,
    ) -> Result<Json<RefreshTokenResponse>, AuthError> {
        let ctx = state.to_context();
        let result = refresh_token_core(&body, &session, &ps.config, &ctx).await?;
        Ok(Json(result))
    }

    #[async_trait]
    impl<DB: DatabaseAdapter> AxumPlugin<DB> for OAuthPlugin {
        fn name(&self) -> &'static str {
            "oauth"
        }

        /// Builds the OAuth router.
        ///
        /// The config is cloned once into an `Arc<PluginState>` and shared with
        /// every route via the trailing `Extension` layer. The layer applies to
        /// all routes registered before it, which here is all of them.
        fn router(&self) -> axum::Router<AuthState<DB>> {
            use axum::routing::{get, post};

            let plugin_state = Arc::new(PluginState {
                config: self.config.clone(),
            });

            // NOTE(review): `:provider` is the axum <= 0.7 capture syntax; axum 0.8
            // switched to `{provider}` and rejects `:` segments. Confirm the
            // axum version in use.
            axum::Router::new()
                .route("/sign-in/social", post(handle_social_sign_in::<DB>))
                .route("/callback/:provider", get(handle_callback::<DB>))
                .route("/link-social", post(handle_link_social::<DB>))
                .route("/get-access-token", post(handle_get_access_token::<DB>))
                .route("/refresh-token", post(handle_refresh_token::<DB>))
                .layer(Extension(plugin_state))
        }
    }
}