use crate::base32::Base32;
use crate::hmac::HmacBuilder;
use crate::types::SHAFamily;
use std::time::{SystemTime, UNIX_EPOCH};
use url::Url;
/// Default TOTP time-step length, in seconds.
const DEFAULT_PERIOD: u32 = 30;
/// Default number of digits in a generated one-time password.
const DEFAULT_DIGITS: u32 = 6;

/// Error type returned by every fallible operation in this module.
///
/// It wraps a human-readable message. The field is private, so callers read the
/// message through `Display` (verbatim) or `Debug`.
#[derive(Debug)]
pub struct Error(String);

impl std::fmt::Display for Error {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emit the stored message unchanged; formatter flags are not applied.
        formatter.write_str(&self.0)
    }
}

impl std::error::Error for Error {}
pub async fn generate_hotp(
secret: &str,
counter: u64,
digits: Option<u32>,
hash: Option<SHAFamily>,
) -> Result<String, Error> {
let digits = digits.unwrap_or(DEFAULT_DIGITS);
if digits < 1 || digits > 8 {
return Err(Error("Digits must be between 1 and 8".to_string()));
}
let mut buffer = [0u8; 8];
buffer.copy_from_slice(&counter.to_be_bytes());
let hmac = HmacBuilder::new(hash, None);
let hmac_result = hmac
.sign(secret.as_bytes(), &buffer)
.map_err(|e| Error(e.to_string()))?;
let offset = hmac_result[hmac_result.len() - 1] & 0x0f;
let truncated = ((hmac_result[offset as usize] & 0x7f) as u32) << 24
| ((hmac_result[(offset + 1) as usize] & 0xff) as u32) << 16
| ((hmac_result[(offset + 2) as usize] & 0xff) as u32) << 8
| (hmac_result[(offset + 3) as usize] & 0xff) as u32;
let otp = truncated % 10u32.pow(digits);
Ok(format!("{:0width$}", otp, width = digits as usize))
}
pub async fn generate_totp(
secret: &str,
period: Option<u32>,
digits: Option<u32>,
hash: Option<SHAFamily>,
) -> Result<String, Error> {
let period = period.unwrap_or(DEFAULT_PERIOD);
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| Error(e.to_string()))?
.as_secs();
let counter = now / period as u64;
generate_hotp(secret, counter, digits, hash).await
}
pub async fn verify_totp(
otp: &str,
secret: &str,
window: Option<i32>,
digits: Option<u32>,
period: Option<u32>,
) -> Result<bool, Error> {
let window = window.unwrap_or(1);
let period = period.unwrap_or(DEFAULT_PERIOD);
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| Error(e.to_string()))?
.as_secs();
let counter = now / period as u64;
for i in -window..=window {
let current_counter = counter
.checked_add_signed(i as i64)
.ok_or_else(|| Error("Counter overflow".to_string()))?;
let generated_otp = generate_hotp(secret, current_counter, digits, None).await?;
if otp == generated_otp {
return Ok(true);
}
}
Ok(false)
}
pub fn generate_qr_code(
issuer: &str,
account: &str,
secret: &str,
digits: Option<u32>,
period: Option<u32>,
) -> Result<String, Error> {
let digits = digits.unwrap_or(DEFAULT_DIGITS);
let period = period.unwrap_or(DEFAULT_PERIOD);
let base_uri = format!(
"otpauth://totp/{}:{}",
urlencoding::encode(issuer),
urlencoding::encode(account)
);
let mut url = Url::parse(&base_uri).map_err(|e| Error(e.to_string()))?;
let encoded_secret = Base32::encode(secret.as_bytes(), Some(false));
url.query_pairs_mut()
.append_pair("secret", &encoded_secret)
.append_pair("issuer", issuer)
.append_pair("digits", &digits.to_string())
.append_pair("period", &period.to_string());
Ok(url.to_string())
}
/// A secret bundled with its digit count and time-step period.
///
/// Every method delegates to the module's free functions
/// (`generate_hotp`, `generate_totp`, `verify_totp`, `generate_qr_code`).
/// The hash family passed to them is always `None`.
pub struct OTP {
    secret: String,
    digits: u32,
    period: u32,
}

impl OTP {
    /// Creates a new `OTP`, using `DEFAULT_DIGITS` / `DEFAULT_PERIOD` for missing values.
    ///
    /// Values are stored as given and are not validated here. Invalid settings
    /// surface as errors from the generating/verifying methods.
    pub fn new(secret: &str, digits: Option<u32>, period: Option<u32>) -> Self {
        let owned_secret = String::from(secret);
        let resolved_digits = digits.unwrap_or(DEFAULT_DIGITS);
        let resolved_period = period.unwrap_or(DEFAULT_PERIOD);
        Self {
            secret: owned_secret,
            digits: resolved_digits,
            period: resolved_period,
        }
    }

    /// Returns the HOTP code for `counter` using this instance's digit count.
    pub async fn hotp(&self, counter: u64) -> Result<String, Error> {
        let Self { secret, digits, .. } = self;
        generate_hotp(secret, counter, Some(*digits), None).await
    }

    /// Returns the TOTP code for the current time using this instance's settings.
    pub async fn totp(&self) -> Result<String, Error> {
        let Self { secret, digits, period } = self;
        generate_totp(secret, Some(*period), Some(*digits), None).await
    }

    /// Checks `otp` against the current time step, tolerating `window` steps either side.
    pub async fn verify(&self, otp: &str, window: Option<i32>) -> Result<bool, Error> {
        let Self { secret, digits, period } = self;
        verify_totp(otp, secret, window, Some(*digits), Some(*period)).await
    }

    /// Returns the `otpauth://totp/` provisioning URI for this secret.
    pub fn url(&self, issuer: &str, account: &str) -> Result<String, Error> {
        let Self { secret, digits, period } = self;
        generate_qr_code(issuer, account, secret, Some(*digits), Some(*period))
    }
}