1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
// OAuth 2.0 / OIDC provider trait, shared token/user types, and PKCE helpers
use std::fmt;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::auth::error::{AuthError, Result};
/// User information retrieved from an OAuth / OIDC provider.
///
/// The common profile fields are exposed directly. The provider's raw claims
/// are kept in `raw_claims`, so provider-specific fields remain reachable.
///
/// NOTE(review): `email` is a required `String`. Providers that may omit an
/// email must either fail or supply a value when building this struct.
/// Confirm that this is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
/// Unique, provider-scoped user identifier.
pub id: String,
/// User's email address.
pub email: String,
/// User's display name, if the provider returned one.
pub name: Option<String>,
/// URL of the user's profile picture, if the provider returned one.
pub picture: Option<String>,
/// Raw claims as returned by the provider (for custom fields).
pub raw_claims: serde_json::Value,
}
/// Token set returned by an OAuth provider's token endpoint.
///
/// NOTE(review): the derived `Debug` prints `access_token` and
/// `refresh_token` verbatim. Avoid formatting this type with `{:?}` in logs.
///
/// NOTE(review): RFC 6749 §5.1 marks `expires_in` as RECOMMENDED, not
/// required. If this type is deserialized directly from a provider's JSON, a
/// response without `expires_in` will fail to parse. Verify against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenResponse {
/// Access token (short-lived); a bearer secret.
pub access_token: String,
/// Refresh token, present only if the provider issued one.
pub refresh_token: Option<String>,
/// Access-token lifetime in seconds.
pub expires_in: u64,
/// Token type (typically "Bearer").
pub token_type: String,
}
/// OAuth 2.0 / OIDC provider abstraction.
///
/// Implement this trait to plug a custom identity provider into the auth
/// layer. Implementors must also be `Send + Sync + Debug`.
///
/// NOTE(review): neither `authorization_url` nor `exchange_code` accepts a
/// PKCE challenge or verifier. Implementations that use [`PkceChallenge`]
/// must therefore carry it some other way. Confirm that this is intended.
#[async_trait]
pub trait OAuthProvider: Send + Sync + fmt::Debug {
    /// Human-readable provider name, used for logging, debugging and error
    /// messages.
    fn name(&self) -> &str;

    /// Build the URL the user must visit to authorize this application.
    ///
    /// # Arguments
    /// * `state` - Opaque CSRF-protection value. Callers should generate it
    ///   with a cryptographically secure RNG and check it on the callback.
    fn authorization_url(&self, state: &str) -> String;

    /// Trade an authorization code for a token set.
    ///
    /// # Arguments
    /// * `code` - Authorization code handed back by the provider
    ///
    /// # Returns
    /// A [`TokenResponse`] holding the access token and, if one was issued,
    /// a refresh token.
    async fn exchange_code(&self, code: &str) -> Result<TokenResponse>;

    /// Fetch the authenticated user's profile from the provider.
    ///
    /// # Arguments
    /// * `access_token` - Token used to authorize the profile request
    ///
    /// # Returns
    /// A [`UserInfo`] describing the user.
    async fn user_info(&self, access_token: &str) -> Result<UserInfo>;

    /// Obtain a fresh token set from a refresh token.
    ///
    /// The default implementation always fails with
    /// `AuthError::OAuthError`, whose message names the provider. Override it
    /// for providers that support refresh.
    ///
    /// # Arguments
    /// * `refresh_token` - The refresh token previously issued by the provider
    async fn refresh_token(&self, _refresh_token: &str) -> Result<TokenResponse> {
        let message = format!("{} does not support token refresh", self.name());
        Err(AuthError::OAuthError { message })
    }

    /// Ask the provider to revoke a token.
    ///
    /// The default implementation does nothing and reports success. Callers
    /// therefore cannot tell from the result alone whether a revocation
    /// actually happened.
    ///
    /// # Arguments
    /// * `token` - Token to revoke
    async fn revoke_token(&self, _token: &str) -> Result<()> {
        Ok(())
    }
}
/// PKCE (Proof Key for Public Clients) helper
///
/// Used to prevent authorization code interception attacks
#[derive(Debug, Clone)]
pub struct PkceChallenge {
/// Generated code verifier (cryptographically random)
pub verifier: String,
/// Code challenge (SHA256 hash of verifier)
pub challenge: String,
}
impl PkceChallenge {
/// Generate a new PKCE challenge
pub fn generate() -> Result<Self> {
use sha2::{Digest, Sha256};
let verifier = generate_pkce_verifier()?;
let mut hasher = Sha256::new();
hasher.update(verifier.as_bytes());
let challenge_bytes = hasher.finalize();
let challenge = base64_url_encode(&challenge_bytes);
Ok(Self {
verifier,
challenge,
})
}
/// Validate a verifier against a challenge
pub fn validate(&self, verifier: &str) -> bool {
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(verifier.as_bytes());
let hash = hasher.finalize();
let encoded = base64_url_encode(&hash);
encoded == self.challenge
}
}
/// Produce a PKCE code verifier.
///
/// The verifier is 128 characters long, the maximum RFC 7636 §4.1 allows
/// (range 43-128). Each character is drawn uniformly from the unreserved
/// alphabet `[A-Za-z0-9-._~]`.
///
/// # Errors
/// The current implementation never fails. The `Result` return type is part
/// of the existing signature.
fn generate_pkce_verifier() -> Result<String> {
    use rand::Rng;

    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
    const LEN: usize = 128;

    let mut rng = rand::thread_rng();
    let mut out = String::with_capacity(LEN);
    for _ in 0..LEN {
        let pick = rng.gen_range(0..ALPHABET.len());
        out.push(char::from(ALPHABET[pick]));
    }
    Ok(out)
}
/// Encode `bytes` as unpadded base64url (RFC 4648 §5 alphabet, no `=`).
///
/// This is the encoding RFC 7636 prescribes for S256 code challenges.
fn base64_url_encode(bytes: &[u8]) -> String {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::Engine as _;

    URL_SAFE_NO_PAD.encode(bytes)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Characters RFC 7636 §4.1 permits in a code verifier ("unreserved").
    fn is_unreserved(c: char) -> bool {
        c.is_ascii_alphanumeric() || matches!(c, '-' | '.' | '_' | '~')
    }

    #[test]
    fn test_pkce_challenge_generation() {
        let challenge = PkceChallenge::generate().expect("Failed to generate challenge");
        assert!(!challenge.verifier.is_empty());
        assert!(!challenge.challenge.is_empty());
        assert!((43..=128).contains(&challenge.verifier.len()));
        assert!(challenge.verifier.chars().all(is_unreserved));
        // A SHA-256 digest is 32 bytes, i.e. 43 characters of unpadded base64url.
        assert_eq!(challenge.challenge.len(), 43);
    }

    #[test]
    fn test_pkce_validation() {
        let challenge = PkceChallenge::generate().expect("Failed to generate challenge");
        assert!(challenge.validate(&challenge.verifier));
        let wrong_verifier = "wrong_verifier";
        assert!(!challenge.validate(wrong_verifier));
    }

    #[test]
    fn test_pkce_rfc7636_known_answer() {
        // Worked example from RFC 7636 Appendix B. It pins the exact S256
        // transform, which a generate/validate round-trip alone cannot do.
        let pkce = PkceChallenge {
            verifier: "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".to_string(),
            challenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM".to_string(),
        };
        assert!(pkce.validate(&pkce.verifier));
        assert!(!pkce.validate("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXl"));
    }

    #[test]
    fn test_base64_url_encode() {
        // Padding must be stripped ("aGVsbG8gd29ybGQ=" in standard base64).
        assert_eq!(base64_url_encode(b"hello world"), "aGVsbG8gd29ybGQ");
        // 0xFB 0xFF encodes to "+/8=" in standard base64. The URL-safe,
        // unpadded alphabet must yield "-_8" instead.
        let encoded = base64_url_encode(&[0xfb, 0xff]);
        assert_eq!(encoded, "-_8");
        assert!(!encoded.contains('+'));
        assert!(!encoded.contains('/'));
        assert!(!encoded.contains('='));
    }
}