use std::sync::Arc;
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use sqlx::Database;
use crate::sql::sqlx;
use crate::tenancy::{
session::SessionSecret, ChainResolver, DatabaseConn, DatabasePools, Org, OrgResolver,
};
/// Shared state the [`DatabaseTenant`] extractor needs for a given backend `DB`.
///
/// The extractor looks this up in the request extensions as
/// `Arc<DatabaseTenantContext<DB>>`; if it is absent the request is rejected.
pub struct DatabaseTenantContext<DB>
where
    DB: Database,
{
    /// Per-tenant connection pools; a connection is acquired from here for the
    /// resolved org.
    pub pools: Arc<DatabasePools<DB>>,
    /// Resolver used to map the incoming request (plus `registry`) to an org.
    pub resolver: ChainResolver,
    /// Session secret. Not read by the extractor in this file.
    pub session_secret: SessionSecret,
    /// Operator session secret. Not read by the extractor in this file.
    pub operator_secret: SessionSecret,
    /// Pool handed to the resolver when resolving the org.
    pub registry: crate::sql::Pool,
}
/// A request-scoped tenant: the resolved org plus a connection checked out
/// for it.
///
/// Obtained via its `FromRequestParts` impl; the connection is reachable
/// through [`DatabaseTenant::conn`] or [`DatabaseTenant::into_conn`].
pub struct DatabaseTenant<DB>
where
    DB: Database,
{
    /// The org this request was resolved to.
    pub org: Org,
    // Kept private so callers go through the accessor methods.
    conn: DatabaseConn<DB>,
}
impl<DB: Database> DatabaseTenant<DB> {
    /// Mutably borrows the tenant's connection.
    pub fn conn(&mut self) -> &mut DatabaseConn<DB> {
        let Self { conn, .. } = self;
        conn
    }

    /// Consumes the tenant and returns its connection, discarding the org.
    #[must_use]
    pub fn into_conn(self) -> DatabaseConn<DB> {
        let Self { conn, .. } = self;
        conn
    }

    /// Builds a tenant directly from parts, bypassing request extraction.
    ///
    /// Only compiled for tests or with the `test_utils` feature.
    #[must_use]
    #[cfg(any(test, feature = "test_utils"))]
    pub fn for_test(org: Org, conn: DatabaseConn<DB>) -> Self {
        Self { org, conn }
    }
}
/// Reasons the `DatabaseTenant` extractor can refuse a request.
///
/// The `IntoResponse` impl in this module renders `NotFound` as `404` and the
/// other variants as `500`. The type also implements [`std::fmt::Display`] and
/// [`std::error::Error`] so it can be logged or propagated like any other error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DatabaseTenantRejection {
    /// No `Arc<DatabaseTenantContext<DB>>` was present in the request
    /// extensions for the requested backend — a server-setup problem rather
    /// than a client error.
    MissingContext,
    /// The resolver completed without error but did not identify a tenant.
    NotFound,
    /// Resolving the tenant or acquiring its connection failed; carries the
    /// underlying error's string form.
    Internal(String),
}

impl std::fmt::Display for DatabaseTenantRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingContext => f.write_str(
                "DatabaseTenantContext not installed — the server wasn't built \
                 with `Cli::tenants::<DB>()` for the matching backend.",
            ),
            Self::NotFound => f.write_str("tenant not found"),
            Self::Internal(msg) => write!(f, "internal tenant error: {msg}"),
        }
    }
}

// No wrapped source error: `Internal` only keeps the stringified message.
impl std::error::Error for DatabaseTenantRejection {}
/// Renders a rejection as an HTTP response.
///
/// * `NotFound` → `404` with body `tenant not found`
/// * `MissingContext` → `500` with a hint about how the server was built
/// * `Internal(msg)` → `500` with `msg` as the body
impl IntoResponse for DatabaseTenantRejection {
    fn into_response(self) -> Response {
        // Missing context is a server-configuration mistake, not a client
        // error, hence the 500 and the build hint.
        const MISSING_CONTEXT_HINT: &str = "DatabaseTenantContext not installed — the server wasn't built \
             with `Cli::tenants::<DB>()` for the matching backend.";

        match self {
            Self::NotFound => (StatusCode::NOT_FOUND, "tenant not found").into_response(),
            Self::MissingContext => {
                (StatusCode::INTERNAL_SERVER_ERROR, MISSING_CONTEXT_HINT).into_response()
            }
            // NOTE(review): `msg` is sent to the client verbatim. In this file it
            // is built from resolver / pool errors' `to_string()`, which may expose
            // backend details — consider logging it and returning a generic body.
            Self::Internal(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg).into_response(),
        }
    }
}
impl<S, DB> FromRequestParts<S> for DatabaseTenant<DB>
where
    DB: Database + 'static,
    S: Send + Sync,
{
    type Rejection = DatabaseTenantRejection;

    /// Resolves the tenant for this request and checks out a connection for it.
    ///
    /// 1. Looks up `Arc<DatabaseTenantContext<DB>>` in the request extensions
    ///    (`MissingContext` if absent).
    /// 2. Asks the context's resolver for the org (`Internal` on resolver
    ///    error, `NotFound` if nothing resolves).
    /// 3. Acquires a connection for that org from the context's pools
    ///    (`Internal` on failure).
    ///
    /// The router state `S` is not used.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        // Collapses any backend error into the `Internal` rejection.
        fn internal(err: impl ToString) -> DatabaseTenantRejection {
            DatabaseTenantRejection::Internal(err.to_string())
        }

        // Clone the Arc out so the extensions borrow ends before `parts` is
        // handed to the resolver.
        let Some(ctx) = parts
            .extensions
            .get::<Arc<DatabaseTenantContext<DB>>>()
            .cloned()
        else {
            return Err(DatabaseTenantRejection::MissingContext);
        };

        let resolved = ctx
            .resolver
            .resolve(parts, &ctx.registry)
            .await
            .map_err(internal)?;
        let Some(org) = resolved else {
            return Err(DatabaseTenantRejection::NotFound);
        };

        let conn = ctx.pools.acquire(&org).await.map_err(internal)?;
        Ok(Self { org, conn })
    }
}