use std::array;
use bon::Builder;
use constant_time_eq::constant_time_eq;
#[cfg(feature = "auth")]
use miette::Diagnostic;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "auth")]
use thiserror::Error;
use crate::{algorithm::Algorithm, digits::Digits, secret::core::Secret};
#[cfg(feature = "auth")]
use crate::{
algorithm,
auth::{query::Query, url::Url},
digits, secret,
};
/// Shared parameters for generating and verifying one-time codes.
///
/// `algorithm` and `digits` fall back to their `Default` values when omitted from the
/// builder or (with the `serde` feature enabled) from deserialized input.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Builder)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Base<'b> {
    /// Secret passed as the first argument (presumably the HMAC key) to
    /// `Algorithm::hmac` in [`Base::generate`].
    pub secret: Secret<'b>,
    /// HMAC algorithm used by [`Base::generate`]; defaults to `Algorithm::default()`.
    #[builder(default)]
    #[cfg_attr(feature = "serde", serde(default))]
    pub algorithm: Algorithm,
    /// Determines the modulus (`Digits::power`) and the string formatting of generated
    /// codes; defaults to `Digits::default()`.
    #[builder(default)]
    #[cfg_attr(feature = "serde", serde(default))]
    pub digits: Digits,
}
/// Keeps the low 31 bits of a `u32`, clearing the most significant bit
/// (`0x7FFF_FFFF`).
pub const MASK: u32 = u32::MAX >> 1;
/// Selects the low nibble (four bits) of a byte (`0xF`).
pub const HALF_BYTE: u8 = 0x0F;
impl Base<'_> {
    /// Generates the numeric one-time code for the moving factor `input`.
    ///
    /// This computes the HMAC of `input` (big-endian bytes) with the configured
    /// algorithm and secret, then applies the RFC 4226 dynamic truncation:
    /// 1. the low nibble of the last HMAC byte is used as an offset;
    /// 2. four bytes starting at that offset are read as a big-endian `u32`;
    /// 3. the top bit is cleared;
    /// 4. the result is reduced modulo `self.digits.power()`.
    ///
    /// # Panics
    ///
    /// Panics if the HMAC output is shorter than 19 bytes. The offset can be up to 15,
    /// and four bytes are read from it.
    /// NOTE(review): assumes every `Algorithm` yields a digest at least that long
    /// (SHA-1 yields 20 bytes) — confirm against `Algorithm::hmac`.
    pub fn generate(&self, input: u64) -> u32 {
        let hmac = self
            .algorithm
            .hmac(self.secret.as_ref(), input.to_be_bytes());

        // The last byte's low nibble picks where the 4-byte window starts (0..=15).
        let last = hmac.last().expect("HMAC output is never empty");
        let offset = (last & HALF_BYTE) as usize;

        let bytes = array::from_fn(|index| hmac[offset + index]);

        // Clearing the most significant bit avoids signed/unsigned ambiguity (RFC 4226).
        let value = u32::from_be_bytes(bytes) & MASK;

        value % self.digits.power()
    }

    /// Generates the code for `input` and formats it with `Digits::string`.
    ///
    /// NOTE(review): presumably this zero-pads to the configured digit count —
    /// confirm against `Digits::string`.
    pub fn generate_string(&self, input: u64) -> String {
        self.digits.string(self.generate(input))
    }

    /// Returns `true` if `code` equals the code generated for `input`.
    ///
    /// Like [`Base::verify_string`], this compares in constant time (here over the
    /// big-endian bytes of both values). Timing safety therefore does not depend on
    /// how the compiler lowers an integer `==`.
    pub fn verify(&self, input: u64, code: u32) -> bool {
        let expected = self.generate(input).to_be_bytes();
        let provided = code.to_be_bytes();
        constant_time_eq(&expected, &provided)
    }

    /// Returns `true` if `code` is byte-for-byte equal to
    /// [`Base::generate_string`] for `input`, compared in constant time.
    ///
    /// The comparison is exact, so `code` must use the same formatting as
    /// `generate_string` (for example, the same number of characters).
    pub fn verify_string<S: AsRef<str>>(&self, input: u64, code: S) -> bool {
        constant_time_eq(
            self.generate_string(input).as_bytes(),
            code.as_ref().as_bytes(),
        )
    }
}
/// Query parameter name under which the digit count is stored in an OTP URL.
#[cfg(feature = "auth")]
pub const DIGITS: &str = "digits";

/// Query parameter name under which the HMAC algorithm is stored in an OTP URL.
#[cfg(feature = "auth")]
pub const ALGORITHM: &str = "algorithm";

/// Query parameter name under which the encoded secret is stored in an OTP URL.
#[cfg(feature = "auth")]
pub const SECRET: &str = "secret";
/// Raised by `Base::extract_from` when the query has no `secret` parameter.
#[cfg(feature = "auth")]
#[derive(Debug, Error, Diagnostic)]
#[error("failed to find secret")]
#[diagnostic(code(otp_std::base::secret), help("make sure the secret is present"))]
pub struct SecretNotFoundError;
/// Underlying cause of a base-extraction [`Error`].
///
/// Both the `Display` output and the miette diagnostic data are forwarded unchanged
/// to the wrapped error (`transparent`).
#[cfg(feature = "auth")]
#[derive(Debug, Error, Diagnostic)]
#[error(transparent)]
#[diagnostic(transparent)]
pub enum ErrorSource {
    /// The secret was absent.
    SecretNotFound(#[from] SecretNotFoundError),
    /// Parsing the secret value failed.
    Secret(#[from] secret::core::Error),
    /// Parsing the algorithm value failed.
    Algorithm(#[from] algorithm::Error),
    /// Parsing the digits value failed.
    Digits(#[from] digits::ParseError),
}
/// Error returned when a `Base` cannot be extracted from an OTP URL query.
///
/// The specific cause is kept in `source`. It is exposed both as the standard error
/// source and as miette's diagnostic source.
#[cfg(feature = "auth")]
#[derive(Debug, Error, Diagnostic)]
#[error("failed to extract base from OTP URL")]
#[diagnostic(
    code(otp_std::base::extract),
    help("see the report for more information")
)]
pub struct Error {
    /// The specific failure that caused extraction to fail.
    #[source]
    #[diagnostic_source]
    pub source: ErrorSource,
}
#[cfg(feature = "auth")]
impl Error {
    /// Wraps `source` into an extraction error.
    pub const fn new(source: ErrorSource) -> Self {
        Self { source }
    }

    /// Builds an extraction error caused by a missing secret.
    pub fn secret_not_found(error: SecretNotFoundError) -> Self {
        Self::new(ErrorSource::SecretNotFound(error))
    }

    /// Shorthand for [`Error::secret_not_found`] with a fresh [`SecretNotFoundError`].
    pub fn new_secret_not_found() -> Self {
        Self::new(ErrorSource::SecretNotFound(SecretNotFoundError))
    }

    /// Builds an extraction error caused by a secret parsing failure.
    pub fn secret(error: secret::core::Error) -> Self {
        Self::new(ErrorSource::Secret(error))
    }

    /// Builds an extraction error caused by an algorithm parsing failure.
    pub fn algorithm(error: algorithm::Error) -> Self {
        Self::new(ErrorSource::Algorithm(error))
    }

    /// Builds an extraction error caused by a digits parsing failure.
    pub fn digits(error: digits::ParseError) -> Self {
        Self::new(ErrorSource::Digits(error))
    }
}
#[cfg(feature = "auth")]
impl Base<'_> {
    /// Appends this base's parameters to the query string of `url`.
    ///
    /// Writes the encoded secret under [`SECRET`], the algorithm name under
    /// [`ALGORITHM`] and the digit count under [`DIGITS`], in that order.
    pub fn query_for(&self, url: &mut Url) {
        let encoded_secret = self.secret.encode();
        let digits_text = self.digits.to_string();

        url.query_pairs_mut()
            .append_pair(SECRET, encoded_secret.as_str())
            .append_pair(ALGORITHM, self.algorithm.static_str())
            .append_pair(DIGITS, digits_text.as_str());
    }

    /// Builds a base by removing its parameters from `query`.
    ///
    /// The secret is required. The algorithm and digits are optional; when absent,
    /// the builder applies their defaults.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] if the secret is missing, or if any present parameter fails to
    /// parse.
    pub fn extract_from(query: &mut Query<'_>) -> Result<Self, Error> {
        let secret = match query.remove(SECRET) {
            Some(raw) => raw.parse().map_err(Error::secret)?,
            None => return Err(Error::new_secret_not_found()),
        };

        let algorithm: Option<Algorithm> = match query.remove(ALGORITHM) {
            Some(raw) => Some(raw.parse().map_err(Error::algorithm)?),
            None => None,
        };

        let digits: Option<Digits> = match query.remove(DIGITS) {
            Some(raw) => Some(raw.parse().map_err(Error::digits)?),
            None => None,
        };

        Ok(Self::builder()
            .secret(secret)
            .maybe_algorithm(algorithm)
            .maybe_digits(digits)
            .build())
    }
}
/// A [`Base`] whose secret has the `'static` lifetime.
pub type Owned = Base<'static>;

impl Base<'_> {
    /// Converts this base into an [`Owned`] one.
    ///
    /// The secret is converted with `Secret::into_owned`; the algorithm and digits are
    /// moved over unchanged.
    pub fn into_owned(self) -> Owned {
        let Self {
            secret,
            algorithm,
            digits,
        } = self;

        Owned {
            secret: secret.into_owned(),
            algorithm,
            digits,
        }
    }
}