1use std::sync::Arc;
21
22use chrono::{DateTime, Utc};
23use serde::Serialize;
24
25use crate::db::{ColumnMap, ConnExt, ConnQueryExt, Database, FromRow};
26use crate::dns::{DomainVerifier, generate_verification_token};
27use crate::error::{Error, Result};
28use crate::{db, id};
29
/// How long (in hours) a pending claim may wait for DNS verification before
/// it is treated as failed: two days.
const VERIFICATION_EXPIRY_HOURS: i64 = 2 * 24;
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
39#[serde(rename_all = "lowercase")]
40pub enum ClaimStatus {
41 Pending,
43 Verified,
45 Failed,
47}
48
49impl ClaimStatus {
50 pub fn as_str(&self) -> &'static str {
52 match self {
53 Self::Pending => "pending",
54 Self::Verified => "verified",
55 Self::Failed => "failed",
56 }
57 }
58
59 fn from_str(s: &str) -> Result<Self> {
60 match s {
61 "pending" => Ok(Self::Pending),
62 "verified" => Ok(Self::Verified),
63 "failed" => Ok(Self::Failed),
64 _ => Err(Error::internal(format!("unknown claim status: {s}"))),
65 }
66 }
67}
68
69#[derive(Debug, Clone, Serialize)]
71pub struct DomainClaim {
72 pub id: String,
74 pub tenant_id: String,
76 pub domain: String,
78 pub verification_token: String,
80 pub status: ClaimStatus,
82 pub use_for_email: bool,
84 pub use_for_routing: bool,
86 pub created_at: String,
88 pub verified_at: Option<String>,
90}
91
92#[derive(Debug, Clone, Serialize)]
94pub struct TenantMatch {
95 pub tenant_id: String,
97 pub domain: String,
99}
100
/// Raw `tenant_domains` row, with `status` still in its stored string form.
struct DomainRow {
    id: String,
    tenant_id: String,
    domain: String,
    verification_token: String,
    // Stored lowercase status string; parsed into `ClaimStatus` on conversion.
    status: String,
    use_for_email: bool,
    use_for_routing: bool,
    created_at: String,
    verified_at: Option<String>,
}
116
117impl FromRow for DomainRow {
118 fn from_row(row: &libsql::Row) -> Result<Self> {
119 let cols = ColumnMap::from_row(row);
120 Ok(Self {
121 id: cols.get(row, "id")?,
122 tenant_id: cols.get(row, "tenant_id")?,
123 domain: cols.get(row, "domain")?,
124 verification_token: cols.get(row, "verification_token")?,
125 status: cols.get(row, "status")?,
126 use_for_email: cols.get(row, "use_for_email")?,
127 use_for_routing: cols.get(row, "use_for_routing")?,
128 created_at: cols.get(row, "created_at")?,
129 verified_at: cols.get(row, "verified_at")?,
130 })
131 }
132}
133
134impl DomainRow {
135 fn into_claim(self) -> Result<DomainClaim> {
136 let status = ClaimStatus::from_str(&self.status)?;
137 Ok(DomainClaim {
138 id: self.id,
139 tenant_id: self.tenant_id,
140 domain: self.domain,
141 verification_token: self.verification_token,
142 status,
143 use_for_email: self.use_for_email,
144 use_for_routing: self.use_for_routing,
145 created_at: self.created_at,
146 verified_at: self.verified_at,
147 })
148 }
149
150 fn into_claim_with_expiry(self) -> Result<DomainClaim> {
152 let mut claim = self.into_claim()?;
153 if claim.status == ClaimStatus::Pending && is_expired(&claim.created_at) {
154 claim.status = ClaimStatus::Failed;
155 }
156 Ok(claim)
157 }
158}
159
/// Projection of `tenant_domains` used by the tenant-lookup queries.
struct MatchRow {
    tenant_id: String,
    domain: String,
}
165
166impl FromRow for MatchRow {
167 fn from_row(row: &libsql::Row) -> Result<Self> {
168 let cols = ColumnMap::from_row(row);
169 Ok(Self {
170 tenant_id: cols.get(row, "tenant_id")?,
171 domain: cols.get(row, "domain")?,
172 })
173 }
174}
175
176pub fn validate_domain(domain: &str) -> Result<String> {
191 let domain = domain.trim().to_lowercase();
192
193 if domain.is_empty() {
194 return Err(Error::bad_request("domain must not be empty"));
195 }
196 if !domain.contains('.') {
197 return Err(Error::bad_request("domain must contain at least one dot"));
198 }
199 if domain.starts_with('.') || domain.ends_with('.') {
200 return Err(Error::bad_request(
201 "domain must not start or end with a dot",
202 ));
203 }
204 if domain.len() > 253 {
205 return Err(Error::bad_request("domain must not exceed 253 characters"));
206 }
207
208 for label in domain.split('.') {
209 if label.is_empty() {
210 return Err(Error::bad_request("domain labels must not be empty"));
211 }
212 if label.len() > 63 {
213 return Err(Error::bad_request(
214 "domain labels must not exceed 63 characters",
215 ));
216 }
217 if label.starts_with('-') || label.ends_with('-') {
218 return Err(Error::bad_request(
219 "domain labels must not start or end with a hyphen",
220 ));
221 }
222 if !label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') {
223 return Err(Error::bad_request(
224 "domain labels must contain only alphanumeric characters and hyphens",
225 ));
226 }
227 }
228
229 Ok(domain)
230}
231
232pub fn extract_email_domain(email: &str) -> Result<String> {
242 let email = email.trim();
243 let parts: Vec<&str> = email.splitn(2, '@').collect();
244 if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() {
245 return Err(Error::bad_request("invalid email address"));
246 }
247 validate_domain(parts[1])
248}
249
250struct Inner {
255 db: Database,
256 verifier: DomainVerifier,
257}
258
259#[derive(Clone)]
279pub struct DomainService {
280 inner: Arc<Inner>,
281}
282
283impl DomainService {
284 pub fn new(db: Database, verifier: DomainVerifier) -> Self {
286 Self {
287 inner: Arc::new(Inner { db, verifier }),
288 }
289 }
290
291 pub async fn register(&self, tenant_id: &str, domain: &str) -> Result<DomainClaim> {
305 let domain = validate_domain(domain)?;
306
307 let existing: Option<DomainRow> = self
310 .inner
311 .db
312 .conn()
313 .query_optional(
314 "SELECT id, tenant_id, domain, verification_token, status, \
315 use_for_email, use_for_routing, created_at, verified_at \
316 FROM tenant_domains \
317 WHERE tenant_id = ?1 AND domain = ?2 AND status = 'pending' \
318 LIMIT 1",
319 libsql::params![tenant_id, domain.as_str()],
320 )
321 .await?;
322
323 if let Some(row) = existing {
324 let claim = row.into_claim_with_expiry()?;
325 if claim.status == ClaimStatus::Pending {
326 return Ok(claim);
327 }
328 }
330
331 let id = id::ulid();
332 let token = generate_verification_token();
333 let now = Utc::now().to_rfc3339();
334
335 self.inner
336 .db
337 .conn()
338 .execute_raw(
339 "INSERT INTO tenant_domains (id, tenant_id, domain, verification_token, status, created_at) \
340 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
341 libsql::params![id.as_str(), tenant_id, domain.as_str(), token.as_str(), "pending", now.as_str()],
342 )
343 .await
344 .map_err(Error::from)?;
345
346 Ok(DomainClaim {
347 id,
348 tenant_id: tenant_id.to_owned(),
349 domain,
350 verification_token: token,
351 status: ClaimStatus::Pending,
352 use_for_email: false,
353 use_for_routing: false,
354 created_at: now,
355 verified_at: None,
356 })
357 }
358
359 pub async fn verify(&self, id: &str) -> Result<DomainClaim> {
372 let row: DomainRow = self
373 .inner
374 .db
375 .conn()
376 .query_one(
377 "SELECT id, tenant_id, domain, verification_token, status, \
378 use_for_email, use_for_routing, created_at, verified_at \
379 FROM tenant_domains WHERE id = ?1",
380 libsql::params![id],
381 )
382 .await?;
383
384 let claim = row.into_claim()?;
385
386 if claim.status == ClaimStatus::Verified {
387 return Ok(claim);
388 }
389
390 if is_expired(&claim.created_at) {
392 self.inner
393 .db
394 .conn()
395 .execute_raw(
396 "UPDATE tenant_domains SET status = ?1 WHERE id = ?2",
397 libsql::params!["failed", id],
398 )
399 .await
400 .map_err(Error::from)?;
401
402 return Err(Error::bad_request(
403 "verification window has expired (48 hours)",
404 ));
405 }
406
407 let txt_ok = self
409 .inner
410 .verifier
411 .check_txt(&claim.domain, &claim.verification_token)
412 .await?;
413
414 if !txt_ok {
415 return Err(Error::bad_request(
416 "DNS TXT record not found or does not match verification token",
417 ));
418 }
419
420 let now = Utc::now().to_rfc3339();
421 self.inner
422 .db
423 .conn()
424 .execute_raw(
425 "UPDATE tenant_domains SET status = ?1, verified_at = ?2 WHERE id = ?3",
426 libsql::params!["verified", now.as_str(), id],
427 )
428 .await
429 .map_err(Error::from)?;
430
431 Ok(DomainClaim {
432 status: ClaimStatus::Verified,
433 verified_at: Some(now),
434 ..claim
435 })
436 }
437
438 pub async fn remove(&self, id: &str) -> Result<()> {
444 self.inner
445 .db
446 .conn()
447 .execute_raw(
448 "DELETE FROM tenant_domains WHERE id = ?1",
449 libsql::params![id],
450 )
451 .await
452 .map_err(Error::from)?;
453 Ok(())
454 }
455
456 pub async fn enable_email(&self, id: &str) -> Result<()> {
463 self.require_verified(id).await?;
464 self.inner
465 .db
466 .conn()
467 .execute_raw(
468 "UPDATE tenant_domains SET use_for_email = 1 WHERE id = ?1",
469 libsql::params![id],
470 )
471 .await
472 .map_err(Error::from)?;
473 Ok(())
474 }
475
476 pub async fn disable_email(&self, id: &str) -> Result<()> {
482 self.inner
483 .db
484 .conn()
485 .execute_raw(
486 "UPDATE tenant_domains SET use_for_email = 0 WHERE id = ?1",
487 libsql::params![id],
488 )
489 .await
490 .map_err(Error::from)?;
491 Ok(())
492 }
493
494 pub async fn enable_routing(&self, id: &str) -> Result<()> {
501 self.require_verified(id).await?;
502 self.inner
503 .db
504 .conn()
505 .execute_raw(
506 "UPDATE tenant_domains SET use_for_routing = 1 WHERE id = ?1",
507 libsql::params![id],
508 )
509 .await
510 .map_err(Error::from)?;
511 Ok(())
512 }
513
514 pub async fn disable_routing(&self, id: &str) -> Result<()> {
520 self.inner
521 .db
522 .conn()
523 .execute_raw(
524 "UPDATE tenant_domains SET use_for_routing = 0 WHERE id = ?1",
525 libsql::params![id],
526 )
527 .await
528 .map_err(Error::from)?;
529 Ok(())
530 }
531
532 pub async fn lookup_email_domain(&self, email: &str) -> Result<Option<TenantMatch>> {
543 let domain = extract_email_domain(email)?;
544 let row: Option<MatchRow> = self
545 .inner
546 .db
547 .conn()
548 .query_optional(
549 "SELECT tenant_id, domain FROM tenant_domains \
550 WHERE domain = ?1 AND status = 'verified' AND use_for_email = 1 \
551 LIMIT 1",
552 libsql::params![domain.as_str()],
553 )
554 .await?;
555 Ok(row.map(|r| TenantMatch {
556 tenant_id: r.tenant_id,
557 domain: r.domain,
558 }))
559 }
560
561 pub async fn lookup_routing_domain(&self, domain: &str) -> Result<Option<TenantMatch>> {
568 let domain = validate_domain(domain)?;
569 let row: Option<MatchRow> = self
570 .inner
571 .db
572 .conn()
573 .query_optional(
574 "SELECT tenant_id, domain FROM tenant_domains \
575 WHERE domain = ?1 AND status = 'verified' AND use_for_routing = 1 \
576 LIMIT 1",
577 libsql::params![domain.as_str()],
578 )
579 .await?;
580 Ok(row.map(|r| TenantMatch {
581 tenant_id: r.tenant_id,
582 domain: r.domain,
583 }))
584 }
585
586 pub async fn resolve_tenant(&self, domain: &str) -> Result<Option<String>> {
596 Ok(self
597 .lookup_routing_domain(domain)
598 .await?
599 .map(|m| m.tenant_id))
600 }
601
602 pub async fn list(&self, tenant_id: &str) -> Result<Vec<DomainClaim>> {
611 let rows: Vec<DomainRow> = self
612 .inner
613 .db
614 .conn()
615 .query_all(
616 "SELECT id, tenant_id, domain, verification_token, status, \
617 use_for_email, use_for_routing, created_at, verified_at \
618 FROM tenant_domains WHERE tenant_id = ?1 \
619 ORDER BY created_at DESC",
620 libsql::params![tenant_id],
621 )
622 .await?;
623
624 rows.into_iter()
625 .map(|r| r.into_claim_with_expiry())
626 .collect()
627 }
628
629 async fn require_verified(&self, id: &str) -> Result<()> {
633 let status: String = self
634 .inner
635 .db
636 .conn()
637 .query_one_map(
638 "SELECT status FROM tenant_domains WHERE id = ?1",
639 libsql::params![id],
640 |row| {
641 let val = row.get_value(0).map_err(Error::from)?;
642 db::FromValue::from_value(val)
643 },
644 )
645 .await?;
646
647 if status != "verified" {
648 return Err(Error::bad_request(
649 "domain must be verified before enabling features",
650 ));
651 }
652 Ok(())
653 }
654}
655
656fn is_expired(created_at: &str) -> bool {
663 let Ok(created) = DateTime::parse_from_rfc3339(created_at) else {
664 return false;
665 };
666 let age = Utc::now() - created.with_timezone(&Utc);
667 age > chrono::Duration::hours(VERIFICATION_EXPIRY_HOURS)
668}
669
#[cfg(test)]
mod tests {
    use super::*;

    // ----- validate_domain -----

    #[test]
    fn valid_domain() {
        assert_eq!(validate_domain("Example.COM").unwrap(), "example.com");
    }

    #[test]
    fn domain_with_subdomain() {
        assert_eq!(
            validate_domain("sub.example.com").unwrap(),
            "sub.example.com"
        );
    }

    #[test]
    fn domain_trimmed() {
        assert_eq!(validate_domain("  example.com  ").unwrap(), "example.com");
    }

    #[test]
    fn empty_domain_rejected() {
        assert!(validate_domain("").is_err());
    }

    #[test]
    fn no_dot_rejected() {
        assert!(validate_domain("localhost").is_err());
    }

    #[test]
    fn leading_dot_rejected() {
        assert!(validate_domain(".example.com").is_err());
    }

    #[test]
    fn trailing_dot_rejected() {
        assert!(validate_domain("example.com.").is_err());
    }

    #[test]
    fn consecutive_dots_rejected() {
        assert!(validate_domain("example..com").is_err());
    }

    #[test]
    fn label_starting_with_hyphen_rejected() {
        assert!(validate_domain("-example.com").is_err());
    }

    #[test]
    fn label_ending_with_hyphen_rejected() {
        assert!(validate_domain("example-.com").is_err());
    }

    #[test]
    fn hyphen_inside_label_accepted() {
        assert_eq!(validate_domain("my-site.co.uk").unwrap(), "my-site.co.uk");
    }

    #[test]
    fn domain_too_long_rejected() {
        let long = format!("{}.com", "a".repeat(250));
        assert!(validate_domain(&long).is_err());
    }

    #[test]
    fn domain_at_max_length_accepted() {
        // 3 * 63 + 61 label bytes + 3 dots = 253 bytes.
        let domain = format!("{a}.{a}.{a}.{b}", a = "a".repeat(63), b = "b".repeat(61));
        assert_eq!(domain.len(), 253);
        assert_eq!(validate_domain(&domain).unwrap(), domain);
    }

    #[test]
    fn domain_one_over_max_length_rejected() {
        // Every label is within 63 bytes, so only the overall length rule fails.
        let domain = format!("{a}.{a}.{a}.{b}", a = "a".repeat(63), b = "b".repeat(62));
        assert_eq!(domain.len(), 254);
        assert!(validate_domain(&domain).is_err());
    }

    #[test]
    fn label_too_long_rejected() {
        let long = format!("{}.com", "a".repeat(64));
        assert!(validate_domain(&long).is_err());
    }

    #[test]
    fn label_at_max_length_accepted() {
        let domain = format!("{}.com", "a".repeat(63));
        assert_eq!(validate_domain(&domain).unwrap(), domain);
    }

    #[test]
    fn invalid_chars_rejected() {
        assert!(validate_domain("ex ample.com").is_err());
        assert!(validate_domain("ex_ample.com").is_err());
    }

    // ----- extract_email_domain -----

    #[test]
    fn extract_valid_email_domain() {
        assert_eq!(
            extract_email_domain("user@Example.COM").unwrap(),
            "example.com"
        );
    }

    #[test]
    fn extract_email_trims_whitespace() {
        assert_eq!(
            extract_email_domain("  user@Example.com ").unwrap(),
            "example.com"
        );
    }

    #[test]
    fn extract_email_no_at_rejected() {
        assert!(extract_email_domain("nope").is_err());
    }

    #[test]
    fn extract_email_empty_local_rejected() {
        assert!(extract_email_domain("@example.com").is_err());
    }

    #[test]
    fn extract_email_empty_domain_rejected() {
        assert!(extract_email_domain("user@").is_err());
    }

    #[test]
    fn extract_email_multiple_at_rejected() {
        // Split happens at the first '@', leaving '@' inside the domain part.
        assert!(extract_email_domain("a@b@example.com").is_err());
    }

    // ----- ClaimStatus -----

    #[test]
    fn claim_status_round_trip() {
        for status in [
            ClaimStatus::Pending,
            ClaimStatus::Verified,
            ClaimStatus::Failed,
        ] {
            let s = status.as_str();
            assert_eq!(ClaimStatus::from_str(s).unwrap(), status);
        }
    }

    #[test]
    fn claim_status_as_str_values() {
        assert_eq!(ClaimStatus::Pending.as_str(), "pending");
        assert_eq!(ClaimStatus::Verified.as_str(), "verified");
        assert_eq!(ClaimStatus::Failed.as_str(), "failed");
    }

    #[test]
    fn claim_status_unknown_rejected() {
        assert!(ClaimStatus::from_str("bogus").is_err());
    }

    // ----- is_expired -----

    #[test]
    fn fresh_claim_not_expired() {
        let now = Utc::now().to_rfc3339();
        assert!(!is_expired(&now));
    }

    #[test]
    fn claim_inside_window_not_expired() {
        let recent =
            (Utc::now() - chrono::Duration::hours(VERIFICATION_EXPIRY_HOURS - 1)).to_rfc3339();
        assert!(!is_expired(&recent));
    }

    #[test]
    fn old_claim_expired() {
        let old = (Utc::now() - chrono::Duration::hours(49)).to_rfc3339();
        assert!(is_expired(&old));
    }

    #[test]
    fn future_timestamp_not_expired() {
        let future = (Utc::now() + chrono::Duration::hours(1)).to_rfc3339();
        assert!(!is_expired(&future));
    }

    #[test]
    fn invalid_timestamp_not_expired() {
        assert!(!is_expired("not-a-timestamp"));
    }
}