use std::{sync::Arc, time::SystemTime};
use blake2::digest::{FixedOutput, Mac};
use bytes::{BufMut, Bytes, BytesMut};
use chacha20poly1305::{
XChaCha20Poly1305, XNonce,
aead::{Aead, AeadCore, OsRng},
};
use flexbuffers::VectorReader;
use ordinary_config::{
AccessTokenConfig, AuthConfig, ClientPasswordHash, HmacTokenAlgorithm, HmacTokenConfig,
InviteConfig, InviteMode, MfaConfig, PasswordConfig, PasswordProtocol, RefreshTokenConfig,
SignedTokenAlgorithm, SignedTokenConfig, TotpAlgorithm, TotpConfig,
};
use std::error::Error;
use tracing::instrument;
use uuid::Uuid;
use x25519_dalek::{EphemeralSecret, PublicKey};
use saferlmdb::{
self as lmdb, Database, DatabaseOptions, Environment, ReadTransaction, WriteTransaction, put,
};
pub use opaque_ke::ServerSetup;
use totp_rs::{Algorithm, Secret, TOTP};
use crate::{
DefaultCipherSuite,
keys::decrypt_256_bit_key,
recovery::{check_code, consume_code},
token::generate_hmac,
};
use crate::{
keys::{KeyAlg, decrypt_ed25519_pair, generate_256_bit_key, generate_ed25519_pair},
token::extract_signed_no_check,
};
use sha2::{Digest, Sha256};
/// Authentication service state: configuration, derived token keys and handles to the
/// LMDB databases that hold accounts, in-flight protocol state, invites and recovery codes.
pub struct Auth {
/// Domain name; embedded in token claims, used as the TOTP issuer and compared against
/// the domain stored in invite tokens.
domain: String,
/// Effective configuration (the built-in defaults are used when `new` receives `None`).
pub config: AuthConfig,
/// `totp_rs` algorithm mapped from `config.mfa.totp.algorithm`.
totp_alg: Algorithm,
/// Key for access tokens; derived from the stored HMAC key with the label `b"access"`.
access_token_key: [u8; 32],
/// Key for password-reset tokens; derived with the label `b"reset_password"`.
reset_password_token_key: [u8; 32],
/// Key for invite tokens; derived with the label `b"invite"`.
invite_token_key: [u8; 32],
/// Signing half of the Ed25519 pair loaded or generated in `new`.
token_signing_key: [u8; 32],
/// Verifying half of the Ed25519 pair loaded or generated in `new`.
pub token_verifying_key: [u8; 32],
/// XChaCha20-Poly1305 key used to encrypt secrets before they are written to LMDB.
encryption_key: [u8; 32],
/// Shared LMDB environment all transactions are opened against.
env: Arc<Environment>,
/// Database named `"user"`: one record per account.
account_db: Arc<Database<'static>>,
/// Database named `"auth"`: per-account state kept between the start and finish steps.
state_db: Arc<Database<'static>>,
/// Database named `"invite"`: ids of invite tokens that have been issued and not yet used.
invite_db: Arc<Database<'static>>,
/// Database named `"recovery"`: stored recovery codes per account.
recovery_db: Arc<Database<'static>>,
}
/// Optional caller-supplied check run during `registration_start` against the claims of a
/// verified invite token. It receives the account name being registered and the token's
/// claims vector; returning `Ok(false)` rejects the registration.
type InviteClaimsValidator = fn(&str, &VectorReader<&[u8]>) -> Result<bool, Box<dyn Error>>;
impl Auth {
/// Builds an [`Auth`] instance on top of an existing LMDB environment.
///
/// Steps performed:
/// 1. Opens (creating if needed) the `"user"`, `"auth"`, `"invite"`, `"recovery"` and
///    `"key"` databases.
/// 2. Falls back to a built-in default [`AuthConfig`] when `config` is `None`, then sorts
///    the access-token claims by their `idx`.
/// 3. Loads the HMAC key and the Ed25519 key pair from the `"key"` database, decrypting
///    them with `encryption_key`; a key is generated and stored when the lookup fails, and
///    regenerated when its id timestamp plus the configured rotation period (seconds) is
///    in the past.
/// 4. Derives the access, reset-password and invite token keys from the HMAC key.
///
/// # Errors
/// Returns an error if a database cannot be opened, key material cannot be decrypted or
/// generated, the system clock is before the Unix epoch, or the transaction fails.
#[allow(clippy::too_many_lines)]
pub fn new(
domain: String,
config: Option<AuthConfig>,
encryption_key: [u8; 32],
env: Arc<Environment>,
) -> Result<Self, Box<dyn Error>> {
let account_db = Arc::new(Database::open(
env.clone(),
Some("user"),
&DatabaseOptions::new(lmdb::db::Flags::CREATE),
)?);
let state_db = Arc::new(Database::open(
env.clone(),
Some("auth"),
&DatabaseOptions::new(lmdb::db::Flags::CREATE),
)?);
let invite_db = Arc::new(Database::open(
env.clone(),
Some("invite"),
&DatabaseOptions::new(lmdb::db::Flags::CREATE),
)?);
let recovery_db = Arc::new(Database::open(
env.clone(),
Some("recovery"),
&DatabaseOptions::new(lmdb::db::Flags::CREATE),
)?);
// Only used inside this constructor; the key database is not kept on the struct.
let key_db = Arc::new(Database::open(
env.clone(),
Some("key"),
&DatabaseOptions::new(lmdb::db::Flags::CREATE),
)?);
// Default configuration used when the caller does not supply one.
let mut config = config.unwrap_or_else(|| AuthConfig {
password: PasswordConfig {
protocol: PasswordProtocol::Opaque,
},
mfa: MfaConfig {
totp: TotpConfig {
template: None,
algorithm: TotpAlgorithm::Sha1,
},
},
hmac_token: HmacTokenConfig {
// NOTE(review): the default names Blake2b, while the key handling below uses
// `KeyAlg::Blake2SMac256` / `Blake2sMac256` — confirm this is intended.
algorithm: HmacTokenAlgorithm::HmacBlake2b256,
// 7 days, in seconds (compared against Unix seconds below).
rotation: 60 * 60 * 24 * 7,
},
signed_token: SignedTokenConfig {
algorithm: SignedTokenAlgorithm::DsaEd25519,
// 30 days, in seconds.
rotation: 60 * 60 * 24 * 30,
},
refresh_token: RefreshTokenConfig {
// presumably seconds (7 days) — TODO confirm against the token module.
lifetime: 60 * 60 * 24 * 7,
},
access_token: AccessTokenConfig {
// presumably seconds (1 day) — TODO confirm against the token module.
lifetime: 60 * 60 * 24,
claims: vec![],
},
client_hash: ClientPasswordHash::Sha256,
cookies_enabled: true,
invite: None,
});
// Keep claims ordered by index so later copies iterate them in a stable order.
config.access_token.claims.sort_by(|a, b| a.idx.cmp(&b.idx));
let txn = WriteTransaction::new(env.clone())?;
let (token_hmac_key, token_signing_key, token_verifying_key) = {
use chacha20poly1305::aead::KeyInit;
let mut access = txn.access();
let mut rng = OsRng;
let cipher = XChaCha20Poly1305::new(&encryption_key.into());
// HMAC key: reuse the stored key unless it is missing or past its rotation period.
// Note that any lookup error (not only "not found") takes the generate-new-key path.
let hmac_key = if let Ok(kid_key) =
access.get::<[u8], [u8]>(&key_db, &[KeyAlg::Blake2SMac256.as_byte()])
{
let (kid, curr_key) = decrypt_256_bit_key(&cipher, kid_key)?;
if let Some(timestamp) = kid.get_timestamp() {
// Rotate when creation time + rotation period is earlier than now.
if (timestamp.to_unix().0 + u64::from(config.hmac_token.rotation))
< SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)?
.as_secs()
{
let (kid_key, new_key, _kid) = generate_256_bit_key(&cipher, &mut rng)?;
access.put(
&key_db,
&[KeyAlg::Blake2SMac256.as_byte()],
&kid_key,
&put::Flags::empty(),
)?;
new_key
} else {
curr_key
}
} else {
// A key id without a timestamp is never rotated.
curr_key
}
} else {
let (kid_key, hmac_key, _kid) = generate_256_bit_key(&cipher, &mut rng)?;
access.put(
&key_db,
&[KeyAlg::Blake2SMac256.as_byte()],
&kid_key,
&put::Flags::empty(),
)?;
hmac_key
};
// Ed25519 pair: same load / rotate / generate logic as the HMAC key above.
let (signing_key, verifying_key) = if let Ok(kid_keys) =
access.get::<[u8], [u8]>(&key_db, &[KeyAlg::Ed25519Pair.as_byte()])
{
let (kid, curr_signing_key, curr_verifying_key) =
decrypt_ed25519_pair(&cipher, kid_keys)?;
if let Some(timestamp) = kid.get_timestamp() {
if (timestamp.to_unix().0 + u64::from(config.signed_token.rotation))
< SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)?
.as_secs()
{
let (kid_keys, new_signing_key, new_verifying_key, _kid) =
generate_ed25519_pair(&cipher, &mut rng)?;
access.put(
&key_db,
&[KeyAlg::Ed25519Pair.as_byte()],
&kid_keys,
&put::Flags::empty(),
)?;
(new_signing_key, new_verifying_key)
} else {
(curr_signing_key, curr_verifying_key)
}
} else {
(curr_signing_key, curr_verifying_key)
}
} else {
let (kid_keys, signing_key, verifying_key, _kid) =
generate_ed25519_pair(&cipher, &mut rng)?;
access.put(
&key_db,
&[KeyAlg::Ed25519Pair.as_byte()],
&kid_keys,
&put::Flags::empty(),
)?;
(signing_key, verifying_key)
};
(hmac_key, signing_key, verifying_key)
};
txn.commit()?;
// Derive one key per token purpose by MACing a fixed label with the stored HMAC key.
let access_token_key: [u8; 32] =
blake2::Blake2sMac256::new_from_slice(&token_hmac_key[..])?
.chain_update(b"access")
.finalize_fixed()
.into();
let reset_password_token_key: [u8; 32] =
blake2::Blake2sMac256::new_from_slice(&token_hmac_key[..])?
.chain_update(b"reset_password")
.finalize_fixed()
.into();
let invite_token_key: [u8; 32] =
blake2::Blake2sMac256::new_from_slice(&token_hmac_key[..])?
.chain_update(b"invite")
.finalize_fixed()
.into();
Ok(Self {
domain,
totp_alg: match &config.mfa.totp.algorithm {
TotpAlgorithm::Sha1 => Algorithm::SHA1,
},
config,
access_token_key,
reset_password_token_key,
invite_token_key,
token_signing_key,
token_verifying_key,
encryption_key,
env,
state_db,
account_db,
invite_db,
recovery_db,
})
}
/// Handles the first message of an OPAQUE registration.
///
/// `payload` layout: `[account_len: u8][account][32-byte client start][invite token?]`.
/// The trailing invite token is only read when invites are configured and this is a new
/// registration (`existing_account` is `None`).
///
/// * `existing_account` — when set, the account in the payload must equal it, and the
///   "already registered" checks are skipped (used by the password reset flows).
/// * `invite_claims_validator` — optional extra check on the verified invite claims.
///
/// On success the freshly generated server setup (plus the raw invite claims, if any) is
/// stored in the state database under the account name and the server's response message
/// is returned.
///
/// # Errors
/// Returns an error for an empty or too-short payload, a non-UTF-8 account name, an
/// account mismatch, an invalid invite, an already registered account, or a storage
/// failure.
#[allow(clippy::type_complexity)]
#[instrument(skip(self, payload, existing_account, invite_claims_validator), err)]
pub fn registration_start(
    &self,
    payload: Bytes,
    existing_account: Option<&str>,
    invite_claims_validator: Option<InviteClaimsValidator>,
) -> Result<Bytes, Box<dyn Error>> {
    let Some(&account_len) = payload.first() else {
        return Err("invalid payload of len 0".into());
    };
    let account_len = usize::from(account_len);
    let client_start_begin = account_len + 1;
    let client_start_end = client_start_begin + 32;
    // Compare with additions only: `payload.len() - 33` would underflow on short,
    // attacker-controlled payloads and panic instead of returning an error.
    if payload.len() < client_start_end {
        return Err("invalid format".into());
    }
    let account = &payload[1..=account_len];
    let account_str = std::str::from_utf8(account)?;
    if let Some(existing_account) = existing_account
        && existing_account != account_str
    {
        return Err("account mismatch".into());
    }
    tracing::info!(account = %account_str);
    let invite_claims = if self.config.invite.is_some() && existing_account.is_none() {
        // Everything after the client start message is the invite token.
        let invite_token = &payload[client_start_end..];
        let (claims_vec, claims) = self.invite_check(invite_token)?;
        if let Some(validator) = invite_claims_validator {
            let res = validator(account_str, &claims_vec)?;
            if !res {
                return Err("failed to validate invite token claims".into());
            }
        }
        Some(claims)
    } else {
        None
    };
    let txn = WriteTransaction::new(self.env.clone())?;
    let out = {
        let mut access = txn.access();
        if existing_account.is_none() {
            // Refuse both completed registrations and ones still in flight.
            if access.get::<[u8], [u8]>(&self.account_db, account).is_ok() {
                return Err("account has already been registered.".into());
            }
            if access.get::<[u8], [u8]>(&self.state_db, account).is_ok() {
                return Err("account has already been registered.".into());
            }
        }
        let client_start = &payload[client_start_begin..client_start_end];
        let mut rng = OsRng;
        let opaque = ServerSetup::<DefaultCipherSuite>::new(&mut rng);
        let out = crate::registration::server_start(&opaque, account, client_start)?;
        // Persist the server setup followed by the invite claims (if any) so that
        // `registration_finish` can pick them up.
        let mut stored = BytesMut::new();
        stored.put(&opaque.serialize()[..]);
        if let Some(invite_claims) = invite_claims {
            stored.put(invite_claims);
        }
        access.put(&self.state_db, account, &stored[..], &put::Flags::empty())?;
        out
    };
    txn.commit()?;
    Ok(out)
}
/// Handles the final message of an OPAQUE registration and writes the account record.
///
/// `payload` layout: `[account_len: u8][account][32-byte X25519 public key][client finish]`.
///
/// For a new account (`existing_account` is `None`) a TOTP secret and recovery codes are
/// generated. For an existing account the stored TOTP section (record bytes `168..228`)
/// and claims (record bytes `420..`) are carried over unchanged. In both cases the stored
/// record is laid out as: encrypted server setup + nonce, encrypted TOTP secret + nonce,
/// the OPAQUE password file, then the claims buffer. The state entry written by
/// `registration_start` is removed.
///
/// Returns `(encrypted_secrets, account, invite_claims)`. For a new account
/// `encrypted_secrets` is the TOTP secret followed by the recovery codes, encrypted with a
/// key agreed via X25519 between the client's public key and a fresh ephemeral key, with
/// the nonce and the ephemeral public key appended. For an existing account it is empty.
///
/// # Errors
/// Returns an error for an empty or too-short payload, a non-UTF-8 account name, an
/// account mismatch, an already registered account, missing registration state, or a
/// crypto/storage failure.
#[allow(clippy::type_complexity, clippy::too_many_lines)]
#[instrument(skip(self, payload, existing_account), err)]
pub fn registration_finish(
    &self,
    payload: Bytes,
    existing_account: Option<&str>,
) -> Result<(Bytes, Vec<u8>, Option<Bytes>), Box<dyn Error>> {
    use chacha20poly1305::aead::KeyInit;

    let Some(&account_len) = payload.first() else {
        return Err("invalid payload of len 0".into());
    };
    let account_len = usize::from(account_len);
    let public_key_begin = account_len + 1;
    let public_key_end = public_key_begin + 32;
    // The payload must hold the account and the full 32-byte public key. Checking with
    // additions avoids the `len - 2` underflow and the out-of-range slice panic that short
    // payloads used to trigger.
    if payload.len() < public_key_end {
        return Err("invalid format".into());
    }
    let account = &payload[1..=account_len];
    let account_str = std::str::from_utf8(account)?;
    if let Some(existing_account) = existing_account
        && existing_account != account_str
    {
        return Err("account mismatch".into());
    }
    tracing::info!(account = %account_str);
    let public_key: [u8; 32] = payload[public_key_begin..public_key_end].try_into()?;
    let client_finish = &payload[public_key_end..];
    let mut password_file = crate::registration::server_finish(client_finish)?;
    let cipher = XChaCha20Poly1305::new(&self.encryption_key.into());
    let totp_nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
    let opaque_nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
    let txn = WriteTransaction::new(self.env.clone())?;
    let (mut totp_mfa_secret, recovery_codes, invite_claims) = {
        let mut invite_claims = None;
        let mut access = txn.access();
        let (totp_mfa_secret, recovery_codes) = if existing_account.is_none() {
            if access.get::<[u8], [u8]>(&self.account_db, account).is_ok() {
                return Err("account has already been registered".into());
            }
            // New account: create and encrypt a fresh TOTP secret, prepended first so it
            // ends up after the server setup section.
            let totp_mfa_secret = Secret::generate_secret().to_bytes()?;
            let mut encrypted_totp_mfa_secret = cipher
                .encrypt(&totp_nonce, totp_mfa_secret.as_ref())
                .map_err(|err| err.to_string())?;
            encrypted_totp_mfa_secret.extend_from_slice(&totp_nonce);
            password_file.splice(0..0, encrypted_totp_mfa_secret.iter().copied());
            // State written by `registration_start`: 128 bytes of server setup followed
            // by optional invite claims.
            let serialized_opaque_and_invite_claims =
                access.get::<[u8], [u8]>(&self.state_db, account)?;
            if serialized_opaque_and_invite_claims.len() > 128 {
                invite_claims = Some(Bytes::copy_from_slice(
                    &serialized_opaque_and_invite_claims[128..],
                ));
            }
            let mut encrypted_opaque = cipher
                .encrypt(&opaque_nonce, &serialized_opaque_and_invite_claims[0..128])
                .map_err(|err| err.to_string())?;
            encrypted_opaque.extend_from_slice(&opaque_nonce);
            password_file.splice(0..0, encrypted_opaque.iter().copied());
            // A new account starts with an empty claims vector.
            let mut empty_claims_builder =
                flexbuffers::Builder::new(&flexbuffers::BuilderOptions::SHARE_NONE);
            let empty_claims_vec = empty_claims_builder.start_vector();
            empty_claims_vec.end_vector();
            password_file.extend_from_slice(empty_claims_builder.view());
            let (stored_recovery_codes, recovery_codes) = crate::recovery::generate_codes()?;
            access.put(
                &self.recovery_db,
                account,
                &stored_recovery_codes[..],
                &put::Flags::empty(),
            )?;
            (totp_mfa_secret, recovery_codes)
        } else {
            // Existing account: keep the stored TOTP section and claims, replace the rest.
            let account_record: &[u8] = access.get(&self.account_db, account)?;
            password_file.splice(0..0, account_record[168..228].iter().copied());
            let serialized_opaque_and_invite_claims =
                access.get::<[u8], [u8]>(&self.state_db, account)?;
            if serialized_opaque_and_invite_claims.len() > 128 {
                invite_claims = Some(Bytes::copy_from_slice(
                    &serialized_opaque_and_invite_claims[128..],
                ));
            }
            let mut encrypted_opaque = cipher
                .encrypt(&opaque_nonce, &serialized_opaque_and_invite_claims[0..128])
                .map_err(|err| err.to_string())?;
            encrypted_opaque.extend_from_slice(&opaque_nonce);
            password_file.splice(0..0, encrypted_opaque.iter().copied());
            password_file.extend_from_slice(&account_record[420..]);
            (vec![], String::new())
        };
        access.put(
            &self.account_db,
            account,
            &password_file,
            &put::Flags::empty(),
        )?;
        access.del_key(&self.state_db, account)?;
        (totp_mfa_secret, recovery_codes, invite_claims)
    };
    txn.commit()?;
    if existing_account.is_none() {
        // Encrypt the new secrets to the client with an ephemeral X25519 exchange.
        let public_key = PublicKey::from(public_key);
        let ephemeral_secret = EphemeralSecret::random_from_rng(OsRng);
        let ephemeral_public_key = PublicKey::from(&ephemeral_secret);
        let shared_secret = ephemeral_secret.diffie_hellman(&public_key);
        let cipher = XChaCha20Poly1305::new(shared_secret.as_bytes().into());
        let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
        totp_mfa_secret.extend_from_slice(recovery_codes.as_bytes());
        let mut encrypted_secret = cipher
            .encrypt(&nonce, totp_mfa_secret.as_ref())
            .map_err(|err| err.to_string())?;
        encrypted_secret.extend_from_slice(&nonce);
        encrypted_secret.extend_from_slice(ephemeral_public_key.as_bytes());
        Ok((
            Bytes::from(encrypted_secret),
            account.to_vec(),
            invite_claims,
        ))
    } else {
        Ok((Bytes::new(), account.to_vec(), invite_claims))
    }
}
/// Handles the first message of an OPAQUE login.
///
/// `payload` layout: `[account_len: u8][account][client start message (at least 96 bytes)]`.
///
/// The account record is loaded, its encrypted server setup (bytes `0..144`, nonce at
/// `144..168`) is decrypted, and bytes `228..420` of the record are handed to the login
/// protocol together with the client message. The resulting server state is stored in the
/// state database under the account name and the server's response message is returned.
///
/// # Errors
/// Returns an error for an empty or too-short payload, a non-UTF-8 account name, an
/// unknown account, or a crypto/storage failure.
#[instrument(skip(self, payload), err)]
pub fn login_start(&self, payload: Bytes) -> Result<Bytes, Box<dyn Error>> {
    let Some(&account_len) = payload.first() else {
        return Err("invalid payload of size 0".into());
    };
    let account_len = usize::from(account_len);
    // Compare with additions only: `payload.len() - 1 - 96` would underflow on short,
    // attacker-controlled payloads and panic instead of returning an error.
    if payload.len() < account_len + 1 + 96 {
        return Err("invalid format".into());
    }
    let account = &payload[1..=account_len];
    tracing::info!(account = %std::str::from_utf8(account)?);
    let client_start = &payload[account_len + 1..];
    let txn = WriteTransaction::new(self.env.clone())?;
    let message = {
        use chacha20poly1305::aead::KeyInit;
        let mut access = txn.access();
        let account_record: &[u8] = access.get(&self.account_db, account)?;
        let cipher = XChaCha20Poly1305::new(&self.encryption_key.into());
        let nonce = XNonce::from_slice(&account_record[144..168]);
        let serialized_opaque = cipher
            .decrypt(nonce, &account_record[0..144])
            .map_err(|err| err.to_string())?;
        let opaque = ServerSetup::<DefaultCipherSuite>::deserialize(&serialized_opaque)
            .map_err(|err| err.to_string())?;
        let (state, message) = crate::login::server_start(
            &opaque,
            account,
            &account_record[228..420],
            client_start,
        )?;
        // Keep the server-side login state until `login_finish` consumes it.
        access.put(&self.state_db, account, &state[..], &put::Flags::empty())?;
        message
    };
    txn.commit()?;
    Ok(message)
}
/// Handles the final message of an OPAQUE login.
///
/// `payload` layout:
/// `[account_len: u8][account][encrypted MFA code (48 or 80 bytes)][24-byte nonce][64-byte client finish]`.
///
/// When `check_mfa` is true the current TOTP code is generated from the account's stored
/// secret and hashed together with the domain and account name; otherwise 32 zero bytes
/// are used in its place. The hash, the encrypted client MFA code and the stored login
/// state are passed to `crate::login::server_finish`, after which the login state is
/// deleted.
///
/// Returns the tuple produced by `crate::login::server_finish`: the encrypted refresh
/// token, a 32-byte key, and an optional verifier.
///
/// # Errors
/// Returns an error for an empty or wrongly sized payload, a non-UTF-8 account name, an
/// unknown account, missing login state, or a crypto/storage failure.
#[instrument(skip(self, payload, check_mfa), err)]
#[allow(clippy::type_complexity)]
pub fn login_finish(
    &self,
    payload: Bytes,
    check_mfa: bool,
) -> Result<(Bytes, [u8; 32], Option<Bytes>), Box<dyn Error>> {
    let Some(&account_len) = payload.first() else {
        return Err("invalid payload size 0".into());
    };
    let account_len = usize::from(account_len);
    // Compare with additions only: `payload.len() - 2` would underflow for a one-byte
    // payload and panic instead of returning an error.
    if payload.len() < account_len + 2 {
        return Err("invalid format".into());
    }
    let account = &payload[1..=account_len];
    tracing::info!(account = %std::str::from_utf8(account)?);
    let payload_len = payload.len();
    // Only two exact sizes are accepted; they differ by 32 bytes of encrypted MFA data.
    if payload_len != account_len + 1 + 32 + 16 + 24 + 64
        && payload_len != account_len + 1 + 32 + 32 + 16 + 24 + 64
    {
        return Err("invalid format".into());
    }
    let encrypted_mfa_code = &payload[account_len + 1..payload_len - 24 - 64];
    let mfa_code_nonce = &payload[payload_len - 24 - 64..payload_len - 64];
    let client_finish = &payload[payload_len - 64..];
    let txn = WriteTransaction::new(self.env.clone())?;
    let (encrypted_refresh_token, key, verifier) = {
        let mut access = txn.access();
        let account_record: &[u8] = access.get(&self.account_db, account)?;
        let cipher = {
            use chacha20poly1305::aead::KeyInit;
            XChaCha20Poly1305::new(&self.encryption_key.into())
        };
        // Encrypted TOTP secret lives at 168..204 of the record, its nonce at 204..228.
        let nonce = XNonce::from_slice(&account_record[204..228]);
        let mfa_secret = cipher
            .decrypt(nonce, &account_record[168..204])
            .map_err(|err| err.to_string())?;
        let mfa_hash = if check_mfa {
            let mfa_code = TOTP::new(
                self.totp_alg,
                6,
                1,
                30,
                Secret::Raw(mfa_secret).to_bytes()?,
                Some(self.domain.clone()),
                std::str::from_utf8(account)?.to_string(),
            )?
            .generate_current()?;
            // Hash input is domain || account || code.
            let mut mfa_input = self.domain.as_bytes().to_vec();
            mfa_input.extend_from_slice(account);
            mfa_input.extend_from_slice(mfa_code.as_bytes());
            match &self.config.client_hash {
                ClientPasswordHash::Sha256 => {
                    let mut hasher = Sha256::new();
                    hasher.update(&mfa_input);
                    hasher.finalize().to_vec()
                }
            }
        } else {
            vec![0u8; 32]
        };
        let server_start: &[u8] = access.get(&self.state_db, account)?;
        let (encrypted_refresh_token, key, verifier) = crate::login::server_finish(
            self,
            account,
            &mfa_hash[..],
            encrypted_mfa_code,
            mfa_code_nonce,
            client_finish,
            server_start,
        )?;
        // The login state is single-use.
        access.del_key(&self.state_db, account)?;
        (encrypted_refresh_token, key, verifier)
    };
    txn.commit()?;
    Ok((encrypted_refresh_token, key, verifier))
}
/// First step of the password reset flow; delegates directly to `login_start`, so the
/// payload format and errors are those of `login_start`.
#[instrument(skip(self, payload), err)]
pub fn reset_password_login_start(&self, payload: Bytes) -> Result<Bytes, Box<dyn Error>> {
self.login_start(payload)
}
/// Second step of the password reset flow.
///
/// Completes the login with MFA enforced, then issues a password-reset token whose claims
/// are `[token uuid, domain, account, verifier?]`, MACed with the reset-password key and a
/// lifetime argument of `60 * 15`. The token is returned encrypted under the key produced
/// by the login, with the nonce appended.
///
/// # Errors
/// Propagates any `login_finish` failure, plus UTF-8, token generation and encryption
/// errors.
#[instrument(skip(self, payload), err)]
pub fn reset_password_login_finish(&self, payload: Bytes) -> Result<Bytes, Box<dyn Error>> {
    use chacha20poly1305::aead::KeyInit;

    let (_, session_key, client_verifier) = self.login_finish(payload.clone(), true)?;
    // `login_finish` succeeded, so the payload is known to be non-empty and well-formed.
    let name_len = usize::from(*payload.first().expect("payload already checked"));
    let account_name = std::str::from_utf8(&payload[1..=name_len])?;

    let mut claims_builder =
        flexbuffers::Builder::new(&flexbuffers::BuilderOptions::SHARE_NONE);
    let mut claims = claims_builder.start_vector();
    let token_id = Uuid::now_v7();
    claims.push(flexbuffers::Blob(&token_id.as_bytes()[..]));
    claims.push(self.domain.as_str());
    claims.push(account_name);
    if let Some(client_verifier) = client_verifier {
        claims.push(flexbuffers::Blob(&client_verifier[..]));
    }
    claims.end_vector();

    let reset_token = generate_hmac(
        claims_builder.view(),
        60 * 15,
        self.reset_password_token_key,
    )?;

    let session_cipher = XChaCha20Poly1305::new(&session_key.into());
    let reply_nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
    let mut sealed_token = session_cipher
        .encrypt(&reply_nonce, reset_token.as_ref())
        .map_err(|err| err.to_string())?;
    sealed_token.extend_from_slice(&reply_nonce);
    Ok(Bytes::from(sealed_token))
}
/// Starts re-registration of a password for the account named in a password-reset token.
///
/// The token is verified first; the account it yields is passed to `registration_start`
/// as the existing account, so the payload must name that same account. No invite
/// validator is used.
#[instrument(skip(self, payload, token), err)]
pub fn reset_password_registration_start(
&self,
payload: Bytes,
token: &[u8],
) -> Result<Bytes, Box<dyn Error>> {
let account_str = self.verify_password_reset_token(token)?;
self.registration_start(payload, Some(account_str), None)
}
/// Finishes re-registration of a password for the account named in a password-reset token.
///
/// The token is verified first; the account it yields is passed to `registration_finish`
/// as the existing account, so the payload must name that same account. The values
/// returned by `registration_finish` are discarded.
///
/// # Errors
/// Returns an error if the token fails verification or `registration_finish` fails.
// `token` is a bearer secret: skip it so it is not recorded as a span field, matching
// `reset_password_registration_start`.
#[instrument(skip(self, payload, token), err)]
pub fn reset_password_registration_finish(
    &self,
    payload: Bytes,
    token: &[u8],
) -> Result<(), Box<dyn Error>> {
    let account_str = self.verify_password_reset_token(token)?;
    self.registration_finish(payload, Some(account_str))?;
    Ok(())
}
/// Starts the "forgot password" flow, which is authorised by a recovery code.
///
/// `payload` is a `registration_start` payload with 32 trailing bytes that are checked
/// against the account's stored recovery codes (the code is only checked here, not
/// consumed). When the check passes, the remaining payload is forwarded to
/// `registration_start` with the account marked as existing.
///
/// # Errors
/// Returns an error for a payload shorter than 33 bytes, a malformed account field, a
/// missing recovery record, an invalid recovery code, or any `registration_start` failure.
#[instrument(skip(self, payload), err)]
pub fn forgot_password_start(&self, mut payload: Bytes) -> Result<Bytes, Box<dyn Error>> {
    let payload_len = payload.len();
    if payload_len < 33 {
        return Err("invalid payload".into());
    }
    let recovery_code = payload.split_off(payload_len - 32);
    // At least one byte remains after the split, so `first()` and `len() - 1` are safe.
    let account_len = (*payload.first().expect("payload already checked")) as usize;
    if account_len > payload.len() - 1 {
        return Err("invalid format".into());
    }
    let account = Bytes::copy_from_slice(&payload[1..=account_len]);
    let account_str = std::str::from_utf8(&account[..])?;
    // Scope the read transaction so it is closed before `registration_start` opens its
    // write transaction; LMDB expects a thread to use one transaction at a time.
    let is_valid_recovery_code = {
        let txn = ReadTransaction::new(self.env.clone())?;
        let access = txn.access();
        let read_recovery_codes = access.get(&self.recovery_db, &account[..])?;
        check_code(&recovery_code[..], read_recovery_codes)?
    };
    if !is_valid_recovery_code {
        return Err("invalid recovery code".into());
    }
    self.registration_start(payload, Some(account_str), None)
}
/// Finishes the "forgot password" flow.
///
/// `payload` is a `registration_finish` payload with an 11-byte recovery code appended.
/// The code is consumed (the updated code list is written back and committed) and then the
/// remaining payload is forwarded to `registration_finish` with the account marked as
/// existing.
///
/// Note that the code is committed as used before `registration_finish` runs, so a
/// failure there does not restore it.
///
/// # Errors
/// Returns an error for a payload shorter than 12 bytes, a malformed account field, a
/// missing recovery record, an invalid recovery code, or any `registration_finish` failure.
#[instrument(skip(self, payload), err)]
pub fn forgot_password_finish(&self, mut payload: Bytes) -> Result<(), Box<dyn Error>> {
    let total_len = payload.len();
    if total_len < 12 {
        return Err("invalid payload".into());
    }
    let recovery_code = payload.split_off(total_len - 11);
    let name_len = usize::from(*payload.first().expect("payload already checked"));
    // At least one byte is left after the split, so this is the same bound as
    // `name_len > payload.len() - 1`.
    if name_len >= payload.len() {
        return Err("invalid format".into());
    }
    let account = Bytes::copy_from_slice(&payload[1..=name_len]);
    let account_str = std::str::from_utf8(&account[..])?;

    let txn = WriteTransaction::new(self.env.clone())?;
    {
        let mut access = txn.access();
        let stored = access.get(&self.recovery_db, &account[..])?;
        let (accepted, remaining_codes) = consume_code(&recovery_code[..], stored)?;
        if !accepted {
            return Err("invalid recovery code".into());
        }
        access.put(
            &self.recovery_db,
            &account[..],
            &remaining_codes[..],
            &put::Flags::empty(),
        )?;
    }
    txn.commit()?;

    self.registration_finish(payload, Some(account_str))
        .map(|_| ())
}
/// First step of the TOTP reset flow; delegates directly to `login_start`, so the payload
/// format and errors are those of `login_start`.
#[instrument(skip(self, payload), err)]
pub fn reset_totp_mfa_start(&self, payload: Bytes) -> Result<Bytes, Box<dyn Error>> {
self.login_start(payload)
}
/// Second step of the TOTP reset flow.
///
/// Completes the login with MFA enforced, generates a new TOTP secret, stores it
/// (encrypted with the service key, nonce appended) in bytes `168..228` of the account
/// record, and returns the new secret encrypted under the key produced by the login, with
/// the nonce appended.
///
/// # Errors
/// Propagates any `login_finish` failure, plus secret generation, encryption and storage
/// errors.
#[instrument(skip(self, payload), err)]
pub fn reset_totp_mfa_finish(&self, payload: Bytes) -> Result<Bytes, Box<dyn Error>> {
    use chacha20poly1305::aead::KeyInit;

    let (_, session_key, _) = self.login_finish(payload.clone(), true)?;
    let name_len = usize::from(*payload.first().expect("payload already checked"));
    let account = &payload[1..=name_len];

    let txn = WriteTransaction::new(self.env.clone())?;
    let new_secret = {
        let mut access = txn.access();
        let new_secret = Secret::generate_secret().to_bytes()?;
        let storage_nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
        let storage_cipher = XChaCha20Poly1305::new(&self.encryption_key.into());
        let mut sealed_for_storage = storage_cipher
            .encrypt(&storage_nonce, new_secret.as_ref())
            .map_err(|err| err.to_string())?;
        sealed_for_storage.extend_from_slice(&storage_nonce);

        // Swap the TOTP section of the record, leaving everything else untouched.
        let current: &[u8] = access.get(&self.account_db, account)?;
        let mut updated = current.to_vec();
        updated.splice(168..228, sealed_for_storage);
        access.put(&self.account_db, account, &updated, &put::Flags::empty())?;
        new_secret
    };
    txn.commit()?;

    let session_cipher = XChaCha20Poly1305::new(&session_key.into());
    let reply_nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
    let mut sealed_reply = session_cipher
        .encrypt(&reply_nonce, new_secret.as_ref())
        .map_err(|err| err.to_string())?;
    sealed_reply.extend_from_slice(&reply_nonce);
    Ok(Bytes::copy_from_slice(&sealed_reply))
}
/// Starts the "lost TOTP device" flow, which is authorised by a recovery code.
///
/// `payload` is a `login_start` payload with 32 trailing bytes that are checked against
/// the account's stored recovery codes. `login_start` runs first (so its login state is
/// already stored when the code is checked); its response is returned only if the code is
/// valid.
///
/// # Errors
/// Returns an error for a payload shorter than 33 bytes, any `login_start` failure, a
/// missing recovery record, or an invalid recovery code.
#[instrument(skip(self, payload), err)]
pub fn lost_totp_mfa_start(&self, mut payload: Bytes) -> Result<Bytes, Box<dyn Error>> {
    let total_len = payload.len();
    if total_len < 33 {
        return Err("invalid payload".into());
    }
    let recovery_code = payload.split_off(total_len - 32);
    let server_message = self.login_start(payload.clone())?;

    let name_len = usize::from(*payload.first().expect("payload already checked"));
    let account = &payload[1..=name_len];

    let txn = ReadTransaction::new(self.env.clone())?;
    let access = txn.access();
    let stored_codes = access.get(&self.recovery_db, account)?;
    if check_code(&recovery_code[..], stored_codes)? {
        Ok(server_message)
    } else {
        Err("invalid recovery code".into())
    }
}
/// Finishes the "lost TOTP device" flow.
///
/// `payload` is a `login_finish` payload with an 11-byte recovery code appended. The login
/// is completed without the MFA check, the recovery code is consumed, a new TOTP secret is
/// generated and stored (encrypted with the service key) in bytes `168..228` of the
/// account record, and the new secret is returned encrypted under the key produced by the
/// login, with the nonce appended.
///
/// # Errors
/// Returns an error for a payload shorter than 12 bytes, any `login_finish` failure, a
/// missing recovery record, an invalid recovery code, or a crypto/storage failure.
#[instrument(skip(self, payload), err)]
pub fn lost_totp_mfa_finish(&self, mut payload: Bytes) -> Result<Bytes, Box<dyn Error>> {
    use chacha20poly1305::aead::KeyInit;

    let total_len = payload.len();
    if total_len < 12 {
        return Err("invalid payload".into());
    }
    let recovery_code = payload.split_off(total_len - 11);
    let (_, session_key, _) = self.login_finish(payload.clone(), false)?;
    let name_len = usize::from(*payload.first().expect("payload already checked"));
    let account = &payload[1..=name_len];

    let txn = WriteTransaction::new(self.env.clone())?;
    let new_secret = {
        let mut access = txn.access();
        let stored_codes = access.get(&self.recovery_db, account)?;
        let (accepted, remaining_codes) = consume_code(&recovery_code[..], stored_codes)?;
        if !accepted {
            return Err("invalid recovery code".into());
        }
        // Persist the code list with the used code consumed.
        access.put(
            &self.recovery_db,
            account,
            &remaining_codes[..],
            &put::Flags::empty(),
        )?;

        let new_secret = Secret::generate_secret().to_bytes()?;
        let storage_nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
        let storage_cipher = XChaCha20Poly1305::new(&self.encryption_key.into());
        let mut sealed_for_storage = storage_cipher
            .encrypt(&storage_nonce, new_secret.as_ref())
            .map_err(|err| err.to_string())?;
        sealed_for_storage.extend_from_slice(&storage_nonce);

        // Swap the TOTP section of the record, leaving everything else untouched.
        let current: &[u8] = access.get(&self.account_db, account)?;
        let mut updated = current.to_vec();
        updated.splice(168..228, sealed_for_storage);
        access.put(&self.account_db, account, &updated, &put::Flags::empty())?;
        new_secret
    };
    txn.commit()?;

    let session_cipher = XChaCha20Poly1305::new(&session_key.into());
    let reply_nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
    let mut sealed_reply = session_cipher
        .encrypt(&reply_nonce, new_secret.as_ref())
        .map_err(|err| err.to_string())?;
    sealed_reply.extend_from_slice(&reply_nonce);
    Ok(Bytes::copy_from_slice(&sealed_reply))
}
/// First step of the recovery-code reset flow; delegates directly to `login_start`, so
/// the payload format and errors are those of `login_start`.
#[instrument(skip(self, payload), err)]
pub fn reset_recovery_codes_start(&self, payload: Bytes) -> Result<Bytes, Box<dyn Error>> {
self.login_start(payload)
}
/// Second step of the recovery-code reset flow.
///
/// Completes the login with MFA enforced, generates a new set of recovery codes, replaces
/// the stored set for the account, and returns the new codes encrypted under the key
/// produced by the login, with the nonce appended.
///
/// # Errors
/// Propagates any `login_finish` failure, plus code generation, storage and encryption
/// errors.
#[instrument(skip(self, payload), err)]
pub fn reset_recovery_codes_finish(&self, payload: Bytes) -> Result<Bytes, Box<dyn Error>> {
    use chacha20poly1305::aead::KeyInit;

    let (_, session_key, _) = self.login_finish(payload.clone(), true)?;
    let name_len = usize::from(*payload.first().expect("payload already checked"));
    let account = &payload[1..=name_len];

    let (codes_for_storage, codes_for_client) = crate::recovery::generate_codes()?;

    let txn = WriteTransaction::new(self.env.clone())?;
    {
        let mut access = txn.access();
        access.put(
            &self.recovery_db,
            account,
            &codes_for_storage[..],
            &put::Flags::empty(),
        )?;
    }
    txn.commit()?;

    let session_cipher = XChaCha20Poly1305::new(&session_key.into());
    let reply_nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
    let mut sealed_codes = session_cipher
        .encrypt(&reply_nonce, codes_for_client.as_bytes())
        .map_err(|err| err.to_string())?;
    sealed_codes.extend_from_slice(&reply_nonce);
    Ok(Bytes::copy_from_slice(&sealed_codes))
}
/// First step of account deletion; delegates directly to `login_start`, so the payload
/// format and errors are those of `login_start`.
#[instrument(skip(self, payload), err)]
pub fn delete_account_start(&self, payload: Bytes) -> Result<Bytes, Box<dyn Error>> {
self.login_start(payload)
}
/// Second step of account deletion.
///
/// Completes the login with MFA enforced and then removes the account's record from the
/// account database. Only the account database is touched here.
///
/// # Errors
/// Propagates any `login_finish` failure and any storage error from the delete or commit.
#[instrument(skip(self, payload), err)]
pub fn delete_account_finish(&self, payload: Bytes) -> Result<(), Box<dyn Error>> {
    let _session = self.login_finish(payload.clone(), true)?;
    let name_len = usize::from(*payload.first().expect("payload already checked"));
    let account_key = &payload[1..=name_len];

    let txn = WriteTransaction::new(self.env.clone())?;
    {
        let mut writer = txn.access();
        writer.del_key(&self.account_db, account_key)?;
    }
    txn.commit()?;
    Ok(())
}
/// Serialises every account in the account database into a flexbuffers vector.
///
/// Each element is itself a vector: the account key followed by the configured
/// access-token claims copied from the claims buffer stored after byte 420 of the record
/// (records of 420 bytes or fewer contribute only the key).
///
/// # Errors
/// Returns an error if the cursor cannot be positioned on a first record (presumably
/// including an empty database — confirm against saferlmdb), or if a stored claims
/// buffer cannot be parsed or copied. Iteration stops silently at the first failing
/// `next` call.
#[instrument(skip(self), err)]
pub fn list_accounts(&self) -> Result<Bytes, Box<dyn Error>> {
    let txn = ReadTransaction::new(self.env.clone())?;
    let access = txn.access();
    let mut cursor = txn.cursor(self.account_db.clone())?;
    let mut entry = cursor.first::<[u8], [u8]>(&access)?;

    let mut builder = flexbuffers::Builder::new(&flexbuffers::BuilderOptions::SHARE_NONE);
    let mut accounts_vec = builder.start_vector();
    loop {
        let (account, record) = entry;
        let mut row = accounts_vec.start_vector();
        row.push(account);
        if record.len() > 420 {
            let stored_claims = flexbuffers::Reader::get_root(&record[420..])?;
            for field in &self.config.access_token.claims {
                field.kind.copy_to(
                    &stored_claims.as_vector().idx(field.idx as usize),
                    &mut row,
                    None,
                )?;
            }
        }
        row.end_vector();

        // Any error from `next` (normally "no more records") ends the listing.
        match cursor.next::<[u8], [u8]>(&access) {
            Ok(following) => entry = following,
            Err(_) => break,
        }
    }
    accounts_vec.end_vector();
    Ok(Bytes::copy_from_slice(builder.view()))
}
/// Replaces the claims stored for `account`.
///
/// `claims` is a flexbuffers vector; the configured access-token claims are copied out of
/// it by index into a new vector whose first element is an empty vector (the slot used
/// for system claims). The first 420 bytes of the account record are kept and the new
/// claims buffer is written after them.
///
/// # Errors
/// Returns an error if `account` is not UTF-8, `claims` cannot be parsed or copied, the
/// account does not exist, or storage fails.
#[instrument(skip(self, claims, account), err)]
pub fn set_claims(&self, account: &[u8], claims: &[u8]) -> Result<(), Box<dyn Error>> {
    let account_name = std::str::from_utf8(account)?;
    let incoming_root = flexbuffers::Reader::get_root(claims)?;

    let mut builder = flexbuffers::Builder::new(&flexbuffers::BuilderOptions::SHARE_NONE);
    let mut outer = builder.start_vector();
    // Index 0 is an empty placeholder vector for the system claims.
    outer.start_vector().end_vector();
    let incoming = incoming_root.as_vector();
    for field in &self.config.access_token.claims {
        field
            .kind
            .copy_to(&incoming.idx(field.idx as usize), &mut outer, None)?;
    }
    outer.end_vector();

    let txn = WriteTransaction::new(self.env.clone())?;
    {
        let mut access = txn.access();
        let current: &[u8] = access.get(&self.account_db, account)?;
        let mut updated = current.to_vec();
        // Drops any previously stored claims; a no-op when the record is not longer
        // than 420 bytes.
        updated.truncate(420);
        updated.extend_from_slice(builder.view());
        access.put(
            &self.account_db,
            account,
            &updated[..],
            &put::Flags::empty(),
        )?;
    }
    txn.commit()?;
    tracing::info!(account = %account_name);
    Ok(())
}
/// Builds and records an invite token issued on behalf of `account`.
///
/// The claims are a vector whose first element is `[token uuid, domain, account]`,
/// followed by the configured invite claims copied by index from `custom_claims` (only
/// when both the configuration and the caller provide claims). In `Viral` mode the token
/// is MACed with the invite key and its uuid is stored in the invite database.
///
/// The outer `Result` carries failures raised here; the inner `Result` is the value the
/// public invite methods return to their callers.
fn create_invite(
    &self,
    invite_config: &InviteConfig,
    account: &str,
    custom_claims: Option<&[u8]>,
) -> Result<Result<Bytes, Box<dyn Error>>, Box<dyn Error>> {
    let mut builder = flexbuffers::Builder::new(&flexbuffers::BuilderOptions::SHARE_NONE);
    let mut outer = builder.start_vector();
    let mut system_claims = outer.start_vector();
    let token_id = Uuid::now_v7();
    let token_id_bytes = token_id.as_bytes();
    tracing::info!(account = %account, token = %token_id);
    system_claims.push(flexbuffers::Blob(&token_id_bytes[..]));
    system_claims.push(self.domain.as_str());
    system_claims.push(account);
    system_claims.end_vector();

    if let (Some(configured), Some(raw_custom)) = (&invite_config.claims, custom_claims) {
        let custom_root = flexbuffers::Reader::get_root(raw_custom)?;
        let custom = custom_root.as_vector();
        for field in configured {
            field
                .kind
                .copy_to(&custom.idx(field.idx as usize), &mut outer, None)?;
        }
    }
    outer.end_vector();

    match &invite_config.mode {
        InviteMode::Viral => {
            let token = generate_hmac(
                builder.view(),
                invite_config.lifetime,
                self.invite_token_key,
            )?;
            // Record the token id; `invite_check` requires it to exist and deletes it.
            let txn = WriteTransaction::new(self.env.clone())?;
            {
                let mut access = txn.access();
                access.put(
                    &self.invite_db,
                    &token_id_bytes[..],
                    &[] as &[u8],
                    &put::Flags::empty(),
                )?;
            }
            txn.commit()?;
            Ok(Ok(token))
        }
    }
}
/// Issues an invite on behalf of an administrator identified as `account@domain`,
/// optionally embedding `custom_claims`.
///
/// # Errors
/// Returns `"no invite config"` when invites are not configured, otherwise any error from
/// creating the invite.
#[instrument(skip(self, domain, account), err)]
pub fn admin_invite_get(
    &self,
    domain: &str,
    account: &str,
    custom_claims: Option<&[u8]>,
) -> Result<Bytes, Box<dyn Error>> {
    let Some(invite_config) = &self.config.invite else {
        return Err("no invite config".into());
    };
    let issuer = format!("{account}@{domain}");
    self.create_invite(invite_config, issuer.as_str(), custom_claims)?
}
/// Issues an invite on behalf of the account that owns `access_token`; no custom claims
/// are embedded.
///
/// # Errors
/// Returns `"no invite config"` when invites are not configured, an error if the access
/// token fails verification, or any error from creating the invite.
#[instrument(skip(self, access_token), err)]
pub fn invite_get(&self, access_token: Bytes) -> Result<Bytes, Box<dyn Error>> {
    let Some(invite_config) = &self.config.invite else {
        return Err("no invite config".into());
    };
    let (account, _) = self.verify_access_token(&access_token)?;
    self.create_invite(invite_config, account, None)?
}
/// Verifies an invite token and, in `Viral` mode, consumes it.
///
/// The token's MAC is checked with the invite key and its system claims
/// (`[uuid, domain, account]`) are read. The domain must match this service's domain. In
/// `Viral` mode the token uuid must be present in the invite database and is deleted, so
/// a token can only pass this check once.
///
/// Returns the parsed claims vector together with the raw claims bytes.
///
/// # Errors
/// Returns `"no invite config"` when invites are not configured, an error if the token
/// fails verification or parsing, a domain-mismatch error, or `"token id does not exist"`
/// when the token is unknown or already used.
#[allow(clippy::type_complexity)]
#[instrument(skip(self, invite_token), err)]
pub fn invite_check<'a>(
    &self,
    invite_token: &'a [u8],
) -> Result<(VectorReader<&'a [u8]>, &'a [u8]), Box<dyn Error>> {
    let Some(invite_config) = &self.config.invite else {
        return Err("no invite config".into());
    };

    let claims = crate::token::verify_hmac(invite_token, self.invite_token_key)?;
    let claims_root = flexbuffers::Reader::get_root(claims)?;
    let claims_vec = claims_root.as_vector();
    let system_claims = claims_vec.idx(0).as_vector();
    let token_uuid_bytes: [u8; 16] = system_claims.idx(0).as_blob().0.try_into()?;
    let claims_uuid_str = Uuid::from_bytes(token_uuid_bytes).to_string();
    let claims_domain = system_claims.idx(1).as_str();
    let claims_account = system_claims.idx(2).as_str();

    if claims_domain != self.domain {
        return Err(format!(
            "domain {claims_domain} does not match for invite {claims_uuid_str} from user {claims_account}"
        )
        .into());
    }

    match invite_config.mode {
        InviteMode::Viral => {
            let txn = WriteTransaction::new(self.env.clone())?;
            {
                let mut access = txn.access();
                let known = access
                    .get::<[u8], [u8]>(&self.invite_db, &token_uuid_bytes[..])
                    .is_ok();
                if !known {
                    return Err("token id does not exist".into());
                }
                // Single use: remove the id so the same token cannot be redeemed again.
                access.del_key(&self.invite_db, &token_uuid_bytes[..])?;
            }
            txn.commit()?;
        }
    }

    tracing::info!(account = claims_account, token = claims_uuid_str);
    Ok((claims_vec, claims))
}
/// Exchanges a refresh token for a newly generated access token.
///
/// The refresh token is verified, the account it names is loaded from the
/// account database, and an HMAC access token is built with the layout
/// `[[token id, domain, account, client verifying key?], claim, claim, ...]`,
/// where the trailing claims are those listed in
/// `self.config.access_token.claims`, copied from the stored user claims.
///
/// # Errors
///
/// Propagates refresh-token verification, database, flexbuffer and token
/// generation errors. Returns an error (instead of panicking) when the
/// stored account record is too short to contain user claims.
pub fn access_get(&self, refresh_token: &Bytes) -> Result<Bytes, Box<dyn Error>> {
    // Byte offset at which the flexbuffer-encoded user claims start inside
    // an account record.
    // NOTE(review): the layout of the bytes before this offset is not
    // visible here — confirm against the code that writes account records.
    const USER_CLAIMS_OFFSET: usize = 420;

    let span = tracing::info_span!("auth");
    let span = span.in_scope(|| tracing::info_span!("token"));
    let span = span.in_scope(|| tracing::info_span!("access"));
    let span = span.in_scope(|| tracing::info_span!("get"));
    span.in_scope(|| {
        let (account, client_verifying_key) = self.verify_refresh_token(&refresh_token[..])?;

        let mut claims_builder =
            flexbuffers::Builder::new(&flexbuffers::BuilderOptions::SHARE_NONE);
        let mut dest = claims_builder.start_vector();

        let txn = ReadTransaction::new(self.env.clone())?;
        let access = txn.access();
        let account_record: &[u8] = access.get(&self.account_db, account.as_bytes())?;
        // Checked slicing: a truncated or corrupted record must surface as
        // an error rather than panic the caller.
        let user_claims_bytes = account_record
            .get(USER_CLAIMS_OFFSET..)
            .ok_or("account record is too short to contain user claims")?;
        let user_claims = flexbuffers::Reader::get_root(user_claims_bytes)?.as_vector();

        // Slot 0: system claims identifying the token and its owner.
        let mut system_claims = dest.start_vector();
        let token_id = Uuid::new_v4();
        system_claims.push(flexbuffers::Blob(&token_id.as_bytes()[..]));
        system_claims.push(self.domain.as_str());
        system_claims.push(account);
        if let Some(client_verifying_key) = client_verifying_key {
            system_claims.push(flexbuffers::Blob(&client_verifying_key[..]));
        }
        system_claims.end_vector();

        // Remaining slots: the configured subset of the stored user claims.
        for claim in &self.config.access_token.claims {
            let src = user_claims.idx(claim.idx as usize);
            claim.kind.copy_to(&src, &mut dest, None)?;
        }
        dest.end_vector();

        let access_token = generate_hmac(
            claims_builder.view(),
            self.config.access_token.lifetime,
            self.access_token_key,
        )?;
        tracing::info!(token = %token_id, "generated");
        Ok(access_token)
    })
}
/// Verifies an access token and returns its account claim together with
/// the full claims vector.
///
/// The claims are first read without verification to find out how the
/// token must be checked:
/// * if the system claims (slot 0) carry no verifier blob at index 3, the
///   HMAC is verified over the whole token;
/// * otherwise the client signature is verified with that blob as key, and
///   the HMAC is verified over the token minus its trailing
///   `4 + 64` bytes.
///
/// The domain claim must equal this instance's domain.
///
/// # Errors
///
/// Returns a domain mismatch message when the domain claim differs,
/// an error when a client-signed token is too short to contain the
/// trailing suffix, and propagates extraction, parsing and verification
/// errors.
#[allow(clippy::type_complexity)]
pub fn verify_access_token<'a>(
    &self,
    token: &'a [u8],
) -> Result<(&'a str, VectorReader<&'a [u8]>), Box<dyn Error>> {
    // Number of trailing bytes present only on client-signed tokens.
    // NOTE(review): presumably a 4-byte field plus a 64-byte signature —
    // confirm against `crate::token`.
    const CLIENT_SIGNATURE_SUFFIX_LEN: usize = 4 + 64;

    let span = tracing::info_span!("auth");
    let span = span.in_scope(|| tracing::info_span!("token"));
    let span = span.in_scope(|| tracing::info_span!("access"));
    let span = span.in_scope(|| tracing::info_span!("verify"));
    span.in_scope(|| {
        let claims: &'a [u8] = crate::token::extract_hmac_no_check(token)?;
        // Try the claims without the suffix first and fall back to the full
        // slice when that does not parse.
        // NOTE(review): this relies on the truncated slice failing to parse
        // when no suffix is present — confirm it cannot succeed by accident.
        let claims_vec = match flexbuffers::Reader::get_root(
            &claims[..claims
                .len()
                .checked_sub(CLIENT_SIGNATURE_SUFFIX_LEN)
                .unwrap_or(claims.len())],
        ) {
            Ok(v) => v.as_vector(),
            Err(_) => flexbuffers::Reader::get_root(claims)?.as_vector(),
        };

        let system_claims = claims_vec.idx(0).as_vector();
        let token_uuid_bytes: [u8; 16] = system_claims.idx(0).as_blob().0.try_into()?;
        let claims_uuid_str = Uuid::from_bytes(token_uuid_bytes).to_string();
        let claims_domain = system_claims.idx(1).as_str();
        let claims_account = system_claims.idx(2).as_str();

        if claims_domain == self.domain {
            let claims_verifier = system_claims.idx(3).as_blob().0;
            if claims_verifier.is_empty() {
                crate::token::verify_hmac(token, self.access_token_key)?;
            } else {
                crate::token::verify_client_signature(token, claims_verifier.try_into()?)?;
                // Checked subtraction: the token is untrusted input, so a
                // too-short token must yield an error, never an underflow.
                let server_part_len = token
                    .len()
                    .checked_sub(CLIENT_SIGNATURE_SUFFIX_LEN)
                    .ok_or("token is too short to carry a client signature")?;
                crate::token::verify_hmac(&token[..server_part_len], self.access_token_key)?;
            }
            tracing::info!(
                account = %claims_account,
                token = %claims_uuid_str,
                "verified"
            );
            Ok((claims_account, claims_vec))
        } else {
            Err(format!(
                "domain {claims_domain} does not match for user {claims_account} with token {claims_uuid_str}"
            )
            .into())
        }
    })
}
/// Builds a signed refresh token for `account`.
///
/// The claims are a flat vector `[token id, domain, account, verifier?]`:
/// a fresh v4 UUID, this instance's domain, the account as a string, and —
/// only when `verifier_claim` is `Some` — that value as a blob. The token
/// is signed with the token signing key and uses the configured refresh
/// token lifetime.
///
/// # Errors
///
/// Fails when `account` is not valid UTF-8 or when signing fails.
pub(crate) fn generate_refresh_token(
    &self,
    account: &[u8],
    verifier_claim: Option<&[u8]>,
) -> Result<Bytes, Box<dyn Error>> {
    let span = tracing::info_span!("auth");
    let span = span.in_scope(|| tracing::info_span!("token"));
    let span = span.in_scope(|| tracing::info_span!("refresh"));
    let span = span.in_scope(|| tracing::info_span!("generate"));
    span.in_scope(|| {
        let account_name = std::str::from_utf8(account)?;
        let token_id = Uuid::new_v4();

        let mut builder = flexbuffers::Builder::new(&flexbuffers::BuilderOptions::SHARE_NONE);
        let mut claims = builder.start_vector();
        claims.push(flexbuffers::Blob(token_id.as_bytes().as_slice()));
        claims.push(self.domain.as_str());
        claims.push(account_name);
        // The verifier slot is simply absent for tokens without one.
        if let Some(verifier) = verifier_claim {
            claims.push(flexbuffers::Blob(verifier));
        }
        claims.end_vector();

        let refresh_token = crate::token::generate_signed(
            builder.view(),
            self.config.refresh_token.lifetime,
            self.token_signing_key,
        )?;
        tracing::info!(token = token_id.to_string(), "generated");
        Ok(refresh_token)
    })
}
/// Verifies a refresh token and returns its account claim plus the client
/// verifying key, if the token carries one.
///
/// The claims (`[token id, domain, account, verifier?]`) are read without
/// verification to decide how to check the token:
/// * with no verifier blob at index 3, the signature is verified over the
///   whole token and `None` is returned as key;
/// * otherwise the client signature is verified with that blob as key, and
///   the server signature is verified over the token minus its trailing
///   `4 + 64` bytes; the key is returned as `Some`.
///
/// NOTE(review): unlike `verify_access_token`, the domain claim is only
/// logged here and never compared with `self.domain` — confirm that this
/// is intentional.
///
/// # Errors
///
/// Returns an error when a client-signed token is too short to contain the
/// trailing suffix, and propagates extraction, parsing and verification
/// errors.
#[allow(clippy::type_complexity)]
fn verify_refresh_token<'a>(
    &self,
    token: &'a [u8],
) -> Result<(&'a str, Option<[u8; 32]>), Box<dyn Error>> {
    // Number of trailing bytes present only on client-signed tokens.
    // NOTE(review): presumably a 4-byte field plus a 64-byte signature —
    // confirm against `crate::token`.
    const CLIENT_SIGNATURE_SUFFIX_LEN: usize = 4 + 64;

    let span = tracing::info_span!("auth");
    let span = span.in_scope(|| tracing::info_span!("token"));
    let span = span.in_scope(|| tracing::info_span!("refresh"));
    let span = span.in_scope(|| tracing::info_span!("verify"));
    span.in_scope(|| {
        let claims = extract_signed_no_check(token)?;
        // Try the claims without the suffix first and fall back to the full
        // slice when that does not parse.
        let claims_vec = match flexbuffers::Reader::get_root(
            &claims[..claims
                .len()
                .checked_sub(CLIENT_SIGNATURE_SUFFIX_LEN)
                .unwrap_or(claims.len())],
        ) {
            Ok(v) => v.as_vector(),
            Err(_) => flexbuffers::Reader::get_root(claims)?.as_vector(),
        };

        let token_uuid_bytes: [u8; 16] = claims_vec.idx(0).as_blob().0.try_into()?;
        let claims_uuid_str = Uuid::from_bytes(token_uuid_bytes).to_string();
        let claims_domain = claims_vec.idx(1).as_str();
        let claims_account = claims_vec.idx(2).as_str();
        let claims_verifier = claims_vec.idx(3).as_blob().0;

        let client_verifying_key = if claims_verifier.is_empty() {
            crate::token::verify_signed(token, self.token_verifying_key)?;
            None
        } else {
            let verifying_key_bytes: [u8; 32] = claims_verifier.try_into()?;
            crate::token::verify_client_signature(token, verifying_key_bytes)?;
            // Checked subtraction: the token is untrusted input, so a
            // too-short token must yield an error, never an underflow.
            let server_part_len = token
                .len()
                .checked_sub(CLIENT_SIGNATURE_SUFFIX_LEN)
                .ok_or("token is too short to carry a client signature")?;
            crate::token::verify_signed(&token[..server_part_len], self.token_verifying_key)?;
            Some(verifying_key_bytes)
        };

        tracing::info!(
            domain = %claims_domain,
            account = %claims_account,
            token = %claims_uuid_str,
            "verified"
        );
        Ok((claims_account, client_verifying_key))
    })
}
/// Verifies a password-reset token and returns its account claim.
///
/// The claims (`[token id, domain, account, verifier?]`) are read without
/// verification to decide how to check the token:
/// * with no verifier blob at index 3, the HMAC is verified over the whole
///   token with the reset-password key;
/// * otherwise the client signature is verified with that blob as key, and
///   the HMAC is verified over the token minus its trailing `4 + 64`
///   bytes.
///
/// The domain claim must equal this instance's domain.
///
/// # Errors
///
/// Returns a domain mismatch message when the domain claim differs, an
/// error when a client-signed token is too short to contain the trailing
/// suffix, and propagates extraction, parsing and verification errors.
#[instrument(skip(self, token), err)]
pub fn verify_password_reset_token<'a>(
    &self,
    token: &'a [u8],
) -> Result<&'a str, Box<dyn Error>> {
    // Number of trailing bytes present only on client-signed tokens.
    // NOTE(review): presumably a 4-byte field plus a 64-byte signature —
    // confirm against `crate::token`.
    const CLIENT_SIGNATURE_SUFFIX_LEN: usize = 4 + 64;

    let claims: &'a [u8] = crate::token::extract_hmac_no_check(token)?;
    // Try the claims without the suffix first and fall back to the full
    // slice when that does not parse.
    let claims_vec = match flexbuffers::Reader::get_root(
        &claims[..claims
            .len()
            .checked_sub(CLIENT_SIGNATURE_SUFFIX_LEN)
            .unwrap_or(claims.len())],
    ) {
        Ok(v) => v.as_vector(),
        Err(_) => flexbuffers::Reader::get_root(claims)?.as_vector(),
    };

    let token_uuid_bytes: [u8; 16] = claims_vec.idx(0).as_blob().0.try_into()?;
    let claims_uuid_str = Uuid::from_bytes(token_uuid_bytes).to_string();
    let claims_domain = claims_vec.idx(1).as_str();
    let claims_account = claims_vec.idx(2).as_str();

    if claims_domain == self.domain {
        let claims_verifier = claims_vec.idx(3).as_blob().0;
        if claims_verifier.is_empty() {
            crate::token::verify_hmac(token, self.reset_password_token_key)?;
        } else {
            crate::token::verify_client_signature(token, claims_verifier.try_into()?)?;
            // Checked subtraction: the token is untrusted input, so a
            // too-short token must yield an error, never an underflow.
            let server_part_len = token
                .len()
                .checked_sub(CLIENT_SIGNATURE_SUFFIX_LEN)
                .ok_or("token is too short to carry a client signature")?;
            crate::token::verify_hmac(
                &token[..server_part_len],
                self.reset_password_token_key,
            )?;
        }
        tracing::info!(
            account = %claims_account,
            token = %claims_uuid_str,
            "verified"
        );
        Ok(claims_account)
    } else {
        Err(format!(
            "domain {claims_domain} does not match for user {claims_account} with token {claims_uuid_str}"
        )
        .into())
    }
}
}