// Source: wae_authentication/oauth2/client.rs

//! OAuth2 客户端实现 — OAuth2 client implementation.
2
3use crate::oauth2::{AuthorizationUrl, OAuth2ClientConfig, OAuth2Error, OAuth2Result, TokenResponse, UserInfo};
4use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
5use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use wae_request::{HttpClient, HttpClientConfig, HttpError};
8
9/// OAuth2 客户端
10#[derive(Debug, Clone)]
11pub struct OAuth2Client {
12    config: OAuth2ClientConfig,
13    http_client: HttpClient,
14}
15
16impl OAuth2Client {
17    /// 创建新的 OAuth2 客户端
18    ///
19    /// # Arguments
20    /// * `config` - 客户端配置
21    pub fn new(config: OAuth2ClientConfig) -> OAuth2Result<Self> {
22        let http_config = HttpClientConfig {
23            timeout: std::time::Duration::from_millis(config.timeout_ms),
24            connect_timeout: std::time::Duration::from_secs(10),
25            user_agent: "wae-oauth2/0.1.0".to_string(),
26            max_retries: 3,
27            retry_delay: std::time::Duration::from_millis(1000),
28            default_headers: HashMap::new(),
29        };
30
31        let http_client = HttpClient::new(http_config);
32
33        Ok(Self { config, http_client })
34    }
35
36    /// 生成授权 URL
37    ///
38    /// # Returns
39    /// 返回授权 URL、状态参数和 PKCE code verifier
40    pub fn authorization_url(&self) -> OAuth2Result<AuthorizationUrl> {
41        let state = self.generate_state();
42        let mut params: Vec<(String, String)> = vec![
43            ("client_id".to_string(), self.config.provider.client_id.clone()),
44            ("redirect_uri".to_string(), self.config.provider.redirect_uri.clone()),
45            ("response_type".to_string(), "code".to_string()),
46            ("state".to_string(), state.clone()),
47        ];
48
49        if !self.config.provider.scopes.is_empty() {
50            params.push(("scope".to_string(), self.config.provider.scopes.join(" ")));
51        }
52
53        let code_verifier = if self.config.use_pkce {
54            let verifier = self.generate_code_verifier();
55            let challenge = self.generate_code_challenge(&verifier);
56            params.push(("code_challenge".to_string(), challenge));
57            params.push(("code_challenge_method".to_string(), "S256".to_string()));
58            Some(verifier)
59        }
60        else {
61            None
62        };
63
64        for (key, value) in &self.config.provider.extra_params {
65            params.push((key.clone(), value.clone()));
66        }
67
68        let query = params
69            .iter()
70            .map(|(k, v)| format!("{}={}", wae_types::url_encode(k), wae_types::url_encode(v)))
71            .collect::<Vec<_>>()
72            .join("&");
73
74        let url = format!("{}?{}", self.config.provider.authorization_url, query);
75
76        Ok(AuthorizationUrl { url, state, code_verifier })
77    }
78
79    /// 使用授权码交换令牌
80    ///
81    /// # Arguments
82    /// * `code` - 授权码
83    /// * `code_verifier` - PKCE code verifier (如果启用了 PKCE)
84    pub async fn exchange_code(&self, code: &str, code_verifier: Option<&str>) -> OAuth2Result<TokenResponse> {
85        let mut params = HashMap::new();
86        params.insert("grant_type", "authorization_code".to_string());
87        params.insert("code", code.to_string());
88        params.insert("redirect_uri", self.config.provider.redirect_uri.clone());
89        params.insert("client_id", self.config.provider.client_id.clone());
90        params.insert("client_secret", self.config.provider.client_secret.clone());
91
92        if let Some(verifier) = code_verifier {
93            params.insert("code_verifier", verifier.to_string());
94        }
95
96        let form_body = self.encode_form_data(&params);
97
98        let response = self
99            .http_client
100            .post_with_headers(&self.config.provider.token_url, form_body.into_bytes(), self.form_headers())
101            .await
102            .map_err(OAuth2Error::from)?;
103
104        if !response.is_success() {
105            let error_text = response.text().unwrap_or_default();
106            return Err(OAuth2Error::ProviderError(error_text));
107        }
108
109        let token_response: TokenResponse = response.json().map_err(OAuth2Error::from)?;
110        Ok(token_response)
111    }
112
113    /// 刷新访问令牌
114    ///
115    /// # Arguments
116    /// * `refresh_token` - 刷新令牌
117    pub async fn refresh_token(&self, refresh_token: &str) -> OAuth2Result<TokenResponse> {
118        let mut params = HashMap::new();
119        params.insert("grant_type", "refresh_token".to_string());
120        params.insert("refresh_token", refresh_token.to_string());
121        params.insert("client_id", self.config.provider.client_id.clone());
122        params.insert("client_secret", self.config.provider.client_secret.clone());
123
124        let form_body = self.encode_form_data(&params);
125
126        let response = self
127            .http_client
128            .post_with_headers(&self.config.provider.token_url, form_body.into_bytes(), self.form_headers())
129            .await
130            .map_err(OAuth2Error::from)?;
131
132        if !response.is_success() {
133            let error_text = response.text().unwrap_or_default();
134            return Err(OAuth2Error::ProviderError(error_text));
135        }
136
137        let token_response: TokenResponse = response.json().map_err(OAuth2Error::from)?;
138        Ok(token_response)
139    }
140
141    /// 获取用户信息
142    ///
143    /// # Arguments
144    /// * `access_token` - 访问令牌
145    pub async fn get_user_info(&self, access_token: &str) -> OAuth2Result<UserInfo> {
146        let userinfo_url = self
147            .config
148            .provider
149            .userinfo_url
150            .as_ref()
151            .ok_or_else(|| OAuth2Error::ConfigurationError("userinfo_url not configured".into()))?;
152
153        let mut headers = HashMap::new();
154        headers.insert("Authorization".to_string(), format!("Bearer {}", access_token));
155
156        let response = self.http_client.get_with_headers(userinfo_url, headers).await.map_err(OAuth2Error::from)?;
157
158        if !response.is_success() {
159            let error_text = response.text().unwrap_or_default();
160            return Err(OAuth2Error::ProviderError(error_text));
161        }
162
163        let user_info: UserInfo = response.json().map_err(OAuth2Error::from)?;
164        Ok(user_info)
165    }
166
167    /// 撤销令牌
168    ///
169    /// # Arguments
170    /// * `token` - 要撤销的令牌
171    /// * `token_type_hint` - 令牌类型提示 (access_token 或 refresh_token)
172    pub async fn revoke_token(&self, token: &str, token_type_hint: Option<&str>) -> OAuth2Result<()> {
173        let revocation_url = self
174            .config
175            .provider
176            .revocation_url
177            .as_ref()
178            .ok_or_else(|| OAuth2Error::ConfigurationError("revocation_url not configured".into()))?;
179
180        let mut params = HashMap::new();
181        params.insert("token", token.to_string());
182        params.insert("client_id", self.config.provider.client_id.clone());
183        params.insert("client_secret", self.config.provider.client_secret.clone());
184
185        if let Some(hint) = token_type_hint {
186            params.insert("token_type_hint", hint.to_string());
187        }
188
189        let form_body = self.encode_form_data(&params);
190
191        let response = self
192            .http_client
193            .post_with_headers(revocation_url, form_body.into_bytes(), self.form_headers())
194            .await
195            .map_err(OAuth2Error::from)?;
196
197        if !response.is_success() {
198            let error_text = response.text().unwrap_or_default();
199            return Err(OAuth2Error::ProviderError(error_text));
200        }
201
202        Ok(())
203    }
204
205    /// 验证状态参数
206    ///
207    /// # Arguments
208    /// * `expected` - 期望的状态值
209    /// * `received` - 接收到的状态值
210    pub fn validate_state(&self, expected: &str, received: &str) -> OAuth2Result<()> {
211        if !self.config.use_state {
212            return Ok(());
213        }
214
215        if expected == received { Ok(()) } else { Err(OAuth2Error::StateMismatch) }
216    }
217
218    fn generate_state(&self) -> String {
219        uuid::Uuid::new_v4().to_string().replace('-', "")
220    }
221
222    fn generate_code_verifier(&self) -> String {
223        let random_bytes: Vec<u8> = (0..32).map(|_| rand::random::<u8>()).collect();
224        URL_SAFE_NO_PAD.encode(&random_bytes)
225    }
226
227    fn generate_code_challenge(&self, verifier: &str) -> String {
228        let mut hasher = Sha256::new();
229        hasher.update(verifier.as_bytes());
230        let hash = hasher.finalize();
231        URL_SAFE_NO_PAD.encode(&hash)
232    }
233
234    fn encode_form_data(&self, params: &HashMap<&str, String>) -> String {
235        params
236            .iter()
237            .map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
238            .collect::<Vec<_>>()
239            .join("&")
240    }
241
242    fn form_headers(&self) -> HashMap<String, String> {
243        let mut headers = HashMap::new();
244        headers.insert("Content-Type".to_string(), "application/x-www-form-urlencoded".to_string());
245        headers
246    }
247
248    /// 获取提供者名称
249    pub fn provider_name(&self) -> &str {
250        &self.config.provider.name
251    }
252
253    /// 获取配置
254    pub fn config(&self) -> &OAuth2ClientConfig {
255        &self.config
256    }
257}
258
259impl From<HttpError> for OAuth2Error {
260    fn from(err: HttpError) -> Self {
261        match err {
262            HttpError::InvalidUrl(msg) => OAuth2Error::ConfigurationError(msg),
263            HttpError::Timeout => OAuth2Error::RequestError("Request timeout".into()),
264            HttpError::ConnectionFailed(msg) => OAuth2Error::RequestError(msg),
265            HttpError::DnsFailed(msg) => OAuth2Error::RequestError(msg),
266            HttpError::TlsError(msg) => OAuth2Error::RequestError(msg),
267            HttpError::StatusError { status, body } => match status {
268                401 => OAuth2Error::AccessDenied(body),
269                403 => OAuth2Error::AccessDenied(body),
270                _ => OAuth2Error::ProviderError(format!("HTTP {}: {}", status, body)),
271            },
272            _ => OAuth2Error::RequestError(err.to_string()),
273        }
274    }
275}
276
277mod urlencoding {
278    pub fn encode(s: &str) -> String {
279        url::form_urlencoded::byte_serialize(s.as_bytes()).collect()
280    }
281}