1use chrono::{DateTime, Utc};
15use serde::{Deserialize, Serialize};
16use sqlx::Row;
17
18use crate::db::Db;
19use crate::error::AuthError;
20use crate::social_provider_encrypt::{decrypt_split, encrypt_split};
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
26#[sqlx(type_name = "TEXT", rename_all = "lowercase")]
27#[serde(rename_all = "lowercase")]
28pub enum EmailConfigMode {
29 Managed,
30 Smtp,
31 Webhook,
32}
33
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
41#[sqlx(type_name = "TEXT", rename_all = "lowercase")]
42#[serde(rename_all = "lowercase")]
43pub enum SmtpTlsMode {
44 None,
45 StartTls,
46 #[sqlx(rename = "implicit")]
47 #[serde(rename = "implicit")]
48 ImplicitTls,
49}
50
51#[derive(Debug, Clone)]
52pub struct SmtpOverride {
53 pub host: String,
54 pub port: u16,
55 pub username: Option<String>,
56 pub password: Option<String>,
59 pub from_address: String,
60 pub tls: SmtpTlsMode,
61}
62
/// Settings for delivering mail through an HTTP webhook.
///
/// `Debug` is implemented by hand so the signing key is never printed: a set
/// secret shows as `Some("<redacted>")`, an absent one as `None`.
#[derive(Clone)]
pub struct WebhookOverride {
    /// Webhook endpoint URL.
    pub url: String,
    /// Optional signing key bytes; `Db::set_email_config` encrypts them before
    /// storing them.
    pub signing_secret: Option<Vec<u8>>,
}

impl std::fmt::Debug for WebhookOverride {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WebhookOverride")
            .field("url", &self.url)
            // Reveal presence only; the key material itself stays out of logs.
            .field(
                "signing_secret",
                &self.signing_secret.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
69
/// Optional settings used when the mode is [`EmailConfigMode::Managed`].
#[derive(Clone, Debug)]
pub struct ManagedOverride {
    /// Sender address override; `None` is stored as NULL (no override).
    pub from_address: Option<String>,
}
76
77#[derive(Debug, Clone)]
79pub struct EmailConfig {
80 pub mode: EmailConfigMode,
81 pub smtp: Option<SmtpOverride>,
82 pub webhook: Option<WebhookOverride>,
83 pub managed: Option<ManagedOverride>,
84 pub created_at: DateTime<Utc>,
85 pub updated_at: DateTime<Utc>,
86}
87
88#[derive(Debug, Clone)]
91pub struct SetEmailConfig {
92 pub mode: EmailConfigMode,
93 pub smtp: Option<SmtpOverride>,
94 pub webhook: Option<WebhookOverride>,
95 pub managed: Option<ManagedOverride>,
96}
97
98impl Db {
101 pub async fn get_email_config(
108 &self,
109 mfa_key: &[u8; 32],
110 ) -> Result<Option<EmailConfig>, AuthError> {
111 let row_opt = sqlx::query(
112 "SELECT mode, smtp_host, smtp_port, smtp_username, \
113 smtp_password_enc, smtp_password_nonce, smtp_from_address, smtp_tls, \
114 webhook_url, webhook_secret_enc, webhook_secret_nonce, \
115 managed_from_address, created_at, updated_at \
116 FROM allowthem_email_config WHERE singleton = 'singleton'",
117 )
118 .fetch_optional(self.pool())
119 .await?;
120
121 let Some(row) = row_opt else { return Ok(None) };
122
123 let mode: EmailConfigMode = row.try_get("mode")?;
124 let created_at: DateTime<Utc> = row.try_get("created_at")?;
125 let updated_at: DateTime<Utc> = row.try_get("updated_at")?;
126
127 let smtp = if mode == EmailConfigMode::Smtp {
128 let password_enc: Option<Vec<u8>> = row.try_get("smtp_password_enc")?;
129 let password_nonce: Option<Vec<u8>> = row.try_get("smtp_password_nonce")?;
130 let password = match (password_enc, password_nonce) {
131 (Some(enc), Some(nonce)) => {
132 let bytes = decrypt_split(&nonce, &enc, mfa_key)?;
133 Some(String::from_utf8(bytes).map_err(|_| {
134 AuthError::MfaEncryption("smtp password not valid utf-8".into())
135 })?)
136 }
137 _ => None,
138 };
139 let host: String = row.try_get("smtp_host")?;
140 let port_i: i64 = row.try_get("smtp_port")?;
141 let port: u16 = u16::try_from(port_i)
142 .map_err(|_| AuthError::Validation("smtp_port out of range".into()))?;
143 let username: Option<String> = row.try_get("smtp_username")?;
144 let from_address: String = row.try_get("smtp_from_address")?;
145 let tls: SmtpTlsMode = row.try_get("smtp_tls")?;
146 Some(SmtpOverride {
147 host,
148 port,
149 username,
150 password,
151 from_address,
152 tls,
153 })
154 } else {
155 None
156 };
157
158 let webhook = if mode == EmailConfigMode::Webhook {
159 let url: String = row.try_get("webhook_url")?;
160 let secret_enc: Option<Vec<u8>> = row.try_get("webhook_secret_enc")?;
161 let secret_nonce: Option<Vec<u8>> = row.try_get("webhook_secret_nonce")?;
162 let signing_secret = match (secret_enc, secret_nonce) {
163 (Some(enc), Some(nonce)) => Some(decrypt_split(&nonce, &enc, mfa_key)?),
164 _ => None,
165 };
166 Some(WebhookOverride {
167 url,
168 signing_secret,
169 })
170 } else {
171 None
172 };
173
174 let managed = if mode == EmailConfigMode::Managed {
175 let from_address: Option<String> = row.try_get("managed_from_address")?;
176 Some(ManagedOverride { from_address })
177 } else {
178 None
179 };
180
181 Ok(Some(EmailConfig {
182 mode,
183 smtp,
184 webhook,
185 managed,
186 created_at,
187 updated_at,
188 }))
189 }
190
191 pub async fn set_email_config(
196 &self,
197 cfg: &SetEmailConfig,
198 mfa_key: &[u8; 32],
199 ) -> Result<(), AuthError> {
200 match cfg.mode {
204 EmailConfigMode::Smtp if cfg.smtp.is_none() => {
205 return Err(AuthError::Validation(
206 "mode=smtp requires smtp override".into(),
207 ));
208 }
209 EmailConfigMode::Webhook if cfg.webhook.is_none() => {
210 return Err(AuthError::Validation(
211 "mode=webhook requires webhook override".into(),
212 ));
213 }
214 _ => {}
215 }
216
217 let (smtp_pw_enc, smtp_pw_nonce) = match cfg.smtp.as_ref().and_then(|s| s.password.as_ref())
220 {
221 Some(pw) => {
222 let enc = encrypt_split(pw.as_bytes(), mfa_key)?;
223 (Some(enc.ciphertext), Some(enc.nonce.to_vec()))
224 }
225 None => (None, None),
226 };
227 let (wh_secret_enc, wh_secret_nonce) =
228 match cfg.webhook.as_ref().and_then(|w| w.signing_secret.as_ref()) {
229 Some(secret) => {
230 let enc = encrypt_split(secret, mfa_key)?;
231 (Some(enc.ciphertext), Some(enc.nonce.to_vec()))
232 }
233 None => (None, None),
234 };
235
236 sqlx::query(
237 "INSERT INTO allowthem_email_config \
238 (singleton, mode, \
239 smtp_host, smtp_port, smtp_username, \
240 smtp_password_enc, smtp_password_nonce, smtp_from_address, smtp_tls, \
241 webhook_url, webhook_secret_enc, webhook_secret_nonce, \
242 managed_from_address, updated_at) \
243 VALUES ('singleton', ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, \
244 strftime('%Y-%m-%dT%H:%M:%fZ', 'now')) \
245 ON CONFLICT (singleton) DO UPDATE SET \
246 mode = excluded.mode, \
247 smtp_host = excluded.smtp_host, \
248 smtp_port = excluded.smtp_port, \
249 smtp_username = excluded.smtp_username, \
250 smtp_password_enc = excluded.smtp_password_enc, \
251 smtp_password_nonce = excluded.smtp_password_nonce, \
252 smtp_from_address = excluded.smtp_from_address, \
253 smtp_tls = excluded.smtp_tls, \
254 webhook_url = excluded.webhook_url, \
255 webhook_secret_enc = excluded.webhook_secret_enc, \
256 webhook_secret_nonce = excluded.webhook_secret_nonce, \
257 managed_from_address = excluded.managed_from_address, \
258 updated_at = excluded.updated_at",
259 )
260 .bind(cfg.mode)
261 .bind(cfg.smtp.as_ref().map(|s| s.host.as_str()))
262 .bind(cfg.smtp.as_ref().map(|s| s.port as i64))
263 .bind(cfg.smtp.as_ref().and_then(|s| s.username.as_deref()))
264 .bind(smtp_pw_enc)
265 .bind(smtp_pw_nonce)
266 .bind(cfg.smtp.as_ref().map(|s| s.from_address.as_str()))
267 .bind(cfg.smtp.as_ref().map(|s| s.tls))
268 .bind(cfg.webhook.as_ref().map(|w| w.url.as_str()))
269 .bind(wh_secret_enc)
270 .bind(wh_secret_nonce)
271 .bind(cfg.managed.as_ref().and_then(|m| m.from_address.as_deref()))
272 .execute(self.pool())
273 .await?;
274
275 Ok(())
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    use crate::db::Db;

    const KEY_A: [u8; 32] = [7u8; 32];
    const KEY_B: [u8; 32] = [11u8; 32];

    /// Fresh in-memory database for a single test.
    async fn make_db() -> Db {
        Db::connect("sqlite::memory:").await.unwrap()
    }

    /// SMTP-mode input: smtp.example.com:587, STARTTLS, user `alice`.
    fn smtp_input(password: &str) -> SetEmailConfig {
        SetEmailConfig {
            mode: EmailConfigMode::Smtp,
            smtp: Some(SmtpOverride {
                host: "smtp.example.com".into(),
                port: 587,
                username: Some("alice".into()),
                password: Some(password.into()),
                from_address: "noreply@example.com".into(),
                tls: SmtpTlsMode::StartTls,
            }),
            webhook: None,
            managed: None,
        }
    }

    /// Webhook-mode input with the given endpoint and optional signing key.
    fn webhook_input(url: &str, signing_secret: Option<Vec<u8>>) -> SetEmailConfig {
        SetEmailConfig {
            mode: EmailConfigMode::Webhook,
            smtp: None,
            webhook: Some(WebhookOverride {
                url: url.into(),
                signing_secret,
            }),
            managed: None,
        }
    }

    /// Managed-mode input with an optional sender override.
    fn managed_input(from_address: Option<&str>) -> SetEmailConfig {
        SetEmailConfig {
            mode: EmailConfigMode::Managed,
            smtp: None,
            webhook: None,
            managed: Some(ManagedOverride {
                from_address: from_address.map(Into::into),
            }),
        }
    }

    /// Input that names a mode but carries no override blocks at all.
    fn blockless_input(mode: EmailConfigMode) -> SetEmailConfig {
        SetEmailConfig {
            mode,
            smtp: None,
            webhook: None,
            managed: None,
        }
    }

    #[tokio::test]
    async fn get_returns_none_when_no_row() {
        let db = make_db().await;
        let loaded = db.get_email_config(&KEY_A).await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn set_then_get_round_trips_managed() {
        let db = make_db().await;
        db.set_email_config(&managed_input(Some("noreply@auth.acme.com")), &KEY_A)
            .await
            .unwrap();

        let loaded = db.get_email_config(&KEY_A).await.unwrap().unwrap();
        assert_eq!(loaded.mode, EmailConfigMode::Managed);
        assert!(loaded.smtp.is_none());
        assert!(loaded.webhook.is_none());
        let managed = loaded.managed.unwrap();
        assert_eq!(
            managed.from_address.as_deref(),
            Some("noreply@auth.acme.com")
        );
    }

    #[tokio::test]
    async fn set_then_get_round_trips_smtp_with_decryption() {
        let db = make_db().await;
        let secret_pw = "hunter2!";
        db.set_email_config(&smtp_input(secret_pw), &KEY_A)
            .await
            .unwrap();

        let loaded = db.get_email_config(&KEY_A).await.unwrap().unwrap();
        let smtp = loaded.smtp.unwrap();
        assert_eq!(smtp.password.as_deref(), Some(secret_pw));
        assert_eq!(smtp.tls, SmtpTlsMode::StartTls);

        // The column on disk must hold ciphertext, not the plaintext password.
        let stored: Vec<u8> =
            sqlx::query_scalar("SELECT smtp_password_enc FROM allowthem_email_config")
                .fetch_one(db.pool())
                .await
                .unwrap();
        assert_ne!(stored.as_slice(), secret_pw.as_bytes());
    }

    #[tokio::test]
    async fn set_then_get_round_trips_webhook_with_decryption() {
        let db = make_db().await;
        let key_bytes = b"hmac-secret-bytes".to_vec();
        db.set_email_config(
            &webhook_input("https://hooks.acme.com/email", Some(key_bytes.clone())),
            &KEY_A,
        )
        .await
        .unwrap();

        let webhook = db
            .get_email_config(&KEY_A)
            .await
            .unwrap()
            .unwrap()
            .webhook
            .unwrap();
        assert_eq!(webhook.url, "https://hooks.acme.com/email");
        assert_eq!(webhook.signing_secret.as_deref(), Some(key_bytes.as_slice()));
    }

    #[tokio::test]
    async fn second_set_overwrites_first() {
        let db = make_db().await;
        db.set_email_config(&managed_input(None), &KEY_A)
            .await
            .unwrap();
        db.set_email_config(&webhook_input("https://x", None), &KEY_A)
            .await
            .unwrap();

        let rows: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM allowthem_email_config")
            .fetch_one(db.pool())
            .await
            .unwrap();
        assert_eq!(rows, 1, "single-row guard must hold across upserts");

        let loaded = db.get_email_config(&KEY_A).await.unwrap().unwrap();
        assert_eq!(loaded.mode, EmailConfigMode::Webhook);
    }

    #[tokio::test]
    async fn set_rejects_smtp_mode_without_smtp_block() {
        let db = make_db().await;
        let outcome = db
            .set_email_config(&blockless_input(EmailConfigMode::Smtp), &KEY_A)
            .await;
        assert!(matches!(outcome.unwrap_err(), AuthError::Validation(_)));
    }

    #[tokio::test]
    async fn set_rejects_webhook_mode_without_webhook_block() {
        let db = make_db().await;
        let outcome = db
            .set_email_config(&blockless_input(EmailConfigMode::Webhook), &KEY_A)
            .await;
        assert!(matches!(outcome.unwrap_err(), AuthError::Validation(_)));
    }

    #[tokio::test]
    async fn get_with_wrong_mfa_key_fails_decrypt() {
        let db = make_db().await;
        db.set_email_config(&smtp_input("hunter2"), &KEY_A)
            .await
            .unwrap();

        let outcome = db.get_email_config(&KEY_B).await;
        assert!(matches!(outcome.unwrap_err(), AuthError::MfaEncryption(_)));
    }
}