Skip to main content

openauth_core/
verification.rs

//! Database-backed helpers for creating, looking up, updating, and expiring verification values.
2
3use time::OffsetDateTime;
4
5use crate::crypto::random::generate_random_string;
6use crate::db::{
7    Create, DbAdapter, DbRecord, DbValue, Delete, DeleteMany, FindMany, Sort, SortDirection,
8    Update, Verification, Where, WhereOperator,
9};
10use crate::error::OpenAuthError;
11
/// Model (table) name every verification query targets.
const VERIFICATION_MODEL: &'static str = "verification";

/// Length of the random id generated when the caller does not provide one.
const DEFAULT_ID_LENGTH: usize = 32;

/// Columns requested from the adapter whenever a verification row is returned.
const VERIFICATION_FIELDS: [&'static str; 6] = [
    "id",
    "identifier",
    "value",
    "expires_at",
    "created_at",
    "updated_at",
];
22
23#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct CreateVerificationInput {
25    pub id: Option<String>,
26    pub identifier: String,
27    pub value: String,
28    pub expires_at: OffsetDateTime,
29}
30
31impl CreateVerificationInput {
32    pub fn new(
33        identifier: impl Into<String>,
34        value: impl Into<String>,
35        expires_at: OffsetDateTime,
36    ) -> Self {
37        Self {
38            id: None,
39            identifier: identifier.into(),
40            value: value.into(),
41            expires_at,
42        }
43    }
44
45    #[must_use]
46    pub fn id(mut self, id: impl Into<String>) -> Self {
47        self.id = Some(id.into());
48        self
49    }
50}
51
52#[derive(Debug, Clone, Default, PartialEq, Eq)]
53pub struct UpdateVerificationInput {
54    pub value: Option<String>,
55    pub expires_at: Option<OffsetDateTime>,
56}
57
58impl UpdateVerificationInput {
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    #[must_use]
64    pub fn value(mut self, value: impl Into<String>) -> Self {
65        self.value = Some(value.into());
66        self
67    }
68
69    #[must_use]
70    pub fn expires_at(mut self, expires_at: OffsetDateTime) -> Self {
71        self.expires_at = Some(expires_at);
72        self
73    }
74}
75
76#[derive(Clone, Copy)]
77pub struct DbVerificationStore<'a> {
78    adapter: &'a dyn DbAdapter,
79}
80
81impl<'a> DbVerificationStore<'a> {
82    pub fn new(adapter: &'a dyn DbAdapter) -> Self {
83        Self { adapter }
84    }
85
86    pub async fn create_verification(
87        &self,
88        input: CreateVerificationInput,
89    ) -> Result<Verification, OpenAuthError> {
90        let now = OffsetDateTime::now_utc();
91        let id = input
92            .id
93            .unwrap_or_else(|| generate_random_string(DEFAULT_ID_LENGTH));
94
95        let record = self
96            .adapter
97            .create(
98                Create::new(VERIFICATION_MODEL)
99                    .data("id", DbValue::String(id))
100                    .data("identifier", DbValue::String(input.identifier))
101                    .data("value", DbValue::String(input.value))
102                    .data("expires_at", DbValue::Timestamp(input.expires_at))
103                    .data("created_at", DbValue::Timestamp(now))
104                    .data("updated_at", DbValue::Timestamp(now))
105                    .select(VERIFICATION_FIELDS)
106                    .force_allow_id(),
107            )
108            .await?;
109
110        verification_from_record(record)
111    }
112
113    pub async fn find_verification(
114        &self,
115        identifier: &str,
116    ) -> Result<Option<Verification>, OpenAuthError> {
117        self.delete_expired_verifications().await?;
118
119        let Some(record) = self
120            .adapter
121            .find_many(
122                FindMany::new(VERIFICATION_MODEL)
123                    .where_clause(identifier_where(identifier))
124                    .sort_by(Sort::new("created_at", SortDirection::Desc))
125                    .limit(1)
126                    .select(VERIFICATION_FIELDS),
127            )
128            .await?
129            .into_iter()
130            .next()
131        else {
132            return Ok(None);
133        };
134
135        let verification = verification_from_record(record)?;
136        if verification.expires_at <= OffsetDateTime::now_utc() {
137            self.delete_expired_verifications().await?;
138            return Ok(None);
139        }
140
141        Ok(Some(verification))
142    }
143
144    pub async fn find_verification_including_expired(
145        &self,
146        identifier: &str,
147    ) -> Result<Option<Verification>, OpenAuthError> {
148        self.adapter
149            .find_many(
150                FindMany::new(VERIFICATION_MODEL)
151                    .where_clause(identifier_where(identifier))
152                    .sort_by(Sort::new("created_at", SortDirection::Desc))
153                    .limit(1)
154                    .select(VERIFICATION_FIELDS),
155            )
156            .await?
157            .into_iter()
158            .next()
159            .map(verification_from_record)
160            .transpose()
161    }
162
163    pub async fn update_verification(
164        &self,
165        identifier: &str,
166        input: UpdateVerificationInput,
167    ) -> Result<Option<Verification>, OpenAuthError> {
168        let mut query = Update::new(VERIFICATION_MODEL).where_clause(identifier_where(identifier));
169
170        if let Some(value) = input.value {
171            query = query.data("value", DbValue::String(value));
172        }
173        if let Some(expires_at) = input.expires_at {
174            query = query.data("expires_at", DbValue::Timestamp(expires_at));
175        }
176        query = query.data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc()));
177
178        self.adapter
179            .update(query)
180            .await?
181            .map(verification_from_record)
182            .transpose()
183    }
184
185    pub async fn delete_verification(&self, identifier: &str) -> Result<(), OpenAuthError> {
186        self.adapter
187            .delete(Delete::new(VERIFICATION_MODEL).where_clause(identifier_where(identifier)))
188            .await
189    }
190
191    pub async fn delete_expired_verifications(&self) -> Result<u64, OpenAuthError> {
192        self.adapter
193            .delete_many(
194                DeleteMany::new(VERIFICATION_MODEL).where_clause(
195                    Where::new("expires_at", DbValue::Timestamp(OffsetDateTime::now_utc()))
196                        .operator(WhereOperator::Lt),
197                ),
198            )
199            .await
200    }
201}
202
203fn identifier_where(identifier: &str) -> Where {
204    Where::new("identifier", DbValue::String(identifier.to_owned()))
205}
206
207fn verification_from_record(record: DbRecord) -> Result<Verification, OpenAuthError> {
208    Ok(Verification {
209        id: required_string(&record, "id")?.to_owned(),
210        identifier: required_string(&record, "identifier")?.to_owned(),
211        value: required_string(&record, "value")?.to_owned(),
212        expires_at: required_timestamp(&record, "expires_at")?,
213        created_at: required_timestamp(&record, "created_at")?,
214        updated_at: required_timestamp(&record, "updated_at")?,
215    })
216}
217
218fn required_string<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, OpenAuthError> {
219    match record.get(field) {
220        Some(DbValue::String(value)) => Ok(value),
221        Some(_) => Err(invalid_field(field, "string")),
222        None => Err(missing_field(field)),
223    }
224}
225
226fn required_timestamp(record: &DbRecord, field: &str) -> Result<OffsetDateTime, OpenAuthError> {
227    match record.get(field) {
228        Some(DbValue::Timestamp(value)) => Ok(*value),
229        Some(_) => Err(invalid_field(field, "timestamp")),
230        None => Err(missing_field(field)),
231    }
232}
233
234fn missing_field(field: &str) -> OpenAuthError {
235    OpenAuthError::Adapter(format!("verification record is missing `{field}`"))
236}
237
238fn invalid_field(field: &str, expected: &str) -> OpenAuthError {
239    OpenAuthError::Adapter(format!(
240        "verification record field `{field}` must be {expected}"
241    ))
242}