avl_auth/
oauth2.rs

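//! OAuth2 authorization-code flow (optionally with PKCE) and provider-specific user-info lookup.
//!
//! OpenID Connect ID tokens are not currently extracted or validated.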
2
3use crate::error::{AuthError, Result};
4use crate::models::OAuth2Provider;
5use oauth2::{
6    AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
7    PkceCodeVerifier, RedirectUrl, Scope, TokenResponse, TokenUrl,
8};
9use oauth2::basic::BasicClient;
10use oauth2::reqwest::async_http_client;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16pub struct OAuth2Manager {
17    providers: Arc<RwLock<HashMap<String, ProviderClient>>>,
18}
19
20struct ProviderClient {
21    client: BasicClient,
22    scopes: Vec<String>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct AuthorizationRequest {
27    pub url: String,
28    pub state: String,
29    pub pkce_verifier: Option<String>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct TokenExchange {
34    pub access_token: String,
35    pub refresh_token: Option<String>,
36    pub expires_in: Option<u64>,
37    pub id_token: Option<String>,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct UserInfo {
42    pub id: String,
43    pub email: String,
44    pub email_verified: bool,
45    pub name: Option<String>,
46    pub picture: Option<String>,
47    pub provider: String,
48}
49
50impl OAuth2Manager {
51    pub fn new() -> Self {
52        Self {
53            providers: Arc::new(RwLock::new(HashMap::new())),
54        }
55    }
56
57    pub async fn register_provider(&self, provider: OAuth2Provider) -> Result<()> {
58        let client = BasicClient::new(
59            ClientId::new(provider.client_id.clone()),
60            Some(ClientSecret::new(provider.client_secret.clone())),
61            AuthUrl::new(provider.auth_url.clone())
62                .map_err(|e| AuthError::ConfigError(e.to_string()))?,
63            Some(
64                TokenUrl::new(provider.token_url.clone())
65                    .map_err(|e| AuthError::ConfigError(e.to_string()))?,
66            ),
67        )
68        .set_redirect_uri(
69            RedirectUrl::new(provider.redirect_url.clone())
70                .map_err(|e| AuthError::ConfigError(e.to_string()))?,
71        );
72
73        let provider_client = ProviderClient {
74            client,
75            scopes: provider.scopes.clone(),
76        };
77
78        let mut providers = self.providers.write().await;
79        providers.insert(provider.name.clone(), provider_client);
80
81        tracing::info!("Registered OAuth2 provider: {}", provider.name);
82        Ok(())
83    }
84
85    pub async fn authorize_url(
86        &self,
87        provider_name: &str,
88        use_pkce: bool,
89    ) -> Result<AuthorizationRequest> {
90        let providers = self.providers.read().await;
91        let provider = providers
92            .get(provider_name)
93            .ok_or_else(|| AuthError::OAuth2Error(format!("Provider not found: {}", provider_name)))?;
94
95        let mut auth_request = provider.client.authorize_url(CsrfToken::new_random);
96
97        for scope in &provider.scopes {
98            auth_request = auth_request.add_scope(Scope::new(scope.clone()));
99        }
100
101        let (url, state, pkce_verifier) = if use_pkce {
102            let (challenge, verifier) = PkceCodeChallenge::new_random_sha256();
103            let (url, state) = auth_request.set_pkce_challenge(challenge).url();
104            (url, state, Some(verifier.secret().to_string()))
105        } else {
106            let (url, state) = auth_request.url();
107            (url, state, None)
108        };
109
110        Ok(AuthorizationRequest {
111            url: url.to_string(),
112            state: state.secret().to_string(),
113            pkce_verifier,
114        })
115    }
116
117    pub async fn exchange_code(
118        &self,
119        provider_name: &str,
120        code: &str,
121        pkce_verifier: Option<&str>,
122    ) -> Result<TokenExchange> {
123        let providers = self.providers.read().await;
124        let provider = providers
125            .get(provider_name)
126            .ok_or_else(|| AuthError::OAuth2Error(format!("Provider not found: {}", provider_name)))?;
127
128        let mut token_request = provider
129            .client
130            .exchange_code(AuthorizationCode::new(code.to_string()));
131
132        if let Some(verifier) = pkce_verifier {
133            token_request = token_request.set_pkce_verifier(PkceCodeVerifier::new(verifier.to_string()));
134        }
135
136        let token_response = token_request
137            .request_async(async_http_client)
138            .await
139            .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
140
141        Ok(TokenExchange {
142            access_token: token_response.access_token().secret().to_string(),
143            refresh_token: token_response.refresh_token().map(|t| t.secret().to_string()),
144            expires_in: token_response.expires_in().map(|d| d.as_secs()),
145            id_token: None, // OAuth2 crate doesn't provide id_token directly
146        })
147    }
148
149    pub async fn get_user_info(
150        &self,
151        provider_name: &str,
152        access_token: &str,
153    ) -> Result<UserInfo> {
154        match provider_name {
155            "google" => self.get_google_user_info(access_token).await,
156            "github" => self.get_github_user_info(access_token).await,
157            "microsoft" => self.get_microsoft_user_info(access_token).await,
158            _ => Err(AuthError::OAuth2Error(format!(
159                "User info not supported for provider: {}",
160                provider_name
161            ))),
162        }
163    }
164
165    async fn get_google_user_info(&self, access_token: &str) -> Result<UserInfo> {
166        let client = reqwest::Client::new();
167        let response = client
168            .get("https://www.googleapis.com/oauth2/v2/userinfo")
169            .bearer_auth(access_token)
170            .send()
171            .await
172            .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
173
174        #[derive(Deserialize)]
175        struct GoogleUserInfo {
176            id: String,
177            email: String,
178            verified_email: bool,
179            name: Option<String>,
180            picture: Option<String>,
181        }
182
183        let user: GoogleUserInfo = response
184            .json()
185            .await
186            .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
187
188        Ok(UserInfo {
189            id: user.id,
190            email: user.email,
191            email_verified: user.verified_email,
192            name: user.name,
193            picture: user.picture,
194            provider: "google".to_string(),
195        })
196    }
197
198    async fn get_github_user_info(&self, access_token: &str) -> Result<UserInfo> {
199        let client = reqwest::Client::new();
200
201        // Get user info
202        let user_response = client
203            .get("https://api.github.com/user")
204            .bearer_auth(access_token)
205            .header("User-Agent", "AVL-Auth")
206            .send()
207            .await
208            .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
209
210        #[derive(Deserialize)]
211        struct GitHubUser {
212            id: u64,
213            _login: String,
214            name: Option<String>,
215            avatar_url: Option<String>,
216        }
217
218        let user: GitHubUser = user_response
219            .json()
220            .await
221            .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
222
223        // Get primary email
224        let email_response = client
225            .get("https://api.github.com/user/emails")
226            .bearer_auth(access_token)
227            .header("User-Agent", "AVL-Auth")
228            .send()
229            .await
230            .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
231
232        #[derive(Deserialize)]
233        struct GitHubEmail {
234            email: String,
235            primary: bool,
236            verified: bool,
237        }
238
239        let emails: Vec<GitHubEmail> = email_response
240            .json()
241            .await
242            .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
243
244        let primary_email = emails
245            .into_iter()
246            .find(|e| e.primary)
247            .ok_or_else(|| AuthError::OAuth2Error("No primary email found".to_string()))?;
248
249        Ok(UserInfo {
250            id: user.id.to_string(),
251            email: primary_email.email,
252            email_verified: primary_email.verified,
253            name: user.name,
254            picture: user.avatar_url,
255            provider: "github".to_string(),
256        })
257    }
258
259    async fn get_microsoft_user_info(&self, access_token: &str) -> Result<UserInfo> {
260        let client = reqwest::Client::new();
261        let response = client
262            .get("https://graph.microsoft.com/v1.0/me")
263            .bearer_auth(access_token)
264            .send()
265            .await
266            .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
267
268        #[derive(Deserialize)]
269        struct MicrosoftUserInfo {
270            id: String,
271            #[serde(rename = "userPrincipalName")]
272            user_principal_name: String,
273            #[serde(rename = "displayName")]
274            display_name: Option<String>,
275        }
276
277        let user: MicrosoftUserInfo = response
278            .json()
279            .await
280            .map_err(|e| AuthError::OAuth2Error(e.to_string()))?;
281
282        Ok(UserInfo {
283            id: user.id,
284            email: user.user_principal_name,
285            email_verified: true, // Microsoft accounts are verified
286            name: user.display_name,
287            picture: None,
288            provider: "microsoft".to_string(),
289        })
290    }
291}
292
293impl Default for OAuth2Manager {
294    fn default() -> Self {
295        Self::new()
296    }
297}