use std::fmt;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::{AuthError, Result};
/// Profile information for an authenticated user, as returned by
/// [`OAuthProvider::user_info`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
/// Provider-specific user identifier.
/// NOTE(review): presumably the provider's stable subject id (e.g. OIDC `sub`) — confirm per implementor.
pub id: String,
/// Email address reported by the provider.
/// NOTE(review): this field is not optional, so implementors must produce some value even when a
/// provider withholds the email — confirm how each implementor handles that case.
pub email: String,
/// Display name, if the provider returned one.
pub name: Option<String>,
/// Profile picture reference, if the provider returned one (presumably a URL — confirm).
pub picture: Option<String>,
/// The provider's claims as an untyped JSON value, for data not mapped into the fields above.
/// NOTE(review): assumed to be the unmodified userinfo/ID-token payload — confirm per implementor.
pub raw_claims: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenResponse {
pub access_token: String,
pub refresh_token: Option<String>,
pub expires_in: u64,
pub token_type: String,
}
/// Abstraction over an OAuth identity provider.
///
/// Implementors must be `Send + Sync + Debug`. Four methods are required:
/// [`name`](Self::name), [`authorization_url`](Self::authorization_url),
/// [`exchange_code`](Self::exchange_code) and [`user_info`](Self::user_info).
/// Token refresh and revocation have default implementations that can be
/// overridden.
#[async_trait]
pub trait OAuthProvider: Send + Sync + fmt::Debug {
    /// Provider name. The default [`refresh_token`](Self::refresh_token)
    /// interpolates it into its error message.
    fn name(&self) -> &str;

    /// Returns the URL that starts the authorization flow for `state`.
    /// NOTE(review): assumed to embed `state` as the OAuth `state` parameter — confirm per implementor.
    fn authorization_url(&self, state: &str) -> String;

    /// Exchanges an authorization `code` for tokens.
    async fn exchange_code(&self, code: &str) -> Result<TokenResponse>;

    /// Fetches the user's profile using `access_token`.
    async fn user_info(&self, access_token: &str) -> Result<UserInfo>;

    /// Obtains fresh tokens from a refresh token.
    ///
    /// # Errors
    ///
    /// The default implementation always returns `AuthError::OAuthError` with
    /// the message "`<name>` does not support token refresh". Providers that
    /// support refresh must override it.
    async fn refresh_token(&self, _refresh_token: &str) -> Result<TokenResponse> {
        let message = format!("{} does not support token refresh", self.name());
        Err(AuthError::OAuthError { message })
    }

    /// Revokes `token` at the provider.
    ///
    /// The default implementation performs no work and reports success, so a
    /// caller cannot tell "revoked" apart from "revocation unsupported".
    /// NOTE(review): unlike `refresh_token`, unsupported revocation is silent — confirm intended.
    async fn revoke_token(&self, _token: &str) -> Result<()> {
        Ok(())
    }
}
/// A PKCE (RFC 7636) code-verifier / code-challenge pair.
///
/// `challenge` is derived from `verifier` (see `PkceChallenge::generate`).
/// Per RFC 7636 the verifier must stay private to the client, so `Debug` is
/// implemented by hand to redact it instead of printing it through `{:?}`.
#[derive(Clone)]
pub struct PkceChallenge {
    /// The code verifier (secret).
    pub verifier: String,
    /// The code challenge derived from `verifier`.
    pub challenge: String,
}

impl fmt::Debug for PkceChallenge {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("PkceChallenge")
            .field("verifier", &"[REDACTED]")
            .field("challenge", &self.challenge)
            .finish()
    }
}
impl PkceChallenge {
    /// Creates a new verifier and its S256 challenge.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `generate_pkce_verifier`.
    pub fn generate() -> Result<Self> {
        let verifier = generate_pkce_verifier()?;
        let challenge = Self::s256_challenge(&verifier);
        Ok(Self {
            verifier,
            challenge,
        })
    }

    /// Returns `true` if the S256 transform of `verifier` equals `self.challenge`.
    pub fn validate(&self, verifier: &str) -> bool {
        self.challenge == Self::s256_challenge(verifier)
    }

    /// Computes `base64url_no_pad(SHA-256(verifier))`. Both `generate` and
    /// `validate` use it, so the two always apply the same transform.
    fn s256_challenge(verifier: &str) -> String {
        use sha2::{Digest, Sha256};

        let digest = Sha256::digest(verifier.as_bytes());
        base64_url_encode(&digest)
    }
}
/// Generates a random PKCE code verifier.
///
/// The verifier is `VERIFIER_LENGTH` (128) characters drawn uniformly from
/// `CHARSET`, the RFC 7636 §4.1 "unreserved" set
/// (`A-Z`, `a-z`, `0-9`, `-`, `.`, `_`, `~`). 128 is the maximum length the RFC allows.
///
/// # Errors
///
/// Returns `AuthError::PkceError` if the generated value is shorter than
/// `MIN_VERIFIER_LENGTH`, longer than `MAX_VERIFIER_LENGTH`, or contains a
/// character outside `CHARSET`. These are defensive post-condition checks; with
/// the constants below they are not expected to fire.
fn generate_pkce_verifier() -> Result<String> {
    use rand::Rng;

    // Single source of truth for the alphabet. It is used both to generate and to
    // validate, so the two can no longer drift apart.
    const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
    const MIN_VERIFIER_LENGTH: usize = 43;
    const MAX_VERIFIER_LENGTH: usize = 128;
    const VERIFIER_LENGTH: usize = MAX_VERIFIER_LENGTH;

    // NOTE(review): unpredictability relies on `rand::thread_rng` being a CSPRNG (rand documents
    // ThreadRng as cryptographically secure) — confirm against the pinned rand version.
    let mut rng = rand::thread_rng();
    let verifier: String = (0..VERIFIER_LENGTH)
        .map(|_| CHARSET[rng.gen_range(0..CHARSET.len())] as char)
        .collect();

    if verifier.len() < MIN_VERIFIER_LENGTH {
        return Err(AuthError::PkceError {
            message: format!(
                "Generated PKCE verifier too short: {} < {} chars",
                verifier.len(),
                MIN_VERIFIER_LENGTH
            ),
        });
    }
    if verifier.len() > MAX_VERIFIER_LENGTH {
        return Err(AuthError::PkceError {
            message: format!(
                "Generated PKCE verifier too long: {} > {} chars",
                verifier.len(),
                MAX_VERIFIER_LENGTH
            ),
        });
    }
    // `is_ascii()` is checked first, so the `as u8` cast cannot truncate a wider char.
    if let Some((i, c)) = verifier
        .chars()
        .enumerate()
        .find(|&(_, c)| !(c.is_ascii() && CHARSET.contains(&(c as u8))))
    {
        return Err(AuthError::PkceError {
            message: format!(
                "Generated PKCE verifier contains invalid character '{}' at position {}",
                c, i
            ),
        });
    }
    Ok(verifier)
}
/// Encodes `bytes` as base64url without `=` padding (the `URL_SAFE_NO_PAD`
/// engine). This is the form `PkceChallenge` stores as its S256 challenge.
fn base64_url_encode(bytes: &[u8]) -> String {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD as ENGINE;
    use base64::Engine as _;

    ENGINE.encode(bytes)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// RFC 7636 Appendix B example verifier.
    const RFC7636_VERIFIER: &str = "dBjftJeZ4CVP-mJ92K9D9ibFs0YFG9wKc9IUtOxRU_Q";
    /// Expected S256 challenge for `RFC7636_VERIFIER` (RFC 7636 Appendix B).
    const RFC7636_CHALLENGE: &str = "E9Melhoa2OwvFrAMTjzIv_2eWjwjPn_qqyUVTz_2cwM";

    #[test]
    fn test_pkce_challenge_generation() {
        let challenge_result = PkceChallenge::generate();
        assert!(challenge_result.is_ok(), "PKCE challenge generation should succeed");
        let challenge = challenge_result.unwrap();
        assert!(!challenge.verifier.is_empty(), "Verifier should not be empty");
        assert!(!challenge.challenge.is_empty(), "Challenge should not be empty");
        assert!(
            challenge.verifier.len() >= 43 && challenge.verifier.len() <= 128,
            "Verifier length must be 43-128 characters per RFC 7636"
        );
        // A 32-byte SHA-256 digest is 43 characters in unpadded base64url.
        assert_eq!(
            challenge.challenge.len(),
            43,
            "S256 challenge should be 43 unpadded base64url characters"
        );
    }

    #[test]
    fn test_pkce_verifier_contains_valid_characters() {
        let challenge = PkceChallenge::generate().unwrap();
        for c in challenge.verifier.chars() {
            assert!(
                c.is_ascii_alphanumeric() || "-._~".contains(c),
                "PKCE verifier contains invalid character: {}",
                c
            );
        }
    }

    #[test]
    fn test_pkce_validation() {
        let challenge = PkceChallenge::generate().unwrap();
        assert!(
            challenge.validate(&challenge.verifier),
            "Challenge should validate against its own verifier"
        );
        let wrong_verifier = "wrong_verifier";
        assert!(!challenge.validate(wrong_verifier), "Challenge should reject invalid verifier");
    }

    #[test]
    fn test_pkce_matches_rfc7636_test_vector() {
        let pair = PkceChallenge {
            verifier: RFC7636_VERIFIER.to_string(),
            challenge: RFC7636_CHALLENGE.to_string(),
        };
        assert!(
            pair.validate(RFC7636_VERIFIER),
            "S256 transform must match the RFC 7636 Appendix B example"
        );
        // Changing only the last character must break validation.
        assert!(
            !pair.validate("dBjftJeZ4CVP-mJ92K9D9ibFs0YFG9wKc9IUtOxRU_R"),
            "A different verifier must not validate"
        );
    }

    #[test]
    fn test_pkce_generation_is_unique() {
        let challenge1 = PkceChallenge::generate().unwrap();
        let challenge2 = PkceChallenge::generate().unwrap();
        assert_ne!(
            challenge1.verifier, challenge2.verifier,
            "Generated verifiers should be unique"
        );
        assert_ne!(
            challenge1.challenge, challenge2.challenge,
            "Generated challenges should be unique"
        );
    }

    #[test]
    fn test_pkce_challenge_is_base64_url_safe() {
        let challenge = PkceChallenge::generate().unwrap();
        // URL_SAFE_NO_PAD uses '-' and '_' and never emits '=' padding.
        for c in challenge.challenge.chars() {
            assert!(
                c.is_ascii_alphanumeric() || c == '-' || c == '_',
                "Challenge contains unexpected character: {}",
                c
            );
        }
    }

    #[test]
    fn test_base64_url_encode() {
        assert_eq!(base64_url_encode(b""), "");
        // Standard base64 would be "aGVsbG8gd29ybGQ=" — the padding must be dropped.
        assert_eq!(base64_url_encode(b"hello world"), "aGVsbG8gd29ybGQ");
        // Standard base64 would be "+/8=" — the URL-safe alphabet maps '+'->'-' and '/'->'_'.
        assert_eq!(base64_url_encode(&[0xfb, 0xff]), "-_8");
    }
}