use base64::Engine;
use crate::core::Column as _;
use crate::sql::sqlx::{PgConnection, PgPool};
use crate::sql::{Auto, Fetcher};
use crate::Model;
use super::error::TenancyError;
use super::password;
/// An operator account persisted in the `rustango_operators` table.
///
/// Rows are loaded by `authenticate_operator` (via the `registry` pool), which accepts an
/// operator only when `active` is set and `password::verify` accepts the supplied password
/// against `password_hash`.
#[derive(Model, Debug, Clone)]
// `display = "username"` presumably selects the field used when rendering an instance —
// confirm against the `Model` derive.
#[rustango(table = "rustango_operators", display = "username")]
// NOTE(review): the reason for `allow(dead_code)` is not visible here; presumably some
// fields are only read through derived code. Document the reason or remove the allow.
#[allow(dead_code)]
pub struct Operator {
/// Primary key, wrapped in `Auto` (presumably so the database can assign it — confirm).
#[rustango(primary_key)]
pub id: rustango::sql::Auto<i64>,
/// Login name (declared max length 64); `authenticate_operator` filters on this column.
#[rustango(max_length = 64)]
pub username: String,
/// Stored password hash (declared max length 255), checked with `password::verify`.
#[rustango(max_length = 255)]
pub password_hash: String,
/// When `false`, `authenticate_operator` refuses this operator regardless of password.
pub active: bool,
/// UTC timestamp; presumably the row's creation time.
pub created_at: chrono::DateTime<chrono::Utc>,
}
/// A user account persisted in the `rustango_users` table.
///
/// `authenticate_user` loads these rows with a hand-written SQL statement that names every
/// column explicitly, so renaming or adding a field here requires updating that query too.
#[derive(Model, Debug, Clone)]
// `display = "username"` presumably selects the field used when rendering an instance —
// confirm against the `Model` derive.
#[rustango(table = "rustango_users", display = "username")]
// NOTE(review): the reason for `allow(dead_code)` is not visible here; presumably some
// fields are only read through derived code. Document the reason or remove the allow.
#[allow(dead_code)]
pub struct User {
/// Primary key, wrapped in `Auto` (presumably so the database can assign it — confirm).
#[rustango(primary_key)]
pub id: rustango::sql::Auto<i64>,
/// Login name (declared max length 64); `authenticate_user` looks rows up by this column.
#[rustango(max_length = 64)]
pub username: String,
/// Stored password hash (declared max length 255), checked with `password::verify`.
#[rustango(max_length = 255)]
pub password_hash: String,
/// Superuser flag. Not consulted by `authenticate_user`; presumably callers use it for
/// permission checks — confirm.
pub is_superuser: bool,
/// When `false`, `authenticate_user` refuses this user regardless of password.
pub active: bool,
/// UTC timestamp; presumably the row's creation time.
pub created_at: chrono::DateTime<chrono::Utc>,
}
/// Looks up an operator by username through the `registry` pool and checks the password.
///
/// Returns `Ok(Some(operator))` only when a row with this username exists, its `active`
/// flag is set, and `password::verify` accepts `password` against the stored hash. If the
/// query yields several rows, only the first is considered. An unknown username, an
/// inactive operator, and a wrong password all produce `Ok(None)`, so the return value
/// does not reveal which check failed.
///
/// # Errors
///
/// Propagates, as [`TenancyError`], any failure from the query or from `password::verify`.
///
/// NOTE(review): the unknown-user and inactive paths return before `password::verify`
/// runs. If that is a deliberately slow hash, response time may reveal which usernames
/// exist; consider verifying against a dummy hash on those paths.
pub async fn authenticate_operator(
    registry: &PgPool,
    username: &str,
    password: &str,
) -> Result<Option<Operator>, TenancyError> {
    let by_username = Operator::username.eq(username.to_owned());
    let found: Vec<Operator> = Operator::objects()
        .where_(by_username)
        .fetch(registry)
        .await?;

    match found.into_iter().next() {
        // Inactive operators never reach password verification.
        Some(operator) if operator.active => {
            let password_ok = password::verify(password, &operator.password_hash)?;
            Ok(password_ok.then_some(operator))
        }
        _ => Ok(None),
    }
}
/// Looks up a user by username on `conn` and checks the password.
///
/// Returns `Ok(Some(user))` only when a row with this username exists, its `active` flag
/// is set, and `password::verify` accepts `password` against the stored hash. A missing
/// row, an inactive user, and a wrong password all produce `Ok(None)`.
///
/// The column list in the SQL below must stay in sync with the fields of [`User`].
///
/// # Errors
///
/// Propagates, as [`TenancyError`], any failure from the query, from decoding a column,
/// or from `password::verify`.
///
/// NOTE(review): as in `authenticate_operator`, the missing-user and inactive paths skip
/// the password check entirely, which may leak username existence through timing.
pub async fn authenticate_user(
    conn: &mut PgConnection,
    username: &str,
    password: &str,
) -> Result<Option<User>, TenancyError> {
    use crate::sql::sqlx::Row;

    let maybe_row = rustango::sql::sqlx::query(
        "SELECT id, username, password_hash, is_superuser, active, created_at \
         FROM rustango_users WHERE username = $1",
    )
    .bind(username)
    .fetch_optional(&mut *conn)
    .await?;
    let row = match maybe_row {
        Some(row) => row,
        None => return Ok(None),
    };

    // Decode every column up front, in SELECT order, so a malformed row is reported as an
    // error even when the account would be rejected anyway.
    let id: i64 = row.try_get("id")?;
    let stored_username: String = row.try_get("username")?;
    let password_hash: String = row.try_get("password_hash")?;
    let is_superuser: bool = row.try_get("is_superuser")?;
    let active: bool = row.try_get("active")?;
    let created_at: chrono::DateTime<chrono::Utc> = row.try_get("created_at")?;

    // `||` short-circuits, so inactive users never reach password verification.
    if !active || !password::verify(password, &password_hash)? {
        return Ok(None);
    }

    Ok(Some(User {
        id: Auto::Set(id),
        username: stored_username,
        password_hash,
        is_superuser,
        active,
        created_at,
    }))
}
/// Parses an HTTP `Authorization` header value using the `Basic` scheme into
/// `(username, password)`.
///
/// The auth-scheme token is matched case-insensitively (`Basic`, `basic`, and `BASIC` are
/// all accepted), as HTTP requires. Leading whitespace before the scheme is ignored, as is
/// whitespace around the encoded credentials. The credentials are decoded with the
/// standard base64 engine and must be valid UTF-8. The decoded text is split at its
/// *first* `:`, so the password may itself contain colons.
///
/// Returns `None` in any of these cases:
/// - the header is absent;
/// - the scheme is not `Basic`, or there is no space after it;
/// - the payload is not valid base64;
/// - the payload does not decode to UTF-8;
/// - the decoded text contains no `:`.
#[must_use]
pub fn parse_basic_auth(header_value: Option<&str>) -> Option<(String, String)> {
    let raw = header_value?.trim_start();
    // Auth-scheme tokens are case-insensitive (RFC 9110 §11.1), so compare the scheme
    // separately instead of matching a literal "Basic " prefix.
    let (scheme, encoded) = raw.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Basic") {
        return None;
    }
    let decoded = base64::engine::general_purpose::STANDARD
        .decode(encoded.trim())
        .ok()?;
    let credentials = String::from_utf8(decoded).ok()?;
    // RFC 7617: the user-id cannot contain ':', so the first colon is the separator.
    let (user, pass) = credentials.split_once(':')?;
    Some((user.to_owned(), pass.to_owned()))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_basic_auth_decodes_standard_format() {
        // Base64 of "alice:hunter2".
        let v = "Basic YWxpY2U6aHVudGVyMg==";
        let (u, p) = parse_basic_auth(Some(v)).unwrap();
        assert_eq!(u, "alice");
        assert_eq!(p, "hunter2");
    }

    #[test]
    fn parse_basic_auth_decodes_rfc7617_example() {
        // RFC 7617 example: "Aladdin:open sesame".
        let parsed = parse_basic_auth(Some("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="));
        assert_eq!(
            parsed,
            Some(("Aladdin".to_owned(), "open sesame".to_owned()))
        );
    }

    #[test]
    fn parse_basic_auth_splits_on_first_colon_only() {
        // Base64 of "a:b:c"; the password keeps its embedded colon.
        let parsed = parse_basic_auth(Some("Basic YTpiOmM="));
        assert_eq!(parsed, Some(("a".to_owned(), "b:c".to_owned())));
    }

    #[test]
    fn parse_basic_auth_allows_empty_password() {
        // Base64 of "bob:".
        let parsed = parse_basic_auth(Some("Basic Ym9iOg=="));
        assert_eq!(parsed, Some(("bob".to_owned(), String::new())));
    }

    #[test]
    fn parse_basic_auth_tolerates_whitespace_around_credentials() {
        let parsed = parse_basic_auth(Some("Basic   YWxpY2U6aHVudGVyMg==  "));
        assert_eq!(parsed, Some(("alice".to_owned(), "hunter2".to_owned())));
    }

    #[test]
    fn parse_basic_auth_rejects_non_basic_scheme() {
        assert!(parse_basic_auth(Some("Bearer tokenhere")).is_none());
        assert!(parse_basic_auth(Some("Digest qop=auth")).is_none());
    }

    #[test]
    fn parse_basic_auth_rejects_missing_colon() {
        // Base64 of "no-colon-here".
        let v = "Basic bm8tY29sb24taGVyZQ==";
        assert!(parse_basic_auth(Some(v)).is_none());
    }

    #[test]
    fn parse_basic_auth_rejects_invalid_base64() {
        assert!(parse_basic_auth(Some("Basic !!!not-base64!!!")).is_none());
    }

    #[test]
    fn parse_basic_auth_rejects_non_utf8_payload() {
        // "/zph" decodes to the bytes FF 3A 61, which are not valid UTF-8.
        assert!(parse_basic_auth(Some("Basic /zph")).is_none());
    }

    #[test]
    fn parse_basic_auth_handles_none_header() {
        assert!(parse_basic_auth(None).is_none());
    }
}