Skip to main content

fraiseql_auth/
provider.rs

1//! OAuth 2.0 / OIDC provider trait and core data types.
2use std::fmt;
3
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6
7use crate::error::{AuthError, Result};
8
9/// User information retrieved from OAuth provider
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct UserInfo {
12    /// Unique user identifier from provider
13    pub id:         String,
14    /// User's email address
15    pub email:      String,
16    /// User's display name (optional)
17    pub name:       Option<String>,
18    /// User's profile picture URL (optional)
19    pub picture:    Option<String>,
20    /// Raw claims from provider (for custom fields)
21    pub raw_claims: serde_json::Value,
22}
23
24/// Token response from OAuth provider
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct TokenResponse {
27    /// Access token (short-lived)
28    pub access_token:  String,
29    /// Refresh token if provider supports it
30    pub refresh_token: Option<String>,
31    /// Token expiration in seconds
32    pub expires_in:    u64,
33    /// Token type (typically "Bearer")
34    pub token_type:    String,
35}
36
37/// OAuth 2.0 / OIDC provider trait
38///
39/// Implement this trait to add support for custom OAuth providers.
40// Reason: used as dyn Trait (Arc<dyn OAuthProvider>, Box<dyn OAuthProvider>); async_trait ensures
41// Send bounds and dyn-compatibility async_trait: dyn-dispatch required; remove when RTN + Send is
42// stable (RFC 3425)
43#[async_trait]
44pub trait OAuthProvider: Send + Sync + fmt::Debug {
45    /// Provider name for logging/debugging
46    fn name(&self) -> &str;
47
48    /// Generate authorization URL for user to visit
49    ///
50    /// # Arguments
51    /// * `state` - CSRF protection state (should be cryptographically random)
52    fn authorization_url(&self, state: &str) -> String;
53
54    /// Exchange authorization code for tokens
55    ///
56    /// # Arguments
57    /// * `code` - Authorization code from provider
58    ///
59    /// # Returns
60    /// Token response with access_token and optional refresh_token
61    async fn exchange_code(&self, code: &str) -> Result<TokenResponse>;
62
63    /// Get user information using access token
64    ///
65    /// # Arguments
66    /// * `access_token` - The access token to use for API call
67    ///
68    /// # Returns
69    /// UserInfo with user details from provider
70    async fn user_info(&self, access_token: &str) -> Result<UserInfo>;
71
72    /// Refresh the access token (optional, default returns error)
73    ///
74    /// # Arguments
75    /// * `refresh_token` - The refresh token
76    ///
77    /// # Returns
78    /// New TokenResponse if provider supports refresh
79    async fn refresh_token(&self, _refresh_token: &str) -> Result<TokenResponse> {
80        Err(AuthError::OAuthError {
81            message: format!("{} does not support token refresh", self.name()),
82        })
83    }
84
85    /// Revoke a token (optional, default is no-op)
86    ///
87    /// # Arguments
88    /// * `token` - Token to revoke
89    async fn revoke_token(&self, _token: &str) -> Result<()> {
90        Ok(())
91    }
92}
93
94/// PKCE (Proof Key for Public Clients) helper
95///
96/// Used to prevent authorization code interception attacks
97#[derive(Debug, Clone)]
98pub struct PkceChallenge {
99    /// Generated code verifier (cryptographically random)
100    pub verifier:  String,
101    /// Code challenge (SHA256 hash of verifier)
102    pub challenge: String,
103}
104
105impl PkceChallenge {
106    /// Generate a new PKCE challenge.
107    ///
108    /// # Errors
109    ///
110    /// Returns [`AuthError::PkceError`] if the generated verifier fails RFC 7636
111    /// length or character-set constraints (essentially never in practice).
112    pub fn generate() -> Result<Self> {
113        use sha2::{Digest, Sha256};
114
115        let verifier = generate_pkce_verifier()?;
116
117        let mut hasher = Sha256::new();
118        hasher.update(verifier.as_bytes());
119        let challenge_bytes = hasher.finalize();
120        let challenge = base64_url_encode(&challenge_bytes);
121
122        Ok(Self {
123            verifier,
124            challenge,
125        })
126    }
127
128    /// Validate a verifier against a challenge
129    pub fn validate(&self, verifier: &str) -> bool {
130        use sha2::{Digest, Sha256};
131
132        let mut hasher = Sha256::new();
133        hasher.update(verifier.as_bytes());
134        let hash = hasher.finalize();
135        let encoded = base64_url_encode(&hash);
136
137        encoded == self.challenge
138    }
139}
140
141/// Generate a PKCE verifier (43-128 characters of unreserved characters)
142///
143/// # SECURITY
144///
145/// This uses `rand::thread_rng()` which is cryptographically secure on all major platforms.
146/// It generates a 128-character random string using only unreserved characters as per RFC 7636.
147///
148/// The generated verifier meets these requirements:
149/// - Length: exactly 128 characters (within 43-128 range)
150/// - Characters: only unreserved ASCII characters: [A-Z a-z 0-9 - . _ ~]
151/// - Randomness: cryptographically secure pseudorandom generation
152/// - No padding: can be used directly in PKCE challenge
153///
154/// # Errors
155///
156/// Returns error if:
157/// - Random number generation fails (extremely rare)
158/// - Generated verifier is invalid (should never happen given the constraints)
159///
160/// # Implementation Notes
161///
162/// We use a fixed 128-character length (maximum allowed by RFC 7636) for:
163/// 1. Maximum security: more entropy means harder to guess
164/// 2. Consistency: predictable length for tests and monitoring
165/// 3. Compatibility: all OAuth providers support 128-char verifiers
166fn generate_pkce_verifier() -> Result<String> {
167    use rand::{Rng, rngs::OsRng};
168
169    const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
170    const VERIFIER_LENGTH: usize = 128; // Maximum allowed by RFC 7636
171    const MIN_VERIFIER_LENGTH: usize = 43; // Minimum allowed by RFC 7636
172
173    // SECURITY: OsRng is used instead of thread_rng() to guarantee OS-level
174    // entropy for PKCE verifiers, regardless of process startup state.
175    let mut rng = OsRng;
176    let verifier: String = (0..VERIFIER_LENGTH)
177        .map(|_| {
178            let idx = rng.gen_range(0..CHARSET.len());
179            CHARSET[idx] as char
180        })
181        .collect();
182
183    // Validate the generated verifier meets RFC 7636 requirements
184    if verifier.len() < MIN_VERIFIER_LENGTH {
185        return Err(AuthError::PkceError {
186            message: format!(
187                "Generated PKCE verifier too short: {} < {} chars",
188                verifier.len(),
189                MIN_VERIFIER_LENGTH
190            ),
191        });
192    }
193
194    if verifier.len() > 128 {
195        return Err(AuthError::PkceError {
196            message: format!("Generated PKCE verifier too long: {} > 128 chars", verifier.len()),
197        });
198    }
199
200    // Verify all characters are from the allowed charset
201    let allowed_chars: std::collections::HashSet<char> =
202        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
203            .chars()
204            .collect();
205
206    for (i, c) in verifier.chars().enumerate() {
207        if !allowed_chars.contains(&c) {
208            return Err(AuthError::PkceError {
209                message: format!(
210                    "Generated PKCE verifier contains invalid character '{}' at position {}",
211                    c, i
212                ),
213            });
214        }
215    }
216
217    Ok(verifier)
218}
219
220/// URL-safe base64 encoding for PKCE
221fn base64_url_encode(bytes: &[u8]) -> String {
222    use base64::Engine;
223    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
224}
225
226#[allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable
227#[cfg(test)]
228mod tests {
229    #[allow(clippy::wildcard_imports)]
230    // Reason: test module — wildcard keeps test boilerplate minimal
231    use super::*;
232
233    #[test]
234    fn test_pkce_challenge_generation() {
235        // Test proper error handling - generation should always succeed
236        let challenge_result = PkceChallenge::generate();
237        assert!(challenge_result.is_ok(), "PKCE challenge generation should succeed");
238
239        let challenge = challenge_result.unwrap();
240        assert!(!challenge.verifier.is_empty(), "Verifier should not be empty");
241        assert!(!challenge.challenge.is_empty(), "Challenge should not be empty");
242        assert!(
243            challenge.verifier.len() >= 43 && challenge.verifier.len() <= 128,
244            "Verifier length must be 43-128 characters per RFC 7636"
245        );
246    }
247
248    #[test]
249    fn test_pkce_verifier_contains_valid_characters() {
250        // Verify that generated verifier only contains unreserved characters
251        let challenge = PkceChallenge::generate().unwrap();
252
253        let allowed_chars: std::collections::HashSet<char> =
254            "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
255                .chars()
256                .collect();
257
258        for c in challenge.verifier.chars() {
259            assert!(allowed_chars.contains(&c), "PKCE verifier contains invalid character: {}", c);
260        }
261    }
262
263    #[test]
264    fn test_pkce_validation() {
265        // Test that validation works correctly
266        let challenge = PkceChallenge::generate().unwrap();
267        assert!(
268            challenge.validate(&challenge.verifier),
269            "Challenge should validate against its own verifier"
270        );
271
272        let wrong_verifier = "wrong_verifier";
273        assert!(!challenge.validate(wrong_verifier), "Challenge should reject invalid verifier");
274    }
275
276    #[test]
277    fn test_pkce_generation_is_unique() {
278        // Test that multiple PKCE generations produce different verifiers
279        let challenge1 = PkceChallenge::generate().unwrap();
280        let challenge2 = PkceChallenge::generate().unwrap();
281
282        assert_ne!(
283            challenge1.verifier, challenge2.verifier,
284            "Generated verifiers should be unique"
285        );
286        assert_ne!(
287            challenge1.challenge, challenge2.challenge,
288            "Generated challenges should be unique"
289        );
290    }
291
292    #[test]
293    fn test_pkce_challenge_is_base64_url_safe() {
294        // Verify that challenge is URL-safe base64 encoded
295        let challenge = PkceChallenge::generate().unwrap();
296
297        // URL-safe base64 should not contain + or / (only -, _, and =)
298        assert!(
299            !challenge.challenge.contains('+'),
300            "Challenge should not contain + (not URL-safe)"
301        );
302        assert!(
303            !challenge.challenge.contains('/'),
304            "Challenge should not contain / (not URL-safe)"
305        );
306
307        // But should only contain valid base64 characters
308        for c in challenge.challenge.chars() {
309            assert!(
310                c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '=',
311                "Challenge contains unexpected character: {}",
312                c
313            );
314        }
315    }
316
317    #[test]
318    fn test_base64_url_encode() {
319        let bytes = b"hello world";
320        let encoded = base64_url_encode(bytes);
321        assert!(!encoded.is_empty());
322        // URL-safe base64 should not contain + or /
323        assert!(!encoded.contains('+'));
324        assert!(!encoded.contains('/'));
325    }
326}