use std::sync::Arc;
use chrono::{DateTime, Utc};
use serde::Serialize;
use crate::db::{ColumnMap, ConnExt, ConnQueryExt, Database, FromRow};
use crate::dns::{DomainVerifier, generate_verification_token};
use crate::error::{Error, Result};
use crate::{db, id};
/// Age, in hours since a claim's `created_at`, beyond which `is_expired`
/// reports the claim as expired.
const VERIFICATION_EXPIRY_HOURS: i64 = 48;
/// Lifecycle state of a domain claim.
///
/// Serialized with lowercase variant names (`rename_all = "lowercase"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ClaimStatus {
/// Registered, not yet verified.
Pending,
/// Stored by `DomainService::verify` after the DNS TXT check succeeds.
Verified,
/// The claim was still unverified when its verification window expired.
Failed,
}
impl ClaimStatus {
pub fn as_str(&self) -> &'static str {
match self {
Self::Pending => "pending",
Self::Verified => "verified",
Self::Failed => "failed",
}
}
fn from_str(s: &str) -> Result<Self> {
match s {
"pending" => Ok(Self::Pending),
"verified" => Ok(Self::Verified),
"failed" => Ok(Self::Failed),
_ => Err(Error::internal(format!("unknown claim status: {s}"))),
}
}
}
/// A tenant's claim on a domain, as returned by `DomainService`.
#[derive(Debug, Clone, Serialize)]
pub struct DomainClaim {
/// Identifier of the claim row.
pub id: String,
/// Tenant that registered the claim.
pub tenant_id: String,
/// Domain in the normalized form returned by `validate_domain`.
pub domain: String,
/// Token passed to the DNS verifier's TXT check for this domain.
pub verification_token: String,
/// Current lifecycle state of the claim.
pub status: ClaimStatus,
/// Whether the domain is enabled for email-domain lookups.
pub use_for_email: bool,
/// Whether the domain is enabled for routing lookups.
pub use_for_routing: bool,
/// Creation time; written as an RFC 3339 string by `register`.
pub created_at: String,
/// Verification time (RFC 3339), set once `verify` succeeds.
pub verified_at: Option<String>,
}
/// Result of a domain lookup: the tenant that owns a matching verified claim.
#[derive(Debug, Clone, Serialize)]
pub struct TenantMatch {
/// Tenant owning the matched claim.
pub tenant_id: String,
/// Domain stored on the matched claim.
pub domain: String,
}
// Raw `tenant_domains` row as read from the database. `status` is still the
// stored string here; `into_claim` parses it into a `ClaimStatus`.
struct DomainRow {
id: String,
tenant_id: String,
domain: String,
verification_token: String,
status: String,
use_for_email: bool,
use_for_routing: bool,
created_at: String,
verified_at: Option<String>,
}
impl FromRow for DomainRow {
fn from_row(row: &libsql::Row) -> Result<Self> {
let cols = ColumnMap::from_row(row);
Ok(Self {
id: cols.get(row, "id")?,
tenant_id: cols.get(row, "tenant_id")?,
domain: cols.get(row, "domain")?,
verification_token: cols.get(row, "verification_token")?,
status: cols.get(row, "status")?,
use_for_email: cols.get(row, "use_for_email")?,
use_for_routing: cols.get(row, "use_for_routing")?,
created_at: cols.get(row, "created_at")?,
verified_at: cols.get(row, "verified_at")?,
})
}
}
impl DomainRow {
fn into_claim(self) -> Result<DomainClaim> {
let status = ClaimStatus::from_str(&self.status)?;
Ok(DomainClaim {
id: self.id,
tenant_id: self.tenant_id,
domain: self.domain,
verification_token: self.verification_token,
status,
use_for_email: self.use_for_email,
use_for_routing: self.use_for_routing,
created_at: self.created_at,
verified_at: self.verified_at,
})
}
fn into_claim_with_expiry(self) -> Result<DomainClaim> {
let mut claim = self.into_claim()?;
if claim.status == ClaimStatus::Pending && is_expired(&claim.created_at) {
claim.status = ClaimStatus::Failed;
}
Ok(claim)
}
}
// Two-column projection of `tenant_domains` used by the lookup queries.
struct MatchRow {
tenant_id: String,
domain: String,
}
impl FromRow for MatchRow {
fn from_row(row: &libsql::Row) -> Result<Self> {
let cols = ColumnMap::from_row(row);
Ok(Self {
tenant_id: cols.get(row, "tenant_id")?,
domain: cols.get(row, "domain")?,
})
}
}
/// Normalizes and validates a domain name.
///
/// The input is trimmed and ASCII-lowercased, then checked against these
/// rules (in this order, so the first violated rule determines the error):
///
/// * not empty;
/// * contains at least one dot;
/// * does not start or end with a dot;
/// * at most 253 bytes overall;
/// * every dot-separated label is non-empty, at most 63 bytes, does not
///   start or end with a hyphen, and contains only ASCII letters, digits
///   and hyphens.
///
/// # Returns
///
/// The normalized (trimmed, lowercase) domain.
///
/// # Errors
///
/// Returns a bad-request error describing the first rule that was violated.
pub fn validate_domain(domain: &str) -> Result<String> {
    // ASCII-only case folding on purpose: Unicode lowercasing maps some
    // non-ASCII characters onto ASCII letters (e.g. U+212A KELVIN SIGN -> 'k'),
    // which would let a non-ASCII input slip past the ASCII-only label check
    // below and be accepted as a different domain.
    let domain = domain.trim().to_ascii_lowercase();

    if domain.is_empty() {
        return Err(Error::bad_request("domain must not be empty"));
    }
    if !domain.contains('.') {
        return Err(Error::bad_request("domain must contain at least one dot"));
    }
    if domain.starts_with('.') || domain.ends_with('.') {
        return Err(Error::bad_request(
            "domain must not start or end with a dot",
        ));
    }
    if domain.len() > 253 {
        return Err(Error::bad_request("domain must not exceed 253 characters"));
    }

    for label in domain.split('.') {
        // An empty label means two consecutive dots somewhere in the middle.
        if label.is_empty() {
            return Err(Error::bad_request("domain labels must not be empty"));
        }
        if label.len() > 63 {
            return Err(Error::bad_request(
                "domain labels must not exceed 63 characters",
            ));
        }
        if label.starts_with('-') || label.ends_with('-') {
            return Err(Error::bad_request(
                "domain labels must not start or end with a hyphen",
            ));
        }
        if !label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') {
            return Err(Error::bad_request(
                "domain labels must contain only alphanumeric characters and hyphens",
            ));
        }
    }

    Ok(domain)
}
pub fn extract_email_domain(email: &str) -> Result<String> {
let email = email.trim();
let parts: Vec<&str> = email.splitn(2, '@').collect();
if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() {
return Err(Error::bad_request("invalid email address"));
}
validate_domain(parts[1])
}
// State shared by every clone of `DomainService`.
struct Inner {
// Database handle used for all `tenant_domains` queries.
db: Database,
// DNS verifier used by `verify` to check TXT records.
verifier: DomainVerifier,
}
/// Service for registering, verifying and looking up tenant domain claims.
///
/// Cloning clones the inner `Arc`, so all clones share the same database
/// handle and verifier.
#[derive(Clone)]
pub struct DomainService {
inner: Arc<Inner>,
}
impl DomainService {
pub fn new(db: Database, verifier: DomainVerifier) -> Self {
Self {
inner: Arc::new(Inner { db, verifier }),
}
}
/// Registers a claim on `domain` for `tenant_id`.
///
/// The domain is normalized with `validate_domain` first. If the tenant
/// already has a pending claim on that domain whose verification window has
/// not elapsed, that claim is returned unchanged. Otherwise a new pending
/// claim with a fresh id and verification token is inserted and returned.
///
/// Expired pending claims for the same tenant and domain are marked
/// `failed` before the new claim is inserted. Without that, a stale row
/// would keep matching the `status = 'pending'` lookup and every later call
/// could insert yet another duplicate claim.
///
/// # Errors
///
/// Returns a bad-request error when the domain is invalid, and propagates
/// database errors from the lookup, the expiry update, or the insert.
pub async fn register(&self, tenant_id: &str, domain: &str) -> Result<DomainClaim> {
    let domain = validate_domain(domain)?;

    // Newest first: if several pending rows exist, the most recent one
    // decides whether there is still a live claim to hand back.
    let existing: Option<DomainRow> = self
        .inner
        .db
        .conn()
        .query_optional(
            "SELECT id, tenant_id, domain, verification_token, status, \
             use_for_email, use_for_routing, created_at, verified_at \
             FROM tenant_domains \
             WHERE tenant_id = ?1 AND domain = ?2 AND status = 'pending' \
             ORDER BY created_at DESC \
             LIMIT 1",
            libsql::params![tenant_id, domain.as_str()],
        )
        .await?;

    if let Some(row) = existing {
        let claim = row.into_claim_with_expiry()?;
        if claim.status == ClaimStatus::Pending {
            return Ok(claim);
        }

        // The newest pending row has expired, so any older pending rows for
        // this tenant and domain have too. Persist that so they stop
        // matching the lookup above on subsequent calls.
        self.inner
            .db
            .conn()
            .execute_raw(
                "UPDATE tenant_domains SET status = ?1 \
                 WHERE tenant_id = ?2 AND domain = ?3 AND status = 'pending'",
                libsql::params!["failed", tenant_id, domain.as_str()],
            )
            .await
            .map_err(Error::from)?;
    }

    let id = id::ulid();
    let token = generate_verification_token();
    let now = Utc::now().to_rfc3339();

    self.inner
        .db
        .conn()
        .execute_raw(
            "INSERT INTO tenant_domains (id, tenant_id, domain, verification_token, status, created_at) \
             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
            libsql::params![id.as_str(), tenant_id, domain.as_str(), token.as_str(), "pending", now.as_str()],
        )
        .await
        .map_err(Error::from)?;

    // Mirror what was just written; the feature flags are not part of the
    // insert and are reported as disabled.
    Ok(DomainClaim {
        id,
        tenant_id: tenant_id.to_owned(),
        domain,
        verification_token: token,
        status: ClaimStatus::Pending,
        use_for_email: false,
        use_for_routing: false,
        created_at: now,
        verified_at: None,
    })
}
/// Attempts to verify the claim identified by `id`.
///
/// An already verified claim is returned as-is. If the verification window
/// has elapsed, the claim is stored as `failed` and an error is returned.
/// Otherwise the DNS verifier checks the domain's TXT record against the
/// claim's token; on success the claim is stored as `verified` with the
/// current time and the updated claim is returned.
///
/// # Errors
///
/// Returns a bad-request error when the window has expired or the TXT record
/// does not match, and propagates database and verifier errors.
pub async fn verify(&self, id: &str) -> Result<DomainClaim> {
    let select_sql = "SELECT id, tenant_id, domain, verification_token, status, \
         use_for_email, use_for_routing, created_at, verified_at \
         FROM tenant_domains WHERE id = ?1";
    let row: DomainRow = self
        .inner
        .db
        .conn()
        .query_one(select_sql, libsql::params![id])
        .await?;

    let mut claim = row.into_claim()?;
    if matches!(claim.status, ClaimStatus::Verified) {
        return Ok(claim);
    }

    if is_expired(&claim.created_at) {
        // Persist the failure before reporting it to the caller.
        let mark_failed_sql = "UPDATE tenant_domains SET status = ?1 WHERE id = ?2";
        self.inner
            .db
            .conn()
            .execute_raw(mark_failed_sql, libsql::params!["failed", id])
            .await
            .map_err(Error::from)?;
        return Err(Error::bad_request(
            "verification window has expired (48 hours)",
        ));
    }

    let record_matches = self
        .inner
        .verifier
        .check_txt(&claim.domain, &claim.verification_token)
        .await?;
    if !record_matches {
        return Err(Error::bad_request(
            "DNS TXT record not found or does not match verification token",
        ));
    }

    let verified_at = Utc::now().to_rfc3339();
    let mark_verified_sql =
        "UPDATE tenant_domains SET status = ?1, verified_at = ?2 WHERE id = ?3";
    self.inner
        .db
        .conn()
        .execute_raw(
            mark_verified_sql,
            libsql::params!["verified", verified_at.as_str(), id],
        )
        .await
        .map_err(Error::from)?;

    claim.status = ClaimStatus::Verified;
    claim.verified_at = Some(verified_at);
    Ok(claim)
}
pub async fn remove(&self, id: &str) -> Result<()> {
self.inner
.db
.conn()
.execute_raw(
"DELETE FROM tenant_domains WHERE id = ?1",
libsql::params![id],
)
.await
.map_err(Error::from)?;
Ok(())
}
pub async fn enable_email(&self, id: &str) -> Result<()> {
self.require_verified(id).await?;
self.inner
.db
.conn()
.execute_raw(
"UPDATE tenant_domains SET use_for_email = 1 WHERE id = ?1",
libsql::params![id],
)
.await
.map_err(Error::from)?;
Ok(())
}
pub async fn disable_email(&self, id: &str) -> Result<()> {
self.inner
.db
.conn()
.execute_raw(
"UPDATE tenant_domains SET use_for_email = 0 WHERE id = ?1",
libsql::params![id],
)
.await
.map_err(Error::from)?;
Ok(())
}
pub async fn enable_routing(&self, id: &str) -> Result<()> {
self.require_verified(id).await?;
self.inner
.db
.conn()
.execute_raw(
"UPDATE tenant_domains SET use_for_routing = 1 WHERE id = ?1",
libsql::params![id],
)
.await
.map_err(Error::from)?;
Ok(())
}
pub async fn disable_routing(&self, id: &str) -> Result<()> {
self.inner
.db
.conn()
.execute_raw(
"UPDATE tenant_domains SET use_for_routing = 0 WHERE id = ?1",
libsql::params![id],
)
.await
.map_err(Error::from)?;
Ok(())
}
pub async fn lookup_email_domain(&self, email: &str) -> Result<Option<TenantMatch>> {
let domain = extract_email_domain(email)?;
let row: Option<MatchRow> = self
.inner
.db
.conn()
.query_optional(
"SELECT tenant_id, domain FROM tenant_domains \
WHERE domain = ?1 AND status = 'verified' AND use_for_email = 1 \
LIMIT 1",
libsql::params![domain.as_str()],
)
.await?;
Ok(row.map(|r| TenantMatch {
tenant_id: r.tenant_id,
domain: r.domain,
}))
}
pub async fn lookup_routing_domain(&self, domain: &str) -> Result<Option<TenantMatch>> {
let domain = validate_domain(domain)?;
let row: Option<MatchRow> = self
.inner
.db
.conn()
.query_optional(
"SELECT tenant_id, domain FROM tenant_domains \
WHERE domain = ?1 AND status = 'verified' AND use_for_routing = 1 \
LIMIT 1",
libsql::params![domain.as_str()],
)
.await?;
Ok(row.map(|r| TenantMatch {
tenant_id: r.tenant_id,
domain: r.domain,
}))
}
pub async fn resolve_tenant(&self, domain: &str) -> Result<Option<String>> {
Ok(self
.lookup_routing_domain(domain)
.await?
.map(|m| m.tenant_id))
}
/// Lists all claims of `tenant_id`, newest first.
///
/// Pending claims whose verification window has elapsed are reported as
/// `Failed` in the returned values; nothing is written back to storage.
///
/// # Errors
///
/// Propagates database errors, and fails on the first row whose stored
/// status cannot be parsed.
pub async fn list(&self, tenant_id: &str) -> Result<Vec<DomainClaim>> {
    let sql = "SELECT id, tenant_id, domain, verification_token, status, \
         use_for_email, use_for_routing, created_at, verified_at \
         FROM tenant_domains WHERE tenant_id = ?1 \
         ORDER BY created_at DESC";
    let rows: Vec<DomainRow> = self
        .inner
        .db
        .conn()
        .query_all(sql, libsql::params![tenant_id])
        .await?;

    let mut claims = Vec::with_capacity(rows.len());
    for row in rows {
        claims.push(row.into_claim_with_expiry()?);
    }
    Ok(claims)
}
async fn require_verified(&self, id: &str) -> Result<()> {
let status: String = self
.inner
.db
.conn()
.query_one_map(
"SELECT status FROM tenant_domains WHERE id = ?1",
libsql::params![id],
|row| {
let val = row.get_value(0).map_err(Error::from)?;
db::FromValue::from_value(val)
},
)
.await?;
if status != "verified" {
return Err(Error::bad_request(
"domain must be verified before enabling features",
));
}
Ok(())
}
}
/// Reports whether a claim created at `created_at` (RFC 3339) is older than
/// `VERIFICATION_EXPIRY_HOURS`.
///
/// A timestamp that cannot be parsed is treated as not expired.
fn is_expired(created_at: &str) -> bool {
    match DateTime::parse_from_rfc3339(created_at) {
        Ok(created) => {
            let limit = chrono::Duration::hours(VERIFICATION_EXPIRY_HOURS);
            Utc::now() - created.with_timezone(&Utc) > limit
        }
        Err(_) => false,
    }
}
// Unit tests for the pure helpers in this module: domain validation, email
// domain extraction, status parsing and expiry calculation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn valid_domain() {
        let normalized = validate_domain("Example.COM").unwrap();
        assert_eq!(normalized, "example.com");
    }

    #[test]
    fn domain_with_subdomain() {
        let normalized = validate_domain("sub.example.com").unwrap();
        assert_eq!(normalized, "sub.example.com");
    }

    #[test]
    fn domain_trimmed() {
        let normalized = validate_domain(" example.com ").unwrap();
        assert_eq!(normalized, "example.com");
    }

    #[test]
    fn empty_domain_rejected() {
        let outcome = validate_domain("");
        assert!(outcome.is_err());
    }

    #[test]
    fn no_dot_rejected() {
        let outcome = validate_domain("localhost");
        assert!(outcome.is_err());
    }

    #[test]
    fn leading_dot_rejected() {
        let outcome = validate_domain(".example.com");
        assert!(outcome.is_err());
    }

    #[test]
    fn trailing_dot_rejected() {
        let outcome = validate_domain("example.com.");
        assert!(outcome.is_err());
    }

    #[test]
    fn label_starting_with_hyphen_rejected() {
        let outcome = validate_domain("-example.com");
        assert!(outcome.is_err());
    }

    #[test]
    fn label_ending_with_hyphen_rejected() {
        let outcome = validate_domain("example-.com");
        assert!(outcome.is_err());
    }

    #[test]
    fn domain_too_long_rejected() {
        // 250 + ".com" = 254 bytes, one past the overall limit.
        let long = format!("{}.com", "a".repeat(250));
        assert!(validate_domain(&long).is_err());
    }

    #[test]
    fn label_too_long_rejected() {
        // A single 64-byte label, one past the per-label limit.
        let long = format!("{}.com", "a".repeat(64));
        assert!(validate_domain(&long).is_err());
    }

    #[test]
    fn invalid_chars_rejected() {
        for candidate in ["ex ample.com", "ex_ample.com"] {
            assert!(validate_domain(candidate).is_err());
        }
    }

    #[test]
    fn extract_valid_email_domain() {
        let domain = extract_email_domain("user@Example.COM").unwrap();
        assert_eq!(domain, "example.com");
    }

    #[test]
    fn extract_email_no_at_rejected() {
        let outcome = extract_email_domain("nope");
        assert!(outcome.is_err());
    }

    #[test]
    fn extract_email_empty_local_rejected() {
        let outcome = extract_email_domain("@example.com");
        assert!(outcome.is_err());
    }

    #[test]
    fn extract_email_empty_domain_rejected() {
        let outcome = extract_email_domain("user@");
        assert!(outcome.is_err());
    }

    #[test]
    fn claim_status_round_trip() {
        let all = [
            ClaimStatus::Pending,
            ClaimStatus::Verified,
            ClaimStatus::Failed,
        ];
        for status in all {
            let parsed = ClaimStatus::from_str(status.as_str()).unwrap();
            assert_eq!(parsed, status);
        }
    }

    #[test]
    fn claim_status_unknown_rejected() {
        let outcome = ClaimStatus::from_str("bogus");
        assert!(outcome.is_err());
    }

    #[test]
    fn fresh_claim_not_expired() {
        let created_at = Utc::now().to_rfc3339();
        assert!(!is_expired(&created_at));
    }

    #[test]
    fn old_claim_expired() {
        // One hour past the 48-hour window.
        let created = Utc::now() - chrono::Duration::hours(49);
        let created_at = created.to_rfc3339();
        assert!(is_expired(&created_at));
    }

    #[test]
    fn invalid_timestamp_not_expired() {
        let garbage = "not-a-timestamp";
        assert!(!is_expired(garbage));
    }
}