use crate::{Authority, PasswordEncoder, Role, SecurityError, SecurityResult};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// A user account: login name, stored password, granted authorities and four
/// boolean status flags.
///
/// The password is excluded from serialized output (`skip_serializing`) but is
/// still expected when deserializing. `Debug` is implemented by hand so that
/// `{:?}` formatting (logs, panics, assertion messages) never prints it.
#[derive(Clone, Serialize, Deserialize)]
pub struct User {
    /// Login name; `InMemoryUserService` uses it as the map key.
    pub username: String,
    /// Stored password string. Never serialized and redacted in `Debug` output.
    #[serde(skip_serializing)]
    pub password: String,
    /// Authorities granted to this user; roles are stored as `Authority::Role`.
    pub authorities: Vec<Authority>,
    /// Status flag; the constructors in this file set it to `true`.
    pub enabled: bool,
    /// Status flag; the constructors in this file set it to `true`.
    pub account_non_expired: bool,
    /// Status flag; the constructors in this file set it to `true`.
    pub credentials_non_expired: bool,
    /// Status flag; the constructors in this file set it to `true`.
    pub account_non_locked: bool,
}

impl std::fmt::Debug for User {
    /// Formats every field like the derived implementation would, except that
    /// the password is replaced by a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("User")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .field("authorities", &self.authorities)
            .field("enabled", &self.enabled)
            .field("account_non_expired", &self.account_non_expired)
            .field("credentials_non_expired", &self.credentials_non_expired)
            .field("account_non_locked", &self.account_non_locked)
            .finish()
    }
}
impl User {
pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
Self {
username: username.into(),
password: password.into(),
authorities: Vec::new(),
enabled: true,
account_non_expired: true,
credentials_non_expired: true,
account_non_locked: true,
}
}
pub fn with_roles(
username: impl Into<String>,
password: impl Into<String>,
roles: &[Role],
) -> Self {
let authorities = roles.iter().map(|r| Authority::Role(r.clone())).collect();
Self {
username: username.into(),
password: password.into(),
authorities,
enabled: true,
account_non_expired: true,
credentials_non_expired: true,
account_non_locked: true,
}
}
pub fn add_authority(mut self, authority: Authority) -> Self {
self.authorities.push(authority);
self
}
pub fn add_role(mut self, role: Role) -> Self {
self.authorities.push(Authority::Role(role));
self
}
pub fn enabled(mut self, enabled: bool) -> Self {
self.enabled = enabled;
self
}
pub fn has_authority(&self, authority: &Authority) -> bool {
self.authorities.contains(authority)
}
pub fn has_role(&self, role: &Role) -> bool {
self.authorities.contains(&Authority::Role(role.clone()))
}
}
/// Step-by-step builder for [`User`].
///
/// `username` and `password` start unset and must be supplied before `build`
/// succeeds; the status flags start as `true`. `Debug` is implemented by hand
/// because the held password may be the raw value passed to `password`, and it
/// must not be printed by `{:?}` formatting.
#[derive(Clone)]
pub struct UserBuilder {
    /// Login name, `None` until `username` is called.
    username: Option<String>,
    /// Password string, `None` until `password` or `password_encoded` is called.
    password: Option<String>,
    /// Authorities collected so far.
    authorities: Vec<Authority>,
    /// Value for `User::enabled`.
    enabled: bool,
    /// Value for `User::account_non_expired`.
    account_non_expired: bool,
    /// Value for `User::credentials_non_expired`.
    credentials_non_expired: bool,
    /// Value for `User::account_non_locked`.
    account_non_locked: bool,
}

impl std::fmt::Debug for UserBuilder {
    /// Formats every field, showing only whether a password has been set
    /// (`Some("<redacted>")` or `None`), never its contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let password = self.password.as_ref().map(|_| "<redacted>");
        f.debug_struct("UserBuilder")
            .field("username", &self.username)
            .field("password", &password)
            .field("authorities", &self.authorities)
            .field("enabled", &self.enabled)
            .field("account_non_expired", &self.account_non_expired)
            .field("credentials_non_expired", &self.credentials_non_expired)
            .field("account_non_locked", &self.account_non_locked)
            .finish()
    }
}
impl UserBuilder {
pub fn new() -> Self {
Self {
username: None,
password: None,
authorities: Vec::new(),
enabled: true,
account_non_expired: true,
credentials_non_expired: true,
account_non_locked: true,
}
}
pub fn username(mut self, username: impl Into<String>) -> Self {
self.username = Some(username.into());
self
}
pub fn password(mut self, password: impl Into<String>) -> Self {
self.password = Some(password.into());
self
}
pub fn password_encoded(
mut self,
password: impl Into<String>,
encoder: &dyn PasswordEncoder,
) -> Self {
let raw = password.into();
self.password = Some(encoder.encode(&raw));
self
}
pub fn roles(mut self, roles: &[Role]) -> Self {
for role in roles {
self.authorities.push(Authority::Role(role.clone()));
}
self
}
pub fn authorities(mut self, authorities: &[Authority]) -> Self {
self.authorities.extend(authorities.iter().cloned());
self
}
pub fn enabled(mut self, enabled: bool) -> Self {
self.enabled = enabled;
self
}
pub fn account_non_expired(mut self, non_expired: bool) -> Self {
self.account_non_expired = non_expired;
self
}
pub fn credentials_non_expired(mut self, non_expired: bool) -> Self {
self.credentials_non_expired = non_expired;
self
}
pub fn account_non_locked(mut self, non_locked: bool) -> Self {
self.account_non_locked = non_locked;
self
}
pub fn build(self) -> SecurityResult<User> {
Ok(User {
username: self
.username
.ok_or_else(|| SecurityError::InvalidCredentials("Missing username".to_string()))?,
password: self
.password
.ok_or_else(|| SecurityError::InvalidCredentials("Missing password".to_string()))?,
authorities: self.authorities,
enabled: self.enabled,
account_non_expired: self.account_non_expired,
credentials_non_expired: self.credentials_non_expired,
account_non_locked: self.account_non_locked,
})
}
}
impl Default for UserBuilder {
fn default() -> Self {
Self::new()
}
}
/// Read-only view of an account: its name, password, authorities and four
/// status flags.
///
/// Implementors must be `Send + Sync`; `UserService::load_user_by_username`
/// hands them out as `Arc<dyn UserDetails>`.
pub trait UserDetails: Send + Sync {
/// Authorities held by the account, returned as an owned `Vec`.
fn authorities(&self) -> Vec<Authority>;
/// The account's password string.
/// NOTE(review): this trait does not fix whether the value is raw or encoded — confirm with callers.
fn password(&self) -> &str;
/// The account's username.
fn username(&self) -> &str;
/// Account-not-expired status flag.
fn is_account_non_expired(&self) -> bool;
/// Account-not-locked status flag.
fn is_account_non_locked(&self) -> bool;
/// Credentials-not-expired status flag.
fn is_credentials_non_expired(&self) -> bool;
/// Enabled status flag.
fn is_enabled(&self) -> bool;
}
/// Exposes a [`User`]'s fields through the [`UserDetails`] interface.
impl UserDetails for User {
    /// Returns a copy of the user's authorities.
    fn authorities(&self) -> Vec<Authority> {
        self.authorities.to_vec()
    }

    /// Borrows the stored password string.
    fn password(&self) -> &str {
        self.password.as_str()
    }

    /// Borrows the username.
    fn username(&self) -> &str {
        self.username.as_str()
    }

    /// Returns the `account_non_expired` field.
    fn is_account_non_expired(&self) -> bool {
        self.account_non_expired
    }

    /// Returns the `account_non_locked` field.
    fn is_account_non_locked(&self) -> bool {
        self.account_non_locked
    }

    /// Returns the `credentials_non_expired` field.
    fn is_credentials_non_expired(&self) -> bool {
        self.credentials_non_expired
    }

    /// Returns the `enabled` field.
    fn is_enabled(&self) -> bool {
        self.enabled
    }
}
/// Asynchronous user store interface: look up, create, update, delete and
/// existence-check users by username.
///
/// Implementors must be `Send + Sync`. The methods are `async` via the
/// `async_trait` attribute.
#[async_trait::async_trait]
pub trait UserService: Send + Sync {
/// Looks up a user by username and returns it as a shared `UserDetails` object.
/// NOTE(review): the error returned for an unknown username is chosen by the implementation.
async fn load_user_by_username(&self, username: &str) -> SecurityResult<Arc<dyn UserDetails>>;
/// Stores a new user.
/// NOTE(review): how an already-existing username is handled is up to the implementation — confirm.
async fn create_user(&self, user: User) -> SecurityResult<()>;
/// Stores the given user data for an existing username.
/// NOTE(review): how an unknown username is handled is up to the implementation — confirm.
async fn update_user(&self, user: User) -> SecurityResult<()>;
/// Removes the user with the given username.
async fn delete_user(&self, username: &str) -> SecurityResult<()>;
/// Reports whether a user with the given username is present.
async fn user_exists(&self, username: &str) -> bool;
}
/// A `UserService` implementation that keeps users in a `HashMap` held in memory.
#[derive(Debug)]
pub struct InMemoryUserService {
// Username -> user, behind a tokio `RwLock` inside an `Arc`.
users: Arc<tokio::sync::RwLock<std::collections::HashMap<String, User>>>,
}
impl InMemoryUserService {
    /// Creates a service with no users.
    pub fn new() -> Self {
        let empty = std::collections::HashMap::new();
        Self {
            users: Arc::new(tokio::sync::RwLock::new(empty)),
        }
    }

    /// Inserts `user` under its username, replacing any user already stored
    /// under that name.
    pub async fn add_user(&self, user: User) {
        let key = user.username.clone();
        self.users.write().await.insert(key, user);
    }

    /// Creates a service pre-populated with `users`.
    ///
    /// If several entries share a username, the last one in the vector is kept.
    pub async fn with_users(users: Vec<User>) -> Self {
        let service = Self::new();
        {
            // Scope the write guard so it is released before `service` is returned.
            let mut stored = service.users.write().await;
            for user in users {
                stored.insert(user.username.clone(), user);
            }
        }
        service
    }
}
impl Default for InMemoryUserService {
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl UserService for InMemoryUserService {
async fn load_user_by_username(&self, username: &str) -> SecurityResult<Arc<dyn UserDetails>> {
let users: tokio::sync::RwLockReadGuard<'_, std::collections::HashMap<String, User>> =
self.users.read().await;
users
.get(username)
.map(|u: &User| Arc::new(u.clone()) as Arc<dyn UserDetails>)
.ok_or_else(|| SecurityError::UserNotFound(username.to_string()))
}
async fn create_user(&self, user: User) -> SecurityResult<()> {
let mut users: tokio::sync::RwLockWriteGuard<'_, std::collections::HashMap<String, User>> =
self.users.write().await;
users.insert(user.username.clone(), user);
Ok(())
}
async fn update_user(&self, user: User) -> SecurityResult<()> {
let mut users: tokio::sync::RwLockWriteGuard<'_, std::collections::HashMap<String, User>> =
self.users.write().await;
users.insert(user.username.clone(), user);
Ok(())
}
async fn delete_user(&self, username: &str) -> SecurityResult<()> {
let mut users: tokio::sync::RwLockWriteGuard<'_, std::collections::HashMap<String, User>> =
self.users.write().await;
users
.remove(username)
.ok_or_else(|| SecurityError::UserNotFound(username.to_string()))?;
Ok(())
}
async fn user_exists(&self, username: &str) -> bool {
let users: tokio::sync::RwLockReadGuard<'_, std::collections::HashMap<String, User>> =
self.users.read().await;
users.contains_key(username)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder carries the username and both roles through to the built user.
    #[test]
    fn test_user_builder() {
        let built = UserBuilder::new()
            .username("john")
            .password("secret")
            .roles(&[Role::User, Role::Admin])
            .build();
        let user = built.unwrap();

        assert_eq!(user.username, "john");
        for role in &[Role::User, Role::Admin] {
            assert!(user.has_role(role));
        }
    }

    /// `User::with_roles` grants every role it is given.
    #[test]
    fn test_user_with_roles() {
        let granted = [Role::User, Role::Admin];
        let user = User::with_roles("john", "secret", &granted);

        for role in &granted {
            assert!(user.has_role(role));
        }
    }

    /// A pre-populated service reports the user as existing and can load it.
    #[tokio::test]
    async fn test_in_memory_user_service() {
        let john = User::with_roles("john", "secret", &[Role::User]);
        let service = InMemoryUserService::with_users(vec![john]).await;

        assert!(service.user_exists("john").await);

        let loaded = service.load_user_by_username("john").await.unwrap();
        assert_eq!(loaded.username(), "john");
        assert!(loaded.is_enabled());
    }
}