fraiseql_auth/
provider.rs1use std::fmt;
3
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6
7use crate::error::{AuthError, Result};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct UserInfo {
12 pub id: String,
14 pub email: String,
16 pub name: Option<String>,
18 pub picture: Option<String>,
20 pub raw_claims: serde_json::Value,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct TokenResponse {
27 pub access_token: String,
29 pub refresh_token: Option<String>,
31 pub expires_in: u64,
33 pub token_type: String,
35}
36
37#[async_trait]
44pub trait OAuthProvider: Send + Sync + fmt::Debug {
45 fn name(&self) -> &str;
47
48 fn authorization_url(&self, state: &str) -> String;
53
54 async fn exchange_code(&self, code: &str) -> Result<TokenResponse>;
62
63 async fn user_info(&self, access_token: &str) -> Result<UserInfo>;
71
72 async fn refresh_token(&self, _refresh_token: &str) -> Result<TokenResponse> {
80 Err(AuthError::OAuthError {
81 message: format!("{} does not support token refresh", self.name()),
82 })
83 }
84
85 async fn revoke_token(&self, _token: &str) -> Result<()> {
90 Ok(())
91 }
92}
93
/// A PKCE (RFC 7636) code verifier paired with its `S256` code challenge.
///
/// `verifier` is the client-side secret presented later during the code
/// exchange; `challenge` is the value that is safe to send up front.
#[derive(Clone)]
pub struct PkceChallenge {
    /// Random code verifier. Treat as a secret: it proves possession of the
    /// authorization request.
    pub verifier: String,
    /// Unpadded base64url SHA-256 digest of `verifier`, as produced by
    /// `PkceChallenge::generate`.
    pub challenge: String,
}

// `Debug` is implemented by hand (instead of derived) so that formatting a
// `PkceChallenge` — in a log line, tracing span, or error context — never
// leaks the secret verifier. The public challenge is still shown.
impl fmt::Debug for PkceChallenge {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("PkceChallenge")
            .field("verifier", &"<redacted>")
            .field("challenge", &self.challenge)
            .finish()
    }
}
104
105impl PkceChallenge {
106 pub fn generate() -> Result<Self> {
113 use sha2::{Digest, Sha256};
114
115 let verifier = generate_pkce_verifier()?;
116
117 let mut hasher = Sha256::new();
118 hasher.update(verifier.as_bytes());
119 let challenge_bytes = hasher.finalize();
120 let challenge = base64_url_encode(&challenge_bytes);
121
122 Ok(Self {
123 verifier,
124 challenge,
125 })
126 }
127
128 pub fn validate(&self, verifier: &str) -> bool {
130 use sha2::{Digest, Sha256};
131
132 let mut hasher = Sha256::new();
133 hasher.update(verifier.as_bytes());
134 let hash = hasher.finalize();
135 let encoded = base64_url_encode(&hash);
136
137 encoded == self.challenge
138 }
139}
140
141fn generate_pkce_verifier() -> Result<String> {
167 use rand::{Rng, rngs::OsRng};
168
169 const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
170 const VERIFIER_LENGTH: usize = 128; const MIN_VERIFIER_LENGTH: usize = 43; let mut rng = OsRng;
176 let verifier: String = (0..VERIFIER_LENGTH)
177 .map(|_| {
178 let idx = rng.gen_range(0..CHARSET.len());
179 CHARSET[idx] as char
180 })
181 .collect();
182
183 if verifier.len() < MIN_VERIFIER_LENGTH {
185 return Err(AuthError::PkceError {
186 message: format!(
187 "Generated PKCE verifier too short: {} < {} chars",
188 verifier.len(),
189 MIN_VERIFIER_LENGTH
190 ),
191 });
192 }
193
194 if verifier.len() > 128 {
195 return Err(AuthError::PkceError {
196 message: format!("Generated PKCE verifier too long: {} > 128 chars", verifier.len()),
197 });
198 }
199
200 let allowed_chars: std::collections::HashSet<char> =
202 "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
203 .chars()
204 .collect();
205
206 for (i, c) in verifier.chars().enumerate() {
207 if !allowed_chars.contains(&c) {
208 return Err(AuthError::PkceError {
209 message: format!(
210 "Generated PKCE verifier contains invalid character '{}' at position {}",
211 c, i
212 ),
213 });
214 }
215 }
216
217 Ok(verifier)
218}
219
220fn base64_url_encode(bytes: &[u8]) -> String {
222 use base64::Engine;
223 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
224}
225
// `unwrap()` is acceptable in tests: a panic is exactly the failure signal wanted.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod tests {
    #[allow(clippy::wildcard_imports)]
    use super::*;

    /// Characters RFC 7636 §4.1 permits in a code verifier.
    const UNRESERVED: &str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";

    /// Length of an unpadded base64url SHA-256 digest (32 bytes -> 43 chars).
    const S256_CHALLENGE_LEN: usize = 43;

    #[test]
    fn test_pkce_challenge_generation() {
        let challenge_result = PkceChallenge::generate();
        assert!(challenge_result.is_ok(), "PKCE challenge generation should succeed");

        let challenge = challenge_result.unwrap();
        assert!(!challenge.verifier.is_empty(), "Verifier should not be empty");
        assert!(!challenge.challenge.is_empty(), "Challenge should not be empty");
        assert!(
            challenge.verifier.len() >= 43 && challenge.verifier.len() <= 128,
            "Verifier length must be 43-128 characters per RFC 7636"
        );
    }

    #[test]
    fn test_pkce_verifier_contains_valid_characters() {
        let challenge = PkceChallenge::generate().unwrap();

        for c in challenge.verifier.chars() {
            assert!(UNRESERVED.contains(c), "PKCE verifier contains invalid character: {}", c);
        }
    }

    #[test]
    fn test_pkce_validation() {
        let challenge = PkceChallenge::generate().unwrap();

        assert!(
            challenge.validate(&challenge.verifier),
            "Challenge should validate against its own verifier"
        );

        let wrong_verifier = "wrong_verifier";
        assert!(!challenge.validate(wrong_verifier), "Challenge should reject invalid verifier");
    }

    #[test]
    fn test_pkce_validation_rfc7636_appendix_b_vector() {
        // Known-answer test from RFC 7636 Appendix B (S256 method).
        let pkce = PkceChallenge {
            verifier: "dBjftJeZ4CZP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".to_string(),
            challenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM".to_string(),
        };
        assert!(pkce.validate(&pkce.verifier), "RFC 7636 Appendix B vector must validate");
        assert!(
            !pkce.validate("dBjftJeZ4CZP-mB92K27uhbUJU1p1r_wW1gFWFOEjXK"),
            "A one-character change in the verifier must be rejected"
        );
    }

    #[test]
    fn test_pkce_generation_is_unique() {
        let challenge1 = PkceChallenge::generate().unwrap();
        let challenge2 = PkceChallenge::generate().unwrap();

        assert_ne!(
            challenge1.verifier, challenge2.verifier,
            "Generated verifiers should be unique"
        );
        assert_ne!(
            challenge1.challenge, challenge2.challenge,
            "Generated challenges should be unique"
        );
    }

    #[test]
    fn test_pkce_challenge_is_base64_url_safe() {
        let challenge = PkceChallenge::generate().unwrap();

        assert!(
            !challenge.challenge.contains('+'),
            "Challenge should not contain + (not URL-safe)"
        );
        assert!(
            !challenge.challenge.contains('/'),
            "Challenge should not contain / (not URL-safe)"
        );
        // The encoder is URL_SAFE_NO_PAD, so padding must never appear.
        assert!(!challenge.challenge.contains('='), "Challenge must be unpadded");
        assert_eq!(
            challenge.challenge.len(),
            S256_CHALLENGE_LEN,
            "S256 challenge should be 43 characters"
        );

        for c in challenge.challenge.chars() {
            assert!(
                c.is_ascii_alphanumeric() || c == '-' || c == '_',
                "Challenge contains unexpected character: {}",
                c
            );
        }
    }

    #[test]
    fn test_base64_url_encode() {
        // Standard base64 would be "aGVsbG8gd29ybGQ=" — padding must be stripped.
        assert_eq!(base64_url_encode(b"hello world"), "aGVsbG8gd29ybGQ");
        // 0xFB 0xFF is "+/8=" in standard base64; URL-safe alphabet maps it to "-_8".
        assert_eq!(base64_url_encode(&[0xfb, 0xff]), "-_8");
        assert_eq!(base64_url_encode(b""), "");
    }
}