Skip to main content

oxidite_auth/oauth2/
client.rs

1use serde::{Deserialize, Serialize};
2use url::Url;
3use reqwest::Client;
4use base64::{Engine as _, engine::general_purpose};
5use crate::{AuthError, Result};
6
/// OAuth2 client configuration: provider endpoints plus this client's registration.
///
/// `Debug` is implemented by hand so that `client_secret` is never written to logs;
/// it is rendered as `[redacted]`.
#[derive(Clone)]
pub struct OAuth2Config {
    /// Client identifier issued by the provider.
    pub client_id: String,
    /// Client secret issued by the provider. Redacted from `Debug` output.
    pub client_secret: String,
    /// Redirect URI registered with the provider; sent in both the authorization
    /// request and the code exchange.
    pub redirect_uri: String,
    /// Absolute URL of the provider's authorization endpoint.
    pub authorization_endpoint: String,
    /// URL of the provider's token endpoint (code exchange and refresh).
    pub token_endpoint: String,
    /// Requested scopes; joined with single spaces into the `scope` parameter.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for OAuth2Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2Config")
            .field("client_id", &self.client_id)
            // Never print the secret itself.
            .field("client_secret", &"[redacted]")
            .field("redirect_uri", &self.redirect_uri)
            .field("authorization_endpoint", &self.authorization_endpoint)
            .field("token_endpoint", &self.token_endpoint)
            .field("scopes", &self.scopes)
            .finish()
    }
}
17
18/// OAuth2 client
19pub struct OAuth2Client {
20    config: OAuth2Config,
21    http_client: Client,
22}
23
24impl OAuth2Client {
25    pub fn new(config: OAuth2Config) -> Self {
26        Self {
27            config,
28            http_client: Client::new(),
29        }
30    }
31
32    /// Generate authorization URL with PKCE
33    pub fn authorization_url(&self, state: &str, code_challenge: Option<&str>) -> Result<String> {
34        let mut url = Url::parse(&self.config.authorization_endpoint)
35            .map_err(|e| AuthError::HashError(e.to_string()))?;
36
37        url.query_pairs_mut()
38            .append_pair("client_id", &self.config.client_id)
39            .append_pair("redirect_uri", &self.config.redirect_uri)
40            .append_pair("response_type", "code")
41            .append_pair("state", state)
42            .append_pair("scope", &self.config.scopes.join(" "));
43
44        if let Some(challenge) = code_challenge {
45            url.query_pairs_mut()
46                .append_pair("code_challenge", challenge)
47                .append_pair("code_challenge_method", "S256");
48        }
49
50        Ok(url.to_string())
51    }
52
53    /// Exchange authorization code for access token
54    pub async fn exchange_code(&self, code: &str, code_verifier: Option<&str>) -> Result<TokenResponse> {
55        let mut params = vec![
56            ("grant_type", "authorization_code"),
57            ("code", code),
58            ("redirect_uri", &self.config.redirect_uri),
59            ("client_id", &self.config.client_id),
60            ("client_secret", &self.config.client_secret),
61        ];
62
63        if let Some(verifier) = code_verifier {
64            params.push(("code_verifier", verifier));
65        }
66
67        let response = self.http_client
68            .post(&self.config.token_endpoint)
69            .form(&params)
70            .send()
71            .await
72            .map_err(|e| AuthError::HashError(e.to_string()))?;
73
74        let token_response: TokenResponse = response
75            .json()
76            .await
77            .map_err(|e| AuthError::HashError(e.to_string()))?;
78
79        Ok(token_response)
80    }
81
82    /// Refresh access token
83    pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse> {
84        let params = vec![
85            ("grant_type", "refresh_token"),
86            ("refresh_token", refresh_token),
87            ("client_id", &self.config.client_id),
88            ("client_secret", &self.config.client_secret),
89        ];
90
91        let response = self.http_client
92            .post(&self.config.token_endpoint)
93            .form(&params)
94            .send()
95            .await
96            .map_err(|e| AuthError::HashError(e.to_string()))?;
97
98        let token_response: TokenResponse = response
99            .json()
100            .await
101            .map_err(|e| AuthError::HashError(e.to_string()))?;
102
103        Ok(token_response)
104    }
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct TokenResponse {
109    pub access_token: String,
110    pub token_type: String,
111    #[serde(skip_serializing_if = "Option::is_none")]
112    pub expires_in: Option<u64>,
113    #[serde(skip_serializing_if = "Option::is_none")]
114    pub refresh_token: Option<String>,
115    #[serde(skip_serializing_if = "Option::is_none")]
116    pub scope: Option<String>,
117}
118
119/// Generate PKCE code verifier and challenge
120pub fn generate_pkce() -> (String, String) {
121    use rand::{Rng, distr::{Alphanumeric}};
122    
123    let verifier: String = rand::rng()
124        .sample_iter(Alphanumeric)
125        .take(128)
126        .map(char::from)
127        .collect();
128
129    let challenge = general_purpose::URL_SAFE_NO_PAD.encode(
130        ring::digest::digest(&ring::digest::SHA256, verifier.as_bytes()).as_ref()
131    );
132
133    (verifier, challenge)
134}