Skip to main content

oxidite_auth/oauth2/
client.rs

1use serde::{Deserialize, Serialize};
2use url::Url;
3use reqwest::Client;
4use base64::{Engine as _, engine::general_purpose};
5use crate::{AuthError, Result};
6
/// OAuth2 client configuration.
///
/// `Debug` is implemented by hand (not derived) so that `client_secret` is
/// never written to logs or panic messages.
#[derive(Clone)]
pub struct OAuth2Config {
    /// Client identifier sent as `client_id` to the provider.
    pub client_id: String,
    /// Client secret used in token requests; redacted from `Debug` output.
    pub client_secret: String,
    /// Redirect URI sent with the authorization and code-exchange requests.
    pub redirect_uri: String,
    /// URL of the provider's authorization endpoint.
    pub authorization_endpoint: String,
    /// URL of the provider's token endpoint (code exchange and refresh).
    pub token_endpoint: String,
    /// Optional userinfo endpoint; required by `OAuth2Client::get_userinfo`.
    pub userinfo_endpoint: Option<String>,
    /// Requested scopes, joined with spaces in the authorization URL.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for OAuth2Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2Config")
            .field("client_id", &self.client_id)
            // Never print the credential itself.
            .field("client_secret", &"<redacted>")
            .field("redirect_uri", &self.redirect_uri)
            .field("authorization_endpoint", &self.authorization_endpoint)
            .field("token_endpoint", &self.token_endpoint)
            .field("userinfo_endpoint", &self.userinfo_endpoint)
            .field("scopes", &self.scopes)
            .finish()
    }
}
18
19/// OAuth2 client
20pub struct OAuth2Client {
21    config: OAuth2Config,
22    http_client: Client,
23}
24
25impl OAuth2Client {
26    pub fn new(config: OAuth2Config) -> Self {
27        Self {
28            config,
29            http_client: Client::new(),
30        }
31    }
32
33    /// Generate authorization URL with PKCE
34    pub fn authorization_url(&self, state: &str, code_challenge: Option<&str>) -> Result<String> {
35        let mut url = Url::parse(&self.config.authorization_endpoint)
36            .map_err(|e| AuthError::HashError(e.to_string()))?;
37
38        url.query_pairs_mut()
39            .append_pair("client_id", &self.config.client_id)
40            .append_pair("redirect_uri", &self.config.redirect_uri)
41            .append_pair("response_type", "code")
42            .append_pair("state", state)
43            .append_pair("scope", &self.config.scopes.join(" "));
44
45        if let Some(challenge) = code_challenge {
46            url.query_pairs_mut()
47                .append_pair("code_challenge", challenge)
48                .append_pair("code_challenge_method", "S256");
49        }
50
51        Ok(url.to_string())
52    }
53
54    /// Exchange authorization code for access token with state validation
55    pub async fn exchange_code_with_state(&self, code: &str, state: &str, expected_state: &str, code_verifier: Option<&str>) -> Result<TokenResponse> {
56        if state != expected_state {
57            return Err(AuthError::TokenError("Invalid OAuth2 state".to_string()));
58        }
59        self.exchange_code(code, code_verifier).await
60    }
61
62    /// Exchange authorization code for access token
63    pub async fn exchange_code(&self, code: &str, code_verifier: Option<&str>) -> Result<TokenResponse> {
64        let mut params = vec![
65            ("grant_type", "authorization_code"),
66            ("code", code),
67            ("redirect_uri", &self.config.redirect_uri),
68            ("client_id", &self.config.client_id),
69            ("client_secret", &self.config.client_secret),
70        ];
71
72        if let Some(verifier) = code_verifier {
73            params.push(("code_verifier", verifier));
74        }
75
76        let response = self.http_client
77            .post(&self.config.token_endpoint)
78            .form(&params)
79            .send()
80            .await
81            .map_err(|e| AuthError::TokenError(e.to_string()))?;
82
83        let token_response: TokenResponse = response
84            .json()
85            .await
86            .map_err(|e| AuthError::TokenError(e.to_string()))?;
87
88        Ok(token_response)
89    }
90
91    /// Get user info from the provider's userinfo endpoint
92    pub async fn get_userinfo(&self, access_token: &str) -> Result<serde_json::Value> {
93        let endpoint = self.config.userinfo_endpoint.as_ref()
94            .ok_or_else(|| AuthError::TokenError("Userinfo endpoint not configured".to_string()))?;
95
96        let response = self.http_client
97            .get(endpoint)
98            .bearer_auth(access_token)
99            .send()
100            .await
101            .map_err(|e| AuthError::TokenError(e.to_string()))?;
102
103        let userinfo = response
104            .json::<serde_json::Value>()
105            .await
106            .map_err(|e| AuthError::TokenError(e.to_string()))?;
107
108        Ok(userinfo)
109    }
110
111    /// Refresh access token
112    pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse> {
113        let params = vec![
114            ("grant_type", "refresh_token"),
115            ("refresh_token", refresh_token),
116            ("client_id", &self.config.client_id),
117            ("client_secret", &self.config.client_secret),
118        ];
119
120        let response = self.http_client
121            .post(&self.config.token_endpoint)
122            .form(&params)
123            .send()
124            .await
125            .map_err(|e| AuthError::TokenError(e.to_string()))?;
126
127        let token_response: TokenResponse = response
128            .json()
129            .await
130            .map_err(|e| AuthError::TokenError(e.to_string()))?;
131
132        Ok(token_response)
133    }
134}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct TokenResponse {
138    pub access_token: String,
139    pub token_type: String,
140    #[serde(skip_serializing_if = "Option::is_none")]
141    pub expires_in: Option<u64>,
142    #[serde(skip_serializing_if = "Option::is_none")]
143    pub refresh_token: Option<String>,
144    #[serde(skip_serializing_if = "Option::is_none")]
145    pub scope: Option<String>,
146}
147
148/// Generate PKCE code verifier and challenge
149pub fn generate_pkce() -> (String, String) {
150    use rand::{Rng, distr::{Alphanumeric}};
151    
152    let verifier: String = rand::rng()
153        .sample_iter(Alphanumeric)
154        .take(128)
155        .map(char::from)
156        .collect();
157
158    let challenge = general_purpose::URL_SAFE_NO_PAD.encode(
159        ring::digest::digest(&ring::digest::SHA256, verifier.as_bytes()).as_ref()
160    );
161
162    (verifier, challenge)
163}