use crate::core::Column as _;
use crate::sql::sqlx::PgPool;
use crate::sql::Fetcher;
use async_trait::async_trait;
use http::request::Parts;
use http::HeaderName;
use super::error::TenancyError;
use super::org::Org;
/// Strategy for mapping an incoming request to the tenant [`Org`] it targets.
///
/// The resolvers in this module return `Ok(None)` when the request carries no
/// usable hint or when no active org matches it, and only return an error when
/// the registry lookup itself fails.
#[async_trait]
pub trait OrgResolver: Send + Sync + 'static {
    /// Attempts to identify the org for the request described by `parts`,
    /// consulting the `registry` pool for the lookup.
    ///
    /// # Errors
    ///
    /// Returns a [`TenancyError`] if resolution fails (for the resolvers in
    /// this module: when the registry query fails).
    async fn resolve(
        &self,
        parts: &Parts,
        registry: &PgPool,
    ) -> Result<Option<Org>, TenancyError>;
}
/// Resolves the org by matching the full request host against `Org::host_pattern`.
///
/// Requests whose host is the apex domain itself are treated as belonging to
/// no org.
pub struct SubdomainResolver {
    /// Root domain of the deployment; requests addressed to it resolve to no org.
    pub apex_domain: String,
}

impl SubdomainResolver {
    /// Builds a resolver that treats `apex_domain` as the tenant-less root host.
    #[must_use]
    pub fn new(apex_domain: impl Into<String>) -> Self {
        let apex: String = apex_domain.into();
        Self { apex_domain: apex }
    }
}
/// Looks up the active org whose `host_pattern` equals the request host.
///
/// Returns `Ok(None)` when the request has no host, when the host is empty, or
/// when the host is the configured apex domain.
#[async_trait]
impl OrgResolver for SubdomainResolver {
    async fn resolve(&self, parts: &Parts, registry: &PgPool) -> Result<Option<Org>, TenancyError> {
        let Some(host) = host_from_parts(parts) else {
            return Ok(None);
        };
        // An empty host can never identify a tenant; querying for it would
        // match any org whose host_pattern happens to be "".
        // DNS names are case-insensitive (RFC 4343), so `EXAMPLE.COM` is still
        // the apex and must not fall through to a registry lookup.
        if host.is_empty() || host.eq_ignore_ascii_case(&self.apex_domain) {
            return Ok(None);
        }
        find_active_org_by(registry, Org::host_pattern.eq(host.to_owned())).await
    }
}
/// Resolves the org from the first non-empty path segment.
///
/// For example, `/acme/dashboard` is looked up as `path_prefix == "/acme"`.
/// Leading slashes are collapsed, so `//acme` behaves like `/acme`.
pub struct PathPrefixResolver;

#[async_trait]
impl OrgResolver for PathPrefixResolver {
    async fn resolve(&self, parts: &Parts, registry: &PgPool) -> Result<Option<Org>, TenancyError> {
        let trimmed = parts.uri.path().trim_start_matches('/');
        let segment = trimmed
            .split_once('/')
            .map_or(trimmed, |(head, _rest)| head);
        if segment.is_empty() {
            return Ok(None);
        }
        let prefix = format!("/{segment}");
        find_active_org_by(registry, Org::path_prefix.eq(prefix)).await
    }
}
/// Resolves the org from a request header whose value is the org slug.
pub struct HeaderResolver {
    /// Header whose trimmed value is matched against `Org::slug`.
    pub header_name: HeaderName,
}

impl HeaderResolver {
    /// Builds a resolver that reads the slug from `header_name`.
    #[must_use]
    pub fn new(header_name: HeaderName) -> Self {
        Self { header_name }
    }
}

impl Default for HeaderResolver {
    /// Reads the slug from the `x-org` header.
    fn default() -> Self {
        Self::new(HeaderName::from_static("x-org"))
    }
}
/// Looks up the active org whose `slug` equals the configured header's value.
///
/// A missing header, a non-UTF-8-visible value, or a blank (all-whitespace)
/// value all yield `Ok(None)`.
#[async_trait]
impl OrgResolver for HeaderResolver {
    async fn resolve(&self, parts: &Parts, registry: &PgPool) -> Result<Option<Org>, TenancyError> {
        let slug = parts
            .headers
            .get(&self.header_name)
            .and_then(|raw| raw.to_str().ok())
            .map(str::trim)
            .filter(|candidate| !candidate.is_empty());
        match slug {
            Some(found) => find_active_org_by(registry, Org::slug.eq(found.to_owned())).await,
            None => Ok(None),
        }
    }
}
pub struct PortResolver;
#[async_trait]
impl OrgResolver for PortResolver {
async fn resolve(&self, parts: &Parts, registry: &PgPool) -> Result<Option<Org>, TenancyError> {
let Some(port) = parts.uri.port_u16() else {
return Ok(None);
};
find_active_org_by(registry, Org::port.eq(i32::from(port))).await
}
}
/// Composite resolver that tries a sequence of resolvers in insertion order.
pub struct ChainResolver {
    resolvers: Vec<Box<dyn OrgResolver>>,
}

impl ChainResolver {
    /// Creates an empty chain; an empty chain always resolves to no org.
    #[must_use]
    pub fn new() -> Self {
        Self { resolvers: Vec::new() }
    }

    /// Appends `resolver` to the end of the chain, builder-style.
    #[must_use]
    pub fn push<R: OrgResolver>(self, resolver: R) -> Self {
        let Self { mut resolvers } = self;
        resolvers.push(Box::new(resolver));
        Self { resolvers }
    }

    /// The default chain: subdomain/host lookup first, then the `x-org` header.
    #[must_use]
    pub fn standard(apex_domain: impl Into<String>) -> Self {
        let by_host = SubdomainResolver::new(apex_domain);
        let by_header = HeaderResolver::default();
        Self::new().push(by_host).push(by_header)
    }
}

impl Default for ChainResolver {
    /// Equivalent to [`ChainResolver::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Returns the first org produced by any resolver in the chain.
///
/// Resolvers are tried in insertion order. The first error aborts the chain and
/// is returned as-is; if every resolver yields `None`, so does the chain.
#[async_trait]
impl OrgResolver for ChainResolver {
    async fn resolve(&self, parts: &Parts, registry: &PgPool) -> Result<Option<Org>, TenancyError> {
        for candidate in self.resolvers.iter() {
            if let Some(org) = candidate.resolve(parts, registry).await? {
                return Ok(Some(org));
            }
        }
        Ok(None)
    }
}
/// Returns the request host without any port.
///
/// Prefers the `Host` header. Falls back to the URI's host when the header is
/// missing, not visible ASCII, or empty. Bracketed IPv6 literals keep their
/// brackets (`[::1]:8080` -> `[::1]`), matching the form `Uri::host` uses.
fn host_from_parts(parts: &Parts) -> Option<&str> {
    let from_header = parts
        .headers
        .get(http::header::HOST)
        .and_then(|value| value.to_str().ok())
        .map(|raw| strip_port(raw.trim()))
        // An empty Host header carries no information; try the URI instead of
        // handing an empty host to callers.
        .filter(|host| !host.is_empty());
    from_header.or_else(|| parts.uri.host())
}

/// Removes a trailing `:port` from a host[:port] authority string.
fn strip_port(authority: &str) -> &str {
    if authority.starts_with('[') {
        // IP-literal (RFC 3986 §3.2.2): colons inside the brackets belong to
        // the address, so cut after the closing bracket instead of at ':'.
        return authority
            .find(']')
            .map_or(authority, |end| &authority[..=end]);
    }
    authority.split(':').next().unwrap_or(authority)
}
/// Fetches the first *active* org matching `filter` from the registry.
///
/// The caller's filter is combined with `active = true`. Returns `Ok(None)`
/// when no row matches.
///
/// NOTE(review): no ordering is applied, so if several active orgs match, the
/// one returned is whichever row the database yields first. Presumably the
/// filtered columns are unique; confirm against the `Org` schema.
///
/// # Errors
///
/// Returns `TenancyError::Driver` when the query fails. Non-driver execution
/// errors are converted into a driver error by `driver_from_exec`.
async fn find_active_org_by<F>(registry: &PgPool, filter: F) -> Result<Option<Org>, TenancyError>
where
    F: Into<rustango::core::TypedFilter<Org>>,
{
    let caller_filter: rustango::core::TypedFilter<Org> = filter.into();
    Org::objects()
        .where_(caller_filter)
        .where_(Org::active.eq(true))
        .fetch(registry)
        .await
        .map(|matches: Vec<Org>| matches.into_iter().next())
        .map_err(|err| TenancyError::Driver(driver_from_exec(err)))
}
fn driver_from_exec(e: rustango::sql::ExecError) -> rustango::sql::sqlx::Error {
use crate::sql::ExecError;
match e {
ExecError::Driver(err) => err,
other => rustango::sql::sqlx::Error::Protocol(format!("resolver query: {other}")),
}
}