use std::sync::Arc;
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use sqlx::Database;
use crate::sql::sqlx;
use crate::tenancy::{
session::SessionSecret, ChainResolver, DefaultTenantDb, Org, OrgResolver, TenancyError,
TenantConn, TenantPools,
};
/// Shared tenancy state that the `Tenant` extractors look up in request extensions
/// as `Arc<TenantContext<DB>>`.
///
/// NOTE(review): presumably installed by `rustango::server::Builder` (that is what the
/// `MissingContext` rejection message refers to) — confirm.
pub struct TenantContext<DB: Database = DefaultTenantDb> {
    /// Pools used by the extractors: the registry pool for org resolution, per-org
    /// connection acquisition, and per-org scoped pools.
    pub pools: Arc<TenantPools<DB>>,
    /// Resolves the request's `Org`; the extractors call it with the request parts
    /// and the registry pool.
    pub resolver: ChainResolver,
    /// Session secret; not read by the extractors in this module.
    pub session_secret: SessionSecret,
    /// Operator session secret; not read by the extractors in this module.
    pub operator_secret: SessionSecret,
}
/// A tenant connection that is either already acquired or still pending acquisition.
///
/// The Postgres extractor and `Tenant::for_test` store `Ready`. The SQLite and MySQL
/// extractors store `Deferred`, which `Tenant::ensure_conn` replaces with `Ready` on
/// first use.
enum TenantConnCell<DB: Database> {
    /// An acquired connection.
    Ready(TenantConn<DB>),
    /// Not acquired yet; keeps the pools needed to acquire it later.
    Deferred(Arc<TenantPools<DB>>),
}
/// Per-request tenant handle produced by the `FromRequestParts` impls in this module.
pub struct Tenant<DB: Database = DefaultTenantDb> {
    /// The org the resolver matched for this request.
    pub org: Org,
    /// The org's connection: acquired eagerly (Postgres) or lazily on first use
    /// (SQLite/MySQL).
    conn: TenantConnCell<DB>,
    /// The pool returned by `TenantPools::scoped_pool_dyn` for `org`.
    pool: crate::sql::Pool,
}
impl<DB: Database> Tenant<DB> {
    /// Borrows the tenant's pooled connection, acquiring it first if the extractor
    /// deferred acquisition.
    ///
    /// # Errors
    ///
    /// Propagates the [`TenancyError`] from `TenantPools::database_acquire` when a
    /// deferred connection cannot be acquired.
    pub async fn pool_conn(
        &mut self,
    ) -> Result<&mut sqlx::pool::PoolConnection<DB>, TenancyError> {
        self.ensure_conn().await?;
        let TenantConnCell::Ready(ready) = &mut self.conn else {
            unreachable!("ensure_conn just populated the connection")
        };
        Ok(ready)
    }

    /// Makes sure `self.conn` is `Ready`. If it is still `Deferred`, this acquires a
    /// connection for `self.org` through the stored pools; otherwise it does nothing.
    async fn ensure_conn(&mut self) -> Result<(), TenancyError> {
        if let TenantConnCell::Deferred(pools) = &self.conn {
            // Take an owned handle so `self.conn` can be overwritten after the await.
            let pools = Arc::clone(pools);
            let acquired = pools.database_acquire(&self.org).await?;
            self.conn = TenantConnCell::Ready(acquired);
        }
        Ok(())
    }

    /// Consumes the handle and returns its connection, acquiring it first if needed.
    ///
    /// # Errors
    ///
    /// Propagates the [`TenancyError`] from acquiring a deferred connection.
    pub async fn into_conn(mut self) -> Result<TenantConn<DB>, TenancyError> {
        self.ensure_conn().await?;
        let TenantConnCell::Ready(owned) = self.conn else {
            unreachable!("ensure_conn just populated the connection")
        };
        Ok(owned)
    }

    /// The tenant-scoped pool resolved when the request was extracted.
    #[must_use]
    pub fn pool(&self) -> &crate::sql::Pool {
        &self.pool
    }

    /// Builds a `Tenant` around an already-acquired connection, bypassing the extractor.
    #[cfg(any(test, feature = "test_utils"))]
    #[must_use]
    pub fn for_test(org: Org, conn: TenantConn<DB>, pool: crate::sql::Pool) -> Self {
        let conn = TenantConnCell::Ready(conn);
        Self { org, conn, pool }
    }
}
#[cfg(feature = "postgres")]
impl Tenant<sqlx::Postgres> {
    /// Direct access to the underlying Postgres connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection is not `Ready`. Every constructor of
    /// `Tenant<Postgres>` in this module (the extractor and `for_test`) stores an
    /// acquired connection, so reaching the panic means that invariant was broken.
    pub fn conn(&mut self) -> &mut sqlx::PgConnection {
        let TenantConnCell::Ready(pooled) = &mut self.conn else {
            unreachable!("Tenant<Postgres> acquires its connection eagerly")
        };
        pooled
    }
}
/// Reasons the `Tenant` extractors reject a request.
///
/// Its `IntoResponse` impl maps `MissingContext` and `Internal` to
/// 500 Internal Server Error, and `NotFound` to 404 Not Found.
///
/// `Clone`/`PartialEq`/`Eq` are derived so callers and tests can compare and copy
/// rejections.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TenantRejection {
    /// No `Arc<TenantContext<_>>` for the requested backend was found in the
    /// request extensions.
    MissingContext,
    /// The resolver succeeded but did not match any org.
    NotFound,
    /// Resolving the org, acquiring its connection, or fetching its scoped pool
    /// failed; carries the error's `to_string()` text.
    Internal(String),
}
impl IntoResponse for TenantRejection {
    /// Turns a rejection into a plain-text HTTP response.
    ///
    /// NOTE(review): the `Internal` text (the stringified underlying error) is sent
    /// verbatim to the client; consider whether that leaks backend details.
    fn into_response(self) -> Response {
        const MISSING_CONTEXT_BODY: &str =
            "rustango::server::Builder did not run — Tenant extractor cannot find TenantContext";

        match self {
            Self::NotFound => (StatusCode::NOT_FOUND, "tenant not found").into_response(),
            Self::Internal(detail) => {
                (StatusCode::INTERNAL_SERVER_ERROR, detail).into_response()
            }
            Self::MissingContext => {
                (StatusCode::INTERNAL_SERVER_ERROR, MISSING_CONTEXT_BODY).into_response()
            }
        }
    }
}
#[cfg(feature = "postgres")]
impl<S> FromRequestParts<S> for Tenant<sqlx::Postgres>
where
    S: Send + Sync,
{
    type Rejection = TenantRejection;

    /// Resolves the request's org and eagerly acquires its connection and scoped pool.
    ///
    /// Rejects with `MissingContext` when no Postgres `TenantContext` is in the
    /// request extensions, `NotFound` when the resolver matches no org, and
    /// `Internal` when resolution, acquisition, or the pool lookup fails.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        fn internal(err: impl ToString) -> TenantRejection {
            TenantRejection::Internal(err.to_string())
        }

        // Clone the Arc so `parts` is free to be borrowed by the resolver below.
        let Some(ctx) = parts
            .extensions
            .get::<Arc<TenantContext<sqlx::Postgres>>>()
            .cloned()
        else {
            return Err(TenantRejection::MissingContext);
        };

        let org = match ctx
            .resolver
            .resolve(parts, &ctx.pools.registry_pool())
            .await
        {
            Ok(Some(org)) => org,
            Ok(None) => return Err(TenantRejection::NotFound),
            Err(err) => return Err(internal(err)),
        };

        let conn = ctx.pools.acquire(&org).await.map_err(internal)?;
        let pool = ctx.pools.scoped_pool_dyn(&org).await.map_err(internal)?;

        let conn = TenantConnCell::Ready(conn);
        Ok(Self { org, conn, pool })
    }
}
#[cfg(feature = "sqlite")]
impl<S> FromRequestParts<S> for Tenant<sqlx::Sqlite>
where
    S: Send + Sync,
{
    type Rejection = TenantRejection;

    /// Resolves the request's org and its scoped pool. The connection itself is left
    /// `Deferred` and acquired on first use (`pool_conn` / `into_conn`).
    ///
    /// Rejects with `MissingContext` when no SQLite `TenantContext` is in the request
    /// extensions, `NotFound` when the resolver matches no org, and `Internal` when
    /// resolution or the pool lookup fails.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        fn internal(err: impl ToString) -> TenantRejection {
            TenantRejection::Internal(err.to_string())
        }

        // Clone the Arc so `parts` is free to be borrowed by the resolver below.
        let Some(ctx) = parts
            .extensions
            .get::<Arc<TenantContext<sqlx::Sqlite>>>()
            .cloned()
        else {
            return Err(TenantRejection::MissingContext);
        };

        let org = match ctx
            .resolver
            .resolve(parts, &ctx.pools.registry_pool())
            .await
        {
            Ok(Some(org)) => org,
            Ok(None) => return Err(TenantRejection::NotFound),
            Err(err) => return Err(internal(err)),
        };

        let pool = ctx.pools.scoped_pool_dyn(&org).await.map_err(internal)?;

        let conn = TenantConnCell::Deferred(Arc::clone(&ctx.pools));
        Ok(Self { org, conn, pool })
    }
}
#[cfg(feature = "mysql")]
impl<S> FromRequestParts<S> for Tenant<sqlx::MySql>
where
    S: Send + Sync,
{
    type Rejection = TenantRejection;

    /// Resolves the request's org and its scoped pool. The connection itself is left
    /// `Deferred` and acquired on first use (`pool_conn` / `into_conn`).
    ///
    /// Rejects with `MissingContext` when no MySQL `TenantContext` is in the request
    /// extensions, `NotFound` when the resolver matches no org, and `Internal` when
    /// resolution or the pool lookup fails.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        fn internal(err: impl ToString) -> TenantRejection {
            TenantRejection::Internal(err.to_string())
        }

        // Clone the Arc so `parts` is free to be borrowed by the resolver below.
        let Some(ctx) = parts
            .extensions
            .get::<Arc<TenantContext<sqlx::MySql>>>()
            .cloned()
        else {
            return Err(TenantRejection::MissingContext);
        };

        let org = match ctx
            .resolver
            .resolve(parts, &ctx.pools.registry_pool())
            .await
        {
            Ok(Some(org)) => org,
            Ok(None) => return Err(TenantRejection::NotFound),
            Err(err) => return Err(internal(err)),
        };

        let pool = ctx.pools.scoped_pool_dyn(&org).await.map_err(internal)?;

        let conn = TenantConnCell::Deferred(Arc::clone(&ctx.pools));
        Ok(Self { org, conn, pool })
    }
}