use std::{
collections::HashSet,
error::Error,
time::{Duration, SystemTime},
};
use bcrypt;
use bson::Document;
use rocket::{
http::Status,
request::{self, FromRequest},
Outcome, Request, State,
};
use serde::{de::Error as DeError, Deserialize, Deserializer};
use crate::{doc::DatabaseDocument, jwt::JwtHandler};
/// How long an issued [`UserRolesToken`] stays valid: seven days.
const ONE_WEEK: Duration = Duration::from_secs(7 * 24 * 60 * 60);
fn hash<'de, D>(deserializer: D) -> std::result::Result<String, D::Error>
where
D: Deserializer<'de>,
{
let pw = String::deserialize(deserializer)?;
bcrypt::hash(&pw, bcrypt::DEFAULT_COST)
.map_err(|err| D::Error::custom(err.description()))
}
/// A user account stored in the `users` collection.
///
/// When deserialized, the input's `password` field is treated as a plaintext
/// password and bcrypt-hashed (via `hash`) into `pw_hash`. When serialized,
/// the hash is written under the field name `pw_hash`.
///
/// NOTE(review): serialization emits `pw_hash`, while deserialization
/// expects `password` and re-hashes whatever it finds there. A `User`
/// written with `Serialize` therefore cannot be read back with
/// `Deserialize`: either the `password` key is missing, or a stored hash
/// would be hashed a second time. Confirm how user documents are loaded
/// from the database.
#[derive(Debug, Serialize, Deserialize)]
pub struct User {
    username: String,
    // Bcrypt hash of the password. Populated from the plaintext `password`
    // key on deserialize (see the struct docs).
    #[serde(deserialize_with = "hash", rename(deserialize = "password"))]
    pw_hash: String,
    roles: UserRoles,
}
impl User {
    /// Checks whether `login.password` matches this user's stored bcrypt hash.
    ///
    /// Returns `true` only on a positive match. It returns `false` both for a
    /// wrong password and when bcrypt reports an error (e.g. a malformed
    /// stored hash).
    pub fn verify(&self, login: &Login) -> bool {
        // `bcrypt::verify` yields `Ok(false)` for a wrong password, so testing
        // `.is_ok()` would accept any password. Only `Ok(true)` is a match.
        bcrypt::verify(&login.password, &self.pw_hash).unwrap_or(false)
    }
}
impl DatabaseDocument for User {
    type QueryData = Login;
    const COLLECTION_NAME: &'static str = "users";

    /// Builds the lookup filter for a login attempt, matching on `username`
    /// only. The password is not part of the filter; see `User::verify` for
    /// the hash check.
    fn gen_query(query: &Self::QueryData) -> Document {
        let username = &query.username;
        doc! { "username" => username }
    }
}
/// The set of role names granted to a user.
///
/// Also usable as a Rocket request guard (see its `FromRequest` impl). In
/// that case the roles are taken from the decoded JWT cookie.
#[derive(Debug, Serialize, Deserialize)]
pub struct UserRoles(HashSet<String>);
impl UserRoles {
pub fn has_roles(&self, roles: &[&str]) -> bool {
roles.iter().all(|role| self.0.contains(*role))
}
}
impl<'a, 'r> FromRequest<'a, 'r> for UserRoles {
type Error = ();
fn from_request(
request: &'a Request<'r>,
) -> request::Outcome<Self, Self::Error> {
let jwt = request.guard::<State<JwtHandler<UserRolesToken>>>()?;
match jwt.decode_cookie(&request.cookies()) {
Some(Ok(token)) => Outcome::Success(token.roles),
Some(Err(_)) => Outcome::Failure((Status::BadRequest, ())),
None => Outcome::Failure((Status::Unauthorized, ())),
}
}
}
/// JWT payload built from a `User` (see the `From<User>` impl) and decoded
/// by the `UserRoles` request guard.
#[derive(Debug, Serialize, Deserialize)]
pub struct UserRolesToken {
    // `exp`: expiry time, set to one week after issue by `From<User>`.
    // NOTE(review): serde encodes `SystemTime` as a
    // `{ secs_since_epoch, nanos_since_epoch }` map, not the integer seconds
    // RFC 7519 uses for `exp`. Confirm that `JwtHandler` actually enforces
    // this expiry.
    // `username`: the account the token was issued to.
    exp: SystemTime,
    username: String,
    // Roles handed to route guards via `UserRoles`.
    roles: UserRoles,
}
impl From<User> for UserRolesToken {
fn from(user: User) -> Self {
Self {
exp: SystemTime::now() + ONE_WEEK, username: user.username,
roles: user.roles,
}
}
}
/// Login credentials, parseable from a form submission (`FromForm`).
///
/// Also serves as the query data for looking up a `User` by `username`.
// `Debug` is not derived, so the plaintext password cannot be printed with
// `{:?}`. Keep it that way.
#[derive(Serialize, Deserialize, FromForm)]
pub struct Login {
    username: String,
    // Plaintext password; compared against the stored hash by `User::verify`.
    password: String,
}