// allowthem_core/invitations.rs

1use base64ct::{Base64UrlUnpadded, Encoding};
2use chrono::{DateTime, Utc};
3use rand::TryRngCore;
4use rand::rngs::OsRng;
5use sha2::{Digest, Sha256};
6
7use crate::db::Db;
8use crate::error::AuthError;
9use crate::types::{Email, InvitationId, UserId};
10
11/// A single-use invitation token record.
12///
13/// Returned by `Db::create_invitation` and `Db::validate_invitation`.
14/// The `token_hash` is never exposed — only the raw token (returned once
15/// at creation) can be used to validate.
16#[derive(Debug, Clone, sqlx::FromRow)]
17pub struct Invitation {
18    pub id: InvitationId,
19    pub email: Option<Email>,
20    pub metadata: Option<String>,
21    pub invited_by: Option<UserId>,
22    pub expires_at: DateTime<Utc>,
23    pub consumed_at: Option<DateTime<Utc>>,
24    pub created_at: DateTime<Utc>,
25}
26
27/// Generate a cryptographically random invitation token.
28///
29/// Fills 32 bytes from the OS random source and encodes as base64url without
30/// padding, producing a 43-character string suitable for inclusion in a URL.
31fn generate_invitation_token() -> String {
32    let mut bytes = [0u8; 32];
33    OsRng
34        .try_fill_bytes(&mut bytes)
35        .expect("OS RNG unavailable");
36    Base64UrlUnpadded::encode_string(&bytes)
37}
38
39/// Hash a raw invitation token with SHA-256.
40///
41/// Returns the hex-encoded digest. This is what is stored in the database.
42/// The raw token is only ever shown once, at creation time.
43fn hash_invitation_token(token: &str) -> String {
44    let digest = Sha256::digest(token.as_bytes());
45    format!("{digest:x}")
46}
47
48impl Db {
49    /// Create an invitation. Returns the raw token (shown once) and the
50    /// `Invitation` record.
51    ///
52    /// If `email` is `Some`, the invitation is targeted at that address.
53    /// If `email` is `None`, it is an open invitation usable by anyone.
54    pub async fn create_invitation(
55        &self,
56        email: Option<&Email>,
57        metadata: Option<&str>,
58        invited_by: Option<UserId>,
59        expires_at: DateTime<Utc>,
60    ) -> Result<(String, Invitation), AuthError> {
61        let id = InvitationId::new();
62        let raw_token = generate_invitation_token();
63        let token_hash = hash_invitation_token(&raw_token);
64        let now = Utc::now();
65        let now_str = now.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
66        let expires_str = expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
67
68        sqlx::query(
69            "INSERT INTO allowthem_invitations \
70             (id, email, token_hash, metadata, invited_by, expires_at, created_at) \
71             VALUES (?, ?, ?, ?, ?, ?, ?)",
72        )
73        .bind(id)
74        .bind(email)
75        .bind(&token_hash)
76        .bind(metadata)
77        .bind(invited_by)
78        .bind(&expires_str)
79        .bind(&now_str)
80        .execute(self.pool())
81        .await
82        .map_err(AuthError::Database)?;
83
84        let inv = Invitation {
85            id,
86            email: email.cloned(),
87            metadata: metadata.map(String::from),
88            invited_by,
89            expires_at,
90            consumed_at: None,
91            created_at: now,
92        };
93
94        Ok((raw_token, inv))
95    }
96
97    /// Mark an invitation as consumed.
98    ///
99    /// Uses `consumed_at IS NULL` as a concurrency guard. Returns `Ok(())`
100    /// on success. Returns `Err(AuthError::NotFound)` if the ID does not
101    /// exist. Returns `Err(AuthError::Gone)` if already consumed — the
102    /// caller should treat this as a race loss.
103    pub async fn consume_invitation(&self, id: InvitationId) -> Result<(), AuthError> {
104        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
105
106        let result = sqlx::query(
107            "UPDATE allowthem_invitations SET consumed_at = ? \
108             WHERE id = ? AND consumed_at IS NULL",
109        )
110        .bind(&now)
111        .bind(id)
112        .execute(self.pool())
113        .await
114        .map_err(AuthError::Database)?;
115
116        if result.rows_affected() == 0 {
117            // Distinguish "does not exist" from "already consumed".
118            let exists: bool = sqlx::query_scalar(
119                "SELECT EXISTS(SELECT 1 FROM allowthem_invitations WHERE id = ?)",
120            )
121            .bind(id)
122            .fetch_one(self.pool())
123            .await
124            .map_err(AuthError::Database)?;
125
126            return Err(if exists {
127                AuthError::Gone
128            } else {
129                AuthError::NotFound
130            });
131        }
132
133        Ok(())
134    }
135
136    /// List unconsumed, non-expired invitations, newest first.
137    pub async fn list_pending_invitations(&self) -> Result<Vec<Invitation>, AuthError> {
138        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
139
140        sqlx::query_as::<_, Invitation>(
141            "SELECT id, email, metadata, invited_by, expires_at, consumed_at, created_at \
142             FROM allowthem_invitations \
143             WHERE consumed_at IS NULL AND expires_at > ? \
144             ORDER BY created_at DESC",
145        )
146        .bind(&now)
147        .fetch_all(self.pool())
148        .await
149        .map_err(AuthError::Database)
150    }
151
152    /// Delete an invitation outright, whether pending or consumed.
153    pub async fn delete_invitation(&self, id: InvitationId) -> Result<(), AuthError> {
154        let result = sqlx::query("DELETE FROM allowthem_invitations WHERE id = ?")
155            .bind(id)
156            .execute(self.pool())
157            .await
158            .map_err(AuthError::Database)?;
159
160        if result.rows_affected() == 0 {
161            return Err(AuthError::NotFound);
162        }
163
164        Ok(())
165    }
166
167    /// Validate a raw invitation token.
168    ///
169    /// Returns `Some(Invitation)` if the token exists, is not expired, and has
170    /// not been consumed. Returns `None` otherwise. The caller is responsible
171    /// for checking email match on targeted invitations.
172    pub async fn validate_invitation(
173        &self,
174        raw_token: &str,
175    ) -> Result<Option<Invitation>, AuthError> {
176        let token_hash = hash_invitation_token(raw_token);
177        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
178
179        sqlx::query_as::<_, Invitation>(
180            "SELECT id, email, metadata, invited_by, expires_at, consumed_at, created_at \
181             FROM allowthem_invitations \
182             WHERE token_hash = ? AND expires_at > ? AND consumed_at IS NULL",
183        )
184        .bind(&token_hash)
185        .bind(&now)
186        .fetch_optional(self.pool())
187        .await
188        .map_err(AuthError::Database)
189    }
190}
191
#[cfg(test)]
mod tests {
    use chrono::{DateTime, Duration, Utc};

    use crate::db::Db;
    use crate::error::AuthError;
    use crate::types::{Email, InvitationId};

    /// Connect to a fresh in-memory SQLite database for one test.
    async fn test_db() -> Db {
        Db::connect("sqlite::memory:").await.expect("in-memory db")
    }

    /// An instant `hours` hours away from now; negative values lie in the past.
    fn hours_from_now(hours: i64) -> DateTime<Utc> {
        Utc::now() + Duration::hours(hours)
    }

    /// A targeted invitation echoes back the email, metadata and defaults.
    #[tokio::test]
    async fn create_invitation_returns_raw_token_and_invitation() {
        const ROLE_METADATA: &str = r#"{"role":"editor"}"#;

        let db = test_db().await;
        let invitee = Email::new(String::from("invite@example.com")).unwrap();

        let (raw_token, inv) = db
            .create_invitation(Some(&invitee), Some(ROLE_METADATA), None, hours_from_now(24))
            .await
            .expect("create_invitation");

        assert!(!raw_token.is_empty());
        assert_eq!(inv.email.as_ref().unwrap().as_str(), "invite@example.com");
        assert_eq!(inv.metadata.as_deref(), Some(ROLE_METADATA));
        assert!(inv.invited_by.is_none());
        assert!(inv.consumed_at.is_none());
    }

    /// An open invitation carries neither an email nor metadata.
    #[tokio::test]
    async fn create_open_invitation_has_no_email() {
        let db = test_db().await;

        let (_token, inv) = db
            .create_invitation(None, None, None, hours_from_now(24))
            .await
            .expect("create open invitation");

        assert!(inv.email.is_none());
        assert!(inv.metadata.is_none());
    }

    /// A freshly created token validates and yields the stored email.
    #[tokio::test]
    async fn validate_returns_invitation_for_valid_token() {
        let db = test_db().await;
        let invitee = Email::new(String::from("v@example.com")).unwrap();
        let (token, _) = db
            .create_invitation(Some(&invitee), Some("{}"), None, hours_from_now(24))
            .await
            .unwrap();

        let found = db.validate_invitation(&token).await.expect("validate");
        assert!(found.is_some());
        let found = found.unwrap();
        assert_eq!(found.email.as_ref().unwrap().as_str(), "v@example.com");
    }

    /// A token that was never issued does not validate.
    #[tokio::test]
    async fn validate_returns_none_for_garbage_token() {
        let db = test_db().await;
        let outcome = db
            .validate_invitation("not-a-real-token")
            .await
            .expect("validate");
        assert!(outcome.is_none());
    }

    /// An invitation whose expiry lies in the past does not validate.
    #[tokio::test]
    async fn validate_returns_none_for_expired_token() {
        let db = test_db().await;
        let (token, _) = db
            .create_invitation(None, None, None, hours_from_now(-1))
            .await
            .unwrap();

        let outcome = db.validate_invitation(&token).await.expect("validate");
        assert!(outcome.is_none(), "expired invitation must return None");
    }

    /// Once consumed, an invitation's token stops validating.
    #[tokio::test]
    async fn consume_marks_invitation_consumed() {
        let db = test_db().await;
        let (token, inv) = db
            .create_invitation(None, None, None, hours_from_now(24))
            .await
            .unwrap();

        db.consume_invitation(inv.id).await.expect("consume");

        let outcome = db.validate_invitation(&token).await.expect("validate");
        assert!(outcome.is_none(), "consumed invitation must not validate");
    }

    /// The second attempt to consume the same invitation reports `Gone`.
    #[tokio::test]
    async fn consume_twice_returns_gone() {
        let db = test_db().await;
        let (_, inv) = db
            .create_invitation(None, None, None, hours_from_now(24))
            .await
            .unwrap();

        db.consume_invitation(inv.id).await.expect("first consume");

        let err = db
            .consume_invitation(inv.id)
            .await
            .expect_err("second consume should fail");
        assert!(
            matches!(err, AuthError::Gone),
            "expected AuthError::Gone, got {err:?}"
        );
    }

    /// Only the unconsumed, unexpired invitation shows up in the listing.
    #[tokio::test]
    async fn list_pending_excludes_expired_and_consumed() {
        let db = test_db().await;
        let later = hours_from_now(24);
        let earlier = hours_from_now(-1);

        // Still pending: expected to be listed.
        let (_, still_pending) = db
            .create_invitation(None, Some("pending"), None, later)
            .await
            .unwrap();

        // Already expired: expected to be filtered out.
        let _ = db
            .create_invitation(None, Some("expired"), None, earlier)
            .await
            .unwrap();

        // Consumed: expected to be filtered out.
        let (_, used) = db
            .create_invitation(None, Some("consumed"), None, later)
            .await
            .unwrap();
        db.consume_invitation(used.id).await.unwrap();

        let listed = db.list_pending_invitations().await.expect("list");
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id, still_pending.id);
    }

    /// The creator's user ID is carried on the returned invitation.
    #[tokio::test]
    async fn create_invitation_with_invited_by_stores_user_id() {
        let db = test_db().await;
        let creator_email = Email::new(String::from("creator@example.com")).unwrap();
        let creator = db
            .create_user(creator_email, "password123", None, None)
            .await
            .expect("create user");

        let (_, inv) = db
            .create_invitation(None, None, Some(creator.id), hours_from_now(24))
            .await
            .expect("create with invited_by");

        assert_eq!(inv.invited_by, Some(creator.id));
    }

    /// Deleting an invitation removes it from validation and from the listing.
    #[tokio::test]
    async fn delete_invitation_removes_it() {
        let db = test_db().await;
        let (token, inv) = db
            .create_invitation(None, None, None, hours_from_now(24))
            .await
            .unwrap();

        db.delete_invitation(inv.id).await.expect("delete");

        let outcome = db.validate_invitation(&token).await.expect("validate");
        assert!(outcome.is_none(), "deleted invitation must not validate");

        // Nothing should remain pending either.
        let listed = db.list_pending_invitations().await.expect("list");
        assert!(listed.is_empty());
    }

    /// Consuming an ID that was never inserted reports `NotFound`.
    #[tokio::test]
    async fn consume_nonexistent_returns_not_found() {
        let db = test_db().await;
        let unknown_id = InvitationId::new();
        let err = db
            .consume_invitation(unknown_id)
            .await
            .expect_err("consume nonexistent should fail");
        assert!(
            matches!(err, AuthError::NotFound),
            "expected AuthError::NotFound, got {err:?}"
        );
    }

    /// Deleting an ID that was never inserted reports `NotFound`.
    #[tokio::test]
    async fn delete_nonexistent_returns_not_found() {
        let db = test_db().await;
        let unknown_id = InvitationId::new();
        let err = db
            .delete_invitation(unknown_id)
            .await
            .expect_err("delete nonexistent should fail");
        assert!(
            matches!(err, AuthError::NotFound),
            "expected AuthError::NotFound, got {err:?}"
        );
    }
}