1use crate::error::{AuthError, Result};
4use crate::models::OAuth2Provider;
5use oauth2::{
6 AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
7 PkceCodeVerifier, RedirectUrl, Scope, TokenResponse, TokenUrl,
8};
9use oauth2::basic::BasicClient;
10use oauth2::reqwest::async_http_client;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16pub struct OAuth2Manager {
17 providers: Arc<RwLock<HashMap<String, ProviderClient>>>,
18}
19
20struct ProviderClient {
21 client: BasicClient,
22 scopes: Vec<String>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct AuthorizationRequest {
27 pub url: String,
28 pub state: String,
29 pub pkce_verifier: Option<String>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct TokenExchange {
34 pub access_token: String,
35 pub refresh_token: Option<String>,
36 pub expires_in: Option<u64>,
37 pub id_token: Option<String>,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct UserInfo {
42 pub id: String,
43 pub email: String,
44 pub email_verified: bool,
45 pub name: Option<String>,
46 pub picture: Option<String>,
47 pub provider: String,
48}
49
50impl OAuth2Manager {
51 pub fn new() -> Self {
52 Self {
53 providers: Arc::new(RwLock::new(HashMap::new())),
54 }
55 }
56
57 pub async fn register_provider(&self, provider: OAuth2Provider) -> Result<()> {
58 let client = BasicClient::new(
59 ClientId::new(provider.client_id.clone()),
60 Some(ClientSecret::new(provider.client_secret.clone())),
61 AuthUrl::new(provider.auth_url.clone())
62 .map_err(|e| AuthError::ConfigError(e.to_string()))?,
63 Some(
64 TokenUrl::new(provider.token_url.clone())
65 .map_err(|e| AuthError::ConfigError(e.to_string()))?,
66 ),
67 )
68 .set_redirect_uri(
69 RedirectUrl::new(provider.redirect_url.clone())
70 .map_err(|e| AuthError::ConfigError(e.to_string()))?,
71 );
72
73 let provider_client = ProviderClient {
74 client,
75 scopes: provider.scopes.clone(),
76 };
77
78 let mut providers = self.providers.write().await;
79 providers.insert(provider.name.clone(), provider_client);
80
81 tracing::info!("Registered OAuth2 provider: {}", provider.name);
82 Ok(())
83 }
84
85 pub async fn authorize_url(
86 &self,
87 provider_name: &str,
88 use_pkce: bool,
89 ) -> Result<AuthorizationRequest> {
90 let providers = self.providers.read().await;
91 let provider = providers
92 .get(provider_name)
93 .ok_or_else(|| AuthError::OAuth2Error(format!("Provider not found: {}", provider_name)))?;
94
95 let mut auth_request = provider.client.authorize_url(CsrfToken::new_random);
96
97 for scope in &provider.scopes {
98 auth_request = auth_request.add_scope(Scope::new(scope.clone()));
99 }
100
101 let (url, state, pkce_verifier) = if use_pkce {
102 let (challenge, verifier) = PkceCodeChallenge::new_random_sha256();
103 let (url, state) = auth_request.set_pkce_challenge(challenge).url();
104 (url, state, Some(verifier.secret().to_string()))
105 } else {
106 let (url, state) = auth_request.url();
107 (url, state, None)
108 };
109
110 Ok(AuthorizationRequest {
111 url: url.to_string(),
112 state: state.secret().to_string(),
113 pkce_verifier,
114 })
115 }
116
117 pub async fn exchange_code(
118 &self,
119 provider_name: &str,
120 code: &str,
121 pkce_verifier: Option<&str>,
122 ) -> Result<TokenExchange> {
123 let providers = self.providers.read().await;
124 let provider = providers
125 .get(provider_name)
126 .ok_or_else(|| AuthError::OAuth2Error(format!("Provider not found: {}", provider_name)))?;
127
128 let mut token_request = provider
129 .client
130 .exchange_code(AuthorizationCode::new(code.to_string()));
131
132 if let Some(verifier) = pkce_verifier {
133 token_request = token_request.set_pkce_verifier(PkceCodeVerifier::new(verifier.to_string()));
134 }
135
136 let token_response = token_request
137 .request_async(async_http_client)
138 .await
139 .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
140
141 Ok(TokenExchange {
142 access_token: token_response.access_token().secret().to_string(),
143 refresh_token: token_response.refresh_token().map(|t| t.secret().to_string()),
144 expires_in: token_response.expires_in().map(|d| d.as_secs()),
145 id_token: None, })
147 }
148
149 pub async fn get_user_info(
150 &self,
151 provider_name: &str,
152 access_token: &str,
153 ) -> Result<UserInfo> {
154 match provider_name {
155 "google" => self.get_google_user_info(access_token).await,
156 "github" => self.get_github_user_info(access_token).await,
157 "microsoft" => self.get_microsoft_user_info(access_token).await,
158 _ => Err(AuthError::OAuth2Error(format!(
159 "User info not supported for provider: {}",
160 provider_name
161 ))),
162 }
163 }
164
165 async fn get_google_user_info(&self, access_token: &str) -> Result<UserInfo> {
166 let client = reqwest::Client::new();
167 let response = client
168 .get("https://www.googleapis.com/oauth2/v2/userinfo")
169 .bearer_auth(access_token)
170 .send()
171 .await
172 .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
173
174 #[derive(Deserialize)]
175 struct GoogleUserInfo {
176 id: String,
177 email: String,
178 verified_email: bool,
179 name: Option<String>,
180 picture: Option<String>,
181 }
182
183 let user: GoogleUserInfo = response
184 .json()
185 .await
186 .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
187
188 Ok(UserInfo {
189 id: user.id,
190 email: user.email,
191 email_verified: user.verified_email,
192 name: user.name,
193 picture: user.picture,
194 provider: "google".to_string(),
195 })
196 }
197
198 async fn get_github_user_info(&self, access_token: &str) -> Result<UserInfo> {
199 let client = reqwest::Client::new();
200
201 let user_response = client
203 .get("https://api.github.com/user")
204 .bearer_auth(access_token)
205 .header("User-Agent", "AVL-Auth")
206 .send()
207 .await
208 .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
209
210 #[derive(Deserialize)]
211 struct GitHubUser {
212 id: u64,
213 _login: String,
214 name: Option<String>,
215 avatar_url: Option<String>,
216 }
217
218 let user: GitHubUser = user_response
219 .json()
220 .await
221 .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
222
223 let email_response = client
225 .get("https://api.github.com/user/emails")
226 .bearer_auth(access_token)
227 .header("User-Agent", "AVL-Auth")
228 .send()
229 .await
230 .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
231
232 #[derive(Deserialize)]
233 struct GitHubEmail {
234 email: String,
235 primary: bool,
236 verified: bool,
237 }
238
239 let emails: Vec<GitHubEmail> = email_response
240 .json()
241 .await
242 .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
243
244 let primary_email = emails
245 .into_iter()
246 .find(|e| e.primary)
247 .ok_or_else(|| AuthError::OAuth2Error("No primary email found".to_string()))?;
248
249 Ok(UserInfo {
250 id: user.id.to_string(),
251 email: primary_email.email,
252 email_verified: primary_email.verified,
253 name: user.name,
254 picture: user.avatar_url,
255 provider: "github".to_string(),
256 })
257 }
258
259 async fn get_microsoft_user_info(&self, access_token: &str) -> Result<UserInfo> {
260 let client = reqwest::Client::new();
261 let response = client
262 .get("https://graph.microsoft.com/v1.0/me")
263 .bearer_auth(access_token)
264 .send()
265 .await
266 .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
267
268 #[derive(Deserialize)]
269 struct MicrosoftUserInfo {
270 id: String,
271 #[serde(rename = "userPrincipalName")]
272 user_principal_name: String,
273 #[serde(rename = "displayName")]
274 display_name: Option<String>,
275 }
276
277 let user: MicrosoftUserInfo = response
278 .json()
279 .await
280 .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
281
282 Ok(UserInfo {
283 id: user.id,
284 email: user.user_principal_name,
285 email_verified: true, name: user.display_name,
287 picture: None,
288 provider: "microsoft".to_string(),
289 })
290 }
291}
292
293impl Default for OAuth2Manager {
294 fn default() -> Self {
295 Self::new()
296 }
297}