use sha2::{Digest, Sha256};
use std::time::{SystemTime, UNIX_EPOCH};
use totp_rs::{Algorithm, TOTP};
/// Application and user identity that a [`Kiwavi`] instance is bound to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AppConfig {
    /// Application name; hashed into the TOTP secret and the derivation seed.
    pub name: String,
    /// Per-user identifier; hashed into the TOTP secret.
    pub user_identifier: String,
    /// Issuer reported in the `otpauth://` setup URI.
    pub issuer: String,
}
impl AppConfig {
    /// Creates a config whose issuer is the application name itself.
    pub fn new(name: &str, user_identifier: &str) -> Self {
        Self::with_issuer(name, user_identifier, name)
    }

    /// Creates a config with an explicit issuer that may differ from the application name.
    pub fn with_issuer(name: &str, user_identifier: &str, issuer: &str) -> Self {
        let (name, user_identifier, issuer) = (
            String::from(name),
            String::from(user_identifier),
            String::from(issuer),
        );
        Self {
            name,
            user_identifier,
            issuer,
        }
    }
}
/// TOTP-gated key derivation: a correct code yields a stable 32-byte value, while an
/// incorrect code yields a value derived from the wrong-value seed, the code and the
/// current time window.
pub struct Kiwavi {
/// TOTP generator built from `totp_secret`.
totp: TOTP,
/// Raw TOTP secret: SHA-256 over salt, app name, user identifier and a domain tag.
totp_secret: Vec<u8>,
/// SHA-256 over salt, app name and `kiwavi_derivation_v1`; feeds the success value.
derivation_seed: [u8; 32],
/// SHA-256 over salt and `kiwavi_wrong_values_v1`; feeds the failure values.
wrong_value_seed: [u8; 32],
/// Application/user identity used for derivation and the setup URI.
app_config: AppConfig,
}
/// Outcome of [`Kiwavi::validate_and_derive`], carrying a 32-byte value either way.
///
/// The payload is a plain byte array, so the whole result is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationResult {
    /// The code was valid; the payload is the stable derived value.
    Success([u8; 32]),
    /// The code was invalid; the payload is a decoy value.
    Invalid([u8; 32]),
}
impl ValidationResult {
    /// Returns the 32-byte payload, whether validation succeeded or not.
    pub fn value(&self) -> [u8; 32] {
        let (Self::Success(payload) | Self::Invalid(payload)) = self;
        *payload
    }

    /// Returns the payload hex-encoded via `bytes_to_hex`.
    pub fn hex(&self) -> String {
        let payload = self.value();
        bytes_to_hex(&payload)
    }

    /// Returns `true` only for [`ValidationResult::Success`].
    pub fn is_valid(&self) -> bool {
        match self {
            Self::Success(_) => true,
            Self::Invalid(_) => false,
        }
    }
}
/// Anything that can be viewed as raw salt bytes for [`Kiwavi::new`].
pub trait SaltInput {
    /// Borrows the salt as a byte slice.
    fn as_bytes(&self) -> &[u8];
}

impl SaltInput for &str {
    fn as_bytes(&self) -> &[u8] {
        // Call the inherent method explicitly: `self.as_bytes()` on `&&str` would resolve to
        // this trait method and recurse forever.
        str::as_bytes(self)
    }
}

impl SaltInput for String {
    fn as_bytes(&self) -> &[u8] {
        self.as_str().as_bytes()
    }
}

/// Lets callers pass a borrowed `String` without converting to `&str` first.
impl SaltInput for &String {
    fn as_bytes(&self) -> &[u8] {
        String::as_bytes(self)
    }
}

impl SaltInput for &[u8] {
    fn as_bytes(&self) -> &[u8] {
        self
    }
}

impl SaltInput for Vec<u8> {
    fn as_bytes(&self) -> &[u8] {
        self.as_slice()
    }
}

/// Lets callers pass a borrowed `Vec<u8>` without slicing it first.
impl SaltInput for &Vec<u8> {
    fn as_bytes(&self) -> &[u8] {
        Vec::as_slice(self)
    }
}

impl<const N: usize> SaltInput for [u8; N] {
    fn as_bytes(&self) -> &[u8] {
        self.as_slice()
    }
}

impl<const N: usize> SaltInput for &[u8; N] {
    fn as_bytes(&self) -> &[u8] {
        self.as_slice()
    }
}
impl Kiwavi {
    /// Creates a verifier from a 32-byte PRF output, used directly as the salt.
    ///
    /// # Errors
    /// Propagates any error from [`Kiwavi::new`].
    pub fn from_prf_salt(prf_salt: [u8; 32], app_config: AppConfig) -> Result<Self, KiwaviError> {
        Self::new(prf_salt, app_config)
    }

    /// Creates a verifier from a hex-encoded salt; the decoded bytes are used as the salt.
    ///
    /// # Errors
    /// Returns the `hex_to_bytes_vec` error if `hex_salt` is not valid hex, or any error
    /// from [`Kiwavi::new`].
    pub fn from_hex(hex_salt: &str, app_config: AppConfig) -> Result<Self, KiwaviError> {
        let bytes = hex_to_bytes_vec(hex_salt)?;
        Self::new(bytes, app_config)
    }

    /// Creates a verifier from a salt string written in the base64 alphabet.
    ///
    /// The string is only checked to consist of `A-Z`, `a-z`, `0-9`, `+`, `/` and `=`. It is
    /// **not** decoded: its UTF-8 bytes are used as the salt as-is.
    ///
    /// # Errors
    /// Returns [`KiwaviError::InvalidInput`] if any other character is present, or any error
    /// from [`Kiwavi::new`].
    pub fn from_base64(b64_salt: &str, app_config: AppConfig) -> Result<Self, KiwaviError> {
        // NOTE(review): unlike `from_hex`, the input is not decoded, so the same salt given as
        // hex and as base64 yields different secrets. Decoding here would silently change
        // every secret already derived from base64 salts, so it is left as-is.
        let in_alphabet = b64_salt
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '/' | '='));
        if in_alphabet {
            Self::new(b64_salt, app_config)
        } else {
            Err(KiwaviError::InvalidInput(
                "Invalid base64 string".to_string(),
            ))
        }
    }

    /// Creates a verifier from arbitrary salt bytes and an application identity.
    ///
    /// Derives the TOTP secret (see `derive_totp_secret`) plus two independent seeds: one for
    /// the value returned on a valid code, one for the values returned on invalid codes.
    ///
    /// # Errors
    /// Returns [`KiwaviError::TotpError`] if `totp_rs` rejects the TOTP parameters or secret.
    pub fn new<T: SaltInput>(user_salt: T, app_config: AppConfig) -> Result<Self, KiwaviError> {
        let salt = user_salt.as_bytes();
        let totp_secret = Self::derive_totp_secret(salt, &app_config);
        // SHA-1, 6 digits, skew of 1 step, 30-second step.
        let totp = TOTP::new(Algorithm::SHA1, 6, 1, 30, totp_secret.clone())
            .map_err(|e| KiwaviError::TotpError(e.to_string()))?;

        let mut hasher = Sha256::new();
        hasher.update(salt);
        hasher.update(app_config.name.as_bytes());
        hasher.update(b"kiwavi_derivation_v1");
        let derivation_seed = hasher.finalize().into();

        let mut hasher = Sha256::new();
        hasher.update(salt);
        hasher.update(b"kiwavi_wrong_values_v1");
        let wrong_value_seed = hasher.finalize().into();

        Ok(Kiwavi {
            totp,
            totp_secret,
            derivation_seed,
            wrong_value_seed,
            app_config,
        })
    }

    /// Checks `code` against a base32-encoded TOTP secret at the current system time.
    ///
    /// Uses the same parameters as [`Kiwavi::new`] (SHA-1, 6 digits, skew 1, 30 s step).
    ///
    /// # Errors
    /// Returns [`KiwaviError::InvalidInput`] if the secret cannot be decoded with
    /// `data_encoding::BASE32` or the system clock is before the Unix epoch, and
    /// [`KiwaviError::TotpError`] if `totp_rs` rejects the decoded secret.
    pub fn validate_totp_code(totp_secret_b32: &str, code: &str) -> Result<bool, KiwaviError> {
        let decoded_secret = data_encoding::BASE32
            .decode(totp_secret_b32.as_bytes())
            .map_err(|e| {
                KiwaviError::InvalidInput(format!("Failed to decode base32 secret: {}", e))
            })?;
        let totp = TOTP::new(Algorithm::SHA1, 6, 1, 30, decoded_secret)
            .map_err(|e| KiwaviError::TotpError(e.to_string()))?;
        Ok(totp.check(code, Self::unix_time_secs()?))
    }

    /// Validates `totp_code` and returns the derived value.
    ///
    /// On success the payload is the stable value from `derive_correct_value`. On failure it
    /// is a value derived from the submitted code and the current 30-second window.
    ///
    /// # Panics
    /// Panics if the system clock is set before the Unix epoch.
    pub fn validate_and_derive(&self, totp_code: &str) -> ValidationResult {
        let now = Self::unix_time_secs().expect("system clock should be after the Unix epoch");
        // `TOTP::check` accepts the ±1-step skew configured in `new` (matching
        // `validate_totp_code`), and totp-rs compares codes in constant time, unlike `==`.
        if self.totp.check(totp_code, now) {
            ValidationResult::Success(self.derive_correct_value())
        } else {
            ValidationResult::Invalid(self.generate_wrong_value(totp_code, now))
        }
    }

    /// Builds an `otpauth://totp/...` URI for enrolling the secret in an authenticator app.
    ///
    /// The label is `issuer:user_identifier`, so its prefix equals the `issuer` query
    /// parameter, as the Key URI format requires.
    pub fn get_setup_qr(&self) -> String {
        let label = format!(
            "{}:{}",
            self.app_config.issuer, self.app_config.user_identifier
        );
        format!(
            "otpauth://totp/{}?secret={}&issuer={}&algorithm=SHA1&digits=6&period=30",
            urlencoding::encode(&label),
            self.totp.get_secret_base32(),
            urlencoding::encode(&self.app_config.issuer)
        )
    }

    /// Returns the raw TOTP secret encoded with `data_encoding::BASE32` (padded RFC 4648).
    ///
    /// This is the encoding that `validate_totp_code` decodes. The setup URI instead uses
    /// totp-rs's `get_secret_base32`, whose padding may differ.
    pub fn get_totp_secret(&self) -> String {
        data_encoding::BASE32.encode(&self.totp_secret)
    }

    /// Returns the success value, hex-encoded via `bytes_to_hex`, without checking a code.
    pub fn preview_derived_value(&self) -> String {
        bytes_to_hex(&self.derive_correct_value())
    }

    /// Returns the TOTP code for the current system time.
    ///
    /// # Panics
    /// Panics if the system clock is set before the Unix epoch.
    pub fn get_current_code(&self) -> String {
        let now = Self::unix_time_secs().expect("system clock should be after the Unix epoch");
        self.totp.generate(now)
    }

    /// Returns the current system time as whole seconds since the Unix epoch.
    ///
    /// # Errors
    /// Returns [`KiwaviError::InvalidInput`] (`"Time error"`) if the clock is before the epoch.
    fn unix_time_secs() -> Result<u64, KiwaviError> {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|elapsed| elapsed.as_secs())
            .map_err(|_| KiwaviError::InvalidInput("Time error".to_string()))
    }

    /// Returns SHA-256 over salt, app name, user identifier and `kiwavi_totp_secret_v1`.
    fn derive_totp_secret(user_salt: &[u8], app_config: &AppConfig) -> Vec<u8> {
        // NOTE(review): fields are concatenated without length prefixes or separators, so
        // e.g. salt "ab" + name "c" hashes like salt "a" + name "bc". Changing the framing
        // would alter every existing secret, so it is kept.
        let mut hasher = Sha256::new();
        hasher.update(user_salt);
        hasher.update(app_config.name.as_bytes());
        hasher.update(app_config.user_identifier.as_bytes());
        hasher.update(b"kiwavi_totp_secret_v1");
        hasher.finalize().to_vec()
    }

    /// Returns SHA-256 over the derivation seed, the base32 TOTP secret and a fixed tag.
    fn derive_correct_value(&self) -> [u8; 32] {
        let mut hasher = Sha256::new();
        hasher.update(self.derivation_seed);
        hasher.update(self.totp.get_secret_base32().as_bytes());
        hasher.update(b"correct_derivation");
        hasher.finalize().into()
    }

    /// Returns SHA-256 over the wrong-value seed, the submitted code and the 30-second window
    /// containing `now` (Unix seconds).
    fn generate_wrong_value(&self, wrong_code: &str, now: u64) -> [u8; 32] {
        let mut hasher = Sha256::new();
        hasher.update(self.wrong_value_seed);
        hasher.update(wrong_code.as_bytes());
        // Same 30-second granularity as the TOTP step, so repeated wrong guesses within one
        // window map to the same value.
        let time_window = now / 30;
        hasher.update(time_window.to_be_bytes());
        hasher.finalize().into()
    }
}
/// Errors produced while constructing a [`Kiwavi`] or validating codes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KiwaviError {
    /// The `totp_rs` crate rejected the TOTP parameters or secret.
    TotpError(String),
    /// A hex-encoded salt could not be decoded.
    InvalidHexString(String),
    /// Other malformed input (bad base32/base64, clock before the Unix epoch).
    InvalidInput(String),
}
impl std::fmt::Display for KiwaviError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
KiwaviError::TotpError(msg) => write!(f, "TOTP error: {msg}"),
KiwaviError::InvalidHexString(msg) => write!(f, "Invalid hex string: {msg}"),
KiwaviError::InvalidInput(msg) => write!(f, "Invalid input: {msg}"),
}
}
}
impl std::error::Error for KiwaviError {}
/// Decodes a hex string, optionally prefixed by a single `0x` or `0X`, into bytes.
///
/// Upper- and lower-case digits are accepted. An empty string, or a bare prefix, decodes to
/// an empty vector. Only ASCII hex digits are accepted; signs such as `+` are rejected.
///
/// # Errors
/// Returns [`KiwaviError::InvalidHexString`] if the digit count is odd or any two-character
/// pair contains a non-hex character.
fn hex_to_bytes_vec(hex: &str) -> Result<Vec<u8>, KiwaviError> {
    // Strip at most one prefix; `trim_start_matches` would also accept "0x0x..".
    let digits = hex
        .strip_prefix("0x")
        .or_else(|| hex.strip_prefix("0X"))
        .unwrap_or(hex);
    if digits.len() % 2 != 0 {
        return Err(KiwaviError::InvalidHexString(
            "Hex string must have even length".to_string(),
        ));
    }
    // `to_digit(16)` accepts exactly 0-9, a-f, A-F, unlike `u8::from_str_radix`, which also
    // accepts a leading `+`.
    let nibble = |b: u8| char::from(b).to_digit(16);
    digits
        .as_bytes()
        .chunks_exact(2)
        .map(|pair| match (nibble(pair[0]), nibble(pair[1])) {
            // Each nibble is < 16, so the combined value always fits in a u8.
            (Some(hi), Some(lo)) => Ok(((hi << 4) | lo) as u8),
            _ => Err(match std::str::from_utf8(pair) {
                Ok(text) => KiwaviError::InvalidHexString(format!("Invalid hex byte: {text}")),
                Err(_) => {
                    KiwaviError::InvalidHexString("Invalid UTF-8 in hex string".to_string())
                }
            }),
        })
        .collect()
}
/// Encodes `bytes` as an upper-case hex string, two digits per byte.
fn bytes_to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0F)]));
    }
    out
}