1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5use crate::core::{AuthPlugin, AuthRoute, AuthContext};
6use crate::types::{AuthRequest, AuthResponse, HttpMethod, User, CreateUser, CreateAccount};
7use crate::error::{AuthError, AuthResult};
8
9pub struct OAuthPlugin {
11 config: OAuthConfig,
12}
13
14#[derive(Debug, Clone)]
15pub struct OAuthConfig {
16 pub providers: HashMap<String, OAuthProvider>,
17}
18
/// Static configuration for a single OAuth 2.0 provider.
///
/// `Debug` is implemented by hand rather than derived so that `client_secret` never ends
/// up in logs or panic messages (including through the derived `Debug` of any containing
/// configuration type). Every other field is printed normally.
#[derive(Clone)]
pub struct OAuthProvider {
    /// Public client identifier issued by the provider.
    pub client_id: String,
    /// Confidential client secret; redacted in `Debug` output.
    pub client_secret: String,
    /// Authorization endpoint; authorization URLs are built on top of this base.
    pub auth_url: String,
    /// Token endpoint URL (presumably used for the code exchange; not shown here).
    pub token_url: String,
    /// User-info endpoint URL (presumably queried after the code exchange; not shown here).
    pub user_info_url: String,
    /// Default scopes requested when the caller does not supply its own list.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for OAuthProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthProvider")
            .field("client_id", &self.client_id)
            // Never print the secret itself; only indicate that one is configured.
            .field("client_secret", &"[REDACTED]")
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .field("user_info_url", &self.user_info_url)
            .field("scopes", &self.scopes)
            .finish()
    }
}
28
29impl OAuthPlugin {
30 pub fn new() -> Self {
31 Self {
32 config: OAuthConfig::default(),
33 }
34 }
35
36 pub fn with_config(config: OAuthConfig) -> Self {
37 Self { config }
38 }
39
40 pub fn add_provider(mut self, name: &str, provider: OAuthProvider) -> Self {
41 self.config.providers.insert(name.to_string(), provider);
42 self
43 }
44}
45
46impl Default for OAuthPlugin {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl Default for OAuthConfig {
53 fn default() -> Self {
54 Self {
55 providers: HashMap::new(),
56 }
57 }
58}
59
60#[derive(Debug, Deserialize)]
62struct SocialSignInRequest {
63 #[serde(rename = "callbackURL")]
64 callback_url: Option<String>,
65 #[serde(rename = "newUserCallbackURL")]
66 new_user_callback_url: Option<String>,
67 #[serde(rename = "errorCallbackURL")]
68 error_callback_url: Option<String>,
69 provider: String,
70 #[serde(rename = "disableRedirect")]
71 disable_redirect: Option<String>,
72 #[serde(rename = "idToken")]
73 id_token: Option<String>,
74 scopes: Option<String>,
75 #[serde(rename = "requestSignUp")]
76 request_sign_up: Option<String>,
77 #[serde(rename = "loginHint")]
78 login_hint: Option<String>,
79}
80
81#[derive(Debug, Deserialize)]
82struct LinkSocialRequest {
83 #[serde(rename = "callbackURL")]
84 callback_url: Option<String>,
85 provider: String,
86 scopes: Option<String>,
87}
88
89#[derive(Debug, Serialize)]
90struct SocialSignInResponse {
91 redirect: bool,
92 token: String,
93 url: Option<String>,
94 user: User,
95}
96
97#[derive(Debug, Serialize)]
98struct LinkSocialResponse {
99 url: String,
100 redirect: bool,
101}
102
103#[async_trait]
104impl AuthPlugin for OAuthPlugin {
105 fn name(&self) -> &'static str {
106 "oauth"
107 }
108
109 fn routes(&self) -> Vec<AuthRoute> {
110 vec![
111 AuthRoute::post("/sign-in/social", "social_sign_in"),
112 AuthRoute::post("/link-social", "link_social"),
113 ]
114 }
115
116 async fn on_request(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<Option<AuthResponse>> {
117 match (req.method(), req.path()) {
118 (HttpMethod::Post, "/sign-in/social") => {
119 Ok(Some(self.handle_social_sign_in(req, ctx).await?))
120 },
121 (HttpMethod::Post, "/link-social") => {
122 Ok(Some(self.handle_link_social(req, ctx).await?))
123 },
124 _ => Ok(None),
125 }
126 }
127}
128
129impl OAuthPlugin {
131 async fn handle_social_sign_in(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<AuthResponse> {
132 let signin_req: SocialSignInRequest = match req.body_as_json() {
133 Ok(req) => req,
134 Err(e) => {
135 return Ok(AuthResponse::json(400, &serde_json::json!({
136 "error": "Invalid request",
137 "message": format!("Invalid JSON: {}", e)
138 }))?);
139 }
140 };
141
142 if !self.config.providers.contains_key(&signin_req.provider) {
144 return Ok(AuthResponse::json(400, &serde_json::json!({
145 "error": "Invalid provider",
146 "message": format!("Provider '{}' is not configured", signin_req.provider)
147 }))?);
148 }
149
150 if let Some(id_token) = &signin_req.id_token {
152 return self.handle_id_token_sign_in(id_token, &signin_req, ctx).await;
153 }
154
155 self.generate_auth_url(&signin_req, ctx).await
157 }
158
159 async fn handle_link_social(&self, req: &AuthRequest, ctx: &AuthContext) -> AuthResult<AuthResponse> {
160 let link_req: LinkSocialRequest = match req.body_as_json() {
161 Ok(req) => req,
162 Err(e) => {
163 return Ok(AuthResponse::json(400, &serde_json::json!({
164 "error": "Invalid request",
165 "message": format!("Invalid JSON: {}", e)
166 }))?);
167 }
168 };
169
170 if !self.config.providers.contains_key(&link_req.provider) {
172 return Ok(AuthResponse::json(400, &serde_json::json!({
173 "error": "Invalid provider",
174 "message": format!("Provider '{}' is not configured", link_req.provider)
175 }))?);
176 }
177
178 let provider = &self.config.providers[&link_req.provider];
180 let callback_url = link_req.callback_url.unwrap_or_else(|| {
181 format!("{}/oauth/{}/callback", ctx.config.base_url, link_req.provider)
182 });
183
184 let scopes = if let Some(scopes) = &link_req.scopes {
185 scopes.split(',').map(|s| s.trim().to_string()).collect()
186 } else {
187 provider.scopes.clone()
188 };
189
190 let auth_url = format!(
191 "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state=link_{}",
192 provider.auth_url,
193 provider.client_id,
194 urlencoding::encode(&callback_url),
195 urlencoding::encode(&scopes.join(" ")),
196 uuid::Uuid::new_v4()
197 );
198
199 let response = LinkSocialResponse {
200 url: auth_url,
201 redirect: true,
202 };
203
204 Ok(AuthResponse::json(200, &response)?)
205 }
206
207 async fn handle_id_token_sign_in(
208 &self,
209 id_token: &str,
210 signin_req: &SocialSignInRequest,
211 ctx: &AuthContext,
212 ) -> AuthResult<AuthResponse> {
213 let email = format!("user+{}@{}.com", uuid::Uuid::new_v4(), signin_req.provider);
218 let name = format!("User from {}", signin_req.provider);
219
220 let existing_user = ctx.database.get_user_by_email(&email).await?;
222 let user = if let Some(user) = existing_user {
223 user
224 } else {
225 let create_user = CreateUser::new()
227 .with_email(&email)
228 .with_name(&name)
229 .with_email_verified(true); ctx.database.create_user(create_user).await?
232 };
233
234 let create_account = CreateAccount {
236 account_id: format!("{}_{}", signin_req.provider, uuid::Uuid::new_v4()),
237 provider_id: signin_req.provider.clone(),
238 user_id: user.id.clone(),
239 access_token: Some("mock_access_token".to_string()),
240 refresh_token: None,
241 id_token: Some(id_token.to_string()),
242 access_token_expires_at: None,
243 refresh_token_expires_at: None,
244 scope: None,
245 password: None,
246 };
247
248 if ctx.database.get_account(&signin_req.provider, &create_account.account_id).await?.is_none() {
250 ctx.database.create_account(create_account).await?;
251 }
252
253 let session_manager = crate::core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
255 let session = session_manager.create_session(&user, None, None).await?;
256
257 let response = SocialSignInResponse {
258 redirect: false,
259 token: session.token,
260 url: None,
261 user,
262 };
263
264 Ok(AuthResponse::json(200, &response)?)
265 }
266
267 async fn generate_auth_url(
268 &self,
269 signin_req: &SocialSignInRequest,
270 ctx: &AuthContext,
271 ) -> AuthResult<AuthResponse> {
272 let provider = &self.config.providers[&signin_req.provider];
273 let callback_url = signin_req.callback_url.clone().unwrap_or_else(|| {
274 format!("{}/oauth/{}/callback", ctx.config.base_url, signin_req.provider)
275 });
276
277 let scopes = if let Some(scopes) = &signin_req.scopes {
278 scopes.split(',').map(|s| s.trim().to_string()).collect()
279 } else {
280 provider.scopes.clone()
281 };
282
283 let mut auth_url = format!(
284 "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}",
285 provider.auth_url,
286 provider.client_id,
287 urlencoding::encode(&callback_url),
288 urlencoding::encode(&scopes.join(" ")),
289 uuid::Uuid::new_v4()
290 );
291
292 if let Some(login_hint) = &signin_req.login_hint {
293 auth_url.push_str(&format!("&login_hint={}", urlencoding::encode(login_hint)));
294 }
295
296 Ok(AuthResponse::json(200, &serde_json::json!({
298 "url": auth_url,
299 "redirect": true
300 }))?)
301 }
302}