#![cfg(feature = "oauth-github")]
mod common;
use arium::oauth::{NormalizedProfile, OAuthProvider};
use arium::{AuditConfig, AuthConfig, Mailer};
use async_trait::async_trait;
use axum::Router;
use reqwest::Client;
use reqwest::redirect::Policy;
use serde_json::json;
use sqlx::SqlitePool;
use std::net::SocketAddr;
use tokio::net::TcpListener;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Test double for the `OAuthProvider` trait whose endpoints, PKCE flag and
/// returned profile are chosen by each test.
struct MockProvider {
    /// Identifier returned from `name()`.
    name: &'static str,
    /// Authorization endpoint returned from `auth_url()`.
    auth_url: String,
    /// Token endpoint returned from `token_url()`.
    token_url: String,
    /// Callback URL returned from `redirect_url()`.
    redirect_url: String,
    /// Profile that `fetch_profile()` hands back as a clone.
    profile: NormalizedProfile,
    /// Value returned from `use_pkce()`.
    use_pkce: bool,
}
/// `OAuthProvider` implementation for [`MockProvider`].
///
/// Endpoint getters echo the struct's fields; the display name, client
/// credentials and scopes are fixed literals; `fetch_profile` returns the
/// stored profile without using the HTTP client or the access token.
#[async_trait]
impl OAuthProvider for MockProvider {
    fn name(&self) -> &str {
        self.name
    }

    fn display_name(&self) -> &str {
        "Test"
    }

    fn client_id(&self) -> &str {
        "test-client-id"
    }

    fn client_secret(&self) -> &str {
        "test-client-secret"
    }

    fn redirect_url(&self) -> &str {
        self.redirect_url.as_str()
    }

    fn auth_url(&self) -> &str {
        self.auth_url.as_str()
    }

    fn token_url(&self) -> &str {
        self.token_url.as_str()
    }

    fn scopes(&self) -> &[&str] {
        // Fixed scope list shared by every instance of the mock.
        const SCOPES: &[&str] = &["read:user", "user:email"];
        SCOPES
    }

    fn use_pkce(&self) -> bool {
        self.use_pkce
    }

    /// Returns a clone of the configured profile; both arguments are ignored.
    async fn fetch_profile(
        &self,
        _http: &reqwest::Client,
        _access_token: &str,
    ) -> anyhow::Result<NormalizedProfile> {
        let canned = self.profile.clone();
        Ok(canned)
    }
}
/// Everything a test needs to talk to one booted application instance.
struct TestApp {
    /// Database pool that was also handed to the app; tests query it directly.
    pool: SqlitePool,
    /// `http://<addr>` of the listener bound in `boot_inner`.
    base_url: String,
    /// HTTP client built with a cookie store and with redirects disabled.
    client: Client,
    /// Handle of the task running `axum::serve`; stored but never awaited in this file.
    _serve: tokio::task::JoinHandle<()>,
}
/// Boots a [`TestApp`] with PKCE turned off.
///
/// Thin wrapper around `boot_inner` that passes `use_pkce = false`.
async fn boot(mock_token_url: &str, profile: NormalizedProfile) -> TestApp {
    let use_pkce = false;
    boot_inner(mock_token_url, profile, use_pkce).await
}
/// Builds and serves the application under test, returning a [`TestApp`].
///
/// Registers a single `MockProvider` named `"test"` whose token endpoint is
/// `mock_token_url` and whose `fetch_profile` yields `profile`. The app is
/// served on an OS-assigned port on `127.0.0.1` from a spawned task, and the
/// returned client keeps cookies and does not follow redirects.
///
/// # Panics
///
/// Panics if any setup step fails (mailer, config, install, bind, client).
async fn boot_inner(mock_token_url: &str, profile: NormalizedProfile, use_pkce: bool) -> TestApp {
    let pool = common::pool().await;
    let mailer = Mailer::from_env().expect("mailer build");

    let provider = MockProvider {
        name: "test",
        auth_url: String::from("https://example.invalid/authorize"),
        token_url: mock_token_url.to_owned(),
        redirect_url: String::from("http://127.0.0.1/auth/test/callback"),
        profile,
        use_pkce,
    };
    // Audit capture of IP / user agent is switched off for these tests.
    let audit = AuditConfig {
        capture_ip: false,
        capture_user_agent: false,
        retention_days: 0,
    };
    let cfg = AuthConfig::builder(pool.clone(), mailer)
        .oauth_provider(provider)
        .unwrap()
        .rate_limit(None)
        .audit(audit)
        .build()
        .unwrap();
    let router: Router = arium::install(Router::new(), cfg).await.expect("install");

    // Port 0 lets the OS pick a free port; read it back for the base URL.
    let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
    let base_url = format!("http://{}", listener.local_addr().expect("local_addr"));

    let make_service = router.into_make_service_with_connect_info::<SocketAddr>();
    let serve = tokio::spawn(async move {
        // The serve result is intentionally ignored, as in a fire-and-forget task.
        let _ = axum::serve(listener, make_service).await;
    });

    let client = Client::builder()
        .cookie_store(true)
        .redirect(Policy::none())
        .build()
        .expect("client");

    TestApp {
        pool,
        base_url,
        client,
        _serve: serve,
    }
}
/// Profile fixture used by most tests: external id `ext-1`, login `testuser`,
/// with name and email set and no avatar or HTML URL.
fn standard_profile() -> NormalizedProfile {
    NormalizedProfile {
        provider_user_id: String::from("ext-1"),
        login: String::from("testuser"),
        name: Some(String::from("Test User")),
        email: Some(String::from("test@example.invalid")),
        avatar_url: None,
        html_url: None,
    }
}
/// `GET /auth/test/login` must answer 303 with a `Location` pointing at the
/// provider's authorize URL, carrying `response_type=code`, the client id,
/// a `state` parameter and both configured scopes.
#[tokio::test]
async fn login_redirects_to_authorize_url_with_state_and_scopes() {
    let app = boot("http://localhost:1/unused", standard_profile()).await;

    let login_url = format!("{}/auth/test/login", app.base_url);
    let resp = app.client.get(login_url).send().await.unwrap();
    assert_eq!(resp.status().as_u16(), 303);

    let location = resp.headers().get("location").unwrap().to_str().unwrap();
    let authorize = url::Url::parse(location).expect("location is a URL");
    assert_eq!(authorize.host_str(), Some("example.invalid"));
    assert_eq!(authorize.path(), "/authorize");

    let params: std::collections::HashMap<_, _> = authorize.query_pairs().collect();
    assert_eq!(
        params.get("response_type").map(|v| v.as_ref()),
        Some("code")
    );
    assert_eq!(
        params.get("client_id").map(|v| v.as_ref()),
        Some("test-client-id")
    );
    assert!(params.contains_key("state"), "state must be present");

    // A missing `scope` parameter degrades to "" so the assertions below report it.
    let scope = params
        .get("scope")
        .map(|v| v.to_string())
        .unwrap_or_default();
    for expected in ["read:user", "user:email"] {
        assert!(scope.contains(expected), "scope={scope:?}");
    }
}
/// Full login → callback round trip against a mocked token endpoint.
///
/// Expects a 303 redirect to `/`, exactly one non-anonymous user, an
/// `oauth_accounts` row for (`test`, `ext-1`), and a `user.login.success`
/// audit event whose details mention the oauth method and provider.
#[tokio::test]
async fn callback_happy_path_creates_user_and_records_audit_event() {
    let mock = MockServer::start().await;
    let token_response = ResponseTemplate::new(200).set_body_json(json!({
        "access_token": "test-access-token",
        "token_type": "bearer",
        "scope": "read:user",
    }));
    Mock::given(method("POST"))
        .and(path("/token"))
        .respond_with(token_response)
        .expect(1)
        .mount(&mock)
        .await;

    let token_url = format!("{}/token", mock.uri());
    let app = boot(&token_url, standard_profile()).await;

    // Step 1: start the flow and lift `state` out of the authorize redirect.
    let login = app
        .client
        .get(format!("{}/auth/test/login", app.base_url))
        .send()
        .await
        .unwrap();
    let location = login.headers().get("location").unwrap().to_str().unwrap();
    let state = url::Url::parse(location)
        .unwrap()
        .query_pairs()
        .find(|(key, _)| key == "state")
        .map(|(_, value)| value.into_owned())
        .unwrap();

    // Step 2: come back through the callback with that state.
    let callback_url = format!(
        "{}/auth/test/callback?code=fake-code&state={state}",
        app.base_url
    );
    let resp = app.client.get(callback_url).send().await.unwrap();
    assert_eq!(
        resp.status().as_u16(),
        303,
        "callback success redirects to /"
    );
    let redirect_target = resp.headers().get("location").unwrap().to_str().unwrap();
    assert_eq!(redirect_target, "/");

    // Step 3: inspect what was persisted.
    let user_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM users WHERE anonymous = false")
        .fetch_one(&app.pool)
        .await
        .unwrap();
    assert_eq!(user_count, 1);

    let (provider, provider_user_id): (String, String) =
        sqlx::query_as("SELECT provider, provider_user_id FROM oauth_accounts LIMIT 1")
            .fetch_one(&app.pool)
            .await
            .unwrap();
    assert_eq!(
        (provider.as_str(), provider_user_id.as_str()),
        ("test", "ext-1")
    );

    let (event_type, details): (String, Option<String>) = sqlx::query_as(
        "SELECT event_type, details FROM audit_events \
         WHERE event_type = 'user.login.success' LIMIT 1",
    )
    .fetch_one(&app.pool)
    .await
    .unwrap();
    assert_eq!(event_type, "user.login.success");
    let details = details.unwrap_or_default();
    for needle in ["\"method\":\"oauth\"", "\"provider\":\"test\""] {
        assert!(details.contains(needle), "{details}");
    }
}
/// Hitting the callback without a prior login request must yield 400 with a
/// body mentioning "missing oauth state".
#[tokio::test]
async fn callback_with_no_session_returns_400_missing_state() {
    let app = boot("http://localhost:1/unused", standard_profile()).await;

    let callback_url = format!(
        "{}/auth/test/callback?code=fake&state=anything",
        app.base_url
    );
    let resp = app.client.get(callback_url).send().await.unwrap();

    let status = resp.status().as_u16();
    assert_eq!(status, 400);
    let body = resp.text().await.unwrap();
    assert!(body.contains("missing oauth state"), "body={body:?}");
}
/// After a login request, a callback carrying a different `state` value must
/// yield 400 with a body mentioning "state mismatch".
#[tokio::test]
async fn callback_with_state_mismatch_returns_400() {
    let app = boot("http://localhost:1/unused", standard_profile()).await;

    // Start a flow; the response itself is not needed, only what the
    // cookie-enabled client retains from it.
    let login_url = format!("{}/auth/test/login", app.base_url);
    app.client.get(login_url).send().await.unwrap();

    let callback_url = format!(
        "{}/auth/test/callback?code=fake&state=not-the-real-state",
        app.base_url
    );
    let resp = app.client.get(callback_url).send().await.unwrap();

    let status = resp.status().as_u16();
    assert_eq!(status, 400);
    let body = resp.text().await.unwrap();
    assert!(body.contains("state mismatch"), "body={body:?}");
}
/// An OAuth state must be single-use: once a callback attempt has been made
/// (here with a wrong state), replaying the callback with the *correct* state
/// must fail with 400 "missing oauth state".
///
/// The first attempt's status is asserted too, so the test fails loudly if the
/// attempt that is supposed to consume the state was not actually rejected.
#[tokio::test]
async fn callback_state_is_consumed_after_one_attempt() {
    let app = boot("http://localhost:1/unused", standard_profile()).await;

    // Start a flow and remember the genuine state value.
    let resp = app
        .client
        .get(format!("{}/auth/test/login", app.base_url))
        .send()
        .await
        .unwrap();
    let loc = resp.headers().get("location").unwrap().to_str().unwrap();
    let state = url::Url::parse(loc)
        .unwrap()
        .query_pairs()
        .find(|(k, _)| k == "state")
        .map(|(_, v)| v.to_string())
        .unwrap();

    // First attempt: wrong state. It must be rejected, not silently accepted.
    let first = app
        .client
        .get(format!(
            "{}/auth/test/callback?code=x&state=wrong",
            app.base_url
        ))
        .send()
        .await
        .unwrap();
    assert_eq!(
        first.status().as_u16(),
        400,
        "wrong-state attempt must be rejected"
    );

    // Second attempt: the right state, but it should already be gone.
    let resp = app
        .client
        .get(format!(
            "{}/auth/test/callback?code=x&state={state}",
            app.base_url
        ))
        .send()
        .await
        .unwrap();
    assert_eq!(resp.status().as_u16(), 400);
    let body = resp.text().await.unwrap();
    assert!(body.contains("missing oauth state"), "body={body:?}");
}
/// When the provider's token endpoint answers 500, the callback must surface
/// the failure as 502 Bad Gateway.
///
/// The mock carries an "at least one call" expectation so that a 502 produced
/// without ever contacting the token endpoint does not let the test pass.
#[tokio::test]
async fn callback_when_token_endpoint_returns_500_returns_502() {
    let mock = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/token"))
        .respond_with(ResponseTemplate::new(500).set_body_string("upstream boom"))
        // `1..` rather than `1`: tolerate retries, but require the exchange to happen.
        .expect(1..)
        .mount(&mock)
        .await;
    let app = boot(&format!("{}/token", mock.uri()), standard_profile()).await;

    // Start a flow and lift `state` out of the authorize redirect.
    let resp = app
        .client
        .get(format!("{}/auth/test/login", app.base_url))
        .send()
        .await
        .unwrap();
    let state = url::Url::parse(resp.headers().get("location").unwrap().to_str().unwrap())
        .unwrap()
        .query_pairs()
        .find(|(k, _)| k == "state")
        .map(|(_, v)| v.to_string())
        .unwrap();

    let resp = app
        .client
        .get(format!(
            "{}/auth/test/callback?code=x&state={state}",
            app.base_url
        ))
        .send()
        .await
        .unwrap();
    assert_eq!(
        resp.status().as_u16(),
        502,
        "provider 5xx must surface as Bad Gateway",
    );
}
/// When the token endpoint answers 200 with a body that is labelled JSON but
/// is not parseable, the callback must respond with 502.
///
/// The mock carries an "at least one call" expectation so that a 502 produced
/// without ever contacting the token endpoint does not let the test pass.
#[tokio::test]
async fn callback_with_malformed_token_body_returns_502() {
    let mock = MockServer::start().await;
    Mock::given(method("POST"))
        .and(path("/token"))
        .respond_with(
            ResponseTemplate::new(200)
                .set_body_string("this is not json")
                .insert_header("content-type", "application/json"),
        )
        // `1..` rather than `1`: tolerate retries, but require the exchange to happen.
        .expect(1..)
        .mount(&mock)
        .await;
    let app = boot(&format!("{}/token", mock.uri()), standard_profile()).await;

    // Start a flow and lift `state` out of the authorize redirect.
    let resp = app
        .client
        .get(format!("{}/auth/test/login", app.base_url))
        .send()
        .await
        .unwrap();
    let state = url::Url::parse(resp.headers().get("location").unwrap().to_str().unwrap())
        .unwrap()
        .query_pairs()
        .find(|(k, _)| k == "state")
        .map(|(_, v)| v.to_string())
        .unwrap();

    let resp = app
        .client
        .get(format!(
            "{}/auth/test/callback?code=x&state={state}",
            app.base_url
        ))
        .send()
        .await
        .unwrap();
    assert_eq!(resp.status().as_u16(), 502);
}
/// A callback addressed to a provider name that was never registered
/// (`nosuch`) must yield 404.
#[tokio::test]
async fn callback_for_unknown_provider_returns_404() {
    let app = boot("http://localhost:1/unused", standard_profile()).await;

    let callback_url = format!("{}/auth/nosuch/callback?code=x&state=y", app.base_url);
    let resp = app.client.get(callback_url).send().await.unwrap();

    let status = resp.status().as_u16();
    assert_eq!(status, 404);
}
/// A login request for a provider name that was never registered (`nosuch`)
/// must yield 404.
#[tokio::test]
async fn login_for_unknown_provider_returns_404() {
    let app = boot("http://localhost:1/unused", standard_profile()).await;

    let login_url = format!("{}/auth/nosuch/login", app.base_url);
    let resp = app.client.get(login_url).send().await.unwrap();

    let status = resp.status().as_u16();
    assert_eq!(status, 404);
}
/// The token exchange must be a POST carrying `grant_type=authorization_code`,
/// the received `code`, and HTTP Basic client credentials.
///
/// The mock only matches requests satisfying all of those conditions and
/// expects exactly one such request.
#[tokio::test]
async fn token_request_carries_client_credentials_and_code() {
    use wiremock::matchers::{body_string_contains, header};

    // Base64 of "test-client-id:test-client-secret", the mock provider's credentials.
    let basic_auth = "Basic dGVzdC1jbGllbnQtaWQ6dGVzdC1jbGllbnQtc2VjcmV0";

    let mock = MockServer::start().await;
    let token_response = ResponseTemplate::new(200).set_body_json(json!({
        "access_token": "tok",
        "token_type": "bearer",
    }));
    Mock::given(method("POST"))
        .and(path("/token"))
        .and(body_string_contains("grant_type=authorization_code"))
        .and(body_string_contains("code=fake-code"))
        .and(header("authorization", basic_auth))
        .respond_with(token_response)
        .expect(1)
        .mount(&mock)
        .await;

    let token_url = format!("{}/token", mock.uri());
    let app = boot(&token_url, standard_profile()).await;

    // Start a flow and lift `state` out of the authorize redirect.
    let login = app
        .client
        .get(format!("{}/auth/test/login", app.base_url))
        .send()
        .await
        .unwrap();
    let location = login.headers().get("location").unwrap().to_str().unwrap();
    let state = url::Url::parse(location)
        .unwrap()
        .query_pairs()
        .find(|(key, _)| key == "state")
        .map(|(_, value)| value.into_owned())
        .unwrap();

    let callback_url = format!(
        "{}/auth/test/callback?code=fake-code&state={state}",
        app.base_url
    );
    let resp = app.client.get(callback_url).send().await.unwrap();
    assert_eq!(resp.status().as_u16(), 303);

    // Explicit teardown order: the app first, then the mock server.
    drop(app);
    drop(mock);
}
/// With the default (`boot`) configuration the authorize redirect must not
/// contain any `code_challenge` parameter.
#[tokio::test]
async fn pkce_off_by_default_omits_code_challenge() {
    let app = boot("http://localhost:1/unused", standard_profile()).await;

    let login_url = format!("{}/auth/test/login", app.base_url);
    let resp = app.client.get(login_url).send().await.unwrap();
    let location = resp.headers().get("location").unwrap().to_str().unwrap();

    let has_challenge = location.contains("code_challenge");
    assert!(
        !has_challenge,
        "no PKCE expected by default, got {location}"
    );
}
/// With PKCE enabled, the authorize redirect must carry a `code_challenge`
/// with method `S256`, and the token request must include a `code_verifier`.
///
/// The mock only matches token requests whose body contains `code_verifier=`
/// and expects exactly one such request.
#[tokio::test]
async fn pkce_round_trips_challenge_in_authorize_and_verifier_at_token() {
    let mock = MockServer::start().await;
    let token_response = ResponseTemplate::new(200).set_body_json(json!({
        "access_token": "test-access-token",
        "token_type": "bearer",
    }));
    Mock::given(method("POST"))
        .and(path("/token"))
        .and(wiremock::matchers::body_string_contains("code_verifier="))
        .respond_with(token_response)
        .expect(1)
        .mount(&mock)
        .await;

    let token_url = format!("{}/token", mock.uri());
    let app = boot_inner(&token_url, standard_profile(), true).await;

    // Authorize leg: the challenge and its method must be present.
    let login = app
        .client
        .get(format!("{}/auth/test/login", app.base_url))
        .send()
        .await
        .unwrap();
    let location = login.headers().get("location").unwrap().to_str().unwrap();
    let authorize = url::Url::parse(location).unwrap();
    let params: std::collections::HashMap<_, _> = authorize.query_pairs().collect();
    assert!(params.contains_key("code_challenge"), "loc={location}");
    assert_eq!(
        params.get("code_challenge_method").map(|v| v.as_ref()),
        Some("S256")
    );
    let state = params.get("state").unwrap().to_string();

    // Token leg: the mock's matcher enforces the verifier.
    let callback_url = format!(
        "{}/auth/test/callback?code=fake-code&state={state}",
        app.base_url
    );
    let resp = app.client.get(callback_url).send().await.unwrap();
    assert_eq!(resp.status().as_u16(), 303);

    // Explicit teardown order: the app first, then the mock server.
    drop(app);
    drop(mock);
}
/// `GithubProvider` must omit PKCE by default and add an S256 code challenge
/// (plus a stored verifier) once `with_pkce(true)` is applied.
#[tokio::test]
async fn github_provider_pkce_opt_in_adds_code_challenge() {
    use arium::oauth::github::GithubProvider;

    // Both providers are built from the same constructor arguments.
    let new_provider = || {
        GithubProvider::new(
            "id".to_string(),
            "secret".to_string(),
            "http://localhost/cb".to_string(),
        )
    };

    // Default: no verifier, no challenge in the URL.
    let off = new_provider();
    let (url_off, attempt_off) = off.begin().unwrap();
    assert!(attempt_off.pkce_verifier.is_none());
    let off_has_challenge = url_off.contains("code_challenge");
    assert!(!off_has_challenge, "url_off={url_off}");

    // Opt-in: verifier stored, challenge and method present in the URL.
    let on = new_provider().with_pkce(true);
    let (url_on, attempt_on) = on.begin().unwrap();
    assert!(attempt_on.pkce_verifier.is_some());
    assert!(url_on.contains("code_challenge="), "url_on={url_on}");
    assert!(
        url_on.contains("code_challenge_method=S256"),
        "url_on={url_on}"
    );
}