#![allow(dead_code)]
use std::path::PathBuf;
use sqlx::SqlitePool;
use sqlx::sqlite::SqlitePoolOptions;
/// Build a fresh in-memory SQLite pool with every migration applied.
///
/// The pool holds a single connection. An in-memory SQLite database exists only
/// while a connection to it is open, so idle timeout and max lifetime are
/// disabled: if the pool reaped or recycled that connection, the database (and
/// every migrated table) would silently disappear in the middle of a test.
///
/// # Panics
///
/// Panics if the connection cannot be opened or any migration statement fails.
pub async fn pool() -> SqlitePool {
    let pool = SqlitePoolOptions::new()
        .max_connections(1)
        // Never let the pool close the only connection to the in-memory DB.
        .idle_timeout(None)
        .max_lifetime(None)
        .connect("sqlite::memory:")
        .await
        .expect("connect sqlite memory");
    run_migrations(&pool).await;
    pool
}
/// Apply the crate's SQLite migrations to `pool`.
///
/// The base set in `<CARGO_MANIFEST_DIR>/migrations/sqlite` is always applied.
/// The membership set in `<CARGO_MANIFEST_DIR>/migrations-membership/sqlite` is
/// added when the `sql-membership` feature is enabled. `CARGO_MANIFEST_DIR` is
/// resolved at compile time.
async fn run_migrations(pool: &SqlitePool) {
    let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));

    let core_dir = manifest_dir.join("migrations").join("sqlite");
    apply_dir(pool, core_dir).await;

    #[cfg(feature = "sql-membership")]
    {
        let membership_dir = manifest_dir.join("migrations-membership").join("sqlite");
        apply_dir(pool, membership_dir).await;
    }
}
/// Execute every `*.sql` file in `dir` against `pool`, one statement at a time.
///
/// Files are run in sorted path order (i.e. lexicographic file-name order within
/// `dir`). Each file is split into statements by [`split_sql_statements`], so a
/// failure message names the exact statement that broke.
///
/// # Panics
///
/// Panics if `dir` or any file in it cannot be read, or if a statement fails.
async fn apply_dir(pool: &SqlitePool, dir: PathBuf) {
    let mut files: Vec<PathBuf> = std::fs::read_dir(&dir)
        .unwrap_or_else(|e| panic!("read_dir {}: {e}", dir.display()))
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .filter(|p| p.extension().is_some_and(|ext| ext == "sql"))
        .collect();
    // Deterministic order, e.g. zero-padded numeric prefixes run in sequence.
    files.sort();
    for path in files {
        let sql = std::fs::read_to_string(&path)
            .unwrap_or_else(|e| panic!("read {}: {e}", path.display()));
        for stmt in split_sql_statements(&sql) {
            sqlx::query(&stmt)
                .execute(pool)
                .await
                .unwrap_or_else(|e| panic!("migrate {} ({stmt}): {e}", path.display()));
        }
    }
}

/// Split a SQL script into trimmed, non-empty statements on top-level `;`.
///
/// `--` line comments and `/* ... */` block comments are removed. Text inside
/// `'...'` strings, `"..."` identifiers and `` `...` `` identifiers is copied
/// verbatim, so a `;` or `--` there is not mistaken for a terminator or a
/// comment. Doubled quotes (`''`) work naturally: the pair closes and reopens
/// the quoted run.
///
/// Limitation: a `;` that is legitimately part of one statement outside any
/// quotes (e.g. inside a `CREATE TRIGGER ... BEGIN ...; END` body) still splits
/// the statement.
fn split_sql_statements(sql: &str) -> Vec<String> {
    let mut statements = Vec::new();
    let mut current = String::new();
    // Closing character of the quoted run we are inside, if any.
    let mut quote: Option<char> = None;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        if let Some(close) = quote {
            current.push(c);
            if c == close {
                quote = None;
            }
            continue;
        }
        match c {
            '\'' | '"' | '`' => {
                quote = Some(c);
                current.push(c);
            }
            '-' if chars.peek() == Some(&'-') => {
                // Drop the comment but keep the newline so tokens stay separated.
                while chars.next_if(|&next| next != '\n').is_some() {}
            }
            '/' if chars.peek() == Some(&'*') => {
                chars.next();
                let mut prev = '\0';
                for next in chars.by_ref() {
                    if prev == '*' && next == '/' {
                        break;
                    }
                    prev = next;
                }
                // SQL treats a comment as whitespace; don't glue adjacent tokens.
                current.push(' ');
            }
            ';' => {
                push_statement(&mut statements, &current);
                current.clear();
            }
            _ => current.push(c),
        }
    }
    push_statement(&mut statements, &current);
    statements
}

/// Append `stmt`, trimmed, to `out` unless it is blank.
fn push_statement(out: &mut Vec<String>, stmt: &str) {
    let trimmed = stmt.trim();
    if !trimmed.is_empty() {
        out.push(trimmed.to_owned());
    }
}
/// Create a password user and mark their email as verified.
///
/// The user is created through `arium::auth::create_password_user`. The row in
/// `users` then gets `email_verified_at` set to the current Unix time in seconds
/// (see `now_secs`). Returns the new user's id.
///
/// # Panics
///
/// Panics if user creation or the verification update fails.
pub async fn make_user(pool: &SqlitePool, email: &str, password: &str) -> i64 {
    let created = arium::auth::create_password_user(pool, email, password).await;
    let user_id = created.expect("create_password_user");

    let verified_at = now_secs();
    let outcome = sqlx::query("UPDATE users SET email_verified_at = $1 WHERE id = $2")
        .bind(verified_at)
        .bind(user_id)
        .execute(pool)
        .await;
    outcome.expect("verify user");

    user_id
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time before the epoch.
pub fn now_secs() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    since_epoch.as_secs() as i64
}
/// Generate the TOTP code valid right now for a base32-encoded secret.
///
/// Uses `totp_rs::TOTP` with SHA-1, 6 digits, skew 1 and a 30-second step, no
/// issuer and an empty account name. Presumably this mirrors the application's
/// TOTP settings; TODO confirm against the code under test.
///
/// # Panics
///
/// Panics if the secret is not valid base32, if `TOTP::new` rejects the
/// parameters or secret, or if code generation fails.
pub fn current_totp(secret_base32: &str) -> String {
    use totp_rs::{Algorithm, Secret, TOTP};

    let encoded = Secret::Encoded(secret_base32.to_owned());
    let key = encoded.to_bytes().expect("decode secret");
    TOTP::new(Algorithm::SHA1, 6, 1, 30, key, None, String::new())
        .expect("totp construct")
        .generate_current()
        .expect("totp generate")
}
/// Guard that sets or clears an environment variable and restores the
/// variable's previous state when dropped.
///
/// The previous value is captured with [`std::env::var_os`], so values that are
/// not valid UTF-8 are restored exactly instead of being deleted.
///
/// The process environment is global, so tests that touch the same variable
/// must not run concurrently with each other.
#[must_use = "the environment variable is restored as soon as the guard is dropped"]
pub struct EnvGuard {
    /// Name of the variable this guard manages.
    key: &'static str,
    /// Value before the guard was created; `None` if the variable was unset.
    prev: Option<std::ffi::OsString>,
}

impl EnvGuard {
    /// Set `key` to `value` until the returned guard is dropped.
    pub fn set(key: &'static str, value: &str) -> Self {
        let prev = std::env::var_os(key);
        // SAFETY: mutating the environment while another thread reads or writes
        // it is undefined behavior on some platforms. Callers (tests) must not
        // access the environment concurrently while this guard is in use.
        unsafe {
            std::env::set_var(key, value);
        }
        Self { key, prev }
    }

    /// Remove `key` from the environment until the returned guard is dropped.
    pub fn unset(key: &'static str) -> Self {
        let prev = std::env::var_os(key);
        // SAFETY: same contract as `set` — no concurrent environment access.
        unsafe {
            std::env::remove_var(key);
        }
        Self { key, prev }
    }
}

impl Drop for EnvGuard {
    /// Put the variable back exactly as it was when the guard was created.
    fn drop(&mut self) {
        // SAFETY: same contract as `EnvGuard::set` — no concurrent environment
        // access while the guard is alive or being dropped.
        unsafe {
            match &self.prev {
                Some(v) => std::env::set_var(self.key, v),
                None => std::env::remove_var(self.key),
            }
        }
    }
}
#[cfg(feature = "oauth-github")]
pub mod test_provider {
    //! An [`OAuthProvider`] that returns a canned profile and never performs
    //! network I/O.

    use arium::oauth::{NormalizedProfile, OAuthProvider};
    use async_trait::async_trait;

    const CLIENT_ID: &str = "test-client-id";
    const CLIENT_SECRET: &str = "test-client-secret";
    const REDIRECT_URL: &str = "http://localhost:8080/auth/test/callback";
    const AUTH_URL: &str = "https://example.invalid/authorize";
    const TOKEN_URL: &str = "https://example.invalid/token";
    const SCOPES: &[&str] = &["read:user"];

    /// Fake provider whose `fetch_profile` always yields a clone of `profile`.
    pub struct TestProvider {
        pub name: &'static str,
        pub display_name: &'static str,
        pub profile: NormalizedProfile,
    }

    impl TestProvider {
        /// Provider registered as `name`, displayed as "Test", with a fixed
        /// `testuser` profile (id "1", email `test@example.invalid`).
        pub fn new(name: &'static str) -> Self {
            let profile = NormalizedProfile {
                provider_user_id: String::from("1"),
                login: String::from("testuser"),
                name: Some(String::from("Test User")),
                email: Some(String::from("test@example.invalid")),
                avatar_url: None,
                html_url: None,
            };
            Self {
                name,
                display_name: "Test",
                profile,
            }
        }
    }

    #[async_trait]
    impl OAuthProvider for TestProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn display_name(&self) -> &str {
            self.display_name
        }

        fn client_id(&self) -> &str {
            CLIENT_ID
        }

        fn client_secret(&self) -> &str {
            CLIENT_SECRET
        }

        fn redirect_url(&self) -> &str {
            REDIRECT_URL
        }

        fn auth_url(&self) -> &str {
            AUTH_URL
        }

        fn token_url(&self) -> &str {
            TOKEN_URL
        }

        fn scopes(&self) -> &[&str] {
            SCOPES
        }

        /// Ignores the HTTP client and token; returns the stored profile.
        async fn fetch_profile(
            &self,
            _http: &reqwest::Client,
            _access_token: &str,
        ) -> anyhow::Result<NormalizedProfile> {
            let profile = self.profile.clone();
            Ok(profile)
        }
    }
}
pub mod test_authority {
    //! Authorization test doubles.
    //!
    //! [`TableAuthority`] keeps `(user_id, kind, resource_id, role)` rows in a
    //! `test_memberships` table. [`FailingAuthority`] errors on every lookup.

    use arium::ResourceRole;
    use arium::authz::{ResourceAuthority, ResourceRef};
    use arium::membership::{Membership, MembershipStore, TxExec};
    use arium::pool::Pool;
    use async_trait::async_trait;

    const CREATE_TABLE_SQL: &str = "CREATE TABLE IF NOT EXISTS test_memberships (\
        user_id INTEGER NOT NULL,\
        kind TEXT NOT NULL,\
        resource_id INTEGER NOT NULL,\
        role TEXT NOT NULL,\
        PRIMARY KEY (user_id, kind, resource_id))";

    const INSERT_SQL: &str = "INSERT INTO test_memberships (user_id, kind, resource_id, role) \
        VALUES ($1, $2, $3, $4)";

    const UPSERT_SQL: &str = "INSERT INTO test_memberships (user_id, kind, resource_id, role) \
        VALUES ($1, $2, $3, $4) \
        ON CONFLICT (user_id, kind, resource_id) DO UPDATE SET role = excluded.role";

    const DELETE_SQL: &str = "DELETE FROM test_memberships \
        WHERE user_id = $1 AND kind = $2 AND resource_id = $3";

    const SELECT_ROLE_SQL: &str = "SELECT role FROM test_memberships \
        WHERE user_id = $1 AND kind = $2 AND resource_id = $3";

    const LIST_MEMBERS_SQL: &str = "SELECT user_id, role FROM test_memberships \
        WHERE kind = $1 AND resource_id = $2 ORDER BY user_id";

    const LIST_RESOURCES_SQL: &str = "SELECT resource_id, role FROM test_memberships \
        WHERE user_id = $1 AND kind = $2 ORDER BY resource_id";

    const COUNT_ROLE_SQL: &str = "SELECT COUNT(*) FROM test_memberships \
        WHERE kind = $1 AND resource_id = $2 AND role = $3";

    /// Convert a stored role string with `ResourceRole::from_str_lossy`.
    fn parse_role(raw: &str) -> ResourceRole {
        ResourceRole::from_str_lossy(raw)
    }

    /// Membership store/authority backed by the `test_memberships` table.
    pub struct TableAuthority;

    impl TableAuthority {
        /// Create `test_memberships` if it does not exist yet.
        pub async fn create_table(pool: &Pool) {
            sqlx::query(CREATE_TABLE_SQL)
                .execute(pool)
                .await
                .expect("create test_memberships");
        }

        /// Insert a membership row. Panics if the insert fails (including when
        /// a row for the same user/kind/resource already exists).
        pub async fn grant(pool: &Pool, user_id: i64, kind: &str, id: i64, role: &str) {
            sqlx::query(INSERT_SQL)
                .bind(user_id)
                .bind(kind)
                .bind(id)
                .bind(role)
                .execute(pool)
                .await
                .expect("grant membership");
        }

        /// Delete the membership row for this user/kind/resource, if present.
        pub async fn revoke(pool: &Pool, user_id: i64, kind: &str, id: i64) {
            sqlx::query(DELETE_SQL)
                .bind(user_id)
                .bind(kind)
                .bind(id)
                .execute(pool)
                .await
                .expect("revoke membership");
        }
    }

    #[async_trait]
    impl ResourceAuthority for TableAuthority {
        async fn role_on(
            &self,
            db: &Pool,
            user_id: i64,
            r: ResourceRef<'_>,
        ) -> anyhow::Result<Option<ResourceRole>> {
            let stored: Option<String> = sqlx::query_scalar(SELECT_ROLE_SQL)
                .bind(user_id)
                .bind(r.kind)
                .bind(r.id)
                .fetch_optional(db)
                .await?;
            Ok(stored.as_deref().map(parse_role))
        }
    }

    #[async_trait]
    impl MembershipStore for TableAuthority {
        async fn list_members(
            &self,
            db: &Pool,
            r: ResourceRef<'_>,
        ) -> anyhow::Result<Vec<Membership>> {
            let rows: Vec<(i64, String)> = sqlx::query_as(LIST_MEMBERS_SQL)
                .bind(r.kind)
                .bind(r.id)
                .fetch_all(db)
                .await?;
            let mut members = Vec::with_capacity(rows.len());
            for (user_id, role) in rows {
                members.push(Membership {
                    user_id,
                    role: parse_role(&role),
                });
            }
            Ok(members)
        }

        async fn list_resources_for_user(
            &self,
            db: &Pool,
            user_id: i64,
            kind: &str,
            min_role: ResourceRole,
        ) -> anyhow::Result<Vec<i64>> {
            let rows: Vec<(i64, String)> = sqlx::query_as(LIST_RESOURCES_SQL)
                .bind(user_id)
                .bind(kind)
                .fetch_all(db)
                .await?;
            Ok(rows
                .into_iter()
                .filter_map(|(id, role)| parse_role(&role).at_least(min_role).then_some(id))
                .collect())
        }

        async fn role_on_tx(
            &self,
            tx: &mut TxExec<'_>,
            r: ResourceRef<'_>,
            user_id: i64,
        ) -> anyhow::Result<Option<ResourceRole>> {
            let stored: Option<String> = sqlx::query_scalar(SELECT_ROLE_SQL)
                .bind(user_id)
                .bind(r.kind)
                .bind(r.id)
                .fetch_optional(&mut **tx)
                .await?;
            Ok(stored.as_deref().map(parse_role))
        }

        async fn count_holders_of_role(
            &self,
            tx: &mut TxExec<'_>,
            r: ResourceRef<'_>,
            role: ResourceRole,
        ) -> anyhow::Result<u64> {
            let total: i64 = sqlx::query_scalar(COUNT_ROLE_SQL)
                .bind(r.kind)
                .bind(r.id)
                .bind(role.as_str())
                .fetch_one(&mut **tx)
                .await?;
            Ok(total as u64)
        }

        async fn upsert_role(
            &self,
            tx: &mut TxExec<'_>,
            r: ResourceRef<'_>,
            user_id: i64,
            role: ResourceRole,
        ) -> anyhow::Result<()> {
            sqlx::query(UPSERT_SQL)
                .bind(user_id)
                .bind(r.kind)
                .bind(r.id)
                .bind(role.as_str())
                .execute(&mut **tx)
                .await?;
            Ok(())
        }

        async fn remove_role(
            &self,
            tx: &mut TxExec<'_>,
            r: ResourceRef<'_>,
            user_id: i64,
        ) -> anyhow::Result<()> {
            sqlx::query(DELETE_SQL)
                .bind(user_id)
                .bind(r.kind)
                .bind(r.id)
                .execute(&mut **tx)
                .await?;
            Ok(())
        }
    }

    /// Authority whose every lookup fails, for exercising error paths.
    pub struct FailingAuthority;

    #[async_trait]
    impl ResourceAuthority for FailingAuthority {
        async fn role_on(
            &self,
            _db: &Pool,
            _user_id: i64,
            _r: ResourceRef<'_>,
        ) -> anyhow::Result<Option<ResourceRole>> {
            anyhow::bail!("simulated storage failure")
        }
    }
}