use std::collections::HashMap;
use std::sync::Arc;
use subtle::ConstantTimeEq;
/// Compares two byte slices with `subtle`'s [`ConstantTimeEq`] and converts the
/// resulting choice into a plain `bool`.
///
/// NOTE(review): `subtle` documents that its slice comparison returns early
/// when the two lengths differ, so only equal-length inputs get a
/// content-independent comparison — confirm that is acceptable for callers.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    bool::from(a.ct_eq(b))
}
/// Identity produced by a successful [`AuthValidator::validate`] call.
#[derive(Debug, Clone)]
pub struct AuthResult {
/// Name of the authenticated user.
pub username: String,
/// Permission names granted to that user.
pub permissions: Vec<String>,
}
/// Reasons an authentication attempt can be rejected.
///
/// Every variant carries a human-readable message; [`AuthError::code`] maps the
/// variant to a numeric code.
#[derive(Debug, Clone)]
pub enum AuthError {
/// Rejection that maps to code 401.
Unauthorized(String),
/// Rejection that maps to code 403; not constructed in the code visible here.
Forbidden(String),
}
/// Renders the error as `<Kind>: <message>`, e.g. `Unauthorized: Unknown user`.
impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the prefix and the payload first, then format once.
        let (label, detail) = match self {
            Self::Unauthorized(detail) => ("Unauthorized", detail),
            Self::Forbidden(detail) => ("Forbidden", detail),
        };
        write!(f, "{label}: {detail}")
    }
}
// Empty impl: only the trait's default methods are used. The `Debug` derive and
// the `Display` impl on `AuthError` satisfy the trait's supertrait bounds.
impl std::error::Error for AuthError {}
impl AuthError {
    /// Numeric code for this error: 401 for `Unauthorized`, 403 for `Forbidden`.
    pub fn code(&self) -> i32 {
        match self {
            Self::Forbidden(..) => 403,
            Self::Unauthorized(..) => 401,
        }
    }

    /// The message carried by the variant, without the `Display` prefix.
    pub fn message(&self) -> &str {
        // Both variants hold exactly one `String`, so the pattern is irrefutable.
        let (Self::Unauthorized(text) | Self::Forbidden(text)) = self;
        text.as_str()
    }
}
/// Pluggable credential check.
///
/// Implementors must be `Send + Sync`.
pub trait AuthValidator: Send + Sync {
/// Validates `param` under the auth scheme named `scheme`.
///
/// For example, `StaticAuthValidator` accepts the scheme `"Basic"` with a
/// base64-encoded `user:password` parameter.
///
/// Returns the authenticated identity, or an [`AuthError`] describing the
/// rejection.
fn validate(&self, scheme: &str, param: &str) -> Result<AuthResult, AuthError>;
}
/// The permission name `"Streaming"` as a constant; it is not referenced by the
/// other code visible in this file.
pub const PERM_STREAMING: &str = "Streaming";
/// A named set of permissions that users are assigned to by name.
#[derive(Debug, Clone)]
pub struct Role {
/// Role name; `StaticAuthValidator::new` indexes roles by this value.
pub name: String,
/// Permission names granted to every user holding this role.
pub permissions: Vec<String>,
}
/// A configured account: login name, plaintext password, and the name of the
/// role it holds.
///
/// `Debug` is implemented by hand (below) instead of derived so that the
/// plaintext password never ends up in logs, panic messages, or `dbg!` output.
#[derive(Clone)]
pub struct User {
    /// Login name; `StaticAuthValidator` uses it as the lookup key.
    pub name: String,
    /// Plaintext password that supplied credentials are compared against.
    pub password: String,
    /// Name of the [`Role`] this user holds.
    pub role: String,
}

impl std::fmt::Debug for User {
    /// Prints `name` and `role` verbatim and replaces `password` with a
    /// `<redacted>` placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("User")
            .field("name", &self.name)
            .field("password", &format_args!("<redacted>"))
            .field("role", &self.role)
            .finish()
    }
}
#[derive(Debug, Clone)]
pub struct StaticAuthValidator {
users: HashMap<String, (String, Arc<Role>)>, }
impl StaticAuthValidator {
    /// Builds the lookup table from a list of users and a list of roles.
    ///
    /// * Roles are indexed by `Role::name`; each user is linked to the role
    ///   whose name equals `User::role`.
    /// * A user naming a role that is not in `roles` is kept, but receives a
    ///   placeholder role with an empty name and no permissions.
    /// * When several users (or several roles) share a name, the last one in
    ///   the input wins.
    pub fn new(users: Vec<User>, roles: Vec<Role>) -> Self {
        let mut roles_by_name: HashMap<String, Arc<Role>> = HashMap::new();
        for role in roles {
            roles_by_name.insert(role.name.clone(), Arc::new(role));
        }

        // Shared by every user whose role name could not be resolved.
        let no_permissions = Arc::new(Role {
            name: String::new(),
            permissions: Vec::new(),
        });

        let mut accounts = HashMap::new();
        for account in users {
            let assigned = match roles_by_name.get(&account.role) {
                Some(found) => Arc::clone(found),
                None => Arc::clone(&no_permissions),
            };
            accounts.insert(account.name, (account.password, assigned));
        }

        Self { users: accounts }
    }
}
impl AuthValidator for StaticAuthValidator {
    /// Validates credentials against the in-memory user table.
    ///
    /// `scheme` must equal `basic` (ASCII case-insensitive). `param` must be
    /// standard base64 that decodes to UTF-8 text of the form `user:password`;
    /// the text is split at the first `:`, so the password may contain colons.
    ///
    /// On success, returns the user name and a copy of the permissions of the
    /// user's role.
    ///
    /// # Errors
    ///
    /// Always `AuthError::Unauthorized`, with one of these messages:
    /// * `Unsupported auth scheme: <scheme>`
    /// * `Invalid base64`
    /// * `Invalid UTF-8`
    /// * `Expected user:password`
    /// * `Invalid credentials` — the user does not exist *or* the password
    ///   does not match. These two cases are deliberately not distinguished,
    ///   so the message cannot be used to discover which user names exist.
    fn validate(&self, scheme: &str, param: &str) -> Result<AuthResult, AuthError> {
        use base64::Engine;

        if !scheme.eq_ignore_ascii_case("basic") {
            return Err(AuthError::Unauthorized(format!(
                "Unsupported auth scheme: {scheme}"
            )));
        }

        let decoded = base64::engine::general_purpose::STANDARD
            .decode(param)
            .map_err(|_| AuthError::Unauthorized("Invalid base64".into()))?;
        let decoded = String::from_utf8(decoded)
            .map_err(|_| AuthError::Unauthorized("Invalid UTF-8".into()))?;
        let (username, password) = decoded
            .split_once(':')
            .ok_or_else(|| AuthError::Unauthorized("Expected user:password".into()))?;

        // One failure path for "no such user" and "wrong password": identical
        // variant and message, so the response does not reveal which it was.
        // NOTE(review): the two cases may still differ in timing (no comparison
        // runs for an unknown user); equalising that is out of scope here.
        let role = match self.users.get(username) {
            Some((stored_pw, role))
                if constant_time_eq(stored_pw.as_bytes(), password.as_bytes()) =>
            {
                role
            }
            _ => return Err(AuthError::Unauthorized("Invalid credentials".into())),
        };

        Ok(AuthResult {
            username: username.to_string(),
            permissions: role.permissions.clone(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: `admin` holds the `full` role (Streaming + Control) and
    /// `player` holds the `streaming` role (Streaming only).
    fn test_validator() -> StaticAuthValidator {
        let roles = vec![
            Role {
                name: "full".into(),
                permissions: vec!["Streaming".into(), "Control".into()],
            },
            Role {
                name: "streaming".into(),
                permissions: vec!["Streaming".into()],
            },
        ];
        let users = vec![
            User {
                name: "admin".into(),
                password: "secret".into(),
                role: "full".into(),
            },
            User {
                name: "player".into(),
                password: "play".into(),
                role: "streaming".into(),
            },
        ];
        StaticAuthValidator::new(users, roles)
    }

    /// Encodes `user:pass` as the parameter of a Basic credential.
    fn basic(user: &str, pass: &str) -> String {
        use base64::Engine;
        let raw = format!("{user}:{pass}");
        base64::engine::general_purpose::STANDARD.encode(raw)
    }

    /// True when `result` lists the permission called `wanted`.
    fn has_permission(result: &AuthResult, wanted: &str) -> bool {
        result.permissions.iter().any(|granted| granted == wanted)
    }

    #[test]
    fn valid_credentials() {
        let validator = test_validator();
        let credential = basic("admin", "secret");
        let result = validator.validate("Basic", &credential).unwrap();
        assert_eq!(result.username, "admin");
        assert!(has_permission(&result, "Streaming"));
        assert!(has_permission(&result, "Control"));
    }

    #[test]
    fn wrong_password() {
        let validator = test_validator();
        let credential = basic("admin", "wrong");
        let err = validator.validate("Basic", &credential).unwrap_err();
        assert_eq!(err.code(), 401);
    }

    #[test]
    fn unknown_user() {
        let validator = test_validator();
        let credential = basic("nobody", "x");
        let err = validator.validate("Basic", &credential).unwrap_err();
        assert_eq!(err.code(), 401);
    }

    #[test]
    fn unsupported_scheme() {
        let validator = test_validator();
        let err = validator.validate("Bearer", "token123").unwrap_err();
        assert_eq!(err.code(), 401);
    }

    #[test]
    fn streaming_only_user() {
        let validator = test_validator();
        let credential = basic("player", "play");
        let result = validator.validate("Basic", &credential).unwrap();
        assert_eq!(result.username, "player");
        assert!(has_permission(&result, "Streaming"));
        assert!(!has_permission(&result, "Control"));
    }
}