use std::sync::Arc;
use async_trait::async_trait;
use axum::http::request::Parts;
use crate::sql::sqlx;
use crate::sql::Pool;
use super::auth::parse_basic_auth;
use super::password;
/// The identity an [`AuthBackend`] resolves a request to.
///
/// Besides `Debug`/`Clone`, this derives `PartialEq`/`Eq` so callers and tests can compare
/// authenticated identities directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthUser {
    /// The user row's id. The backends in this module copy it from `user.id`, falling back
    /// to `0` if the id is unset.
    pub id: i64,
    /// The user's login name.
    pub username: String,
    /// Whether the user has superuser privileges.
    pub is_superuser: bool,
}
/// Errors an [`AuthBackend`] can return.
///
/// The backends in this module report "no credentials" and "credentials did not match" as
/// `Ok(None)`, not as an error. These variants cover infrastructure failures and credentials
/// that were recognised but rejected.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    /// A query through `sqlx` failed.
    #[error("database error: {0}")]
    Database(#[from] sqlx::Error),
    /// A query through the crate's own `crate::sql` execution layer failed. It shares the
    /// `database error:` display prefix with `Database`.
    #[error("database error: {0}")]
    Exec(#[from] crate::sql::ExecError),
    /// A bearer token failed verification or has expired. The backends also use it when a
    /// stored password/key hash cannot be checked.
    #[error("token is malformed or expired")]
    InvalidToken,
    /// The resolved user account has `active == false`.
    #[error("account is inactive")]
    Inactive,
}
/// A pluggable strategy that authenticates a request from its head (`Parts`).
///
/// The implementations in this module follow one convention:
/// * `Ok(Some(user))`: the request carried credentials this backend accepted.
/// * `Ok(None)`: the request carried no credentials this backend handles, or they did not
///   match.
/// * `Err(_)`: the credentials were recognised but rejected (bad or expired token, inactive
///   account), or a database error occurred.
///
/// `#[async_trait]` boxes the returned future. The `Send + Sync` supertraits allow a backend
/// to be shared as a [`BoxedBackend`] (`Arc<dyn AuthBackend>`).
#[async_trait]
pub trait AuthBackend: Send + Sync {
    /// Attempts to authenticate the request described by `parts`, using `pool` for lookups.
    async fn authenticate(&self, parts: &Parts, pool: &Pool)
        -> Result<Option<AuthUser>, AuthError>;
}
/// Shared, dynamically dispatched handle to any [`AuthBackend`]. Cloning it only bumps the
/// `Arc` reference count.
pub type BoxedBackend = Arc<dyn AuthBackend>;
/// Authenticates HTTP Basic credentials (`Authorization: Basic …`) against the user table
/// (`super::auth::User`), checking the password with `password::verify`.
pub struct ModelBackend;
#[async_trait]
impl AuthBackend for ModelBackend {
    /// Authenticates an `Authorization: Basic …` header against the user table.
    ///
    /// Returns `Ok(None)` in these cases:
    /// * the header is absent or `parse_basic_auth` rejects it;
    /// * no user has the given username;
    /// * the password does not match.
    ///
    /// # Errors
    /// * [`AuthError::Inactive`] when the password matches but the account is inactive.
    /// * [`AuthError::InvalidToken`] when the stored password hash cannot be checked.
    /// * A database error variant when the user lookup fails.
    async fn authenticate(
        &self,
        parts: &Parts,
        pool: &Pool,
    ) -> Result<Option<AuthUser>, AuthError> {
        use crate::core::Column as _;
        use crate::sql::FetcherPool as _;

        let auth_header = parts
            .headers
            .get(axum::http::header::AUTHORIZATION)
            .and_then(|v| v.to_str().ok());
        let Some((username, supplied_password)) = parse_basic_auth(auth_header) else {
            return Ok(None);
        };
        let users = super::auth::User::objects()
            .where_(super::auth::User::username.eq(username))
            .fetch_pool(pool)
            .await?;
        let Some(user) = users.into_iter().next() else {
            // NOTE(review): this path skips password hashing entirely, so response timing may
            // still reveal whether a username exists.
            return Ok(None);
        };
        let verified = password::verify(&supplied_password, &user.password_hash)
            .map_err(|_| AuthError::InvalidToken)?;
        if !verified {
            return Ok(None);
        }
        // Checked only after the password matched. Reporting `Inactive` earlier would let
        // anyone who merely knows a username learn that the account exists and is disabled.
        if !user.active {
            return Err(AuthError::Inactive);
        }
        Ok(Some(AuthUser {
            id: user.id.get().copied().unwrap_or(0),
            username: user.username,
            is_superuser: user.is_superuser,
        }))
    }
}
/// A row of the `rustango_api_keys` table: one API key belonging to a user.
///
/// A token handed to the client has the form `<key_prefix>.<secret>`. Only `key_prefix` is
/// stored in plaintext (it is the lookup column). The secret half is stored as the
/// `password::hash` digest in `key_hash`. See `create_api_key` for minting and
/// `ApiKeyBackend` for verification.
///
/// The `max_length` values below mirror the column sizes in the `API_KEY_ENSURE_SQL*` DDL
/// constants; keep them in sync.
#[derive(crate::Model, Debug, Clone)]
#[rustango(
    table = "rustango_api_keys",
    admin(
        list_display = "user_id, key_prefix, label, expires_at, created_at",
        ordering = "-created_at",
        readonly_fields = "key_prefix, key_hash, created_at",
    )
)]
pub struct ApiKey {
    /// Primary key. The DDL makes it database-generated (`BIGSERIAL` / `AUTOINCREMENT` /
    /// `AUTO_INCREMENT`).
    #[rustango(primary_key)]
    pub id: crate::sql::Auto<i64>,
    /// Owning user's id. The DDL references `rustango_users.id` with `ON DELETE CASCADE`.
    pub user_id: i64,
    /// Plaintext 8-character lookup prefix of the token. Unique per the DDL constraint.
    #[rustango(max_length = 8)]
    pub key_prefix: String,
    /// Hash of the token's secret half, as produced by `password::hash`.
    #[rustango(max_length = 255)]
    pub key_hash: String,
    /// Free-form human-readable label (the DDL defaults it to `''`).
    #[rustango(max_length = 100)]
    pub label: String,
    /// Expiry instant. `None` means `ApiKeyBackend` performs no expiry check.
    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
    /// Creation time, populated automatically (`auto_now_add`).
    #[rustango(auto_now_add)]
    pub created_at: crate::sql::Auto<chrono::DateTime<chrono::Utc>>,
}
/// PostgreSQL DDL for the table backing [`ApiKey`].
///
/// `ensure_api_keys_table_pool` also uses it for any dialect other than `sqlite` or `mysql`.
/// Column lengths mirror the `max_length` attributes on `ApiKey`. Deleting a user cascades to
/// their keys. `key_prefix` is unique, and `ApiKeyBackend` looks keys up by it.
///
/// The ensure functions split this script on `;`. It must therefore contain `;` only as a
/// statement terminator (never inside a literal or identifier).
const API_KEY_ENSURE_SQL: &str = r#"
CREATE TABLE IF NOT EXISTS "rustango_api_keys" (
"id" BIGSERIAL PRIMARY KEY,
"user_id" BIGINT NOT NULL
REFERENCES "rustango_users"("id")
ON DELETE CASCADE,
"key_prefix" VARCHAR(8) NOT NULL,
"key_hash" VARCHAR(255) NOT NULL,
"label" VARCHAR(100) NOT NULL DEFAULT '',
"expires_at" TIMESTAMPTZ,
"created_at" TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT "rustango_api_keys_prefix_uq" UNIQUE ("key_prefix")
);
"#;
/// SQLite DDL for the table backing [`ApiKey`], selected when the pool dialect is `"sqlite"`.
///
/// Timestamps are stored as `TEXT`, and string columns are plain `TEXT`, so the lengths in
/// `ApiKey`'s `max_length` attributes are not enforced by the database here.
///
/// NOTE(review): SQLite only enforces the `REFERENCES … ON DELETE CASCADE` clause when
/// `PRAGMA foreign_keys = ON` is set on the connection. Confirm the pool enables it.
///
/// Like the other DDL constants, this is split on `;`, so `;` may appear only as a statement
/// terminator.
const API_KEY_ENSURE_SQL_SQLITE: &str = r#"
CREATE TABLE IF NOT EXISTS "rustango_api_keys" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
"user_id" INTEGER NOT NULL
REFERENCES "rustango_users"("id")
ON DELETE CASCADE,
"key_prefix" TEXT NOT NULL,
"key_hash" TEXT NOT NULL,
"label" TEXT NOT NULL DEFAULT '',
"expires_at" TEXT,
"created_at" TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT "rustango_api_keys_prefix_uq" UNIQUE ("key_prefix")
);
"#;
/// MySQL DDL for the table backing [`ApiKey`], selected when the pool dialect is `"mysql"`.
///
/// Unlike the PostgreSQL and SQLite scripts, the foreign key to `rustango_users` is declared
/// as a separate table constraint rather than inline on the column.
/// NOTE(review): presumably this is because MySQL has historically parsed but ignored inline
/// column `REFERENCES` clauses. Confirm before "simplifying" it.
///
/// Like the other DDL constants, this is split on `;`, so `;` may appear only as a statement
/// terminator.
const API_KEY_ENSURE_SQL_MYSQL: &str = r#"
CREATE TABLE IF NOT EXISTS `rustango_api_keys` (
`id` BIGINT AUTO_INCREMENT PRIMARY KEY,
`user_id` BIGINT NOT NULL,
`key_prefix` VARCHAR(8) NOT NULL,
`key_hash` VARCHAR(255) NOT NULL,
`label` VARCHAR(100) NOT NULL DEFAULT '',
`expires_at` DATETIME(6),
`created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
CONSTRAINT `rustango_api_keys_prefix_uq` UNIQUE (`key_prefix`),
CONSTRAINT `rustango_api_keys_fk_user`
FOREIGN KEY (`user_id`) REFERENCES `rustango_users`(`id`) ON DELETE CASCADE
);
"#;
/// Creates the `rustango_api_keys` table on PostgreSQL if it does not already exist.
///
/// Every non-blank `;`-separated statement of the PostgreSQL DDL script is executed in order.
///
/// # Errors
/// Returns the first [`sqlx::Error`] raised by a statement; later statements are not run.
#[cfg(feature = "postgres")]
pub async fn ensure_api_keys_table(pool: &crate::sql::sqlx::PgPool) -> Result<(), sqlx::Error> {
    let statements = API_KEY_ENSURE_SQL
        .split(';')
        .map(str::trim)
        .filter(|statement| !statement.is_empty());
    for statement in statements {
        sqlx::query(statement).execute(pool).await?;
    }
    Ok(())
}
/// Creates the `rustango_api_keys` table on whichever backend `pool` is connected to.
///
/// The script is chosen by the pool's dialect name:
/// * `"sqlite"` gets the SQLite script;
/// * `"mysql"` gets the MySQL script;
/// * every other dialect falls back to the PostgreSQL script.
///
/// Its `;`-separated statements then run in order.
///
/// # Errors
/// A driver failure is returned unchanged. Any other executor error is wrapped in
/// [`sqlx::Error::Protocol`] carrying its display text. Execution stops at the first failure.
pub async fn ensure_api_keys_table_pool(pool: &Pool) -> Result<(), sqlx::Error> {
    let script = match pool.dialect().name() {
        "sqlite" => API_KEY_ENSURE_SQL_SQLITE,
        "mysql" => API_KEY_ENSURE_SQL_MYSQL,
        _ => API_KEY_ENSURE_SQL,
    };
    let statements = script
        .split(';')
        .map(str::trim)
        .filter(|statement| !statement.is_empty());
    for statement in statements {
        if let Err(failure) = crate::sql::raw_execute_pool(pool, statement, Vec::new()).await {
            return Err(match failure {
                crate::sql::ExecError::Driver(driver_error) => driver_error,
                other => sqlx::Error::Protocol(format!("{other}")),
            });
        }
    }
    Ok(())
}
/// Authenticates `Authorization: Bearer <prefix>.<secret>` API keys stored in
/// `rustango_api_keys` (see [`ApiKey`]).
pub struct ApiKeyBackend;
#[async_trait]
impl AuthBackend for ApiKeyBackend {
    /// Authenticates an `Authorization: Bearer <prefix>.<secret>` API key.
    ///
    /// Returns `Ok(None)` in these cases:
    /// * there is no bearer token;
    /// * the token has no `.` or its prefix is not 8 characters long;
    /// * no key has that prefix;
    /// * the secret does not match;
    /// * the owning user row is missing.
    ///
    /// # Errors
    /// * [`AuthError::InvalidToken`] if the stored hash cannot be checked, or if the secret is
    ///   correct but the key has expired.
    /// * [`AuthError::Inactive`] if the key is valid but its user is inactive.
    /// * A database error variant if a lookup fails.
    async fn authenticate(
        &self,
        parts: &Parts,
        pool: &Pool,
    ) -> Result<Option<AuthUser>, AuthError> {
        use crate::core::Column as _;
        use crate::sql::FetcherPool as _;

        let Some(token) = extract_bearer(parts)? else {
            return Ok(None);
        };
        let Some((prefix, secret)) = token.split_once('.') else {
            return Ok(None);
        };
        // `create_api_key` always emits an 8-character hex prefix; anything else is not a key.
        if prefix.len() != 8 {
            return Ok(None);
        }
        let keys = ApiKey::objects()
            .where_(ApiKey::key_prefix.eq(prefix.to_owned()))
            .fetch_pool(pool)
            .await?;
        let Some(key) = keys.into_iter().next() else {
            return Ok(None);
        };
        let verified =
            password::verify(secret, &key.key_hash).map_err(|_| AuthError::InvalidToken)?;
        if !verified {
            return Ok(None);
        }
        // Expiry is checked only after the secret matched. Otherwise a caller who knows just a
        // prefix (which the admin list displays) could probe whether that key has expired.
        if let Some(exp) = key.expires_at {
            if chrono::Utc::now() > exp {
                return Err(AuthError::InvalidToken);
            }
        }
        let users = super::auth::User::objects()
            .where_(super::auth::User::id.eq(key.user_id))
            .fetch_pool(pool)
            .await?;
        let Some(user) = users.into_iter().next() else {
            return Ok(None);
        };
        if !user.active {
            return Err(AuthError::Inactive);
        }
        Ok(Some(AuthUser {
            id: user.id.get().copied().unwrap_or(0),
            username: user.username,
            is_superuser: user.is_superuser,
        }))
    }
}
/// Creates a new API key for `user_id` and returns its plaintext token.
///
/// The token has the form `<prefix>.<secret>`:
/// * `<prefix>` is 8 lowercase hex characters (4 random bytes). It is stored verbatim and used
///   for lookup.
/// * `<secret>` is 32 lowercase hex characters (16 random bytes). Only its `password::hash`
///   digest is stored.
///
/// The plaintext returned here is therefore never persisted.
///
/// # Errors
/// * `TenancyError::Validation` if hashing the secret fails.
/// * Any error from saving the row. One example is a unique-constraint violation on
///   `key_prefix` if the random prefix collides with an existing key.
pub async fn create_api_key(
    user_id: i64,
    label: &str,
    expires_at: Option<chrono::DateTime<chrono::Utc>>,
    pool: &Pool,
) -> Result<String, crate::tenancy::error::TenancyError> {
    use crate::sql::Auto;
    use rand::Rng;

    // Draw all randomness inside this block so the `ThreadRng` handle is dropped before the
    // `.await` below. `ThreadRng` is `!Send`. Keeping it in scope across an await point makes
    // this whole future `!Send`, which rules out `tokio::spawn` and axum handlers.
    let (prefix, secret) = {
        let mut rng = rand::thread_rng();
        let prefix_bytes: [u8; 4] = rng.gen();
        let secret_bytes: [u8; 16] = rng.gen();
        (to_hex(&prefix_bytes), to_hex(&secret_bytes))
    };
    let hash = password::hash(&secret)
        .map_err(|e| crate::tenancy::error::TenancyError::Validation(e.to_string()))?;
    // Build the plaintext token first so `prefix` can be moved into the row without a clone.
    let token = format!("{prefix}.{secret}");
    let mut key = ApiKey {
        id: Auto::default(),
        user_id,
        key_prefix: prefix,
        key_hash: hash,
        label: label.to_owned(),
        expires_at,
        created_at: Auto::default(),
    };
    key.save_pool(pool).await?;
    Ok(token)
}
/// Stateless bearer-token backend using HMAC-SHA256-signed tokens.
///
/// Despite the name, the tokens are not standard three-segment JWTs. They have two segments,
/// `<payload>.<signature>`, as produced by `JwtBackend::issue`.
pub struct JwtBackend {
    /// HMAC key used to sign and verify tokens. It is private, and the struct does not derive
    /// `Debug`.
    secret: Vec<u8>,
    /// Lifetime, in seconds, of tokens produced by `issue`. `new` sets it to 3600.
    pub ttl_secs: i64,
}
impl JwtBackend {
    /// Creates a backend that signs tokens with `secret`.
    ///
    /// Issued tokens are valid for `ttl_secs` seconds, which starts at 3600 (one hour) and can
    /// be changed through the public field.
    #[must_use]
    pub fn new(secret: Vec<u8>) -> Self {
        Self {
            secret,
            ttl_secs: 3600,
        }
    }

    /// Creates a backend whose signing key is a copy of the session secret's key bytes.
    #[cfg(feature = "postgres")]
    #[must_use]
    pub fn from_session_secret(s: &super::operator_console::SessionSecret) -> Self {
        Self::new(s.key().to_vec())
    }

    /// Issues a signed token for `user_id` that expires `ttl_secs` seconds from now.
    ///
    /// The token is `<payload>.<signature>`:
    /// * `<payload>` is the URL-safe, unpadded base64 encoding of the JSON object
    ///   `{"sub": user_id, "exp": unix_seconds}`.
    /// * `<signature>` is the same encoding of the HMAC-SHA256 tag over the *encoded* payload.
    #[must_use]
    pub fn issue(&self, user_id: i64) -> String {
        use base64::Engine;
        // Saturate so a huge TTL (e.g. `i64::MAX` meaning "effectively never") clamps instead
        // of panicking in debug builds or wrapping to an already-expired timestamp in release.
        let exp = chrono::Utc::now().timestamp().saturating_add(self.ttl_secs);
        let payload = serde_json::json!({"sub": user_id, "exp": exp});
        // Serialising a `Value` built from two integers is not expected to fail.
        let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(serde_json::to_vec(&payload).unwrap_or_default());
        let sig = hmac_sha256(&self.secret, payload_b64.as_bytes());
        let sig_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(sig);
        format!("{payload_b64}.{sig_b64}")
    }

    /// Verifies a token produced by [`JwtBackend::issue`] and returns its `sub` claim.
    ///
    /// Returns `None` in these cases:
    /// * the token has no `.`;
    /// * the signature is not valid base64 or does not match (compared in constant time);
    /// * the payload is not valid base64/JSON;
    /// * `exp` or `sub` is missing or not an integer;
    /// * the token has expired (`now >= exp`).
    fn verify_token(&self, token: &str) -> Option<i64> {
        use base64::Engine;
        use subtle::ConstantTimeEq;
        let (payload_b64, sig_b64) = token.split_once('.')?;
        // Authenticate first: the payload is decoded only after the MAC has been checked.
        let expected = hmac_sha256(&self.secret, payload_b64.as_bytes());
        let provided = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(sig_b64)
            .ok()?;
        // Constant-time comparison; a signature of the wrong length also compares unequal.
        if expected.ct_eq(&provided[..]).unwrap_u8() == 0 {
            return None;
        }
        let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(payload_b64)
            .ok()?;
        let payload: serde_json::Value = serde_json::from_slice(&payload_bytes).ok()?;
        let exp = payload.get("exp")?.as_i64()?;
        if chrono::Utc::now().timestamp() >= exp {
            return None;
        }
        payload.get("sub")?.as_i64()
    }
}
/// Computes the HMAC-SHA256 tag of `msg` keyed with `secret`, returned as 32 raw bytes.
///
/// # Panics
/// Only if the HMAC constructor rejects the key. The `expect` records that HMAC accepts keys
/// of any length.
fn hmac_sha256(secret: &[u8], msg: &[u8]) -> [u8; 32] {
    use hmac::{Hmac, Mac};
    use sha2::Sha256;

    type HmacSha256 = Hmac<Sha256>;

    let digest = {
        let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length");
        mac.update(msg);
        mac.finalize().into_bytes()
    };
    let mut tag = [0u8; 32];
    tag.copy_from_slice(&digest[..tag.len()]);
    tag
}
#[async_trait]
impl AuthBackend for JwtBackend {
    /// Authenticates an `Authorization: Bearer <payload>.<signature>` token minted by
    /// [`JwtBackend::issue`].
    ///
    /// Returns `Ok(None)` in these cases:
    /// * there is no bearer token;
    /// * the token does not contain exactly one `.`;
    /// * its first segment is exactly 8 characters (the API-key shape handled by
    ///   `ApiKeyBackend`);
    /// * the `sub` claim matches no user.
    ///
    /// # Errors
    /// * [`AuthError::InvalidToken`] if a token of the right shape fails signature or expiry
    ///   checks.
    /// * [`AuthError::Inactive`] if the token is valid but the user is inactive.
    /// * A database error variant if the user lookup fails.
    async fn authenticate(
        &self,
        parts: &Parts,
        pool: &Pool,
    ) -> Result<Option<AuthUser>, AuthError> {
        use super::auth::User;
        use crate::core::Column as _;
        use crate::sql::FetcherPool as _;

        let Some(token) = extract_bearer(parts)? else {
            return Ok(None);
        };
        let Some((head, tail)) = token.split_once('.') else {
            return Ok(None);
        };
        // A second '.' means the wrong shape. An 8-character head matches the API-key
        // `prefix.secret` shape, which this backend leaves alone.
        if tail.contains('.') || head.len() == 8 {
            return Ok(None);
        }
        let user_id = self.verify_token(token).ok_or(AuthError::InvalidToken)?;
        let matches = User::objects()
            .where_(User::id.eq(user_id))
            .fetch_pool(pool)
            .await?;
        let Some(user) = matches.into_iter().next() else {
            return Ok(None);
        };
        if user.active {
            Ok(Some(AuthUser {
                id: user.id.get().copied().unwrap_or(0),
                username: user.username,
                is_superuser: user.is_superuser,
            }))
        } else {
            Err(AuthError::Inactive)
        }
    }
}
/// Encodes `bytes` as lowercase hexadecimal, two characters per byte.
fn to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}
/// Extracts the token from an `Authorization: Bearer <token>` header.
///
/// The scheme name is matched case-insensitively (RFC 7235 §2.1), and surrounding whitespace
/// is trimmed from the token.
///
/// Returns `Ok(None)` in these cases:
/// * the header is absent;
/// * the header is not visible ASCII;
/// * it uses another scheme;
/// * the token is empty.
///
/// It currently never returns `Err`; the `Result` keeps the signature open for stricter
/// parsing.
fn extract_bearer(parts: &Parts) -> Result<Option<&str>, AuthError> {
    const SCHEME: &str = "Bearer ";
    let Some(value) = parts.headers.get(axum::http::header::AUTHORIZATION) else {
        return Ok(None);
    };
    // A header value that is not visible ASCII cannot carry a token we issue; treat it as absent.
    let Ok(raw) = value.to_str() else {
        return Ok(None);
    };
    let Some(scheme) = raw.get(..SCHEME.len()) else {
        return Ok(None);
    };
    if !scheme.eq_ignore_ascii_case(SCHEME) {
        return Ok(None);
    }
    // `scheme` matched seven ASCII bytes, so this index is on a char boundary.
    let token = raw[SCHEME.len()..].trim();
    Ok((!token.is_empty()).then_some(token))
}