use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use crate::auth::{
AccessTokenClaims, Auth, PaginationParams, Permission, Role, User, UserChangeset, UserSession,
UserSessionChangeset, UserSessionJson, UserSessionResponse, ID,
};
use crate::{Connection, Database, Mailer};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
/// Name of the cookie that carries the refresh token.
pub const COOKIE_NAME: &str = "refresh_token";

lazy_static! {
    /// Argon2id (version 0x13) configuration used to hash and verify user passwords.
    ///
    /// The value of the `SECRET_KEY` environment variable is passed as argon2's `secret`
    /// input. It is read once, on first access, and leaked so that it lives for `'static`.
    ///
    /// # Panics
    ///
    /// Panics on first access if `SECRET_KEY` cannot be read (unset or not valid unicode).
    pub static ref ARGON_CONFIG: argon2::Config<'static> = {
        let leaked_secret: &'static str = match std::env::var("SECRET_KEY") {
            Ok(value) => Box::leak(value.into_boxed_str()),
            Err(_) => panic!("No SECRET_KEY environment variable set!"),
        };
        argon2::Config {
            variant: argon2::Variant::Argon2id,
            version: argon2::Version::Version13,
            secret: leaked_secret.as_bytes(),
            ..Default::default()
        }
    };
}

/// A duration expressed in seconds (only declared for release builds).
#[cfg(not(debug_assertions))]
type Seconds = i64;
/// HTTP-style status code paired with an error [`Message`].
type StatusCode = u16;
/// Static, human-readable error message paired with a [`StatusCode`].
type Message = &'static str;
/// Request body consumed by `login`.
#[derive(Deserialize, Serialize)]
#[cfg_attr(feature = "plugin_utoipa", derive(utoipa::ToSchema))]
pub struct LoginInput {
    /// Email address the account was registered with.
    email: String,
    /// Plain-text password, verified against the stored argon2 hash.
    password: String,
    /// Optional device label stored on the created session (rejected above 256 bytes).
    device: Option<String>,
    /// Requested token lifetime in seconds.
    ///
    /// NOTE(review): `login` does not read this field; it always passes `None` as the TTL.
    #[cfg(not(debug_assertions))]
    ttl: Option<Seconds>,
    /// Requested token lifetime in seconds.
    ///
    /// NOTE(review): `login` does not read this field; it always passes `None` as the TTL.
    #[cfg(debug_assertions)]
    ttl: Option<i64>,
}
/// JWT claims of a refresh token, as issued by `create_user_session` and `refresh`.
#[derive(Debug, Serialize, Deserialize)]
pub struct RefreshTokenClaims {
    /// Expiry as a unix timestamp in seconds (issued 24 hours in the future).
    exp: usize,
    /// Id of the user the token belongs to.
    sub: ID,
    /// Set to `"refresh_token"` when issued; checked case-insensitively by `refresh`.
    token_type: String,
}
/// Request body consumed by `register`.
#[derive(Serialize, Deserialize)]
#[cfg_attr(feature = "plugin_utoipa", derive(utoipa::ToSchema))]
pub struct RegisterInput {
    /// Email address for the new account; the activation link is sent here.
    email: String,
    /// Plain-text password; stored only as an argon2 hash.
    password: String,
}
/// JWT claims of the activation token emailed by `register` and checked by `activate`.
#[derive(Debug, Serialize, Deserialize)]
pub struct RegistrationClaims {
    /// Expiry as a unix timestamp in seconds (issued 30 days in the future).
    exp: usize,
    /// Id of the (not yet activated) user the token activates.
    sub: ID,
    /// Set to `"activation_token"` when issued; checked case-insensitively by `activate`.
    token_type: String,
}
/// Input consumed by `activate`.
///
/// NOTE(review): derives `IntoParams` under `plugin_utoipa`, so it is presumably read from the
/// query string — confirm against the route definition.
#[derive(Serialize, Deserialize)]
#[cfg_attr(feature = "plugin_utoipa", derive(utoipa::IntoParams))]
pub struct ActivationInput {
    /// Activation JWT produced by `register` (see `RegistrationClaims`).
    activation_token: String,
}
/// Request body consumed by `forgot_password`.
#[derive(Serialize, Deserialize)]
#[cfg_attr(feature = "plugin_utoipa", derive(utoipa::ToSchema))]
pub struct ForgotInput {
    /// Email address to send the recovery message to.
    email: String,
}
/// JWT claims of the reset token emailed by `forgot_password` and checked by `reset_password`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ResetTokenClaims {
    /// Expiry as a unix timestamp in seconds (issued 24 hours in the future).
    exp: usize,
    /// Id of the user whose password may be reset.
    sub: ID,
    /// Set to `"reset_token"` when issued; checked case-insensitively by `reset_password`.
    token_type: String,
}
/// Request body consumed by `change_password`.
#[derive(Serialize, Deserialize)]
#[cfg_attr(feature = "plugin_utoipa", derive(utoipa::ToSchema))]
pub struct ChangeInput {
    /// Current password; must be non-empty and match the stored hash.
    old_password: String,
    /// Replacement password; must be non-empty and differ from `old_password`.
    new_password: String,
}
/// Request body consumed by `reset_password`.
#[derive(Serialize, Deserialize)]
#[cfg_attr(feature = "plugin_utoipa", derive(utoipa::ToSchema))]
pub struct ResetInput {
    /// Reset JWT produced by `forgot_password` (see `ResetTokenClaims`).
    reset_token: String,
    /// Replacement password; must be non-empty.
    new_password: String,
}
/// Returns one page of the authenticated user's sessions together with the total page count.
///
/// # Errors
///
/// * `400` if `info.page_size` is not positive (it is used as a divisor for `num_pages`).
/// * `500` if a database connection cannot be obtained or either query fails.
pub fn get_sessions(
    db: &Database,
    auth: &Auth,
    info: &PaginationParams,
) -> Result<UserSessionResponse, (StatusCode, Message)> {
    // `page_size` is caller-controlled; zero would otherwise panic on the division below.
    if info.page_size < 1 {
        return Err((400, "'page_size' must be greater than zero."));
    }

    let Ok(mut db) = db.get_connection() else {
        return Err((500, "An internal server error occurred."));
    };

    let Ok(sessions) = UserSession::read_all(&mut db, info, auth.user_id) else {
        return Err((500, "Could not fetch sessions."));
    };

    // The rows are owned, so move their fields instead of cloning them.
    let sessions_json: Vec<UserSessionJson> = sessions
        .into_iter()
        .map(|s| UserSessionJson {
            id: s.id,
            device: s.device,
            created_at: s.created_at,
            #[cfg(not(feature = "database_sqlite"))]
            updated_at: s.updated_at,
        })
        .collect();

    let Ok(num_sessions) = UserSession::count_all(&mut db, auth.user_id) else {
        return Err((500, "Could not fetch sessions."));
    };

    // Ceiling division: a partially filled last page still counts as a page.
    let num_pages = (num_sessions / info.page_size) + i64::from(num_sessions % info.page_size != 0);

    Ok(UserSessionResponse {
        sessions: sessions_json,
        num_pages,
    })
}
/// Deletes one session, provided it belongs to the authenticated user.
///
/// # Errors
///
/// * `404` if the session exists but belongs to a different user.
/// * `500` if a connection cannot be obtained, the lookup fails, or the delete fails.
///   NOTE(review): a lookup of a nonexistent id presumably also lands in the `500` branch.
pub fn destroy_session(
    db: &Database,
    auth: &Auth,
    item_id: ID,
) -> Result<(), (StatusCode, Message)> {
    let Ok(mut db) = db.get_connection() else {
        return Err((500, "An internal server error occurred."));
    };

    let user_session = match UserSession::read(&mut db, item_id) {
        Ok(user_session) if user_session.user_id == auth.user_id => user_session,
        // Another user's session is reported as not found rather than forbidden.
        Ok(_) => return Err((404, "Session not found.")),
        Err(_) => return Err((500, "Internal error.")),
    };

    UserSession::delete(&mut db, user_session.id)
        .map_err(|_| (500, "Could not delete session."))?;
    Ok(())
}
/// Deletes every session belonging to the authenticated user.
///
/// # Errors
///
/// Returns `500` if a connection cannot be obtained or the delete fails.
pub fn destroy_sessions(db: &Database, auth: &Auth) -> Result<(), (StatusCode, Message)> {
    let Ok(mut db) = db.get_connection() else {
        return Err((500, "An internal server error occurred."));
    };
    UserSession::delete_all_for_user(&mut db, auth.user_id)
        .map_err(|_| (500, "Could not delete sessions."))?;
    Ok(())
}
type AccessToken = String;
type RefreshToken = String;
pub fn login(
db: &Database,
item: &LoginInput,
) -> Result<(AccessToken, RefreshToken), (StatusCode, Message)> {
let mut db = db.get_connection().unwrap();
let device = match item.device {
Some(ref device) if device.len() > 256 => {
return Err((400, "'device' cannot be longer than 256 characters."));
}
Some(ref device) => Some(device.clone()),
None => None,
};
let user = match User::find_by_email(&mut db, item.email.clone()) {
Ok(user) if user.activated => user,
Ok(_) => return Err((400, "Account has not been activated.")),
Err(_) => return Err((401, "Invalid credentials.")),
};
let is_valid = argon2::verify_encoded_ext(
&user.hash_password,
item.password.as_bytes(),
ARGON_CONFIG.secret,
ARGON_CONFIG.ad,
)
.unwrap();
if !is_valid {
return Err((401, "Invalid credentials."));
}
create_user_session(&mut db, device, None, user.id)
}
/// Issues a fresh access/refresh token pair for `user_id` and records a new session row.
///
/// * `device_type` — optional device label stored on the session (at most 256 bytes).
/// * `ttl` — access-token lifetime in seconds. `None` means 15 minutes, and values below 1 are
///   raised to 1 second. The refresh token always expires after 24 hours.
///
/// The access token embeds the user's roles and permissions as loaded at issue time.
///
/// # Errors
///
/// * `400` if `device_type` is longer than 256 bytes.
/// * `500` if roles/permissions cannot be loaded, `SECRET_KEY` cannot be read, a token cannot
///   be encoded, or the session row cannot be created.
pub fn create_user_session(
    db: &mut Connection,
    device_type: Option<String>,
    ttl: Option<i64>,
    user_id: i32,
) -> Result<(AccessToken, RefreshToken), (StatusCode, Message)> {
    if matches!(device_type.as_deref(), Some(device) if device.len() > 256) {
        return Err((400, "'device' cannot be longer than 256 characters."));
    }

    let Ok(permissions) = Permission::fetch_all(db, user_id) else {
        return Err((500, "An internal server error occurred."));
    };
    let Ok(roles) = Role::fetch_all(db, user_id) else {
        return Err((500, "An internal server error occurred."));
    };

    // Read the signing secret once (instead of once per token) and fail gracefully if absent.
    let Ok(secret) = std::env::var("SECRET_KEY") else {
        return Err((500, "An internal server error occurred."));
    };
    let encoding_key = EncodingKey::from_secret(secret.as_bytes());

    let access_token_duration =
        chrono::Duration::seconds(ttl.map_or(15 * 60, |tt| std::cmp::max(tt, 1)));
    let now = chrono::Utc::now();

    // `timestamp()` is positive for current dates; the claim structs store `exp` as `usize`.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    let access_token_claims = AccessTokenClaims {
        exp: (now + access_token_duration).timestamp() as usize,
        sub: user_id,
        token_type: "access_token".to_string(),
        roles,
        permissions,
    };
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    let refresh_token_claims = RefreshTokenClaims {
        exp: (now + chrono::Duration::hours(24)).timestamp() as usize,
        sub: user_id,
        token_type: "refresh_token".to_string(),
    };

    let Ok(access_token) = encode(&Header::default(), &access_token_claims, &encoding_key) else {
        return Err((500, "An internal server error occurred."));
    };
    let Ok(refresh_token) = encode(&Header::default(), &refresh_token_claims, &encoding_key)
    else {
        return Err((500, "An internal server error occurred."));
    };

    UserSession::create(
        db,
        &UserSessionChangeset {
            user_id,
            refresh_token: refresh_token.clone(),
            device: device_type,
        },
    )
    .map_err(|_| (500, "Could not create session."))?;

    Ok((access_token, refresh_token))
}
/// Ends the session whose stored refresh token equals `refresh_token`.
///
/// # Errors
///
/// * `401` if no token was supplied or no session matches it.
/// * `500` if a connection cannot be obtained or the session cannot be deleted.
pub fn logout(db: &Database, refresh_token: Option<&'_ str>) -> Result<(), (StatusCode, Message)> {
    // Check the cheap precondition before taking a pooled connection.
    let Some(refresh_token) = refresh_token else {
        return Err((401, "Invalid session."));
    };

    let Ok(mut db) = db.get_connection() else {
        return Err((500, "An internal server error occurred."));
    };

    let Ok(session) = UserSession::find_by_refresh_token(&mut db, refresh_token) else {
        return Err((401, "Invalid session."));
    };

    UserSession::delete(&mut db, session.id).map_err(|_| (500, "Could not delete session."))?;
    Ok(())
}
/// Rotates the refresh token of an existing session and issues a new access token.
///
/// `refresh_token_str` must be a validly signed, unexpired JWT whose `token_type` claim is
/// `refresh_token` (compared ASCII case-insensitively), and it must match a stored session.
/// On success, the session row is updated with the newly issued refresh token, so the presented
/// token no longer matches it. The new access token expires after 15 minutes and the new
/// refresh token after 24 hours.
///
/// # Errors
///
/// * `401` if no token was supplied, the token is invalid, or no session matches it.
/// * `500` if `SECRET_KEY` cannot be read, a connection cannot be obtained, roles/permissions
///   cannot be loaded, a token cannot be encoded, or the session update fails.
pub fn refresh(
    db: &Database,
    refresh_token_str: Option<&'_ str>,
) -> Result<(AccessToken, RefreshToken), (StatusCode, Message)> {
    let Some(refresh_token_str) = refresh_token_str else {
        return Err((401, "Invalid session."));
    };

    // Read the secret once: it verifies the presented token and signs both new tokens.
    let Ok(secret) = std::env::var("SECRET_KEY") else {
        return Err((500, "An internal server error occurred."));
    };

    // Signature, expiry (via `Validation::default()`) and token type are checked here; the
    // session itself is looked up by the raw token string below.
    match decode::<RefreshTokenClaims>(
        refresh_token_str,
        &DecodingKey::from_secret(secret.as_bytes()),
        &Validation::default(),
    ) {
        Ok(token) if token.claims.token_type.eq_ignore_ascii_case("refresh_token") => {}
        _ => return Err((401, "Invalid token.")),
    }

    let Ok(mut db) = db.get_connection() else {
        return Err((500, "An internal server error occurred."));
    };

    let Ok(session) = UserSession::find_by_refresh_token(&mut db, refresh_token_str) else {
        return Err((401, "Invalid session."));
    };
    let Ok(permissions) = Permission::fetch_all(&mut db, session.user_id) else {
        return Err((500, "An internal server error occurred."));
    };
    let Ok(roles) = Role::fetch_all(&mut db, session.user_id) else {
        return Err((500, "An internal server error occurred."));
    };

    let now = chrono::Utc::now();

    // `timestamp()` is positive for current dates; the claim structs store `exp` as `usize`.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    let access_token_claims = AccessTokenClaims {
        exp: (now + chrono::Duration::minutes(15)).timestamp() as usize,
        sub: session.user_id,
        token_type: "access_token".to_string(),
        roles,
        permissions,
    };
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    let refresh_token_claims = RefreshTokenClaims {
        exp: (now + chrono::Duration::hours(24)).timestamp() as usize,
        sub: session.user_id,
        token_type: "refresh_token".to_string(),
    };

    let encoding_key = EncodingKey::from_secret(secret.as_bytes());
    let Ok(access_token) = encode(&Header::default(), &access_token_claims, &encoding_key) else {
        return Err((500, "An internal server error occurred."));
    };
    let Ok(new_refresh_token) =
        encode(&Header::default(), &refresh_token_claims, &encoding_key)
    else {
        return Err((500, "An internal server error occurred."));
    };

    UserSession::update(
        &mut db,
        session.id,
        &UserSessionChangeset {
            user_id: session.user_id,
            refresh_token: new_refresh_token.clone(),
            device: session.device,
        },
    )
    .map_err(|_| (500, "Could not update session."))?;

    Ok((access_token, new_refresh_token))
}
/// Creates a new, not-yet-activated account and emails an activation link to it.
///
/// If an unactivated account already exists for the email, it is deleted and a fresh account
/// is created in its place. The emailed activation token expires after 30 days.
///
/// # Errors
///
/// * `400` if the password is empty or the email belongs to an activated account.
/// * `500` if `SECRET_KEY` cannot be read, a connection cannot be obtained, or deleting,
///   hashing, creating the user, or encoding the token fails.
pub fn register(
    db: &Database,
    item: &RegisterInput,
    mailer: &Mailer,
) -> Result<(), (StatusCode, Message)> {
    // Same rule as `change_password` / `reset_password`.
    if item.password.is_empty() {
        return Err((400, "Missing password"));
    }

    // Resolve the signing secret before touching the database, so a misconfiguration does not
    // leave a half-registered user behind.
    let Ok(secret) = std::env::var("SECRET_KEY") else {
        return Err((500, "An internal server error occurred."));
    };

    let Ok(mut db) = db.get_connection() else {
        return Err((500, "An internal server error occurred."));
    };

    match User::find_by_email(&mut db, item.email.clone()) {
        Ok(user) if user.activated => return Err((400, "Already registered.")),
        Ok(user) => {
            // A pending (unactivated) registration is replaced by this one.
            User::delete(&mut db, user.id)
                .map_err(|_| (500, "An internal server error occurred."))?;
        }
        Err(_) => (),
    }

    let salt = generate_salt();
    let Ok(hash) = argon2::hash_encoded(item.password.as_bytes(), &salt, &ARGON_CONFIG) else {
        return Err((500, "An internal server error occurred."));
    };

    let Ok(user) = User::create(
        &mut db,
        &UserChangeset {
            activated: false,
            email: item.email.clone(),
            hash_password: hash,
        },
    ) else {
        return Err((500, "Could not create user."));
    };

    // `timestamp()` is positive for current dates; the claim struct stores `exp` as `usize`.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    let registration_claims = RegistrationClaims {
        exp: (chrono::Utc::now() + chrono::Duration::days(30)).timestamp() as usize,
        sub: user.id,
        token_type: "activation_token".to_string(),
    };

    let Ok(token) = encode(
        &Header::default(),
        &registration_claims,
        &EncodingKey::from_secret(secret.as_bytes()),
    ) else {
        return Err((500, "An internal server error occurred."));
    };

    mailer
        .templates
        .send_register(mailer, &user.email, &format!("activate?token={token}"));
    Ok(())
}
/// Activates the account referenced by an activation token and sends a confirmation email.
///
/// # Errors
///
/// * `401` if the token is invalid, expired, or not an activation token.
/// * `400` if the token's user cannot be read.
/// * `200` (through the error channel) if the account is already activated.
/// * `500` if `SECRET_KEY` cannot be read, a connection cannot be obtained, or the update
///   fails.
pub fn activate(
    db: &Database,
    item: &ActivationInput,
    mailer: &Mailer,
) -> Result<(), (StatusCode, Message)> {
    let Ok(secret) = std::env::var("SECRET_KEY") else {
        return Err((500, "An internal server error occurred."));
    };

    // Validate the token before acquiring a pooled connection.
    let token = match decode::<RegistrationClaims>(
        &item.activation_token,
        &DecodingKey::from_secret(secret.as_bytes()),
        &Validation::default(),
    ) {
        Ok(token) if token.claims.token_type.eq_ignore_ascii_case("activation_token") => token,
        _ => return Err((401, "Invalid token.")),
    };

    let Ok(mut db) = db.get_connection() else {
        return Err((500, "An internal server error occurred."));
    };

    let user = match User::read(&mut db, token.claims.sub) {
        Ok(user) if !user.activated => user,
        // Reported via `Err` with a success status so callers can show a friendly message.
        Ok(_) => return Err((200, "Already activated!")),
        Err(_) => return Err((400, "Invalid token.")),
    };

    User::update(
        &mut db,
        user.id,
        &UserChangeset {
            activated: true,
            email: user.email.clone(),
            hash_password: user.hash_password,
        },
    )
    .map_err(|_| (500, "Could not activate user."))?;

    mailer.templates.send_activated(mailer, &user.email);
    Ok(())
}
/// Emails password-recovery instructions for `item.email`.
///
/// If an account exists, a reset link carrying a 24-hour reset token is sent. Otherwise, a
/// message pointing to registration is sent. Both outcomes return `Ok(())`, so the response
/// does not reveal whether the email is registered.
///
/// # Errors
///
/// Returns `500` if `SECRET_KEY` cannot be read, a connection cannot be obtained, or the reset
/// token cannot be encoded.
pub fn forgot_password(
    db: &Database,
    item: &ForgotInput,
    mailer: &Mailer,
) -> Result<(), (StatusCode, Message)> {
    // Read the secret before the lookup so a misconfiguration fails identically for known and
    // unknown emails instead of only for existing accounts.
    let Ok(secret) = std::env::var("SECRET_KEY") else {
        return Err((500, "An internal server error occurred."));
    };

    let Ok(mut db) = db.get_connection() else {
        return Err((500, "An internal server error occurred."));
    };

    if let Ok(user) = User::find_by_email(&mut db, item.email.clone()) {
        // `timestamp()` is positive for current dates; the claim struct stores `exp` as `usize`.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let reset_token_claims = ResetTokenClaims {
            exp: (chrono::Utc::now() + chrono::Duration::hours(24)).timestamp() as usize,
            sub: user.id,
            token_type: "reset_token".to_string(),
        };
        let Ok(reset_token) = encode(
            &Header::default(),
            &reset_token_claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        ) else {
            return Err((500, "An internal server error occurred."));
        };
        let link = format!("reset?token={reset_token}");
        mailer
            .templates
            .send_recover_existent_account(mailer, &user.email, &link);
    } else {
        let link = "register".to_string();
        mailer
            .templates
            .send_recover_nonexistent_account(mailer, &item.email, &link);
    }
    Ok(())
}
/// Replaces the authenticated user's password after verifying the current one.
///
/// Sends a "password changed" email on success.
///
/// # Errors
///
/// * `400` if either password is empty, both are equal, or the account is not activated.
/// * `401` if `old_password` does not match the stored hash.
/// * `500` if a connection cannot be obtained, the user cannot be read, hashing or
///   verification fails, or the update fails.
pub fn change_password(
    db: &Database,
    item: &ChangeInput,
    auth: &Auth,
    mailer: &Mailer,
) -> Result<(), (StatusCode, Message)> {
    if item.old_password.is_empty() || item.new_password.is_empty() {
        return Err((400, "Missing password"));
    }
    if item.old_password == item.new_password {
        return Err((400, "The new password must be different"));
    }

    let Ok(mut db) = db.get_connection() else {
        return Err((500, "An internal server error occurred."));
    };

    let user = match User::read(&mut db, auth.user_id) {
        Ok(user) if user.activated => user,
        Ok(_) => return Err((400, "Account has not been activated")),
        Err(_) => return Err((500, "Could not find user")),
    };

    // A malformed stored hash is a server-side problem, not a reason to panic.
    let Ok(is_old_password_valid) = argon2::verify_encoded_ext(
        &user.hash_password,
        item.old_password.as_bytes(),
        ARGON_CONFIG.secret,
        ARGON_CONFIG.ad,
    ) else {
        return Err((500, "An internal server error occurred."));
    };
    if !is_old_password_valid {
        return Err((401, "Invalid credentials"));
    }

    let salt = generate_salt();
    let Ok(new_hash) = argon2::hash_encoded(item.new_password.as_bytes(), &salt, &ARGON_CONFIG)
    else {
        return Err((500, "An internal server error occurred."));
    };

    User::update(
        &mut db,
        auth.user_id,
        &UserChangeset {
            email: user.email.clone(),
            hash_password: new_hash,
            activated: user.activated,
        },
    )
    .map_err(|_| (500, "Could not update password"))?;

    mailer.templates.send_password_changed(mailer, &user.email);
    Ok(())
}
/// Intentionally empty: accepts an [`Auth`] value and does nothing with it.
///
/// NOTE(review): presumably backs an endpoint whose only purpose is to confirm that the caller
/// could produce an `Auth` (i.e. is authenticated) — confirm against the route.
pub const fn check(_auth: &Auth) {}
/// Sets a new password for the user referenced by a reset token and sends a confirmation email.
///
/// # Errors
///
/// * `400` if the new password is empty, the token's user cannot be read, or the account is
///   not activated.
/// * `401` if the token is invalid, expired, or not a reset token.
/// * `500` if `SECRET_KEY` cannot be read, a connection cannot be obtained, hashing fails, or
///   the update fails.
pub fn reset_password(
    db: &Database,
    item: &ResetInput,
    mailer: &Mailer,
) -> Result<(), (StatusCode, Message)> {
    // Validate input and token before acquiring a pooled connection.
    if item.new_password.is_empty() {
        return Err((400, "Missing password"));
    }

    let Ok(secret) = std::env::var("SECRET_KEY") else {
        return Err((500, "An internal server error occurred."));
    };

    let token = match decode::<ResetTokenClaims>(
        &item.reset_token,
        &DecodingKey::from_secret(secret.as_bytes()),
        &Validation::default(),
    ) {
        Ok(token) if token.claims.token_type.eq_ignore_ascii_case("reset_token") => token,
        _ => return Err((401, "Invalid token.")),
    };

    let Ok(mut db) = db.get_connection() else {
        return Err((500, "An internal server error occurred."));
    };

    let user = match User::read(&mut db, token.claims.sub) {
        Ok(user) if user.activated => user,
        Ok(_) => return Err((400, "Account has not been activated")),
        Err(_) => return Err((400, "Invalid token.")),
    };

    let salt = generate_salt();
    let Ok(new_hash) = argon2::hash_encoded(item.new_password.as_bytes(), &salt, &ARGON_CONFIG)
    else {
        return Err((500, "An internal server error occurred."));
    };

    User::update(
        &mut db,
        user.id,
        &UserChangeset {
            email: user.email.clone(),
            hash_password: new_hash,
            activated: user.activated,
        },
    )
    .map_err(|_| (500, "Could not update password"))?;

    mailer.templates.send_password_reset(mailer, &user.email);
    Ok(())
}
#[must_use]
#[allow(clippy::missing_panics_doc)]
pub fn generate_salt() -> [u8; 16] {
use rand::Fill;
let mut salt = [0; 16];
salt.try_fill(&mut rand::thread_rng()).unwrap();
salt
}