use bon::Builder;
#[cfg(feature = "auth")]
use miette::Diagnostic;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "auth")]
use thiserror::Error;
use crate::{base::Base, counter::Counter};
#[cfg(feature = "auth")]
use crate::{
auth::{query::Query, url::Url},
base, counter,
};
/// HOTP (HMAC-based one-time password) configuration: a [`Base`] plus a [`Counter`].
///
/// Instances are assembled through the builder derived via `bon::Builder`
/// (`Hotp::builder()`). With the `serde` feature enabled the type additionally
/// derives `Serialize` and `Deserialize`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Builder)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Hotp<'h> {
/// The base configuration that code generation and verification are delegated to.
///
/// Under the `serde` feature its fields are flattened into this struct's own
/// representation instead of being nested under a `base` key.
#[cfg_attr(feature = "serde", serde(flatten))]
pub base: Base<'h>,
/// The counter passed to the base whenever a code is generated or verified.
///
/// Optional in the builder and, under the `serde` feature, in deserialized input;
/// when omitted it falls back to `Counter::default()`.
#[builder(default)]
#[cfg_attr(feature = "serde", serde(default))]
pub counter: Counter,
}
impl<'h> Hotp<'h> {
    /// Borrows the underlying [`Base`] configuration.
    pub const fn base(&self) -> &Base<'h> {
        let Self { base, .. } = self;
        base
    }

    /// Mutably borrows the underlying [`Base`] configuration.
    pub fn base_mut(&mut self) -> &mut Base<'h> {
        let Self { base, .. } = self;
        base
    }

    /// Consumes `self` and hands back its [`Base`], discarding the counter.
    pub fn into_base(self) -> Base<'h> {
        let Self { base, .. } = self;
        base
    }
}
impl Hotp<'_> {
    /// Returns the current counter value as a plain `u64`.
    pub const fn counter(&self) -> u64 {
        self.counter.get()
    }

    /// Attempts to advance the counter to its successor.
    ///
    /// Returns `true` and stores the new counter when `Counter::try_next` yields
    /// a value; returns `false` and leaves the counter untouched otherwise.
    pub fn try_increment(&mut self) -> bool {
        match self.counter.try_next() {
            Some(successor) => {
                self.counter = successor;
                true
            }
            None => false,
        }
    }

    /// Replaces the counter with the result of `Counter::next`.
    ///
    /// NOTE(review): what happens when the counter has no successor is decided by
    /// `Counter::next`, which is not visible here; prefer [`Hotp::try_increment`]
    /// when exhaustion must be handled explicitly.
    pub fn increment(&mut self) {
        let successor = self.counter.next();
        self.counter = successor;
    }

    /// Generates the numeric code for the current counter by delegating to the base.
    pub fn generate(&self) -> u32 {
        let current = self.counter();
        self.base.generate(current)
    }

    /// Generates the code for the current counter as a `String` by delegating to the base.
    pub fn generate_string(&self) -> String {
        let current = self.counter();
        self.base.generate_string(current)
    }

    /// Checks the numeric `code` against the current counter by delegating to the base.
    pub fn verify(&self, code: u32) -> bool {
        let current = self.counter();
        self.base.verify(current, code)
    }

    /// Checks the textual `code` against the current counter by delegating to the base.
    pub fn verify_string<S>(&self, code: S) -> bool
    where
        S: AsRef<str>,
    {
        let current = self.counter();
        self.base.verify_string(current, code)
    }
}
/// Name of the query parameter under which the HOTP counter is written to,
/// and looked up in, OTP URLs.
#[cfg(feature = "auth")]
pub const COUNTER: &str = "counter";
/// Error reported when the query holds no entry for the counter parameter.
#[cfg(feature = "auth")]
#[derive(Debug, Error, Diagnostic)]
#[error("failed to find counter")]
#[diagnostic(code(otp_std::hotp::counter), help("make sure the counter is present"))]
pub struct CounterNotFoundError;
/// The underlying cause of an HOTP extraction failure.
///
/// Both the error message and the diagnostic are forwarded transparently to the
/// wrapped error; each variant can be created from its payload via `From`.
#[cfg(feature = "auth")]
#[derive(Debug, Error, Diagnostic)]
#[error(transparent)]
#[diagnostic(transparent)]
pub enum ErrorSource {
/// Extracting the base configuration failed.
Base(#[from] base::Error),
/// The counter parameter was absent.
CounterNotFound(#[from] CounterNotFoundError),
/// The counter parameter was present but could not be parsed.
Counter(#[from] counter::Error),
}
/// Error returned when an HOTP configuration cannot be extracted from an OTP URL.
///
/// It carries a fixed top-level message and exposes the concrete cause through
/// `source`, both as the error source and as the diagnostic source.
#[cfg(feature = "auth")]
#[derive(Debug, Error, Diagnostic)]
#[error("failed to extract HOTP from OTP URL")]
#[diagnostic(
code(otp_std::hotp::extract),
help("see the report for more information")
)]
pub struct Error {
/// The concrete reason the extraction failed.
#[source]
#[diagnostic_source]
pub source: ErrorSource,
}
#[cfg(feature = "auth")]
impl Error {
    /// Wraps an already classified [`ErrorSource`].
    pub const fn new(source: ErrorSource) -> Self {
        Self { source }
    }

    /// Builds the error from a failure to extract the base configuration.
    pub fn base(error: base::Error) -> Self {
        Self::new(ErrorSource::from(error))
    }

    /// Builds the error from an existing [`CounterNotFoundError`].
    pub fn counter_not_found(error: CounterNotFoundError) -> Self {
        Self::new(ErrorSource::from(error))
    }

    /// Builds the error for a missing counter without requiring the caller to
    /// create the [`CounterNotFoundError`] value first.
    pub fn new_counter_not_found() -> Self {
        Self::new(ErrorSource::from(CounterNotFoundError))
    }

    /// Builds the error from a failure to parse the counter.
    pub fn counter(error: counter::Error) -> Self {
        Self::new(ErrorSource::from(error))
    }
}
#[cfg(feature = "auth")]
impl Hotp<'_> {
    /// Writes this configuration into the query string of `url`.
    ///
    /// The base configuration adds its own parameters first; the counter is then
    /// appended under the [`COUNTER`] key using its `to_string` representation.
    pub fn query_for(&self, url: &mut Url) {
        self.base.query_for(url);

        let rendered = self.counter.to_string();

        url.query_pairs_mut().append_pair(COUNTER, &rendered);
    }

    /// Reconstructs an HOTP configuration from the parameters in `query`.
    ///
    /// The base configuration is extracted first, after which the [`COUNTER`]
    /// entry is removed from `query` and parsed.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] when the base cannot be extracted, when `query` has no
    /// [`COUNTER`] entry, or when that entry fails to parse as a counter.
    pub fn extract_from(query: &mut Query<'_>) -> Result<Self, Error> {
        let base = Base::extract_from(query).map_err(Error::base)?;

        let raw = query
            .remove(COUNTER)
            .ok_or_else(Error::new_counter_not_found)?;

        let counter = raw.parse().map_err(Error::counter)?;

        Ok(Self::builder().base(base).counter(counter).build())
    }
}
/// A [`Hotp`] whose lifetime parameter is `'static`.
pub type Owned = Hotp<'static>;
impl Hotp<'_> {
    /// Converts `self` into an [`Owned`] configuration.
    ///
    /// The base is converted with its own `into_owned`, while the counter is
    /// carried over unchanged.
    pub fn into_owned(self) -> Owned {
        let Self { base, counter } = self;
        let owned_base = base.into_owned();

        Owned::builder().base(owned_base).counter(counter).build()
    }
}