pub mod github;
pub mod google;
pub mod wechat;
/// Total request timeout, in seconds, configured on clients built by `http_client`.
const OAUTH_TIMEOUT_SECS: u64 = 10;

/// Builds the `reqwest::Client` used for outbound OAuth HTTP calls.
///
/// The client is configured with a total request timeout of
/// `OAUTH_TIMEOUT_SECS` seconds. If building the configured client fails,
/// this falls back to `reqwest::Client::new()`. That fallback does NOT carry
/// the timeout, so requests made through it can wait indefinitely.
pub(crate) fn http_client() -> reqwest::Client {
    let request_timeout = std::time::Duration::from_secs(OAUTH_TIMEOUT_SECS);
    let configured = reqwest::Client::builder().timeout(request_timeout);
    configured
        .build()
        .unwrap_or_else(|_build_err| reqwest::Client::new())
}
use std::collections::HashMap;
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use crate::errors::app_error::AppResult;
/// User profile data produced by [`OAuthProvider::fetch_user_info`].
///
/// NOTE(review): the field contents below are inferred from their names;
/// confirm against the `github` / `google` / `wechat` provider implementations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthUserInfo {
/// Presumably the user's identifier on the provider's side.
pub provider_user_id: String,
/// Email address, if one was obtained for the user.
pub email: Option<String>,
/// Human-readable name, if one was obtained for the user.
pub display_name: Option<String>,
/// Profile picture URL, if one was obtained for the user.
pub avatar_url: Option<String>,
/// Arbitrary JSON profile payload; presumably the provider's unprocessed
/// profile response — TODO confirm in the provider implementations.
pub raw_profile: serde_json::Value,
}
/// Token response produced by [`OAuthProvider::exchange_code`].
///
/// `Debug` is implemented by hand rather than derived so that `access_token`
/// and `refresh_token` are redacted instead of printed whenever the value is
/// formatted with `{:?}` (e.g. in logs or error context).
#[derive(Clone, Deserialize)]
pub struct OAuthTokenResponse {
    /// Access credential; presumably what callers pass to
    /// `OAuthProvider::fetch_user_info` — verify at call sites.
    pub access_token: String,
    /// Token type reported by the provider, if present.
    pub token_type: Option<String>,
    /// Refresh token, if the provider issued one.
    pub refresh_token: Option<String>,
    /// Access-token lifetime as reported by the provider (seconds per
    /// RFC 6749 §5.1), if present.
    pub expires_in: Option<u64>,
    /// Granted scope string reported by the provider, if present.
    pub scope: Option<String>,
}

impl std::fmt::Debug for OAuthTokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder printed in place of secret values. Whether a refresh
        // token exists is still visible (`Some("<redacted>")` vs `None`).
        const REDACTED: &str = "<redacted>";
        f.debug_struct("OAuthTokenResponse")
            .field("access_token", &REDACTED)
            .field("token_type", &self.token_type)
            .field("refresh_token", &self.refresh_token.as_ref().map(|_| REDACTED))
            .field("expires_in", &self.expires_in)
            .field("scope", &self.scope)
            .finish()
    }
}
/// An OAuth login provider.
///
/// Implementors must be `Send + Sync`. They are stored as
/// `Box<dyn OAuthProvider>` in `OAuthProviderRegistry`.
#[async_trait::async_trait]
pub trait OAuthProvider: Send + Sync {
/// Identifier for this provider. `OAuthProviderRegistry::register` uses it as
/// the lookup key (e.g. `"github"`).
fn name(&self) -> &str;
/// Builds the authorization URL for the user to visit.
///
/// `state` is the value from `generate_state`, and `code_challenge` is the
/// PKCE challenge from `generate_code_challenge`. Both are expected to be
/// embedded in the returned URL.
fn authorize_url(&self, state: &str, code_challenge: &str) -> String;
/// Exchanges an authorization `code` for tokens, sending the PKCE
/// `code_verifier` that matches the challenge used in `authorize_url`.
async fn exchange_code(&self, code: &str, code_verifier: &str)
-> AppResult<OAuthTokenResponse>;
/// Fetches the profile of the user authorized by `access_token`.
async fn fetch_user_info(&self, access_token: &str) -> AppResult<OAuthUserInfo>;
}
/// Lookup table of [`OAuthProvider`] implementations, keyed by
/// [`OAuthProvider::name`].
#[derive(Default)]
pub struct OAuthProviderRegistry {
    /// Providers keyed by the value their `name()` returned at registration.
    providers: HashMap<String, Box<dyn OAuthProvider>>,
}

impl OAuthProviderRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `provider` under the key returned by its `name()`.
    ///
    /// A provider already registered under the same name is replaced.
    pub fn register(&mut self, provider: Box<dyn OAuthProvider>) {
        self.providers.insert(provider.name().to_string(), provider);
    }

    /// Looks up a provider by name. Returns `None` if none is registered.
    pub fn get(&self, name: &str) -> Option<&dyn OAuthProvider> {
        self.providers.get(name).map(|p| p.as_ref())
    }

    /// Returns the names of all registered providers in ascending order.
    ///
    /// The names are sorted because `HashMap` iteration order is unspecified
    /// and can differ between runs. Without sorting, anything built from this
    /// list would be nondeterministic.
    pub fn provider_names(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.providers.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }
}
/// Generates a fresh PKCE `code_verifier` (RFC 7636 §4.1).
///
/// Fills 32 bytes from the system random source via `getrandom` and encodes
/// them as unpadded base64url. The result is 43 characters, within RFC 7636's
/// 43..=128 range.
///
/// # Panics
///
/// Panics if the system random source reports an error.
pub fn generate_code_verifier() -> String {
    let mut entropy = [0u8; 32];
    if let Err(e) = getrandom::getrandom(&mut entropy) {
        panic!("code_verifier generation failed: {e}");
    }
    URL_SAFE_NO_PAD.encode(entropy)
}
/// Derives the PKCE `code_challenge` for `verifier` using the `S256` method
/// (RFC 7636 §4.2): unpadded base64url of the SHA-256 digest of the
/// verifier's bytes.
pub fn generate_code_challenge(verifier: &str) -> String {
    URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()))
}
/// Generates the OAuth `state` value passed to `OAuthProvider::authorize_url`.
///
/// Delegates to `crate::utils::id::random_hex(32)`. This module's tests expect
/// the result to be 64 hex characters, so the argument presumably counts
/// random bytes — confirm against `random_hex`.
pub fn generate_state() -> String {
    use crate::utils::id::random_hex;
    random_hex(32)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered provider can be retrieved by its `name()`.
    #[test]
    fn registry_register_and_get() {
        let mut registry = OAuthProviderRegistry::new();
        assert!(registry.get("github").is_none());

        let provider = github::GitHubProvider::new("test_id".into(), "test_secret".into());
        registry.register(Box::new(provider));

        assert!(registry.get("github").is_some());
        assert_eq!(registry.provider_names(), vec!["github"]);
    }

    /// RFC 7636 Appendix B test vector: known verifier maps to known S256 challenge.
    #[test]
    fn pkce_code_challenge_deterministic() {
        const RFC_VERIFIER: &str = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
        const RFC_CHALLENGE: &str = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
        assert_eq!(generate_code_challenge(RFC_VERIFIER), RFC_CHALLENGE);
    }

    /// RFC 7636 §4.1 bounds the verifier length to 43..=128 characters.
    #[test]
    fn pkce_verifier_length() {
        let verifier_len = generate_code_verifier().len();
        assert!((43..=128).contains(&verifier_len));
    }

    /// `state` is a 64-character hexadecimal string.
    #[test]
    fn state_is_hex_64_chars() {
        let state = generate_state();
        assert_eq!(state.len(), 64);
        assert!(state.bytes().all(|b| b.is_ascii_hexdigit()));
    }

    /// The GitHub authorization URL carries client id, state, PKCE and scope.
    #[test]
    fn github_authorize_url_format() {
        let provider = github::GitHubProvider::new("my_client_id".into(), "secret".into());
        let url = provider.authorize_url("state123", "challenge456");
        for expected in [
            "client_id=my_client_id",
            "state=state123",
            "code_challenge=challenge456",
            "code_challenge_method=S256",
            "scope=user:email",
        ] {
            assert!(url.contains(expected), "{url:?} should contain {expected:?}");
        }
    }

    /// The Google authorization URL carries client id, state and scope.
    #[test]
    fn google_authorize_url_format() {
        let provider = google::GoogleProvider::new("my_client_id".into(), "secret".into());
        let url = provider.authorize_url("state123", "challenge456");
        for expected in [
            "client_id=my_client_id",
            "state=state123",
            "scope=openid+email+profile",
        ] {
            assert!(url.contains(expected), "{url:?} should contain {expected:?}");
        }
    }
}