1use time::OffsetDateTime;
4
5use crate::crypto::random::generate_random_string;
6use crate::db::{
7 Create, DbAdapter, DbRecord, DbValue, Delete, DeleteMany, FindMany, Sort, SortDirection,
8 Update, Verification, Where, WhereOperator,
9};
10use crate::error::OpenAuthError;
11
/// Model name passed to the adapter for every verification query.
const VERIFICATION_MODEL: &'static str = "verification";

/// Length handed to `generate_random_string` when a caller does not supply an id.
const DEFAULT_ID_LENGTH: usize = 32;

/// Fields requested from the adapter whenever a verification record is created or read.
const VERIFICATION_FIELDS: [&'static str; 6] = [
    "id",
    "identifier",
    "value",
    "expires_at",
    "created_at",
    "updated_at",
];
22
23#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct CreateVerificationInput {
25 pub id: Option<String>,
26 pub identifier: String,
27 pub value: String,
28 pub expires_at: OffsetDateTime,
29}
30
31impl CreateVerificationInput {
32 pub fn new(
33 identifier: impl Into<String>,
34 value: impl Into<String>,
35 expires_at: OffsetDateTime,
36 ) -> Self {
37 Self {
38 id: None,
39 identifier: identifier.into(),
40 value: value.into(),
41 expires_at,
42 }
43 }
44
45 #[must_use]
46 pub fn id(mut self, id: impl Into<String>) -> Self {
47 self.id = Some(id.into());
48 self
49 }
50}
51
52#[derive(Debug, Clone, Default, PartialEq, Eq)]
53pub struct UpdateVerificationInput {
54 pub value: Option<String>,
55 pub expires_at: Option<OffsetDateTime>,
56}
57
58impl UpdateVerificationInput {
59 pub fn new() -> Self {
60 Self::default()
61 }
62
63 #[must_use]
64 pub fn value(mut self, value: impl Into<String>) -> Self {
65 self.value = Some(value.into());
66 self
67 }
68
69 #[must_use]
70 pub fn expires_at(mut self, expires_at: OffsetDateTime) -> Self {
71 self.expires_at = Some(expires_at);
72 self
73 }
74}
75
76#[derive(Clone, Copy)]
77pub struct DbVerificationStore<'a> {
78 adapter: &'a dyn DbAdapter,
79}
80
81impl<'a> DbVerificationStore<'a> {
82 pub fn new(adapter: &'a dyn DbAdapter) -> Self {
83 Self { adapter }
84 }
85
86 pub async fn create_verification(
87 &self,
88 input: CreateVerificationInput,
89 ) -> Result<Verification, OpenAuthError> {
90 let now = OffsetDateTime::now_utc();
91 let id = input
92 .id
93 .unwrap_or_else(|| generate_random_string(DEFAULT_ID_LENGTH));
94
95 let record = self
96 .adapter
97 .create(
98 Create::new(VERIFICATION_MODEL)
99 .data("id", DbValue::String(id))
100 .data("identifier", DbValue::String(input.identifier))
101 .data("value", DbValue::String(input.value))
102 .data("expires_at", DbValue::Timestamp(input.expires_at))
103 .data("created_at", DbValue::Timestamp(now))
104 .data("updated_at", DbValue::Timestamp(now))
105 .select(VERIFICATION_FIELDS)
106 .force_allow_id(),
107 )
108 .await?;
109
110 verification_from_record(record)
111 }
112
113 pub async fn find_verification(
114 &self,
115 identifier: &str,
116 ) -> Result<Option<Verification>, OpenAuthError> {
117 self.delete_expired_verifications().await?;
118
119 let Some(record) = self
120 .adapter
121 .find_many(
122 FindMany::new(VERIFICATION_MODEL)
123 .where_clause(identifier_where(identifier))
124 .sort_by(Sort::new("created_at", SortDirection::Desc))
125 .limit(1)
126 .select(VERIFICATION_FIELDS),
127 )
128 .await?
129 .into_iter()
130 .next()
131 else {
132 return Ok(None);
133 };
134
135 let verification = verification_from_record(record)?;
136 if verification.expires_at <= OffsetDateTime::now_utc() {
137 self.delete_expired_verifications().await?;
138 return Ok(None);
139 }
140
141 Ok(Some(verification))
142 }
143
144 pub async fn find_verification_including_expired(
145 &self,
146 identifier: &str,
147 ) -> Result<Option<Verification>, OpenAuthError> {
148 self.adapter
149 .find_many(
150 FindMany::new(VERIFICATION_MODEL)
151 .where_clause(identifier_where(identifier))
152 .sort_by(Sort::new("created_at", SortDirection::Desc))
153 .limit(1)
154 .select(VERIFICATION_FIELDS),
155 )
156 .await?
157 .into_iter()
158 .next()
159 .map(verification_from_record)
160 .transpose()
161 }
162
163 pub async fn update_verification(
164 &self,
165 identifier: &str,
166 input: UpdateVerificationInput,
167 ) -> Result<Option<Verification>, OpenAuthError> {
168 let mut query = Update::new(VERIFICATION_MODEL).where_clause(identifier_where(identifier));
169
170 if let Some(value) = input.value {
171 query = query.data("value", DbValue::String(value));
172 }
173 if let Some(expires_at) = input.expires_at {
174 query = query.data("expires_at", DbValue::Timestamp(expires_at));
175 }
176 query = query.data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc()));
177
178 self.adapter
179 .update(query)
180 .await?
181 .map(verification_from_record)
182 .transpose()
183 }
184
185 pub async fn delete_verification(&self, identifier: &str) -> Result<(), OpenAuthError> {
186 self.adapter
187 .delete(Delete::new(VERIFICATION_MODEL).where_clause(identifier_where(identifier)))
188 .await
189 }
190
191 pub async fn delete_expired_verifications(&self) -> Result<u64, OpenAuthError> {
192 self.adapter
193 .delete_many(
194 DeleteMany::new(VERIFICATION_MODEL).where_clause(
195 Where::new("expires_at", DbValue::Timestamp(OffsetDateTime::now_utc()))
196 .operator(WhereOperator::Lt),
197 ),
198 )
199 .await
200 }
201}
202
203fn identifier_where(identifier: &str) -> Where {
204 Where::new("identifier", DbValue::String(identifier.to_owned()))
205}
206
207fn verification_from_record(record: DbRecord) -> Result<Verification, OpenAuthError> {
208 Ok(Verification {
209 id: required_string(&record, "id")?.to_owned(),
210 identifier: required_string(&record, "identifier")?.to_owned(),
211 value: required_string(&record, "value")?.to_owned(),
212 expires_at: required_timestamp(&record, "expires_at")?,
213 created_at: required_timestamp(&record, "created_at")?,
214 updated_at: required_timestamp(&record, "updated_at")?,
215 })
216}
217
218fn required_string<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, OpenAuthError> {
219 match record.get(field) {
220 Some(DbValue::String(value)) => Ok(value),
221 Some(_) => Err(invalid_field(field, "string")),
222 None => Err(missing_field(field)),
223 }
224}
225
226fn required_timestamp(record: &DbRecord, field: &str) -> Result<OffsetDateTime, OpenAuthError> {
227 match record.get(field) {
228 Some(DbValue::Timestamp(value)) => Ok(*value),
229 Some(_) => Err(invalid_field(field, "timestamp")),
230 None => Err(missing_field(field)),
231 }
232}
233
234fn missing_field(field: &str) -> OpenAuthError {
235 OpenAuthError::Adapter(format!("verification record is missing `{field}`"))
236}
237
238fn invalid_field(field: &str, expected: &str) -> OpenAuthError {
239 OpenAuthError::Adapter(format!(
240 "verification record field `{field}` must be {expected}"
241 ))
242}