use std::collections::HashMap;
use std::sync::Arc;
use crate::social::core::{OAuthProvider, SocialAuthError, StandardClaims, TokenResponse};
use crate::social::flow::{InMemoryStateStore, StateData, StateStore};
/// Data needed to send the user to a provider and later finish the flow.
///
/// Produced by `SocialAuthBackend::begin_auth`. The backend also persists the
/// `state`, `nonce` and `code_verifier` values in its state store.
pub struct AuthorizationResult {
    /// Provider URL the user agent should be redirected to.
    pub authorization_url: String,
    /// Random value that the callback must present; it keys the stored flow state.
    pub state: String,
    /// Random nonce, generated only for OIDC providers; `None` otherwise.
    pub nonce: Option<String>,
    /// Code verifier supplied by the caller, if any (sent during code exchange).
    pub code_verifier: Option<String>,
}

impl std::fmt::Debug for AuthorizationResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The code verifier is a secret that must not leak through logs or
        // `{:?}` output; show only whether one is present.
        f.debug_struct("AuthorizationResult")
            .field("authorization_url", &self.authorization_url)
            .field("state", &self.state)
            .field("nonce", &self.nonce)
            .field(
                "code_verifier",
                &self.code_verifier.as_ref().map(|_| "[REDACTED]"),
            )
            .finish()
    }
}
pub struct CallbackResult {
pub token_response: TokenResponse,
pub claims: Option<StandardClaims>,
}
/// Registry of OAuth/OIDC providers plus the store used to track
/// in-flight authorization flows between `begin_auth` and `handle_callback`.
pub struct SocialAuthBackend {
    /// Registered providers, keyed by the name each provider reports.
    providers: HashMap<String, Arc<dyn OAuthProvider>>,
    /// Persistence for per-flow state (state value, nonce, code verifier).
    state_store: Arc<dyn StateStore>,
}

impl std::fmt::Debug for SocialAuthBackend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The fields are trait objects that may not implement `Debug`, so
        // show only the registered provider names (sorted for stable output).
        let mut names: Vec<&str> = self.providers.keys().map(String::as_str).collect();
        names.sort_unstable();
        f.debug_struct("SocialAuthBackend")
            .field("providers", &names)
            .finish_non_exhaustive()
    }
}
impl SocialAuthBackend {
    /// Creates a backend with no providers and an [`InMemoryStateStore`].
    pub fn new() -> Self {
        Self::with_state_store(Arc::new(InMemoryStateStore::new()))
    }

    /// Creates a backend with no providers that persists flow state in `state_store`.
    pub fn with_state_store(state_store: Arc<dyn StateStore>) -> Self {
        Self {
            providers: HashMap::new(),
            state_store,
        }
    }

    /// Registers `provider` under the name returned by its `name()` method.
    ///
    /// A provider previously registered under the same name is replaced.
    pub fn register_provider(&mut self, provider: Arc<dyn OAuthProvider>) {
        self.providers.insert(provider.name().to_string(), provider);
    }

    /// Returns the provider registered under `name`, if any.
    pub fn get_provider(&self, name: &str) -> Option<&Arc<dyn OAuthProvider>> {
        self.providers.get(name)
    }

    /// Returns the names of all registered providers, in unspecified order.
    pub fn provider_names(&self) -> Vec<&str> {
        self.providers.keys().map(String::as_str).collect()
    }

    /// Starts an authorization flow with the named provider.
    ///
    /// This method:
    /// 1. Generates a random 32-character `state` and, for OIDC providers, a
    ///    random 32-character `nonce`.
    /// 2. Asks the provider to build its authorization URL.
    /// 3. Persists `state`, `nonce` and `code_verifier` so that
    ///    [`Self::handle_callback`] can complete the flow.
    ///
    /// `code_challenge` is forwarded to the provider unchanged, and
    /// `code_verifier` is stored for the later code exchange. Nothing here
    /// checks that the two belong together; that is up to the caller.
    ///
    /// # Errors
    ///
    /// Returns [`SocialAuthError::Provider`] if `provider_name` is not
    /// registered. Errors from building the authorization URL or from storing
    /// the flow state are propagated.
    pub async fn begin_auth(
        &self,
        provider_name: &str,
        code_challenge: Option<&str>,
        code_verifier: Option<String>,
    ) -> Result<AuthorizationResult, SocialAuthError> {
        let provider = self.require_provider(provider_name)?;

        let state = generate_random_string(32);
        let nonce = provider.is_oidc().then(|| generate_random_string(32));

        let authorization_url = provider
            .authorization_url(&state, nonce.as_deref(), code_challenge)
            .await?;

        self.state_store
            .store(StateData::new(
                state.clone(),
                nonce.clone(),
                code_verifier.clone(),
            ))
            .await?;

        Ok(AuthorizationResult {
            authorization_url,
            state,
            nonce,
            code_verifier,
        })
    }

    /// Completes an authorization flow after the provider redirects back.
    ///
    /// This method looks up and removes the stored flow state for `state`. It
    /// then exchanges `code` for tokens, passing the stored code verifier if
    /// there is one. Finally it resolves user claims:
    ///
    /// * OIDC provider that returned an ID token: the token is validated
    ///   against the stored nonce and converted into claims. A validation
    ///   failure is returned as an error.
    /// * Any other case: the provider's user-info call is made best-effort. A
    ///   failure is logged as a warning and `claims` is `None`.
    ///
    /// # Errors
    ///
    /// Returns [`SocialAuthError::Provider`] if `provider_name` is not
    /// registered. Errors from the state store, the code exchange and ID-token
    /// validation are propagated.
    pub async fn handle_callback(
        &self,
        provider_name: &str,
        code: &str,
        state: &str,
    ) -> Result<CallbackResult, SocialAuthError> {
        let provider = self.require_provider(provider_name)?;

        // The stored state is removed before the code exchange, so it is
        // consumed even if the exchange below fails.
        // NOTE(review): `retrieve` and `remove` are separate calls, so two
        // concurrent callbacks carrying the same state could both pass
        // `retrieve`. An atomic take on `StateStore` would close that window;
        // confirm whether the store can offer one.
        // NOTE(review): `StateData` as built in `begin_auth` records no
        // provider name, so a state issued for one provider is accepted on
        // another provider's callback. Consider binding the provider.
        let state_data = self.state_store.retrieve(state).await?;
        self.state_store.remove(state).await?;

        let token_response = provider
            .exchange_code(code, state_data.code_verifier.as_deref())
            .await?;

        let claims = match (provider.is_oidc(), &token_response.id_token) {
            (true, Some(id_token_str)) => {
                let id_token = provider
                    .validate_id_token(id_token_str, state_data.nonce.as_deref())
                    .await?;
                Some(StandardClaims::from(id_token))
            }
            (is_oidc, _) => {
                // Best-effort: a user-info failure must not fail the whole
                // callback, but it should be visible in logs.
                let source = if is_oidc {
                    "OIDC UserInfo fallback"
                } else {
                    "OAuth2 provider"
                };
                provider
                    .get_user_info(&token_response.access_token)
                    .await
                    .inspect_err(|e| {
                        tracing::warn!(
                            provider = %provider_name,
                            error = %e,
                            "Failed to fetch user info from {source}; claims will be None",
                        )
                    })
                    .ok()
            }
        };

        Ok(CallbackResult {
            token_response,
            claims,
        })
    }

    /// Looks up a registered provider. An unknown name becomes the
    /// `Provider not registered` error shared by both flow entry points.
    fn require_provider(&self, name: &str) -> Result<&Arc<dyn OAuthProvider>, SocialAuthError> {
        self.providers.get(name).ok_or_else(|| {
            SocialAuthError::Provider(format!("Provider not registered: {}", name))
        })
    }
}
impl Default for SocialAuthBackend {
fn default() -> Self {
Self::new()
}
}
/// Returns `length` random characters drawn from `A-Z`, `a-z` and `0-9`.
///
/// Characters come from `rand::rng()`, the thread-local generator that rand
/// documents as cryptographically secure. This function is used for OAuth
/// `state` and OIDC `nonce` values.
fn generate_random_string(length: usize) -> String {
    use rand::Rng;

    let mut generator = rand::rng();
    (0..length)
        .map(|_| char::from(generator.sample(rand::distr::Alphanumeric)))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built backend starts with no registered providers.
    #[test]
    fn test_backend_creation() {
        assert!(SocialAuthBackend::new().provider_names().is_empty());
    }

    /// `Default` yields the same empty backend as `new`.
    #[test]
    fn test_backend_default() {
        assert!(SocialAuthBackend::default().provider_names().is_empty());
    }

    /// Looking up an unregistered provider name yields `None`.
    #[test]
    fn test_get_nonexistent_provider() {
        assert!(SocialAuthBackend::new()
            .get_provider("nonexistent")
            .is_none());
    }
}