// oxidite_auth/oauth2/client.rs
use base64::{engine::general_purpose, Engine as _};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use url::Url;

use crate::{AuthError, Result};

/// Static configuration for one OAuth2 provider / registered client.
///
/// `Debug` is implemented by hand so that `client_secret` never appears in
/// logs or panic messages; every other field is printed as-is.
#[derive(Clone)]
pub struct OAuth2Config {
    /// Client identifier issued by the provider.
    pub client_id: String,
    /// Client secret issued by the provider (treat as a credential).
    pub client_secret: String,
    /// Redirect URI registered with the provider; sent on authorization and code exchange.
    pub redirect_uri: String,
    /// Provider's authorization endpoint (must parse as an absolute URL).
    pub authorization_endpoint: String,
    /// Provider's token endpoint, used for code exchange and refresh.
    pub token_endpoint: String,
    /// Optional userinfo endpoint; `None` disables userinfo lookups.
    pub userinfo_endpoint: Option<String>,
    /// Scopes requested during authorization, joined with single spaces.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for OAuth2Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2Config")
            .field("client_id", &self.client_id)
            // Never print the secret itself.
            .field("client_secret", &"<redacted>")
            .field("redirect_uri", &self.redirect_uri)
            .field("authorization_endpoint", &self.authorization_endpoint)
            .field("token_endpoint", &self.token_endpoint)
            .field("userinfo_endpoint", &self.userinfo_endpoint)
            .field("scopes", &self.scopes)
            .finish()
    }
}

19pub struct OAuth2Client {
21 config: OAuth2Config,
22 http_client: Client,
23}
24
25impl OAuth2Client {
26 pub fn new(config: OAuth2Config) -> Self {
27 Self {
28 config,
29 http_client: Client::new(),
30 }
31 }
32
33 pub fn authorization_url(&self, state: &str, code_challenge: Option<&str>) -> Result<String> {
35 let mut url = Url::parse(&self.config.authorization_endpoint)
36 .map_err(|e| AuthError::HashError(e.to_string()))?;
37
38 url.query_pairs_mut()
39 .append_pair("client_id", &self.config.client_id)
40 .append_pair("redirect_uri", &self.config.redirect_uri)
41 .append_pair("response_type", "code")
42 .append_pair("state", state)
43 .append_pair("scope", &self.config.scopes.join(" "));
44
45 if let Some(challenge) = code_challenge {
46 url.query_pairs_mut()
47 .append_pair("code_challenge", challenge)
48 .append_pair("code_challenge_method", "S256");
49 }
50
51 Ok(url.to_string())
52 }
53
54 pub async fn exchange_code_with_state(&self, code: &str, state: &str, expected_state: &str, code_verifier: Option<&str>) -> Result<TokenResponse> {
56 if state != expected_state {
57 return Err(AuthError::TokenError("Invalid OAuth2 state".to_string()));
58 }
59 self.exchange_code(code, code_verifier).await
60 }
61
62 pub async fn exchange_code(&self, code: &str, code_verifier: Option<&str>) -> Result<TokenResponse> {
64 let mut params = vec![
65 ("grant_type", "authorization_code"),
66 ("code", code),
67 ("redirect_uri", &self.config.redirect_uri),
68 ("client_id", &self.config.client_id),
69 ("client_secret", &self.config.client_secret),
70 ];
71
72 if let Some(verifier) = code_verifier {
73 params.push(("code_verifier", verifier));
74 }
75
76 let response = self.http_client
77 .post(&self.config.token_endpoint)
78 .form(¶ms)
79 .send()
80 .await
81 .map_err(|e| AuthError::TokenError(e.to_string()))?;
82
83 let token_response: TokenResponse = response
84 .json()
85 .await
86 .map_err(|e| AuthError::TokenError(e.to_string()))?;
87
88 Ok(token_response)
89 }
90
91 pub async fn get_userinfo(&self, access_token: &str) -> Result<serde_json::Value> {
93 let endpoint = self.config.userinfo_endpoint.as_ref()
94 .ok_or_else(|| AuthError::TokenError("Userinfo endpoint not configured".to_string()))?;
95
96 let response = self.http_client
97 .get(endpoint)
98 .bearer_auth(access_token)
99 .send()
100 .await
101 .map_err(|e| AuthError::TokenError(e.to_string()))?;
102
103 let userinfo = response
104 .json::<serde_json::Value>()
105 .await
106 .map_err(|e| AuthError::TokenError(e.to_string()))?;
107
108 Ok(userinfo)
109 }
110
111 pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse> {
113 let params = vec![
114 ("grant_type", "refresh_token"),
115 ("refresh_token", refresh_token),
116 ("client_id", &self.config.client_id),
117 ("client_secret", &self.config.client_secret),
118 ];
119
120 let response = self.http_client
121 .post(&self.config.token_endpoint)
122 .form(¶ms)
123 .send()
124 .await
125 .map_err(|e| AuthError::TokenError(e.to_string()))?;
126
127 let token_response: TokenResponse = response
128 .json()
129 .await
130 .map_err(|e| AuthError::TokenError(e.to_string()))?;
131
132 Ok(token_response)
133 }
134}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct TokenResponse {
138 pub access_token: String,
139 pub token_type: String,
140 #[serde(skip_serializing_if = "Option::is_none")]
141 pub expires_in: Option<u64>,
142 #[serde(skip_serializing_if = "Option::is_none")]
143 pub refresh_token: Option<String>,
144 #[serde(skip_serializing_if = "Option::is_none")]
145 pub scope: Option<String>,
146}
147
148pub fn generate_pkce() -> (String, String) {
150 use rand::{Rng, distr::{Alphanumeric}};
151
152 let verifier: String = rand::rng()
153 .sample_iter(Alphanumeric)
154 .take(128)
155 .map(char::from)
156 .collect();
157
158 let challenge = general_purpose::URL_SAFE_NO_PAD.encode(
159 ring::digest::digest(&ring::digest::SHA256, verifier.as_bytes()).as_ref()
160 );
161
162 (verifier, challenge)
163}