use std::{sync::Arc, time::SystemTime};
use anyhow::bail;
use blake2::digest::{FixedOutput, Mac};
use bytes::{BufMut, Bytes, BytesMut};
use chacha20poly1305::{
XChaCha20Poly1305, XNonce,
aead::{Aead, AeadCore, OsRng},
};
use flexbuffers::VectorReader;
use ordinary_config::{
AccessTokenConfig, AuthConfig, ClientPasswordHash, InviteConfig, InviteMode, MfaConfig,
PasswordConfig, PasswordProtocol, RefreshTokenConfig, TotpAlgorithm, TotpConfig,
};
use tracing::instrument;
use uuid::Uuid;
use x25519_dalek::{EphemeralSecret, PublicKey};
use saferlmdb::{
self as lmdb, Database, DatabaseOptions, Environment, ReadTransaction, WriteTransaction, put,
};
pub use opaque_ke::ServerSetup;
use totp_rs::{Algorithm, Secret, TOTP};
use crate::keys::{KeyAlg, generate_256_bit_key};
use crate::{
DefaultCipherSuite, EXP_LEN, SIG_LEN, ZEROED_KEY,
keys::decrypt_256_bit_key,
recovery::{check_code, consume_code},
token::generate_hmac,
validate_account,
};
use crate::token::{extract_hmac_no_check, verify_client_signature, verify_hmac};
use sha2::{Digest, Sha256};
/// Server-side authentication state for one domain.
///
/// Holds the effective [`AuthConfig`], the derived HMAC keys for each token
/// kind (loaded in [`Auth::new`] via `get_token_hmac_key`), the key used to
/// encrypt data at rest, and handles to the LMDB tables that back accounts,
/// in-flight protocol state, invites and recovery codes.
pub struct Auth {
/// Domain served by this instance; used as the TOTP issuer and pushed into token claims.
pub domain: String,
/// Effective configuration (a built-in default is used when none is supplied).
pub config: AuthConfig,
/// TOTP hash algorithm resolved from `config.mfa.totp.algorithm`.
totp_alg: Algorithm,
/// Derived HMAC key for the `access` token kind.
access_token_key: [u8; 32],
/// Derived HMAC key for the `refresh` token kind.
refresh_token_key: [u8; 32],
/// Derived HMAC key for password-reset tokens.
reset_password_token_key: [u8; 32],
/// Derived HMAC key for invite tokens.
invite_token_key: [u8; 32],
/// XChaCha20-Poly1305 key for the encrypted OPAQUE setup, TOTP secret,
/// password file and stored token keys.
encryption_key: [u8; 32],
/// LMDB environment shared by all tables below.
env: Arc<Environment>,
/// `user` table: account name -> account record (or a `deleted` tombstone).
account_db: Arc<Database<'static>>,
/// `auth` table: account name -> in-flight registration/login state (or `deleted`).
state_db: Arc<Database<'static>>,
/// `invite` table: invite token UUID bytes -> empty value.
invite_db: Arc<Database<'static>>,
/// `recovery` table: account name -> stored recovery codes (or `deleted`).
recovery_db: Arc<Database<'static>>,
}
/// Optional hook that vets invite-token claims during registration: called with
/// the validated account name and the claims vector; `Ok(false)` rejects.
type InviteClaimsValidator = fn(&str, &VectorReader<&[u8]>) -> anyhow::Result<bool>;
impl Auth {
#[allow(clippy::too_many_lines)]
pub fn new(
domain: String,
config: Option<AuthConfig>,
encryption_key: [u8; 32],
env: Arc<Environment>,
) -> anyhow::Result<Self> {
let account_db = Arc::new(Database::open(
env.clone(),
Some("user"),
&DatabaseOptions::new(lmdb::db::Flags::CREATE),
)?);
let state_db = Arc::new(Database::open(
env.clone(),
Some("auth"),
&DatabaseOptions::new(lmdb::db::Flags::CREATE),
)?);
let invite_db = Arc::new(Database::open(
env.clone(),
Some("invite"),
&DatabaseOptions::new(lmdb::db::Flags::CREATE),
)?);
let recovery_db = Arc::new(Database::open(
env.clone(),
Some("recovery"),
&DatabaseOptions::new(lmdb::db::Flags::CREATE),
)?);
let key_db = Arc::new(Database::open(
env.clone(),
Some("key"),
&DatabaseOptions::new(lmdb::db::Flags::CREATE),
)?);
let mut config = config.unwrap_or_else(|| AuthConfig {
password: PasswordConfig {
protocol: PasswordProtocol::Opaque,
},
mfa: MfaConfig {
totp: TotpConfig {
template: None,
algorithm: TotpAlgorithm::Sha1,
},
},
refresh_token: RefreshTokenConfig::default(),
access_token: AccessTokenConfig::default(),
client_hash: ClientPasswordHash::Sha256,
cookies_enabled: true,
invite: None,
});
config.access_token.claims.sort_by_key(|a| a.idx);
let txn = WriteTransaction::new(env.clone())?;
let access_token_key = Self::get_token_hmac_key(
&encryption_key,
&key_db,
config.access_token.rotation,
&txn,
"access",
)?;
let refresh_token_key = Self::get_token_hmac_key(
&encryption_key,
&key_db,
config.refresh_token.rotation,
&txn,
"refresh",
)?;
let reset_password_token_key = Self::get_token_hmac_key(
&encryption_key,
&key_db,
config.access_token.rotation,
&txn,
"reset_password",
)?;
let invite_token_key = Self::get_token_hmac_key(
&encryption_key,
&key_db,
config.access_token.rotation,
&txn,
"invite",
)?;
txn.commit()?;
Ok(Self {
domain,
totp_alg: match &config.mfa.totp.algorithm {
TotpAlgorithm::Sha1 => Algorithm::SHA1,
},
config,
access_token_key,
refresh_token_key,
reset_password_token_key,
invite_token_key,
encryption_key,
env,
state_db,
account_db,
invite_db,
recovery_db,
})
}
fn get_token_hmac_key(
encryption_key: &[u8; 32],
key_db: &Arc<Database>,
rotation: u32,
txn: &WriteTransaction,
key_name: &str,
) -> anyhow::Result<[u8; 32]> {
let mut access = txn.access();
let mut rng = OsRng;
let cipher = {
use chacha20poly1305::aead::KeyInit;
XChaCha20Poly1305::new(encryption_key.as_ref().into())
};
let mut lookup: BytesMut = key_name.as_bytes().into();
lookup.put_u8(KeyAlg::Blake2SMac256.as_byte());
let hmac_key = if let Ok(kid_key) = access.get::<[u8], [u8]>(key_db, lookup.as_ref()) {
let (kid, curr_key) = decrypt_256_bit_key(&cipher, kid_key)?;
if let Some(timestamp) = kid.get_timestamp() {
if (timestamp.to_unix().0 + u64::from(rotation))
< SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)?
.as_secs()
{
let (kid_key, new_key, _kid) = generate_256_bit_key(&cipher, &mut rng)?;
access.put(key_db, lookup.as_ref(), &kid_key, &put::Flags::empty())?;
new_key
} else {
curr_key
}
} else {
curr_key
}
} else {
let (kid_key, hmac_key, _kid) = generate_256_bit_key(&cipher, &mut rng)?;
access.put(key_db, lookup.as_ref(), &kid_key, &put::Flags::empty())?;
hmac_key
};
let hashed_key: [u8; 32] = blake2::Blake2sMac256::new_from_slice(&hmac_key[..])?
.chain_update(lookup.as_ref())
.finalize_fixed()
.into();
Ok(hashed_key)
}
/// Smallest valid `registration_start` payload: the account-length byte plus
/// the 32-byte client registration message.
const MIN_REGISTRATION_START_LEN: usize = 1 + 32;
/// Begins an OPAQUE registration for the account named in `payload`.
///
/// Payload layout: `[account_len: u8][account][client_start: 32 bytes]`.
/// For a new account (`existing_account` is `None`) with invites configured,
/// `checked_claims` is required and is optionally vetted by
/// `invite_claims_validator`. A fresh `ServerSetup` is created, encrypted with
/// the at-rest key, and stored with its nonce (and any invite claims) in the
/// state table. Returns the server's registration response.
///
/// # Errors
/// Fails on malformed payloads, invalid or mismatched account names, missing
/// or rejected invite claims, an account that already has a record or state
/// entry (new registrations only), encryption or database errors.
#[allow(clippy::type_complexity)]
#[instrument(skip_all, err)]
pub fn registration_start(
    &self,
    payload: Bytes,
    existing_account: Option<&str>,
    invite_claims_validator: Option<InviteClaimsValidator>,
    checked_claims: Option<(VectorReader<&[u8]>, &[u8])>,
) -> anyhow::Result<Bytes> {
    use chacha20poly1305::aead::KeyInit;
    if payload.len() < Self::MIN_REGISTRATION_START_LEN {
        bail!("payload too small");
    }
    let account_len = usize::from(payload[0]);
    let client_start_at = account_len + 1;
    if payload.len() != client_start_at + 32 {
        bail!("invalid format");
    }
    let account_str = validate_account(std::str::from_utf8(&payload[1..client_start_at])?)?;
    let account = account_str.as_bytes();
    if matches!(existing_account, Some(expected) if expected != account_str) {
        bail!("account mismatch");
    }
    tracing::info!(account = %account_str);

    let needs_invite = self.config.invite.is_some() && existing_account.is_none();
    let invite_claims = if needs_invite {
        let Some((claims_vec, claims)) = checked_claims else {
            bail!("no checked claims")
        };
        if let Some(validator) = invite_claims_validator
            && !validator(&account_str, &claims_vec)?
        {
            bail!("failed to validate invite token claims");
        }
        Some(claims)
    } else {
        None
    };

    let txn = WriteTransaction::new(self.env.clone())?;
    let response = {
        let mut access = txn.access();
        let taken = existing_account.is_none()
            && (access.get::<[u8], [u8]>(&self.account_db, account).is_ok()
                || access.get::<[u8], [u8]>(&self.state_db, account).is_ok());
        if taken {
            bail!(
                "account {} has already been registered.",
                std::str::from_utf8(account)?
            );
        }
        let client_start = &payload[client_start_at..client_start_at + 32];
        let opaque = ServerSetup::<DefaultCipherSuite>::new(&mut OsRng);
        let response = crate::registration::server_start(&opaque, account, client_start)?;

        // Persist `enc(setup) || nonce || invite_claims?` until registration_finish.
        let cipher = XChaCha20Poly1305::new(&self.encryption_key.into());
        let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
        let mut stored = cipher
            .encrypt(&nonce, opaque.serialize().as_ref())
            .map_err(|err| anyhow::anyhow!("{err}"))?;
        stored.extend_from_slice(&nonce);
        if let Some(claims) = invite_claims {
            stored.extend_from_slice(claims);
        }
        access.put(&self.state_db, account, &stored, &put::Flags::empty())?;
        response
    };
    txn.commit()?;
    Ok(response)
}
const MIN_REGISTRATION_FINISH_LEN: usize = 1 + 32 + 192;
#[allow(clippy::type_complexity, clippy::too_many_lines)]
#[instrument(skip_all, err)]
pub fn registration_finish(
&self,
payload: Bytes,
existing_account: Option<&str>,
) -> anyhow::Result<(Bytes, Vec<u8>, Option<Bytes>)> {
use chacha20poly1305::aead::KeyInit;
if payload.len() < Self::MIN_REGISTRATION_FINISH_LEN {
bail!("payload too small");
}
let account_len = payload[0] as usize;
if account_len != payload.len() - Self::MIN_REGISTRATION_FINISH_LEN {
bail!("invalid format");
}
let raw_account = &payload[1..=account_len];
let raw_account_str = std::str::from_utf8(raw_account)?;
let account_str = validate_account(raw_account_str)?;
let account = account_str.as_bytes();
if let Some(existing_account) = existing_account
&& existing_account != account_str
{
bail!("account mismatch");
}
tracing::info!(account = %account_str);
let public_key: [u8; 32] = payload[account_len + 1..account_len + 33].try_into()?;
if public_key == ZEROED_KEY {
bail!("public key cannot be all 0s");
}
let client_finish = &payload[account_len + 33..account_len + 33 + 192];
let cipher = XChaCha20Poly1305::new(&self.encryption_key.into());
let totp_nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
let password_file_nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
let txn = WriteTransaction::new(self.env.clone())?;
let (mut totp_mfa_secret, recovery_codes, invite_claims) = {
let mut invite_claims = None;
let mut access = txn.access();
let (totp_mfa_secret, recovery_codes) = if existing_account.is_none() {
if access.get::<[u8], [u8]>(&self.account_db, account).is_ok() {
bail!("account has already been registered");
}
let password_file = crate::registration::server_finish(client_finish)?;
let mut password_file =
match cipher.encrypt(&password_file_nonce, password_file.as_ref()) {
Ok(v) => v,
Err(err) => bail!("{err}"),
};
password_file.extend_from_slice(&password_file_nonce);
let totp_mfa_secret = Secret::generate_secret().to_bytes()?;
let mut encrypted_totp_mfa_secret =
match cipher.encrypt(&totp_nonce, totp_mfa_secret.as_ref()) {
Ok(v) => v,
Err(err) => bail!("{err}"),
};
encrypted_totp_mfa_secret.extend_from_slice(&totp_nonce);
password_file.splice(0..0, encrypted_totp_mfa_secret.iter().copied());
let serialized_opaque_and_invite_claims =
access.get::<[u8], [u8]>(&self.state_db, account)?;
if serialized_opaque_and_invite_claims == b"deleted" {
bail!("{} has been deleted", std::str::from_utf8(account)?);
}
if serialized_opaque_and_invite_claims.len() > 168 {
invite_claims = Some(Bytes::copy_from_slice(
&serialized_opaque_and_invite_claims[168..],
));
}
password_file.splice(
0..0,
serialized_opaque_and_invite_claims[0..168].iter().copied(),
);
let mut empty_claims_builder =
flexbuffers::Builder::new(&flexbuffers::BuilderOptions::SHARE_NONE);
let empty_claims_vec = empty_claims_builder.start_vector();
empty_claims_vec.end_vector();
password_file.extend_from_slice(empty_claims_builder.view());
let (stored_recovery_codes, recovery_codes) = crate::recovery::generate_codes()?;
access.put(
&self.recovery_db,
account,
&stored_recovery_codes[..],
&put::Flags::empty(),
)?;
access.put(
&self.account_db,
account,
&password_file,
&put::Flags::empty(),
)?;
(totp_mfa_secret, recovery_codes)
} else {
let account_record: &[u8] = access.get(&self.account_db, account)?;
if account_record == b"deleted" {
bail!("{} has been deleted", std::str::from_utf8(account)?);
}
let password_file = crate::registration::server_finish(client_finish)?;
let mut password_file =
match cipher.encrypt(&password_file_nonce, password_file.as_ref()) {
Ok(v) => v,
Err(err) => bail!("{err}"),
};
password_file.extend_from_slice(&password_file_nonce);
password_file.splice(0..0, account_record[168..228].iter().copied());
let serialized_opaque_and_invite_claims =
access.get::<[u8], [u8]>(&self.state_db, account)?;
if serialized_opaque_and_invite_claims == b"deleted" {
bail!("{} has been deleted", std::str::from_utf8(account)?);
}
if serialized_opaque_and_invite_claims.len() > 168 {
invite_claims = Some(Bytes::copy_from_slice(
&serialized_opaque_and_invite_claims[168..],
));
}
password_file.splice(
0..0,
serialized_opaque_and_invite_claims[0..168].iter().copied(),
);
password_file.extend_from_slice(&account_record[460..]);
access.put(
&self.account_db,
account,
&password_file,
&put::Flags::empty(),
)?;
(vec![], String::new())
};
access.del_key(&self.state_db, account)?;
(totp_mfa_secret, recovery_codes, invite_claims)
};
txn.commit()?;
if existing_account.is_none() {
let public_key = PublicKey::from(public_key);
let ephemeral_secret = EphemeralSecret::random_from_rng(OsRng);
let ephemeral_public_key = PublicKey::from(&ephemeral_secret);
let shared_secret = ephemeral_secret.diffie_hellman(&public_key);
if !shared_secret.was_contributory() {
bail!("non-contributory shared secret");
}
let cipher = XChaCha20Poly1305::new(shared_secret.as_bytes().into());
let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
totp_mfa_secret.extend_from_slice(recovery_codes.as_bytes());
let mut encrypted_secret = match cipher.encrypt(&nonce, totp_mfa_secret.as_ref()) {
Ok(es) => es,
Err(err) => bail!("{err}"),
};
encrypted_secret.extend_from_slice(&nonce);
encrypted_secret.extend_from_slice(ephemeral_public_key.as_bytes());
Ok((
Bytes::from(encrypted_secret),
account.to_vec(),
invite_claims,
))
} else {
Ok((Bytes::new(), account.to_vec(), invite_claims))
}
}
/// Smallest valid `login_start` payload: the account-length byte plus the
/// 96-byte client login message.
const MIN_LOGIN_START_LEN: usize = 1 + 96;
/// Begins an OPAQUE login for the account named in `payload`.
///
/// Payload layout: `[account_len: u8][account][client_start: 96 bytes]`.
/// Decrypts the account's stored OPAQUE setup and password file, runs
/// `crate::login::server_start`, stores the resulting state in the state
/// table (overwriting any existing entry) and returns the server message.
///
/// # Errors
/// Fails on malformed payloads, invalid account names, unknown, deleted or
/// truncated account records, decryption/deserialization failures, or
/// database errors.
#[instrument(skip_all, err)]
pub fn login_start(&self, payload: Bytes) -> anyhow::Result<Bytes> {
    // Record layout used below: [0..144) sealed OPAQUE setup, [144..168) its
    // nonce, [228..436) sealed password file, [436..460) its nonce.
    const RECORD_MIN_LEN: usize = 460;
    if payload.len() < Self::MIN_LOGIN_START_LEN {
        bail!("payload too small");
    }
    let account_len = payload[0] as usize;
    if account_len != payload.len() - Self::MIN_LOGIN_START_LEN {
        bail!("invalid format");
    }
    let raw_account = &payload[1..=account_len];
    let raw_account_str = std::str::from_utf8(raw_account)?;
    let account_str = validate_account(raw_account_str)?;
    let account = account_str.as_bytes();
    tracing::info!(account = %account_str);
    let client_start = &payload[account_len + 1..account_len + 1 + 96];
    let txn = WriteTransaction::new(self.env.clone())?;
    let message = {
        use chacha20poly1305::aead::KeyInit;
        let mut access = txn.access();
        let account_record: &[u8] = access.get(&self.account_db, account)?;
        if account_record == b"deleted" {
            bail!("{} has been deleted", std::str::from_utf8(account)?);
        }
        // Reject truncated records instead of panicking on the fixed offsets.
        if account_record.len() < RECORD_MIN_LEN {
            bail!("account record for {} is corrupt", std::str::from_utf8(account)?);
        }
        let cipher = XChaCha20Poly1305::new(&self.encryption_key.into());
        let serialized_opaque = cipher
            .decrypt(
                XNonce::from_slice(&account_record[144..168]),
                &account_record[0..144],
            )
            .map_err(|err| anyhow::anyhow!("{err}"))?;
        let opaque = ServerSetup::<DefaultCipherSuite>::deserialize(&serialized_opaque)
            .map_err(|err| anyhow::anyhow!("{err}"))?;
        let password_file = cipher
            .decrypt(
                XNonce::from_slice(&account_record[436..460]),
                &account_record[228..436],
            )
            .map_err(|err| anyhow::anyhow!("{err}"))?;
        let (state, message) =
            crate::login::server_start(&opaque, account, &password_file, client_start)?;
        access.put(&self.state_db, account, &state[..], &put::Flags::empty())?;
        message
    };
    txn.commit()?;
    Ok(message)
}
/// Smallest valid `login_finish` payload: the account-length byte, a 48-byte
/// encrypted MFA code (32 + 16), a 24-byte nonce and the 64-byte client finish.
const MIN_LOGIN_FINISH_LEN: usize = 1 + 32 + 16 + 24 + 64;
/// Completes an OPAQUE login started by `login_start`.
///
/// Payload layout: `[account_len: u8][account][encrypted MFA code: 48 or 80
/// bytes][nonce: 24][client_finish: 64]`. With `check_mfa`, the expected MFA
/// hash is SHA-256 over `domain || account || current TOTP code`; otherwise 32
/// zero bytes are passed. `crate::login::server_finish` does the actual
/// verification. The login state is removed on success and its outputs are
/// returned (bound here as the encrypted refresh token, session key and
/// optional verifier).
///
/// # Errors
/// Fails on malformed payloads, invalid account names, unknown, deleted or
/// truncated records, missing login state, TOTP/crypto failures,
/// `server_finish` rejection, or database errors.
#[instrument(skip_all, err)]
#[allow(clippy::type_complexity)]
pub fn login_finish(
    &self,
    payload: Bytes,
    check_mfa: bool,
) -> anyhow::Result<(Bytes, [u8; 32], Option<Bytes>)> {
    // Sealed TOTP secret at [168..204) with its nonce at [204..228).
    const TOTP_END: usize = 228;
    let payload_len = payload.len();
    if payload_len < Self::MIN_LOGIN_FINISH_LEN {
        bail!("payload is too small");
    }
    let account_len = payload[0] as usize;
    if payload_len != account_len + Self::MIN_LOGIN_FINISH_LEN
        && payload_len != account_len + Self::MIN_LOGIN_FINISH_LEN + 32
    {
        bail!("invalid format");
    }
    let raw_account = &payload[1..=account_len];
    let raw_account_str = std::str::from_utf8(raw_account)?;
    let account_str = validate_account(raw_account_str)?;
    let account = account_str.as_bytes();
    tracing::info!(account = %account_str);
    let encrypted_mfa_code = &payload[account_len + 1..payload_len - 24 - 64];
    let mfa_code_nonce = &payload[payload_len - 24 - 64..payload_len - 64];
    let client_finish = &payload[payload_len - 64..];
    let txn = WriteTransaction::new(self.env.clone())?;
    let (encrypted_refresh_token, key, verifier) = {
        let mut access = txn.access();
        let account_record: &[u8] = access.get(&self.account_db, account)?;
        if account_record == b"deleted" {
            bail!("{} has been deleted", std::str::from_utf8(account)?);
        }
        // Only touch the TOTP secret when MFA is checked, so flows that skip
        // it (e.g. a lost authenticator) do not depend on that field.
        let mfa_hash = if check_mfa {
            if account_record.len() < TOTP_END {
                bail!("account record for {} is corrupt", std::str::from_utf8(account)?);
            }
            let cipher = {
                use chacha20poly1305::aead::KeyInit;
                XChaCha20Poly1305::new(&self.encryption_key.into())
            };
            let mfa_secret = cipher
                .decrypt(
                    XNonce::from_slice(&account_record[204..228]),
                    &account_record[168..204],
                )
                .map_err(|err| anyhow::anyhow!("{err}"))?;
            // NOTE(review): only the current step's code is derived, so a code
            // from an adjacent 30 s window is presumably rejected — confirm.
            let mfa_code = TOTP::new(
                self.totp_alg,
                6,
                1,
                30,
                Secret::Raw(mfa_secret).to_bytes()?,
                Some(self.domain.clone()),
                account_str.to_string(),
            )?
            .generate_current()?;
            let mut mfa_input = self.domain.as_bytes().to_vec();
            mfa_input.extend_from_slice(account);
            mfa_input.extend_from_slice(mfa_code.as_bytes());
            match &self.config.client_hash {
                ClientPasswordHash::Sha256 => Sha256::digest(&mfa_input).to_vec(),
            }
        } else {
            vec![0u8; 32]
        };
        let server_start: &[u8] = access.get(&self.state_db, account)?;
        if server_start == b"deleted" {
            bail!("{} has been deleted", std::str::from_utf8(account)?);
        }
        let outcome = crate::login::server_finish(
            self,
            account,
            &mfa_hash[..],
            encrypted_mfa_code,
            mfa_code_nonce,
            client_finish,
            server_start,
        )?;
        access.del_key(&self.state_db, account)?;
        outcome
    };
    txn.commit()?;
    Ok((encrypted_refresh_token, key, verifier))
}
#[instrument(skip_all, err)]
pub fn reset_password_login_start(&self, payload: Bytes) -> anyhow::Result<Bytes> {
self.login_start(payload)
}
#[instrument(skip_all, err)]
pub fn reset_password_login_finish(&self, payload: Bytes) -> anyhow::Result<Bytes> {
use chacha20poly1305::aead::KeyInit;
let (_, key, verifier) = self.login_finish(payload.clone(), true)?;
let account_len = (*payload.first().expect("payload already checked")) as usize;
let account_str = std::str::from_utf8(&payload[1..=account_len])?;
let mut reset_password_claims =
flexbuffers::Builder::new(&flexbuffers::BuilderOptions::SHARE_NONE);
let mut reset_password_claims_vec = reset_password_claims.start_vector();
let token_uuid = Uuid::now_v7();
let token_uuid_bytes = token_uuid.as_bytes();
reset_password_claims_vec.push(flexbuffers::Blob(&token_uuid_bytes[..]));
reset_password_claims_vec.push(self.domain.as_str());
reset_password_claims_vec.push(account_str);
if let Some(verifier) = verifier {
reset_password_claims_vec.push(flexbuffers::Blob(&verifier[..]));
}
reset_password_claims_vec.end_vector();
let password_reset_token = generate_hmac(
reset_password_claims.view(),
60 * 15,
&self.reset_password_token_key,
)?;
let cipher = XChaCha20Poly1305::new(&key.into());
let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
let mut encrypted_password_reset_token =
match cipher.encrypt(&nonce, password_reset_token.as_ref()) {
Ok(et) => et,
Err(err) => bail!("{err}"),
};
encrypted_password_reset_token.extend_from_slice(&nonce);
Ok(Bytes::copy_from_slice(&encrypted_password_reset_token))
}
#[instrument(skip_all, err)]
pub fn reset_password_registration_start(
&self,
payload: Bytes,
token: &[u8],
) -> anyhow::Result<Bytes> {
let account_str = self.verify_password_reset_token(token)?;
self.registration_start(payload, Some(account_str), None, None)
}
/// Completes re-registration for the account named by a password-reset `token`.
///
/// Delegates to [`Auth::registration_finish`] with that account as the
/// existing account and discards its outputs.
///
/// # Errors
/// Fails if the token does not verify, or with any `registration_finish` error.
#[instrument(skip_all, err)]
pub fn reset_password_registration_finish(
    &self,
    payload: Bytes,
    token: &[u8],
) -> anyhow::Result<()> {
    let verified_account = self.verify_password_reset_token(token)?;
    self.registration_finish(payload, Some(verified_account))
        .map(|_| ())
}
#[instrument(skip_all, err)]
pub fn forgot_password_start(&self, mut payload: Bytes) -> anyhow::Result<Bytes> {
let payload_len = payload.len();
if payload_len < 33 {
bail!("invalid payload");
}
let recovery_code = payload.split_off(payload_len - 32);
let account_len = (*payload.first().expect("payload already checked")) as usize;
if account_len > payload.len() - 1 {
bail!("invalid format");
}
let account = Bytes::copy_from_slice(&payload[1..=account_len]);
let account_str = std::str::from_utf8(&account[..])?;
let txn = ReadTransaction::new(self.env.clone())?;
let access = txn.access();
let read_recovery_codes = access.get(&self.recovery_db, account.as_ref())?;
if read_recovery_codes == b"deleted" {
bail!(
"{} has been deleted",
std::str::from_utf8(account.as_ref())?
);
}
let is_valid_recovery_code = check_code(&recovery_code[..], read_recovery_codes)?;
if !is_valid_recovery_code {
bail!("invalid recovery code");
}
self.registration_start(payload, Some(account_str), None, None)
}
#[instrument(skip_all, err)]
pub fn forgot_password_finish(&self, mut payload: Bytes) -> anyhow::Result<()> {
let payload_len = payload.len();
if payload_len < 12 {
bail!("invalid payload");
}
let recovery_code = payload.split_off(payload_len - 11);
let account_len = (*payload.first().expect("payload already checked")) as usize;
if account_len > payload.len() - 1 {
bail!("invalid format");
}
let account = Bytes::copy_from_slice(&payload[1..=account_len]);
let account_str = std::str::from_utf8(&account[..])?;
let txn = WriteTransaction::new(self.env.clone())?;
{
let mut access = txn.access();
let read_recovery_codes = access.get(&self.recovery_db, account.as_ref())?;
if read_recovery_codes == b"deleted" {
bail!(
"{} has been deleted",
std::str::from_utf8(account.as_ref())?
);
}
let (is_valid_recovery_code, stored_codes) =
consume_code(&recovery_code[..], read_recovery_codes)?;
if is_valid_recovery_code {
access.put(
&self.recovery_db,
&account[..],
&stored_codes[..],
&put::Flags::empty(),
)?;
} else {
bail!("invalid recovery code");
}
}
txn.commit()?;
self.registration_finish(payload, Some(account_str))?;
Ok(())
}
#[instrument(skip_all, err)]
pub fn reset_totp_mfa_start(&self, payload: Bytes) -> anyhow::Result<Bytes> {
self.login_start(payload)
}
#[instrument(skip_all, err)]
pub fn reset_totp_mfa_finish(&self, payload: Bytes) -> anyhow::Result<Bytes> {
use chacha20poly1305::aead::KeyInit;
let (_, key, _) = self.login_finish(payload.clone(), true)?;
let account_len = (*payload.first().expect("payload already checked")) as usize;
let account = &payload[1..=account_len];
let txn = WriteTransaction::new(self.env.clone())?;
let mfa_secret = {
let mut access = txn.access();
let mfa_secret = Secret::generate_secret().to_bytes()?;
let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
let cipher = XChaCha20Poly1305::new(&self.encryption_key.into());
let mut encrypted_mfa_secret = match cipher.encrypt(&nonce, mfa_secret.as_ref()) {
Ok(v) => v,
Err(err) => bail!("{err}"),
};
encrypted_mfa_secret.extend_from_slice(&nonce);
let account_record: &[u8] = access.get(&self.account_db, account)?;
if account_record == b"deleted" {
bail!("{} has been deleted", std::str::from_utf8(account)?);
}
let mut account_record = account_record.to_vec();
account_record.splice(168..228, encrypted_mfa_secret);
access.put(
&self.account_db,
account,
&account_record,
&put::Flags::empty(),
)?;
mfa_secret
};
txn.commit()?;
let cipher = XChaCha20Poly1305::new(&key.into());
let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
let mut encrypted_mfa_secret = match cipher.encrypt(&nonce, mfa_secret.as_ref()) {
Ok(es) => es,
Err(err) => bail!("{err}"),
};
encrypted_mfa_secret.extend_from_slice(&nonce);
Ok(Bytes::copy_from_slice(&encrypted_mfa_secret))
}
#[instrument(skip_all, err)]
pub fn lost_totp_mfa_start(&self, mut payload: Bytes) -> anyhow::Result<Bytes> {
let payload_len = payload.len();
if payload_len < 33 {
bail!("invalid payload");
}
let recovery_code = payload.split_off(payload_len - 32);
let server_message = self.login_start(payload.clone())?;
let account_len = (*payload.first().expect("payload already checked")) as usize;
let account = &payload[1..=account_len];
let txn = ReadTransaction::new(self.env.clone())?;
let access = txn.access();
let read_recovery_codes = access.get(&self.recovery_db, account)?;
if read_recovery_codes == b"deleted" {
bail!("{} has been deleted", std::str::from_utf8(account)?);
}
let is_valid_recovery_code = check_code(&recovery_code[..], read_recovery_codes)?;
if !is_valid_recovery_code {
bail!("invalid recovery code");
}
Ok(server_message)
}
#[instrument(skip_all, err)]
pub fn lost_totp_mfa_finish(&self, mut payload: Bytes) -> anyhow::Result<Bytes> {
use chacha20poly1305::aead::KeyInit;
let payload_len = payload.len();
if payload_len < 12 {
bail!("invalid payload");
}
let recovery_code = payload.split_off(payload_len - 11);
let (_, key, _) = self.login_finish(payload.clone(), false)?;
let account_len = (*payload.first().expect("payload already checked")) as usize;
let account = &payload[1..=account_len];
let txn = WriteTransaction::new(self.env.clone())?;
let mfa_secret = {
let mut access = txn.access();
let read_recovery_codes = access.get(&self.recovery_db, account)?;
if read_recovery_codes == b"deleted" {
bail!("{} has been deleted", std::str::from_utf8(account)?);
}
let (is_valid_recovery_code, stored_codes) =
consume_code(&recovery_code[..], read_recovery_codes)?;
if is_valid_recovery_code {
access.put(
&self.recovery_db,
account,
&stored_codes[..],
&put::Flags::empty(),
)?;
let mfa_secret = Secret::generate_secret().to_bytes()?;
let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
let cipher = XChaCha20Poly1305::new(&self.encryption_key.into());
let mut encrypted_mfa_secret = match cipher.encrypt(&nonce, mfa_secret.as_ref()) {
Ok(v) => v,
Err(err) => bail!("{err}"),
};
encrypted_mfa_secret.extend_from_slice(&nonce);
let account_record: &[u8] = access.get(&self.account_db, account)?;
if account_record == b"deleted" {
bail!("{} has been deleted", std::str::from_utf8(account)?);
}
let mut account_record = account_record.to_vec();
account_record.splice(168..228, encrypted_mfa_secret);
access.put(
&self.account_db,
account,
&account_record,
&put::Flags::empty(),
)?;
mfa_secret
} else {
bail!("invalid recovery code");
}
};
txn.commit()?;
let cipher = XChaCha20Poly1305::new(&key.into());
let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
let mut encrypted_mfa_secret = match cipher.encrypt(&nonce, mfa_secret.as_ref()) {
Ok(es) => es,
Err(err) => bail!("{err}"),
};
encrypted_mfa_secret.extend_from_slice(&nonce);
Ok(Bytes::copy_from_slice(&encrypted_mfa_secret))
}
#[instrument(skip_all, err)]
pub fn reset_recovery_codes_start(&self, payload: Bytes) -> anyhow::Result<Bytes> {
self.login_start(payload)
}
#[instrument(skip_all, err)]
pub fn reset_recovery_codes_finish(&self, payload: Bytes) -> anyhow::Result<Bytes> {
use chacha20poly1305::aead::KeyInit;
let (_, key, _) = self.login_finish(payload.clone(), true)?;
let account_len = (*payload.first().expect("payload already checked")) as usize;
let account = &payload[1..=account_len];
let (stored_recovery_codes, recovery_codes) = crate::recovery::generate_codes()?;
let txn = WriteTransaction::new(self.env.clone())?;
{
let mut access = txn.access();
access.put(
&self.recovery_db,
account,
&stored_recovery_codes[..],
&put::Flags::empty(),
)?;
}
txn.commit()?;
let cipher = XChaCha20Poly1305::new(&key.into());
let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
let mut encrypted_recovery_codes = match cipher.encrypt(&nonce, recovery_codes.as_bytes()) {
Ok(rc) => rc,
Err(err) => bail!("{err}"),
};
encrypted_recovery_codes.extend_from_slice(&nonce);
Ok(Bytes::copy_from_slice(&encrypted_recovery_codes))
}
#[instrument(skip_all, err)]
pub fn delete_account_start(&self, payload: Bytes) -> anyhow::Result<Bytes> {
self.login_start(payload)
}
#[instrument(skip_all, err)]
pub fn delete_account_finish(&self, payload: Bytes) -> anyhow::Result<()> {
self.login_finish(payload.clone(), true)?;
let account_len = (*payload.first().expect("payload already checked")) as usize;
let account = &payload[1..=account_len];
let txn = WriteTransaction::new(self.env.clone())?;
{
let mut access = txn.access();
access.put(&self.account_db, account, b"deleted", &put::Flags::empty())?;
access.put(&self.state_db, account, b"deleted", &put::Flags::empty())?;
access.put(&self.recovery_db, account, b"deleted", &put::Flags::empty())?;
}
txn.commit()?;
Ok(())
}
/// Lists every key in the account table along with its configured claims.
///
/// Returns a flexbuffer vector with one inner vector per account:
/// `[account, claim...]`. Claims are copied, in `config.access_token.claims`
/// order, from the flexbuffer stored after byte 460 of the record. Records
/// without that section (including `deleted` tombstones) contribute only the
/// account name. An empty table yields an empty vector; previously,
/// positioning the cursor on an empty table made the whole call fail.
/// Iteration stops at the first cursor error, as it already did for
/// subsequent entries.
///
/// # Errors
/// Fails if the read transaction or cursor cannot be created, a claims section
/// is not a valid flexbuffer, or a claim cannot be copied.
#[instrument(skip_all, err)]
pub fn list_accounts(&self) -> anyhow::Result<Bytes> {
    const CLAIMS_OFFSET: usize = 460;
    let txn = ReadTransaction::new(self.env.clone())?;
    let access = txn.access();
    let mut cursor = txn.cursor(self.account_db.clone())?;
    let mut builder = flexbuffers::Builder::new(&flexbuffers::BuilderOptions::SHARE_NONE);
    let mut accounts_vec = builder.start_vector();
    let mut entry = cursor.first::<[u8], [u8]>(&access).ok();
    while let Some((account, record)) = entry {
        let mut account_vec = accounts_vec.start_vector();
        account_vec.push(account);
        if record.len() > CLAIMS_OFFSET {
            let claims = flexbuffers::Reader::get_root(&record[CLAIMS_OFFSET..])?.as_vector();
            for claims_field in &self.config.access_token.claims {
                claims_field.kind.copy_to(
                    &claims.idx(claims_field.idx as usize),
                    &mut account_vec,
                    None,
                )?;
            }
        }
        account_vec.end_vector();
        entry = cursor.next::<[u8], [u8]>(&access).ok();
    }
    accounts_vec.end_vector();
    Ok(Bytes::copy_from_slice(builder.view()))
}
/// Replaces the claims section of `account`'s record.
///
/// `claims` is read as a flexbuffer vector. For each configured access-token
/// claim, the element at that claim's `idx` is copied into a new vector whose
/// first element is an empty vector (presumably reserved for system claims —
/// TODO confirm). The stored record is cut to its first 460 bytes and the new
/// vector is appended.
///
/// # Errors
/// Fails if `account` is not UTF-8, `claims` is not a valid flexbuffer, a
/// claim cannot be copied, the account is missing or deleted, or on database errors.
#[instrument(skip_all, err)]
pub fn set_claims(&self, account: &[u8], claims: &[u8]) -> anyhow::Result<()> {
    let account_name = std::str::from_utf8(account)?;
    let source = flexbuffers::Reader::get_root(claims)?.as_vector();
    let mut builder = flexbuffers::Builder::new(&flexbuffers::BuilderOptions::SHARE_NONE);
    let mut claims_vec = builder.start_vector();
    claims_vec.start_vector().end_vector();
    for claim in &self.config.access_token.claims {
        claim
            .kind
            .copy_to(&source.idx(claim.idx as usize), &mut claims_vec, None)?;
    }
    claims_vec.end_vector();
    let new_claims = builder.view();

    let txn = WriteTransaction::new(self.env.clone())?;
    {
        let mut access = txn.access();
        let stored: &[u8] = access.get(&self.account_db, account)?;
        if stored == b"deleted" {
            bail!("{} has been deleted", account_name);
        }
        let keep = stored.len().min(460);
        let mut updated = Vec::with_capacity(keep + new_claims.len());
        updated.extend_from_slice(&stored[..keep]);
        updated.extend_from_slice(new_claims);
        access.put(&self.account_db, account, &updated[..], &put::Flags::empty())?;
    }
    txn.commit()?;
    tracing::info!(account = %account_name);
    Ok(())
}
/// Issues an HMAC-signed invite token on behalf of `inviter_account`.
///
/// The inviter is authorized against `invite_config.mode` first, so rejected
/// requests do no work and are not logged as issued invites. The token's
/// claims are a flexbuffer vector whose first element is
/// `[token uuid, domain, inviter]`, followed by each configured custom claim
/// copied from `custom_claims`. It is signed with the invite key for
/// `invite_config.lifetime`. The token id is recorded in the invite table;
/// `invite_check` deletes that entry when the token is redeemed.
///
/// # Errors
/// Fails if the inviter is not authorized, `custom_claims` is not a valid
/// flexbuffer, a claim cannot be copied, signing fails, or on database errors.
fn create_invite(
    &self,
    invite_config: &InviteConfig,
    inviter_account: &str,
    custom_claims: Option<Bytes>,
) -> anyhow::Result<Bytes> {
    match &invite_config.mode {
        InviteMode::Root => {
            if inviter_account != format!("root@{}", self.domain) {
                bail!("inviter must be root");
            }
        }
        // NOTE(review): `Admin` currently performs no inviter check and so
        // behaves like `Viral`; confirm whether an admin check is missing.
        InviteMode::Admin | InviteMode::Viral => {}
    }
    let mut invite_claims = flexbuffers::Builder::new(&flexbuffers::BuilderOptions::SHARE_NONE);
    let mut invite_claims_vec = invite_claims.start_vector();
    let token_uuid = Uuid::now_v7();
    let token_uuid_bytes = token_uuid.as_bytes();
    tracing::info!(account = %inviter_account, token = %token_uuid);
    let mut system_claims_vec = invite_claims_vec.start_vector();
    system_claims_vec.push(flexbuffers::Blob(&token_uuid_bytes[..]));
    system_claims_vec.push(self.domain.as_str());
    system_claims_vec.push(inviter_account);
    system_claims_vec.end_vector();
    if let Some(config_claims) = &invite_config.claims
        && let Some(custom_claims) = custom_claims
    {
        let custom_claims_root = flexbuffers::Reader::get_root(custom_claims.as_ref())?;
        let custom_claims = custom_claims_root.as_vector();
        for claim in config_claims {
            claim.kind.copy_to(
                &custom_claims.idx(claim.idx as usize),
                &mut invite_claims_vec,
                None,
            )?;
        }
    }
    invite_claims_vec.end_vector();
    let invite_token = generate_hmac(
        invite_claims.view(),
        invite_config.lifetime,
        &self.invite_token_key,
    )?;
    let txn = WriteTransaction::new(self.env.clone())?;
    {
        let mut access = txn.access();
        access.put(
            &self.invite_db,
            &token_uuid_bytes[..],
            &[] as &[u8],
            &put::Flags::empty(),
        )?;
    }
    txn.commit()?;
    Ok(invite_token)
}
/// Reports whether `account` has an entry in the account table or in the
/// state table.
///
/// Any readable entry counts, so `deleted` tombstones and in-progress
/// registrations also report `true`.
///
/// # Errors
/// Fails only if the read transaction cannot be opened.
pub fn account_exists(&self, account: &str) -> anyhow::Result<bool> {
    let txn = ReadTransaction::new(self.env.clone())?;
    let access = txn.access();
    let key = account.as_bytes();
    let found = [&self.account_db, &self.state_db]
        .into_iter()
        .any(|db| access.get::<[u8], [u8]>(db, key).is_ok());
    Ok(found)
}
#[instrument(skip_all, err)]
pub fn api_invite_get(
&self,
inviter_domain: &str,
inviter_account: &str,
custom_claims: Option<Bytes>,
) -> anyhow::Result<Bytes> {
if let Some(invite_config) = &self.config.invite {
let account = format!("{inviter_account}@{inviter_domain}");
self.create_invite(invite_config, account.as_str(), custom_claims)
} else {
bail!("no invite config")
}
}
#[instrument(skip_all, err)]
pub fn invite_get(&self, access_token: Bytes) -> anyhow::Result<Bytes> {
if let Some(invite_config) = &self.config.invite {
let (account, _) = self.verify_access_token(&access_token)?;
self.create_invite(invite_config, account, None)
} else {
bail!("no invite config")
}
}
#[allow(clippy::type_complexity)]
#[instrument(skip_all, err)]
pub fn invite_check<'a>(
&self,
invite_token: &'a [u8],
) -> anyhow::Result<(VectorReader<&'a [u8]>, &'a [u8])> {
if let Some(invite_config) = &self.config.invite {
let claims = verify_hmac(invite_token, &self.invite_token_key)?;
let claims_root = flexbuffers::Reader::get_root(claims)?;
let claims_vec = claims_root.as_vector();
let system_claims = claims_vec.idx(0).as_vector();
let token_uuid_bytes: [u8; 16] = system_claims.idx(0).as_blob().0.try_into()?;
let token_uuid_str = Uuid::from_bytes(token_uuid_bytes).to_string();
let api_or_site_domain = system_claims.idx(1).as_str();
let inviter_account = system_claims.idx(2).as_str();
if api_or_site_domain == self.domain {
match invite_config.mode {
InviteMode::Root => {
if inviter_account != format!("root@{}", self.domain) {
bail!("inviter must be root");
}
}
InviteMode::Admin => {
}
InviteMode::Viral => (),
}
let txn = WriteTransaction::new(self.env.clone())?;
{
let mut access = txn.access();
if access
.del_key(&self.invite_db, &token_uuid_bytes[..])
.is_err()
{
bail!("token id does not exist");
}
}
txn.commit()?;
tracing::info!(account = %inviter_account, token = %token_uuid_str);
Ok((claims_vec, claims))
} else {
bail!(
"domain {api_or_site_domain} does not match for invite {token_uuid_str} from user {inviter_account}"
)
}
} else {
bail!("no invite config")
}
}
pub fn access_get(&self, refresh_token: &Bytes) -> anyhow::Result<Bytes> {
let span = tracing::info_span!("auth");
let span = span.in_scope(|| tracing::info_span!("token"));
let span = span.in_scope(|| tracing::info_span!("access"));
let span = span.in_scope(|| tracing::info_span!("get"));
span.in_scope(|| {
let (account, client_verifying_key) = self.verify_refresh_token(&refresh_token[..])?;
let mut claims_builder =
flexbuffers::Builder::new(&flexbuffers::BuilderOptions::SHARE_NONE);
let mut dest = claims_builder.start_vector();
let txn = ReadTransaction::new(self.env.clone())?;
let access = txn.access();
let account_record: &[u8] = access.get(&self.account_db, account.as_bytes())?;
if account_record == b"deleted" {
bail!("{account} has been deleted");
}
let user_claims = flexbuffers::Reader::get_root(&account_record[460..])?.as_vector();
let mut system_claims = dest.start_vector();
let token_id = Uuid::new_v4();
system_claims.push(flexbuffers::Blob(&token_id.as_bytes()[..]));
system_claims.push(self.domain.as_str());
system_claims.push(account);
if let Some(client_verifying_key) = client_verifying_key {
system_claims.push(flexbuffers::Blob(&client_verifying_key[..]));
}
system_claims.end_vector();
for claim in &self.config.access_token.claims {
let src = user_claims.idx(claim.idx as usize);
claim.kind.copy_to(&src, &mut dest, None)?;
}
dest.end_vector();
let access_token = generate_hmac(
claims_builder.view(),
self.config.access_token.lifetime,
&self.access_token_key,
)?;
tracing::info!(token = %token_id, "generated");
Ok(access_token)
})
}
#[allow(clippy::type_complexity)]
pub fn verify_access_token<'a>(
&self,
token: &'a [u8],
) -> anyhow::Result<(&'a str, VectorReader<&'a [u8]>)> {
let span = tracing::info_span!("auth");
let span = span.in_scope(|| tracing::info_span!("token"));
let span = span.in_scope(|| tracing::info_span!("access"));
let span = span.in_scope(|| tracing::info_span!("verify"));
span.in_scope(|| {
let claims: &'a [u8] = extract_hmac_no_check(token)?;
let claims_vec = match flexbuffers::Reader::get_root(
&claims[..claims.len().checked_sub(EXP_LEN + SIG_LEN).unwrap_or(claims.len())],
) {
Ok(v) => v.as_vector(),
Err(_) => flexbuffers::Reader::get_root(claims)?.as_vector(),
};
let system_claims = claims_vec.idx(0).as_vector();
let token_uuid_bytes: [u8; 16] = system_claims.idx(0).as_blob().0.try_into()?;
let claims_uuid_str = Uuid::from_bytes(token_uuid_bytes).to_string();
let claims_domain = system_claims.idx(1).as_str();
let claims_account = system_claims.idx(2).as_str();
if claims_domain == self.domain {
let claims_verifier = system_claims.idx(3).as_blob().0;
if claims_verifier.is_empty() {
verify_hmac(token, &self.access_token_key)?;
} else {
verify_client_signature(token, claims_verifier.try_into()?)?;
verify_hmac(&token[..token.len() - (EXP_LEN + SIG_LEN)], &self.access_token_key)?;
}
tracing::info!(
account = %claims_account,
token = %claims_uuid_str,
"verified"
);
Ok((claims_account, claims_vec))
} else {
bail!(
"domain {claims_domain} does not match for user {claims_account} with token {claims_uuid_str}"
)
}
})
}
/// Issues a refresh token for `account` on this domain.
///
/// The claims are a flat flexbuffer vector `[token uuid, self.domain, account]`. When
/// `verifier_claim` is given, it is appended as a fourth blob; `verify_refresh_token`
/// reads that blob back as the client verifying key. The vector is signed with
/// `generate_hmac` using the refresh-token key and `config.refresh_token.lifetime`.
///
/// # Errors
/// Fails if `account` is not valid UTF-8 or if signing fails.
pub(crate) fn generate_refresh_token(
    &self,
    account: &[u8],
    verifier_claim: Option<&[u8]>,
) -> anyhow::Result<Bytes> {
    let span = tracing::info_span!("auth");
    let span = span.in_scope(|| tracing::info_span!("token"));
    let span = span.in_scope(|| tracing::info_span!("refresh"));
    let span = span.in_scope(|| tracing::info_span!("generate"));
    span.in_scope(|| {
        let str_account = std::str::from_utf8(account)?;
        let mut claims_builder =
            flexbuffers::Builder::new(&flexbuffers::BuilderOptions::SHARE_NONE);
        let mut system_claims = claims_builder.start_vector();
        let token_id = Uuid::new_v4();
        system_claims.push(flexbuffers::Blob(&token_id.as_bytes()[..]));
        system_claims.push(self.domain.as_str());
        system_claims.push(str_account);
        if let Some(verifier_claim) = verifier_claim {
            system_claims.push(flexbuffers::Blob(verifier_claim));
        }
        system_claims.end_vector();
        let refresh_token = generate_hmac(
            claims_builder.view(),
            self.config.refresh_token.lifetime,
            &self.refresh_token_key,
        )?;
        // Record the uuid via Display (`%`), like the other token logs, instead of
        // allocating a String for every call.
        tracing::info!(token = %token_id, "generated");
        Ok(refresh_token)
    })
}
#[allow(clippy::type_complexity)]
fn verify_refresh_token<'a>(
&self,
token: &'a [u8],
) -> anyhow::Result<(&'a str, Option<&'a [u8; 32]>)> {
let span = tracing::info_span!("auth");
let span = span.in_scope(|| tracing::info_span!("token"));
let span = span.in_scope(|| tracing::info_span!("refresh"));
let span = span.in_scope(|| tracing::info_span!("verify"));
span.in_scope(|| {
let claims = extract_hmac_no_check(token)?;
let claims_vec = match flexbuffers::Reader::get_root(
&claims[..claims
.len()
.checked_sub(EXP_LEN + SIG_LEN)
.unwrap_or(claims.len())],
) {
Ok(v) => v.as_vector(),
Err(_) => flexbuffers::Reader::get_root(claims)?.as_vector(),
};
let token_uuid_bytes: [u8; 16] = claims_vec.idx(0).as_blob().0.try_into()?;
let claims_uuid_str = Uuid::from_bytes(token_uuid_bytes).to_string();
let claims_domain = claims_vec.idx(1).as_str();
let claims_account = claims_vec.idx(2).as_str();
let claims_verifier = claims_vec.idx(3).as_blob().0;
let mut client_verifying_key = None;
if claims_verifier.is_empty() {
verify_hmac(token, &self.refresh_token_key)?;
} else {
let verifying_key_bytes = claims_verifier.try_into()?;
client_verifying_key = Some(verifying_key_bytes);
verify_client_signature(token, verifying_key_bytes)?;
verify_hmac(
&token[..token.len() - (EXP_LEN + SIG_LEN)],
&self.refresh_token_key,
)?;
}
tracing::info!(
domain = %claims_domain,
account = %claims_account,
token = %claims_uuid_str,
"verified"
);
Ok((claims_account, client_verifying_key))
})
}
#[instrument(skip_all, err)]
pub fn verify_password_reset_token<'a>(&self, token: &'a [u8]) -> anyhow::Result<&'a str> {
let claims: &'a [u8] = extract_hmac_no_check(token)?;
let claims_vec = match flexbuffers::Reader::get_root(
&claims[..claims
.len()
.checked_sub(EXP_LEN + SIG_LEN)
.unwrap_or(claims.len())],
) {
Ok(v) => v.as_vector(),
Err(_) => flexbuffers::Reader::get_root(claims)?.as_vector(),
};
let token_uuid_bytes: [u8; 16] = claims_vec.idx(0).as_blob().0.try_into()?;
let claims_uuid_str = Uuid::from_bytes(token_uuid_bytes).to_string();
let claims_domain = claims_vec.idx(1).as_str();
let claims_account = claims_vec.idx(2).as_str();
if claims_domain == self.domain {
let claims_verifier = claims_vec.idx(3).as_blob().0;
if claims_verifier.is_empty() {
verify_hmac(token, &self.reset_password_token_key)?;
} else {
verify_client_signature(token, claims_verifier.try_into()?)?;
verify_hmac(
&token[..token.len() - (EXP_LEN + SIG_LEN)],
&self.reset_password_token_key,
)?;
}
tracing::info!(
account = %claims_account,
token = %claims_uuid_str,
"verified"
);
Ok(claims_account)
} else {
bail!(
"domain {claims_domain} does not match for user {claims_account} with token {claims_uuid_str}"
)
}
}
}