1use bon::Builder;
4
5#[cfg(feature = "auth")]
6use miette::Diagnostic;
7
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11#[cfg(feature = "auth")]
12use thiserror::Error;
13
14use crate::{base::Base, counter::Counter};
15
16#[cfg(feature = "auth")]
17use crate::{
18 auth::{query::Query, url::Url},
19 base, counter,
20};
21
22#[derive(Debug, Clone, PartialEq, Eq, Hash, Builder)]
24#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
25pub struct Hotp<'h> {
26 #[cfg_attr(feature = "serde", serde(flatten))]
28 pub base: Base<'h>,
29 #[builder(default)]
31 #[cfg_attr(feature = "serde", serde(default))]
32 pub counter: Counter,
33}
34
35impl<'h> Hotp<'h> {
36 pub const fn base(&self) -> &Base<'h> {
38 &self.base
39 }
40
41 pub fn base_mut(&mut self) -> &mut Base<'h> {
43 &mut self.base
44 }
45
46 pub fn into_base(self) -> Base<'h> {
48 self.base
49 }
50}
51
52impl Hotp<'_> {
53 pub const fn counter(&self) -> u64 {
55 self.counter.get()
56 }
57
58 pub fn try_increment(&mut self) -> bool {
60 if let Some(next) = self.counter.try_next() {
61 self.counter = next;
62
63 true
64 } else {
65 false
66 }
67 }
68
69 pub fn increment(&mut self) {
75 self.counter = self.counter.next();
76 }
77
78 pub fn generate(&self) -> u32 {
80 self.base.generate(self.counter())
81 }
82
83 pub fn generate_string(&self) -> String {
85 self.base.generate_string(self.counter())
86 }
87
88 pub fn verify(&self, code: u32) -> bool {
90 self.base.verify(self.counter(), code)
91 }
92
93 pub fn verify_string<S: AsRef<str>>(&self, code: S) -> bool {
95 self.base.verify_string(self.counter(), code)
96 }
97}
98
/// Key of the OTP URL query parameter that holds the HOTP counter.
#[cfg(feature = "auth")]
pub const COUNTER: &str = "counter";
102
/// Error produced when the counter parameter is missing from an OTP URL query.
#[cfg(feature = "auth")]
#[derive(Debug, Error, Diagnostic)]
#[error("failed to find counter")]
#[diagnostic(
    code(otp_std::hotp::counter),
    help("make sure the counter is present")
)]
pub struct CounterNotFoundError;
109
/// Underlying cause of an HOTP extraction error.
///
/// Every variant forwards both its message and its diagnostic information
/// straight to the wrapped error.
#[cfg(feature = "auth")]
#[derive(Debug, Error, Diagnostic)]
pub enum ErrorSource {
    /// Extracting the shared base configuration failed.
    #[error(transparent)]
    #[diagnostic(transparent)]
    Base(#[from] base::Error),

    /// The counter query parameter was absent.
    #[error(transparent)]
    #[diagnostic(transparent)]
    CounterNotFound(#[from] CounterNotFoundError),

    /// The counter query parameter was present but rejected (e.g. failed to parse).
    #[error(transparent)]
    #[diagnostic(transparent)]
    Counter(#[from] counter::Error),
}
123
/// Error returned when HOTP configuration cannot be extracted from an OTP URL.
#[cfg(feature = "auth")]
#[derive(Debug, Error, Diagnostic)]
#[error("failed to extract HOTP from OTP URL")]
#[diagnostic(
    code(otp_std::hotp::extract),
    help("see the report for more information")
)]
pub struct Error {
    /// The specific reason extraction failed.
    #[diagnostic_source]
    #[source]
    pub source: ErrorSource,
}
138
#[cfg(feature = "auth")]
impl Error {
    /// Wraps the given [`ErrorSource`].
    pub const fn new(source: ErrorSource) -> Self {
        Self { source }
    }

    /// Creates an error caused by a failure to extract the base configuration.
    pub fn base(error: base::Error) -> Self {
        Self::new(ErrorSource::Base(error))
    }

    /// Creates an error from the given [`CounterNotFoundError`].
    pub fn counter_not_found(error: CounterNotFoundError) -> Self {
        Self::new(ErrorSource::CounterNotFound(error))
    }

    /// Creates a counter-not-found error without requiring an error value.
    pub fn new_counter_not_found() -> Self {
        Self::counter_not_found(CounterNotFoundError)
    }

    /// Creates an error caused by a rejected counter value.
    pub fn counter(error: counter::Error) -> Self {
        Self::new(ErrorSource::Counter(error))
    }
}
166
#[cfg(feature = "auth")]
impl Hotp<'_> {
    /// Writes this configuration into the query of `url`.
    ///
    /// The base configuration writes its own parameters first. The counter is
    /// then appended under the [`COUNTER`] key.
    pub fn query_for(&self, url: &mut Url) {
        self.base.query_for(url);

        let encoded = self.counter.to_string();

        url.query_pairs_mut().append_pair(COUNTER, &encoded);
    }

    /// Extracts HOTP configuration from an OTP URL query.
    ///
    /// The base configuration is extracted first. The [`COUNTER`] parameter is
    /// then removed from `query` and parsed.
    ///
    /// # Errors
    ///
    /// Returns [`Error`] if the base configuration cannot be extracted, if the
    /// counter parameter is missing, or if it cannot be parsed.
    pub fn extract_from(query: &mut Query<'_>) -> Result<Self, Error> {
        let base = Base::extract_from(query).map_err(Error::base)?;

        let raw = match query.remove(COUNTER) {
            Some(raw) => raw,
            None => return Err(Error::new_counter_not_found()),
        };

        let counter: Counter = raw.parse().map_err(Error::counter)?;

        Ok(Self { base, counter })
    }
}
197
198pub type Owned = Hotp<'static>;
200
201impl Hotp<'_> {
202 pub fn into_owned(self) -> Owned {
204 Owned::builder()
205 .base(self.base.into_owned())
206 .counter(self.counter)
207 .build()
208 }
209}