#![cfg(feature = "_oauth-core")]
use async_trait::async_trait;
use std::sync::Arc;
use crate::pool::Pool;
pub mod github;
#[cfg(feature = "oauth-google")]
pub mod google;
#[cfg(feature = "oauth-microsoft")]
pub mod microsoft;
#[cfg(feature = "oauth-oidc")]
pub mod oidc;
/// Provider-independent view of a user's profile: produced by
/// `OAuthProvider::fetch_profile` and consumed by `upsert_oauth_user`.
#[derive(Debug, Clone)]
pub struct NormalizedProfile {
/// The provider's id for this user; together with the provider name it is the
/// lookup key in `oauth_accounts`.
pub provider_user_id: String,
/// Base for a new user's unique username (passed to `crate::auth::unique_username`).
pub login: String,
/// Used as `users.display_name` for newly created users, and for email-linked users
/// whose display name is still empty.
pub name: Option<String>,
/// Email address, if any. Used to link to an existing user by case-insensitive match.
/// NOTE(review): no field records whether the provider verified this address — confirm
/// every provider only populates it with a verified email.
pub email: Option<String>,
/// Stored in `users.avatar_url`.
pub avatar_url: Option<String>,
/// Stored in `users.html_url`.
pub html_url: Option<String>,
}
/// State for one in-flight login. Created by `OAuthProvider::begin`, stored in the
/// session until the provider redirects back, then consumed by the callback.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct OAuthAttempt {
/// Expected `state` value; the callback rejects a response whose `state` differs.
pub csrf_state: String,
/// PKCE code verifier when the provider uses PKCE; sent with the token exchange.
pub pkce_verifier: Option<String>,
/// Nonce; the default `begin` always leaves it `None`.
/// NOTE(review): presumably filled in by providers that override `begin` (e.g. OIDC) — confirm.
pub nonce: Option<String>,
}
/// An OAuth 2.0 authorization-code provider.
///
/// Implementors supply credentials, endpoint URLs, scopes, and a way to turn an access
/// token into a [`NormalizedProfile`]. The default [`begin`](Self::begin) and
/// [`finish`](Self::finish) drive the authorization-code flow through the `oauth2` crate.
/// Providers needing more (the default `begin` never sets a nonce) may override them.
#[async_trait]
pub trait OAuthProvider: Send + Sync + 'static {
    /// Identifier used as the registry lookup key and as the `provider` value stored in
    /// `oauth_accounts`.
    fn name(&self) -> &str;

    /// Human-readable provider name.
    fn display_name(&self) -> &str;

    /// Optional inline SVG icon; `None` unless overridden.
    fn icon_svg(&self) -> Option<&str> {
        None
    }

    /// OAuth client id registered with the provider.
    fn client_id(&self) -> &str;

    /// OAuth client secret registered with the provider.
    fn client_secret(&self) -> &str;

    /// Redirect (callback) URL sent to the provider.
    fn redirect_url(&self) -> &str;

    /// Provider's authorization endpoint.
    fn auth_url(&self) -> &str;

    /// Provider's token endpoint.
    fn token_url(&self) -> &str;

    /// Scopes requested in the authorization URL, in this order.
    fn scopes(&self) -> &[&str];

    /// Fetches the user's profile using `access_token` and normalizes it.
    async fn fetch_profile(
        &self,
        http: &reqwest::Client,
        access_token: &str,
    ) -> anyhow::Result<NormalizedProfile>;

    /// Whether the default flow uses PKCE (S256 challenge). Defaults to `false`.
    fn use_pkce(&self) -> bool {
        false
    }

    /// Builds the authorization URL to send the user to, plus the [`OAuthAttempt`] that
    /// must be kept until the callback: a random CSRF state and, when
    /// [`use_pkce`](Self::use_pkce) is on, the PKCE verifier. `nonce` is always `None`.
    ///
    /// # Errors
    /// Fails if the `oauth2` client cannot be built from this provider's URLs.
    fn begin(&self) -> anyhow::Result<(String, OAuthAttempt)> {
        use oauth2::{CsrfToken, PkceCodeChallenge, Scope};

        let client = basic_client(self)?;
        let requested_scopes = self
            .scopes()
            .iter()
            .map(|s| Scope::new((*s).to_owned()));
        let mut auth_request = client
            .authorize_url(CsrfToken::new_random)
            .add_scopes(requested_scopes);

        let mut pkce_verifier = None;
        if self.use_pkce() {
            let (challenge, verifier) = PkceCodeChallenge::new_random_sha256();
            auth_request = auth_request.set_pkce_challenge(challenge);
            pkce_verifier = Some(verifier.secret().to_string());
        }

        let (url, csrf_token) = auth_request.url();
        let attempt = OAuthAttempt {
            csrf_state: csrf_token.secret().to_string(),
            pkce_verifier,
            nonce: None,
        };
        Ok((url.to_string(), attempt))
    }

    /// Exchanges the authorization `code` for a token (including the PKCE verifier from
    /// `attempt` when there is one) and fetches the user's profile with it.
    ///
    /// # Errors
    /// Fails if the client cannot be built, the token request fails, or
    /// [`fetch_profile`](Self::fetch_profile) fails.
    async fn finish(
        &self,
        http: &reqwest::Client,
        code: &str,
        attempt: &OAuthAttempt,
    ) -> anyhow::Result<NormalizedProfile> {
        use oauth2::{AuthorizationCode, PkceCodeVerifier, TokenResponse};

        let client = basic_client(self)?;
        let exchange = client.exchange_code(AuthorizationCode::new(code.to_owned()));
        let exchange = match &attempt.pkce_verifier {
            Some(secret) => exchange.set_pkce_verifier(PkceCodeVerifier::new(secret.clone())),
            None => exchange,
        };
        let token = exchange.request_async(http).await?;
        let access_token = token.access_token().secret();
        self.fetch_profile(http, access_token).await
    }
}
/// Registered OAuth providers plus the shared state the OAuth handlers use.
///
/// The provider list sits behind an `Arc`, so clones of the registry share it.
#[derive(Clone)]
pub struct OAuthRegistry {
/// Database pool used to look up, create and link users.
pub db: Pool,
/// HTTP client passed to providers for token exchange and profile fetches.
pub http: reqwest::Client,
/// Registered providers, looked up by `OAuthProvider::name`.
providers: Arc<Vec<Arc<dyn OAuthProvider>>>,
}
impl OAuthRegistry {
    /// Upper bound for one outbound request to a provider (token exchange, profile fetch).
    /// `reqwest` has no timeout by default, so without this an unresponsive provider would
    /// stall the callback indefinitely.
    const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

    /// Creates a registry with no providers and a fresh HTTP client.
    ///
    /// The client identifies itself as `arium/<crate version>`, never follows redirects
    /// (as the `oauth2` crate recommends, to avoid SSRF through redirecting endpoints), and
    /// abandons any single request after `HTTP_TIMEOUT`.
    ///
    /// # Errors
    /// Fails if the `reqwest` client cannot be built.
    pub fn new(db: Pool) -> anyhow::Result<Self> {
        let http = reqwest::ClientBuilder::new()
            .user_agent(concat!("arium/", env!("CARGO_PKG_VERSION")))
            .redirect(reqwest::redirect::Policy::none())
            .timeout(Self::HTTP_TIMEOUT)
            .build()?;
        Ok(Self {
            db,
            http,
            providers: Arc::new(Vec::new()),
        })
    }

    /// Returns the registry with `p` added after the already-registered providers.
    ///
    /// Names must be unique. A duplicate trips a `debug_assert!` in debug builds; in
    /// release builds it is still added but never returned by [`get`](Self::get), which
    /// yields the first match.
    pub fn with_provider<P: OAuthProvider>(mut self, p: P) -> Self {
        debug_assert!(
            self.providers
                .iter()
                .all(|existing| existing.name() != p.name()),
            "OAuthRegistry: duplicate provider name {:?}",
            p.name()
        );
        // Copies the list only when another clone of this registry still shares it, so
        // existing clones keep seeing their old provider set.
        Arc::make_mut(&mut self.providers).push(Arc::new(p));
        self
    }

    /// Looks up a provider by [`OAuthProvider::name`].
    pub fn get(&self, name: &str) -> Option<Arc<dyn OAuthProvider>> {
        self.providers.iter().find(|p| p.name() == name).cloned()
    }

    /// All registered providers, in registration order.
    pub fn list(&self) -> &[Arc<dyn OAuthProvider>] {
        &self.providers
    }

    /// `true` when no provider has been registered.
    pub fn is_empty(&self) -> bool {
        self.providers.is_empty()
    }
}
/// Maps a provider login onto a local user id, creating or linking the user as needed.
///
/// Resolution order:
/// 1. If `(provider, profile.provider_user_id)` is already in `oauth_accounts`, that user is
///    returned after refreshing `email`, `avatar_url` and `html_url` from `profile`. A
///    missing email keeps the stored one.
/// 2. Otherwise, if `profile.email` case-insensitively matches an existing user's email,
///    the provider identity is linked to that user. The user's `display_name` is filled in
///    only if empty, and `avatar_url`/`html_url` are overwritten.
/// 3. Otherwise a new non-anonymous user is created with a unique username derived from
///    `profile.login`. The user is linked, given the default role, and passed through the
///    admin-bootstrap hooks.
///
/// # Errors
/// Propagates database errors and errors from the `crate::auth` helpers.
pub async fn upsert_oauth_user(
    db: &Pool,
    provider: &str,
    profile: NormalizedProfile,
) -> anyhow::Result<i64> {
    // NOTE(review): these are separate statements, not one transaction. Two concurrent
    // first logins can race, and a failure after creating the user leaves it without an
    // `oauth_accounts` row.
    let existing: Option<(i64,)> = sqlx::query_as(
        "SELECT user_id FROM oauth_accounts WHERE provider = $1 AND provider_user_id = $2",
    )
    .bind(provider)
    .bind(&profile.provider_user_id)
    .fetch_optional(db)
    .await?;
    if let Some((user_id,)) = existing {
        // COALESCE keeps the stored email when the provider did not return one this time,
        // instead of wiping it to NULL.
        sqlx::query(
            "UPDATE users SET email = COALESCE($1, email), avatar_url = $2, html_url = $3 \
             WHERE id = $4",
        )
        .bind(profile.email.as_deref())
        .bind(profile.avatar_url.as_deref())
        .bind(profile.html_url.as_deref())
        .bind(user_id)
        .execute(db)
        .await?;
        return Ok(user_id);
    }

    if let Some(email) = profile.email.as_deref() {
        // SECURITY NOTE(review): linking on email alone is only safe if every provider
        // returns an address it has verified. `NormalizedProfile` carries no verification
        // flag, so confirm this in each provider's `fetch_profile`.
        let matched: Option<(i64,)> =
            sqlx::query_as("SELECT id FROM users WHERE LOWER(email) = LOWER($1) LIMIT 1")
                .bind(email)
                .fetch_optional(db)
                .await?;
        if let Some((user_id,)) = matched {
            link_oauth_account(db, provider, &profile.provider_user_id, user_id).await?;
            sqlx::query(
                "UPDATE users SET display_name = COALESCE(display_name, $1), \
                 avatar_url = $2, html_url = $3 WHERE id = $4",
            )
            .bind(profile.name.as_deref())
            .bind(profile.avatar_url.as_deref())
            .bind(profile.html_url.as_deref())
            .bind(user_id)
            .execute(db)
            .await?;
            return Ok(user_id);
        }
    }

    let username = crate::auth::unique_username(db, &profile.login).await?;
    // Only stamp `email_verified_at` when there is an email to vouch for. Stamping an
    // email-less account could make an email attached to it later appear verified.
    let email_verified_at = profile.email.as_ref().map(|_| unix_now());
    let (user_id,): (i64,) = sqlx::query_as(
        "INSERT INTO users (anonymous, username, display_name, email, avatar_url, html_url, email_verified_at) \
         VALUES (false, $1, $2, $3, $4, $5, $6) RETURNING id",
    )
    .bind(&username)
    .bind(profile.name.as_deref())
    .bind(profile.email.as_deref())
    .bind(profile.avatar_url.as_deref())
    .bind(profile.html_url.as_deref())
    .bind(email_verified_at)
    .fetch_one(db)
    .await?;
    link_oauth_account(db, provider, &profile.provider_user_id, user_id).await?;
    crate::auth::assign_default_role(db, user_id).await?;
    crate::auth::maybe_bootstrap_admin(db, user_id, profile.email.as_deref()).await?;
    crate::auth::maybe_grant_first_admin(db, user_id).await?;
    Ok(user_id)
}

/// Inserts the `oauth_accounts` row linking `(provider, provider_user_id)` to `user_id`.
async fn link_oauth_account(
    db: &Pool,
    provider: &str,
    provider_user_id: &str,
    user_id: i64,
) -> anyhow::Result<()> {
    sqlx::query(
        "INSERT INTO oauth_accounts (provider, provider_user_id, user_id) VALUES ($1, $2, $3)",
    )
    .bind(provider)
    .bind(provider_user_id)
    .bind(user_id)
    .execute(db)
    .await?;
    Ok(())
}
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of an error.
fn unix_now() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
/// Typestate of the `oauth2` client returned by `basic_client`. In order, the parameters
/// mark: auth URL set; device-authorization, introspection and revocation endpoints unset;
/// token URL set.
type BasicClient = oauth2::basic::BasicClient<
oauth2::EndpointSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointSet,
>;
fn basic_client<P: OAuthProvider + ?Sized>(p: &P) -> anyhow::Result<BasicClient> {
use oauth2::basic::BasicClient as Bc;
use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl};
Ok(Bc::new(ClientId::new(p.client_id().to_string()))
.set_client_secret(ClientSecret::new(p.client_secret().to_string()))
.set_auth_uri(AuthUrl::new(p.auth_url().to_string())?)
.set_token_uri(TokenUrl::new(p.token_url().to_string())?)
.set_redirect_uri(RedirectUrl::new(p.redirect_url().to_string())?))
}
/// Pairs an HTTP status with the `Display` text of `e`. This is the error shape the
/// OAuth handlers return.
fn http_err<E: std::fmt::Display>(
    status: axum::http::StatusCode,
    e: E,
) -> (axum::http::StatusCode, String) {
    let message: String = e.to_string();
    (status, message)
}
/// Session key under which the pending `OAuthAttempt` for `provider` is stored.
fn oauth_state_key(provider: &str) -> String {
    const PREFIX: &str = "oauth_state:";
    let mut key = String::with_capacity(PREFIX.len() + provider.len());
    key.push_str(PREFIX);
    key.push_str(provider);
    key
}
/// Query parameters the provider appends when redirecting back to the callback.
///
/// Both fields are required, so a redirect lacking `code` (for example an `error=...`
/// response) presumably fails in the `Query` extractor before the handler body runs.
#[derive(serde::Deserialize)]
pub(crate) struct CallbackParams {
/// Authorization code to exchange for a token.
code: String,
/// Must equal the stored `OAuthAttempt::csrf_state`.
state: String,
}
/// Starts an OAuth login for the provider named in the path.
///
/// Builds the provider's authorization URL and stores the resulting `OAuthAttempt` in the
/// session under `oauth_state:<provider>`. It then redirects the browser to the provider.
///
/// # Errors
/// * `404` if no provider is registered under that name.
/// * `500` if the provider fails to build its authorization request.
pub(crate) async fn oauth_login(
    axum::extract::State(reg): axum::extract::State<OAuthRegistry>,
    axum::extract::Path(provider): axum::extract::Path<String>,
    session: crate::extract::SessionStore,
) -> Result<axum::response::Redirect, (axum::http::StatusCode, String)> {
    use axum::http::StatusCode;

    let oauth = match reg.get(&provider) {
        Some(found) => found,
        None => {
            return Err(http_err(
                StatusCode::NOT_FOUND,
                format!("unknown oauth provider: {provider}"),
            ))
        }
    };
    let (target, attempt) = match oauth.begin() {
        Ok(pair) => pair,
        Err(e) => return Err(http_err(StatusCode::INTERNAL_SERVER_ERROR, e)),
    };
    let key = oauth_state_key(&provider);
    session.set(&key, attempt);
    Ok(axum::response::Redirect::to(&target))
}
pub(crate) async fn oauth_callback(
axum::extract::State(reg): axum::extract::State<OAuthRegistry>,
axum::extract::Path(provider): axum::extract::Path<String>,
session: crate::extract::SessionStore,
auth_session: crate::auth::Session,
audit: crate::extract::AuditCtx,
axum::extract::Query(params): axum::extract::Query<CallbackParams>,
) -> Result<axum::response::Redirect, (axum::http::StatusCode, String)> {
let provider_arc = reg.get(&provider).ok_or_else(|| {
http_err(
axum::http::StatusCode::NOT_FOUND,
format!("unknown oauth provider: {provider}"),
)
})?;
let state_key = oauth_state_key(&provider);
let attempt: Option<OAuthAttempt> = session.get(&state_key);
session.remove(&state_key);
let attempt = attempt.ok_or_else(|| {
http_err(
axum::http::StatusCode::BAD_REQUEST,
"missing oauth state in session",
)
})?;
if attempt.csrf_state != params.state {
return Err(http_err(
axum::http::StatusCode::BAD_REQUEST,
"oauth state mismatch",
));
}
let profile = provider_arc
.finish(®.http, ¶ms.code, &attempt)
.await
.map_err(|e| {
http_err(
axum::http::StatusCode::BAD_GATEWAY,
format!("oauth token exchange / profile fetch failed: {e}"),
)
})?;
let user_id = upsert_oauth_user(®.db, provider_arc.name(), profile)
.await
.map_err(|e| http_err(axum::http::StatusCode::INTERNAL_SERVER_ERROR, e))?;
auth_session.login_user(user_id);
audit
.record(
®.db,
crate::auth::audit::USER_LOGIN_SUCCESS,
Some(user_id),
Some(user_id),
Some(&format!(
"{{\"method\":\"oauth\",\"provider\":\"{}\"}}",
provider_arc.name()
)),
)
.await;
Ok(axum::response::Redirect::to("/"))
}