1use base64ct::{Base64UrlUnpadded, Encoding};
2use chrono::{DateTime, Utc};
3use rand::TryRngCore;
4use rand::rngs::OsRng;
5use sha2::{Digest, Sha256};
6
7use crate::db::Db;
8use crate::error::AuthError;
9use crate::types::{Email, InvitationId, UserId};
10
/// An invitation record as read from the `allowthem_invitations` table.
///
/// The stored token hash is not part of this struct: the SELECTs in this module
/// never fetch the `token_hash` column, and the raw token is only ever handed
/// back by `Db::create_invitation`.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct Invitation {
    /// Primary key of the invitation row.
    pub id: InvitationId,
    /// Address the invitation was issued to; `None` for an open invitation.
    pub email: Option<Email>,
    /// Free-form string stored verbatim alongside the invitation.
    pub metadata: Option<String>,
    /// User who issued the invitation, if one was recorded.
    pub invited_by: Option<UserId>,
    /// Instant at or after which the invitation is no longer valid or pending.
    pub expires_at: DateTime<Utc>,
    /// When the invitation was consumed; `None` while it is still unused.
    pub consumed_at: Option<DateTime<Utc>>,
    /// When the invitation row was created.
    pub created_at: DateTime<Utc>,
}
26
27fn generate_invitation_token() -> String {
32 let mut bytes = [0u8; 32];
33 OsRng
34 .try_fill_bytes(&mut bytes)
35 .expect("OS RNG unavailable");
36 Base64UrlUnpadded::encode_string(&bytes)
37}
38
39fn hash_invitation_token(token: &str) -> String {
44 let digest = Sha256::digest(token.as_bytes());
45 format!("{digest:x}")
46}
47
48impl Db {
49 pub async fn create_invitation(
55 &self,
56 email: Option<&Email>,
57 metadata: Option<&str>,
58 invited_by: Option<UserId>,
59 expires_at: DateTime<Utc>,
60 ) -> Result<(String, Invitation), AuthError> {
61 let id = InvitationId::new();
62 let raw_token = generate_invitation_token();
63 let token_hash = hash_invitation_token(&raw_token);
64 let now = Utc::now();
65 let now_str = now.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
66 let expires_str = expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
67
68 sqlx::query(
69 "INSERT INTO allowthem_invitations \
70 (id, email, token_hash, metadata, invited_by, expires_at, created_at) \
71 VALUES (?, ?, ?, ?, ?, ?, ?)",
72 )
73 .bind(id)
74 .bind(email)
75 .bind(&token_hash)
76 .bind(metadata)
77 .bind(invited_by)
78 .bind(&expires_str)
79 .bind(&now_str)
80 .execute(self.pool())
81 .await
82 .map_err(AuthError::Database)?;
83
84 let inv = Invitation {
85 id,
86 email: email.cloned(),
87 metadata: metadata.map(String::from),
88 invited_by,
89 expires_at,
90 consumed_at: None,
91 created_at: now,
92 };
93
94 Ok((raw_token, inv))
95 }
96
97 pub async fn consume_invitation(&self, id: InvitationId) -> Result<(), AuthError> {
104 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
105
106 let result = sqlx::query(
107 "UPDATE allowthem_invitations SET consumed_at = ? \
108 WHERE id = ? AND consumed_at IS NULL",
109 )
110 .bind(&now)
111 .bind(id)
112 .execute(self.pool())
113 .await
114 .map_err(AuthError::Database)?;
115
116 if result.rows_affected() == 0 {
117 let exists: bool = sqlx::query_scalar(
119 "SELECT EXISTS(SELECT 1 FROM allowthem_invitations WHERE id = ?)",
120 )
121 .bind(id)
122 .fetch_one(self.pool())
123 .await
124 .map_err(AuthError::Database)?;
125
126 return Err(if exists {
127 AuthError::Gone
128 } else {
129 AuthError::NotFound
130 });
131 }
132
133 Ok(())
134 }
135
136 pub async fn list_pending_invitations(&self) -> Result<Vec<Invitation>, AuthError> {
138 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
139
140 sqlx::query_as::<_, Invitation>(
141 "SELECT id, email, metadata, invited_by, expires_at, consumed_at, created_at \
142 FROM allowthem_invitations \
143 WHERE consumed_at IS NULL AND expires_at > ? \
144 ORDER BY created_at DESC",
145 )
146 .bind(&now)
147 .fetch_all(self.pool())
148 .await
149 .map_err(AuthError::Database)
150 }
151
152 pub async fn delete_invitation(&self, id: InvitationId) -> Result<(), AuthError> {
154 let result = sqlx::query("DELETE FROM allowthem_invitations WHERE id = ?")
155 .bind(id)
156 .execute(self.pool())
157 .await
158 .map_err(AuthError::Database)?;
159
160 if result.rows_affected() == 0 {
161 return Err(AuthError::NotFound);
162 }
163
164 Ok(())
165 }
166
167 pub async fn validate_invitation(
173 &self,
174 raw_token: &str,
175 ) -> Result<Option<Invitation>, AuthError> {
176 let token_hash = hash_invitation_token(raw_token);
177 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
178
179 sqlx::query_as::<_, Invitation>(
180 "SELECT id, email, metadata, invited_by, expires_at, consumed_at, created_at \
181 FROM allowthem_invitations \
182 WHERE token_hash = ? AND expires_at > ? AND consumed_at IS NULL",
183 )
184 .bind(&token_hash)
185 .bind(&now)
186 .fetch_optional(self.pool())
187 .await
188 .map_err(AuthError::Database)
189 }
190}
191
#[cfg(test)]
mod tests {
    use chrono::{Duration, Utc};

    use crate::db::Db;
    use crate::error::AuthError;
    use crate::types::Email;

    /// Fresh in-memory database for each test.
    async fn test_db() -> Db {
        Db::connect("sqlite::memory:").await.expect("in-memory db")
    }

    #[tokio::test]
    async fn create_invitation_returns_raw_token_and_invitation() {
        let db = test_db().await;
        let email = Email::new("invite@example.com".to_string()).unwrap();
        let expires = Utc::now() + Duration::hours(24);

        let (raw_token, inv) = db
            .create_invitation(Some(&email), Some(r#"{"role":"editor"}"#), None, expires)
            .await
            .expect("create_invitation");

        assert!(!raw_token.is_empty());
        assert_eq!(inv.email.as_ref().unwrap().as_str(), "invite@example.com");
        assert_eq!(inv.metadata.as_deref(), Some(r#"{"role":"editor"}"#));
        assert!(inv.invited_by.is_none());
        assert!(inv.consumed_at.is_none());
    }

    #[tokio::test]
    async fn create_open_invitation_has_no_email() {
        let db = test_db().await;
        let expires = Utc::now() + Duration::hours(24);

        let (_raw, inv) = db
            .create_invitation(None, None, None, expires)
            .await
            .expect("create open invitation");

        assert!(inv.email.is_none());
        assert!(inv.metadata.is_none());
    }

    #[tokio::test]
    async fn validate_returns_invitation_for_valid_token() {
        let db = test_db().await;
        let email = Email::new("v@example.com".to_string()).unwrap();
        let expires = Utc::now() + Duration::hours(24);
        let (raw, _) = db
            .create_invitation(Some(&email), Some("{}"), None, expires)
            .await
            .unwrap();

        let inv = db.validate_invitation(&raw).await.expect("validate");
        assert!(inv.is_some());
        let inv = inv.unwrap();
        assert_eq!(inv.email.as_ref().unwrap().as_str(), "v@example.com");
    }

    #[tokio::test]
    async fn validate_returns_none_for_garbage_token() {
        let db = test_db().await;
        let result = db
            .validate_invitation("not-a-real-token")
            .await
            .expect("validate");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn validate_returns_none_for_expired_token() {
        let db = test_db().await;
        let expires = Utc::now() - Duration::hours(1);
        let (raw, _) = db
            .create_invitation(None, None, None, expires)
            .await
            .unwrap();

        let result = db.validate_invitation(&raw).await.expect("validate");
        assert!(result.is_none(), "expired invitation must return None");
    }

    /// Only the SHA-256 hash is persisted; presenting that hash as a token must
    /// not validate, otherwise a leaked database row would be a usable credential.
    #[tokio::test]
    async fn validate_rejects_stored_hash_as_token() {
        let db = test_db().await;
        let expires = Utc::now() + Duration::hours(24);
        let (raw, _) = db
            .create_invitation(None, None, None, expires)
            .await
            .unwrap();

        let stored_hash = super::hash_invitation_token(&raw);
        let result = db
            .validate_invitation(&stored_hash)
            .await
            .expect("validate");
        assert!(result.is_none(), "the stored hash must not act as a token");
    }

    #[tokio::test]
    async fn each_invitation_gets_a_distinct_token() {
        let db = test_db().await;
        let expires = Utc::now() + Duration::hours(24);
        let (raw_a, inv_a) = db
            .create_invitation(None, None, None, expires)
            .await
            .unwrap();
        let (raw_b, inv_b) = db
            .create_invitation(None, None, None, expires)
            .await
            .unwrap();

        assert_ne!(raw_a, raw_b);
        assert_ne!(inv_a.id, inv_b.id);

        // Each token must resolve to its own invitation, not the other one.
        let found_a = db
            .validate_invitation(&raw_a)
            .await
            .expect("validate a")
            .expect("token a is valid");
        let found_b = db
            .validate_invitation(&raw_b)
            .await
            .expect("validate b")
            .expect("token b is valid");
        assert_eq!(found_a.id, inv_a.id);
        assert_eq!(found_b.id, inv_b.id);
    }

    #[tokio::test]
    async fn validate_round_trips_metadata_and_invited_by() {
        let db = test_db().await;
        let email = Email::new("roundtrip@example.com".to_string()).unwrap();
        let user = db
            .create_user(email, "password123", None, None)
            .await
            .expect("create user");
        let expires = Utc::now() + Duration::hours(24);

        let (raw, created) = db
            .create_invitation(None, Some(r#"{"team":"ops"}"#), Some(user.id), expires)
            .await
            .expect("create invitation");

        let found = db
            .validate_invitation(&raw)
            .await
            .expect("validate")
            .expect("invitation is valid");
        assert_eq!(found.id, created.id);
        assert_eq!(found.metadata.as_deref(), Some(r#"{"team":"ops"}"#));
        assert_eq!(found.invited_by, Some(user.id));
        assert!(found.consumed_at.is_none());
    }

    #[tokio::test]
    async fn consume_marks_invitation_consumed() {
        let db = test_db().await;
        let expires = Utc::now() + Duration::hours(24);
        let (raw, inv) = db
            .create_invitation(None, None, None, expires)
            .await
            .unwrap();

        db.consume_invitation(inv.id).await.expect("consume");

        let result = db.validate_invitation(&raw).await.expect("validate");
        assert!(result.is_none(), "consumed invitation must not validate");
    }

    #[tokio::test]
    async fn consume_twice_returns_gone() {
        let db = test_db().await;
        let expires = Utc::now() + Duration::hours(24);
        let (_, inv) = db
            .create_invitation(None, None, None, expires)
            .await
            .unwrap();

        db.consume_invitation(inv.id).await.expect("first consume");

        let err = db
            .consume_invitation(inv.id)
            .await
            .expect_err("second consume should fail");
        assert!(
            matches!(err, AuthError::Gone),
            "expected AuthError::Gone, got {err:?}"
        );
    }

    #[tokio::test]
    async fn list_pending_excludes_expired_and_consumed() {
        let db = test_db().await;
        let future = Utc::now() + Duration::hours(24);
        let past = Utc::now() - Duration::hours(1);

        let (_, pending) = db
            .create_invitation(None, Some("pending"), None, future)
            .await
            .unwrap();

        let _ = db
            .create_invitation(None, Some("expired"), None, past)
            .await
            .unwrap();

        let (_, consumed) = db
            .create_invitation(None, Some("consumed"), None, future)
            .await
            .unwrap();
        db.consume_invitation(consumed.id).await.unwrap();

        let list = db.list_pending_invitations().await.expect("list");
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].id, pending.id);
    }

    #[tokio::test]
    async fn create_invitation_with_invited_by_stores_user_id() {
        let db = test_db().await;
        let email = Email::new("creator@example.com".to_string()).unwrap();
        let user = db
            .create_user(email, "password123", None, None)
            .await
            .expect("create user");
        let expires = Utc::now() + Duration::hours(24);

        let (_, inv) = db
            .create_invitation(None, None, Some(user.id), expires)
            .await
            .expect("create with invited_by");

        assert_eq!(inv.invited_by, Some(user.id));
    }

    #[tokio::test]
    async fn delete_invitation_removes_it() {
        let db = test_db().await;
        let expires = Utc::now() + Duration::hours(24);
        let (raw, inv) = db
            .create_invitation(None, None, None, expires)
            .await
            .unwrap();

        db.delete_invitation(inv.id).await.expect("delete");

        let result = db.validate_invitation(&raw).await.expect("validate");
        assert!(result.is_none(), "deleted invitation must not validate");

        let list = db.list_pending_invitations().await.expect("list");
        assert!(list.is_empty());
    }

    #[tokio::test]
    async fn delete_twice_returns_not_found() {
        let db = test_db().await;
        let expires = Utc::now() + Duration::hours(24);
        let (_, inv) = db
            .create_invitation(None, None, None, expires)
            .await
            .unwrap();

        db.delete_invitation(inv.id).await.expect("first delete");

        let err = db
            .delete_invitation(inv.id)
            .await
            .expect_err("second delete should fail");
        assert!(
            matches!(err, AuthError::NotFound),
            "expected AuthError::NotFound, got {err:?}"
        );
    }

    #[tokio::test]
    async fn consume_nonexistent_returns_not_found() {
        let db = test_db().await;
        let fake_id = crate::types::InvitationId::new();
        let err = db
            .consume_invitation(fake_id)
            .await
            .expect_err("consume nonexistent should fail");
        assert!(
            matches!(err, AuthError::NotFound),
            "expected AuthError::NotFound, got {err:?}"
        );
    }

    #[tokio::test]
    async fn delete_nonexistent_returns_not_found() {
        let db = test_db().await;
        let fake_id = crate::types::InvitationId::new();
        let err = db
            .delete_invitation(fake_id)
            .await
            .expect_err("delete nonexistent should fail");
        assert!(
            matches!(err, AuthError::NotFound),
            "expected AuthError::NotFound, got {err:?}"
        );
    }
}