use std::sync::Arc;
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use crate::sql::sqlx;
use crate::tenancy::{ChainResolver, Org, OrgResolver, TenantConn, TenantPools};
/// Shared tenancy state consumed by the [`Tenant`] extractor.
///
/// The extractor looks this up in the request extensions as an
/// `Arc<TenantContext>`, so it has to be inserted there before any handler
/// that extracts a [`Tenant`] runs.
pub struct TenantContext {
    /// Pool set the extractor acquires a per-org connection from; its
    /// registry is also handed to the resolver.
    pub pools: Arc<TenantPools>,

    /// Resolver that maps the incoming request parts to an [`Org`].
    pub resolver: ChainResolver,
}
/// Request extractor output: the organisation a request resolved to, paired
/// with a connection acquired for it.
pub struct Tenant {
    /// The organisation the resolver returned for this request.
    pub org: Org,

    /// Connection acquired from the tenant pools for `org`. Kept private;
    /// reach it through the accessor methods on this type.
    conn: TenantConn,
}
impl Tenant {
    /// Mutably borrows the tenant's connection as a plain
    /// `sqlx::PgConnection`.
    ///
    /// The conversion from the stored `TenantConn` happens through the usual
    /// reference coercion at the typed binding below.
    pub fn conn(&mut self) -> &mut sqlx::PgConnection {
        let pg: &mut sqlx::PgConnection = &mut self.conn;
        pg
    }

    /// Consumes the extractor and hands back the owned [`TenantConn`].
    ///
    /// The `org` field is dropped along with the rest of `self`.
    #[must_use]
    pub fn into_conn(self) -> TenantConn {
        let Self { conn, .. } = self;
        conn
    }
}
/// Reasons the [`Tenant`] extractor can fail.
#[derive(Debug)]
pub enum TenantRejection {
    /// No `Arc<TenantContext>` was present in the request extensions.
    MissingContext,

    /// The resolver completed without error but matched no organisation.
    NotFound,

    /// Resolving the organisation or acquiring its connection failed; the
    /// payload is the stringified underlying error.
    Internal(String),
}
impl IntoResponse for TenantRejection {
fn into_response(self) -> Response {
match self {
Self::MissingContext => (
StatusCode::INTERNAL_SERVER_ERROR,
"rustango::server::Builder did not run — Tenant extractor cannot find TenantContext",
)
.into_response(),
Self::NotFound => (StatusCode::NOT_FOUND, "tenant not found").into_response(),
Self::Internal(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg).into_response(),
}
}
}
/// Extracts a [`Tenant`] from the request parts.
///
/// Steps, in order:
/// 1. Fetch the `Arc<TenantContext>` from the request extensions
///    (`MissingContext` if absent).
/// 2. Ask the context's resolver for the request's org, passing it the
///    pools' registry (`Internal` on error, `NotFound` when it yields
///    `None`).
/// 3. Acquire a connection for that org from the pools (`Internal` on
///    error).
///
/// The handler state `S` is not used.
impl<S> FromRequestParts<S> for Tenant
where
    S: Send + Sync,
{
    type Rejection = TenantRejection;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let Some(shared) = parts.extensions.get::<Arc<TenantContext>>() else {
            return Err(TenantRejection::MissingContext);
        };
        // Take our own handle so the borrow of `parts.extensions` ends
        // before `parts` is handed to the resolver.
        let ctx = Arc::clone(shared);

        let lookup = ctx.resolver.resolve(parts, ctx.pools.registry()).await;
        let org = match lookup {
            Ok(Some(found)) => found,
            Ok(None) => return Err(TenantRejection::NotFound),
            Err(err) => return Err(TenantRejection::Internal(err.to_string())),
        };

        match ctx.pools.acquire(&org).await {
            Ok(conn) => Ok(Tenant { org, conn }),
            Err(err) => Err(TenantRejection::Internal(err.to_string())),
        }
    }
}