use std::future::Future;
use std::marker::PhantomData;
use std::sync::Arc;
use axum::{Extension, Router};
use sqlx::Database;
use tower::ServiceExt as _;
use crate::extractors::TenantContext;
#[cfg(feature = "postgres")]
use crate::sql::sqlx::PgPool;
use crate::tenancy::{
admin::TenantAdminBuilder,
operator_console::{self, SessionSecret},
ChainResolver, DefaultTenantDb, HeaderResolver, SubdomainResolver, TenantPools,
};
/// Router type accepted by [`Builder::api`]: an axum [`Router`] whose state is `()`.
pub type ApiRouter = Router<()>;
/// Fluent builder that collects the configuration for a multi-tenant server and
/// finally runs it via `serve`.
///
/// The type parameter `DB` selects the sqlx database backend and defaults to
/// [`DefaultTenantDb`].
pub struct Builder<DB: Database = DefaultTenantDb> {
/// Apex host name. `serve` sends requests whose `Host` matches it to the operator
/// console and uses it to build the subdomain resolver.
apex: String,
/// Connection URL of the registry database; forwarded to tenant migrations,
/// the seed hook and the tenant admin builder.
registry_url: String,
/// Shared tenant pool manager, created from a clone of `registry`.
pools: Arc<TenantPools<DB>>,
/// Pool connected to the registry database.
registry: sqlx::Pool<DB>,
/// Model names forwarded to the tenant admin's `show_only` when non-empty.
show_only: Vec<String>,
/// Optional title forwarded to the tenant admin builder.
admin_title: Option<String>,
/// Optional subtitle forwarded to the tenant admin builder.
admin_subtitle: Option<String>,
/// Optional application router that is served for tenant hosts alongside the admin routes.
api: Option<ApiRouter>,
/// Admin actions queued by `admin_register_action`; registered on the tenant admin in `serve`.
admin_actions: Vec<PendingAction>,
/// Function `migrate` calls with the migrations directory before running migrations.
init_tenancy_fn: crate::tenancy::manage::InitTenancyFn,
/// URL layout for the admin, login/logout, static and brand routes.
routes: crate::tenancy::RouteConfig,
/// When `true`, `serve` merges `crate::health::health_router` into the tenant router.
health_endpoints: bool,
/// `(url prefix, directory)` pairs that `serve` nests as static file routers.
static_dirs: Vec<(String, std::path::PathBuf)>,
/// Marker tying the builder to its database backend type.
_phantom: PhantomData<DB>,
}
/// An admin action recorded by `Builder::admin_register_action` and handed to the
/// tenant admin builder when the server is assembled.
struct PendingAction {
/// Table name of the model the action is registered for.
table: &'static str,
/// Name under which the action is registered.
name: &'static str,
/// Shared handler invoked with a pool and a slice of `SqlValue`s.
handler: crate::admin::AdminActionFn,
}
#[cfg(feature = "postgres")]
impl Builder<sqlx::Postgres> {
    /// Builds a Postgres-backed [`Builder`] from environment variables.
    ///
    /// * `RUSTANGO_APEX_DOMAIN` — apex host name; falls back to `localhost` when the
    ///   variable cannot be read.
    /// * `DATABASE_URL` — registry database URL; falls back to a local development
    ///   URL (`postgres://rustango:rustango@localhost:5432/rustango_test`).
    ///
    /// # Errors
    ///
    /// Returns the connection error if the registry pool cannot be opened.
    pub async fn from_env() -> Result<Self, Box<dyn std::error::Error>> {
        let apex = std::env::var("RUSTANGO_APEX_DOMAIN").unwrap_or_else(|_| "localhost".into());
        let registry_url = std::env::var("DATABASE_URL")
            .unwrap_or_else(|_| "postgres://rustango:rustango@localhost:5432/rustango_test".into());
        // Borrow the URL for the connection; the owned string is handed to `from_pool`.
        let registry = PgPool::connect(&registry_url).await?;
        Ok(Self::from_pool(registry, registry_url, apex))
    }
}
impl<DB: Database> Builder<DB> {
/// Creates a builder around an already-connected registry pool.
///
/// `registry_url` and `apex` are converted to owned strings. Every optional
/// setting starts out disabled or empty: no admin title/subtitle, no API router,
/// no admin actions, no static directories, health endpoints off, the default
/// [`crate::tenancy::RouteConfig`] and `crate::tenancy::init_tenancy` as the
/// tenancy initialiser.
pub fn from_pool(
    registry: sqlx::Pool<DB>,
    registry_url: impl Into<String>,
    apex: impl Into<String>,
) -> Self {
    // The tenant pool manager gets its own handle to the registry pool; the
    // builder keeps the original.
    let tenant_pools = TenantPools::<DB>::new(registry.clone());
    let apex: String = apex.into();
    let registry_url: String = registry_url.into();

    Self {
        registry,
        registry_url,
        apex,
        pools: Arc::new(tenant_pools),
        routes: crate::tenancy::RouteConfig::default(),
        init_tenancy_fn: crate::tenancy::init_tenancy,
        api: None,
        admin_title: None,
        admin_subtitle: None,
        show_only: Vec::new(),
        admin_actions: Vec::new(),
        static_dirs: Vec::new(),
        health_endpoints: false,
        _phantom: PhantomData,
    }
}
#[must_use]
pub fn with_health(mut self) -> Self {
self.health_endpoints = true;
self
}
/// Registers a directory to be served as static files under `prefix`.
///
/// May be called repeatedly; each call appends another `(prefix, directory)` pair.
#[must_use]
pub fn with_static(
    mut self,
    prefix: impl Into<String>,
    root_dir: impl Into<std::path::PathBuf>,
) -> Self {
    let mount: (String, std::path::PathBuf) = (prefix.into(), root_dir.into());
    self.static_dirs.push(mount);
    self
}
#[must_use]
pub fn routes(mut self, routes: crate::tenancy::RouteConfig) -> Self {
self.routes = routes;
self
}
#[must_use]
pub fn user_model<U: crate::tenancy::TenantUserModel>(mut self) -> Self {
self.init_tenancy_fn = crate::tenancy::init_tenancy_with::<U>;
self
}
#[must_use]
pub fn admin_title(mut self, title: impl Into<String>) -> Self {
self.admin_title = Some(title.into());
self
}
#[must_use]
pub fn admin_subtitle(mut self, subtitle: impl Into<String>) -> Self {
self.admin_subtitle = Some(subtitle.into());
self
}
#[must_use]
pub fn admin_show_only<I, S>(mut self, models: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.show_only = models.into_iter().map(Into::into).collect();
self
}
#[must_use]
pub fn api(mut self, router: ApiRouter) -> Self {
self.api = Some(router);
self
}
#[must_use]
pub fn admin_register_action<F>(
mut self,
model_table: &'static str,
action_name: &'static str,
handler: F,
) -> Self
where
F: for<'a> Fn(
&'a crate::sql::Pool,
&'a [crate::core::SqlValue],
) -> crate::admin::AdminActionFuture<'a>
+ Send
+ Sync
+ 'static,
{
self.admin_actions.push(PendingAction {
table: model_table,
name: action_name,
handler: std::sync::Arc::new(handler),
});
self
}
/// Runs a caller-supplied seeding hook and hands the builder back on success.
///
/// The hook receives clones of the tenant pool manager, the registry pool and
/// the registry URL.
///
/// # Errors
///
/// Propagates whatever error the hook returns.
pub async fn seed_with<F, Fut>(self, hook: F) -> Result<Self, Box<dyn std::error::Error>>
where
    F: FnOnce(Arc<TenantPools<DB>>, sqlx::Pool<DB>, String) -> Fut,
    Fut: Future<Output = Result<(), Box<dyn std::error::Error>>>,
{
    let pools = self.pools.clone();
    let registry = self.registry.clone();
    let registry_url = self.registry_url.clone();
    hook(pools, registry, registry_url).await.map(|()| self)
}
/// Initialises tenancy migrations and applies registry and tenant migrations
/// found under `project_root`.
///
/// Two layouts are handled:
///
/// * **Flat** — no migration directories are discovered and `project_root`
///   itself contains `*.json` files: the root is used as the single migration
///   directory.
/// * **Otherwise** — `project_root/migrations` is created if missing, the
///   tenancy initialiser runs against it, and every directory then discovered
///   under `project_root` is migrated in turn.
///
/// For each directory the registry migrations run before the tenant migrations.
///
/// # Errors
///
/// Returns the first filesystem, initialisation or migration error encountered.
pub async fn migrate<P: AsRef<std::path::Path>>(
    self,
    project_root: P,
) -> Result<Self, Box<dyn std::error::Error>>
where
    crate::sql::Pool: From<sqlx::Pool<DB>>,
{
    let base = project_root.as_ref();
    std::fs::create_dir_all(base)?;

    let discovered = crate::migrate::discover_migration_dirs(base);
    let is_flat_layout = discovered.is_empty() && root_has_json_files(base);

    if is_flat_layout {
        (self.init_tenancy_fn)(base)?;
        let _ = crate::tenancy::migrate_registry(self.pools.as_ref(), base).await?;
        let _ = crate::tenancy::migrate_tenants_dyn(self.pools.as_ref(), base, &self.registry_url)
            .await?;
    } else {
        let migrations_dir = base.join("migrations");
        std::fs::create_dir_all(&migrations_dir)?;
        (self.init_tenancy_fn)(&migrations_dir)?;

        // Discover again: the initialiser may just have populated `migrations/`.
        let targets = crate::migrate::discover_migration_dirs(base);
        for target in &targets {
            let _ = crate::tenancy::migrate_registry(self.pools.as_ref(), target).await?;
            let _ = crate::tenancy::migrate_tenants_dyn(
                self.pools.as_ref(),
                target,
                &self.registry_url,
            )
            .await?;
        }
    }

    Ok(self)
}
/// Assembles the complete application and serves it on `addr`.
///
/// The steps, in order:
///
/// 1. Optionally pre-warms tenant pools when the pool configuration asks for it;
///    a failure is logged and does not abort startup.
/// 2. Obtains the tenant and operator session secrets via
///    `SessionSecret::from_env_or_disk` (key paths under `./var/`).
/// 3. Builds the tenant admin from the collected title, subtitle, `show_only`
///    list and queued admin actions.
/// 4. Builds the tenant-facing app: the optional API router (plus static
///    directories and health endpoints) wrapped with the [`TenantContext`]
///    extension, merged with the admin routes. The extension layer is applied
///    before the merge, so it covers the API routes only.
/// 5. Builds the operator console and dispatches every request by `Host`
///    header: the apex host goes to the operator console, everything else to
///    the tenant app.
///
/// The apex comparison ignores ASCII case because host names are
/// case-insensitive; `Host: EXAMPLE.COM` reaches the operator console for an
/// apex of `example.com`.
///
/// # Errors
///
/// Returns an error if `addr` cannot be bound or if `axum::serve` fails.
pub async fn serve(self, addr: &str) -> Result<(), Box<dyn std::error::Error>>
where
    crate::sql::Pool: From<sqlx::Pool<DB>>,
{
    // The admin builder and the tenant context each get their own resolver instance.
    let resolver_for_admin = build_resolver(&self.apex);

    if self.pools.pool_config().prewarm_active_tenants {
        match self.pools.prewarm_database_tenants().await {
            Ok(report) => {
                tracing::info!(
                    target: "rustango::server",
                    warmed = report.warmed,
                    failed = report.failed,
                    skipped_cap = report.skipped_cap,
                    "tenant pools pre-warmed at boot",
                );
            }
            Err(e) => {
                // Best effort only: startup continues without pre-warmed pools.
                tracing::warn!(
                    target: "rustango::server",
                    error = %e,
                    "tenant-pool pre-warm failed (non-fatal; lazy build will retry)",
                );
            }
        }
    }

    let session_secret_for_tenant = SessionSecret::from_env_or_disk(std::path::Path::new(
        "./var/.rustango_tenant_session.key",
    ));
    let operator_secret = SessionSecret::from_env_or_disk(std::path::Path::new(
        "./var/.rustango_operator_session.key",
    ));

    let ctx = Arc::new(TenantContext {
        pools: self.pools.clone(),
        resolver: build_resolver(&self.apex),
        session_secret: session_secret_for_tenant.clone(),
        operator_secret: operator_secret.clone(),
    });

    // --- Tenant admin -----------------------------------------------------
    let mut tenant_admin_builder = TenantAdminBuilder::new(
        self.pools.clone(),
        self.registry_url.clone(),
        resolver_for_admin,
    )
    .routes(self.routes.clone());
    if !self.show_only.is_empty() {
        tenant_admin_builder = tenant_admin_builder.show_only(self.show_only.clone());
    }
    // From here on individual fields are moved out of `self`; the remaining
    // fields stay usable because the builder is consumed field by field.
    if let Some(t) = self.admin_title {
        tenant_admin_builder = tenant_admin_builder.title(t);
    }
    if let Some(s) = self.admin_subtitle {
        tenant_admin_builder = tenant_admin_builder.subtitle(s);
    }
    for action in self.admin_actions {
        let handler = action.handler;
        tenant_admin_builder = tenant_admin_builder.register_action(
            action.table,
            action.name,
            move |pool, pks| handler(pool, pks),
        );
    }
    let tenant_admin = tenant_admin_builder
        .with_session(session_secret_for_tenant.clone())
        .build();

    // --- Tenant-facing app -------------------------------------------------
    // A router is only materialised when there is something to put in it:
    // a user API, health endpoints, or at least one static directory.
    let had_api = self.api.is_some();
    let api = if had_api || self.health_endpoints || !self.static_dirs.is_empty() {
        let mut r = self.api.unwrap_or_default();
        for (prefix, root) in &self.static_dirs {
            r = r.nest(
                prefix,
                crate::static_files::static_router(crate::static_files::StaticFiles::new(
                    root.clone(),
                )),
            );
        }
        if self.health_endpoints {
            r = r.merge(crate::health::health_router(self.registry.clone()));
        }
        Some(r)
    } else {
        None
    };
    let admin_routes = build_admin_routes(&tenant_admin, &self.routes);
    let tenant_app = match api {
        // The extension layer is applied before the merge, so it wraps the API
        // routes but not the admin routes.
        Some(router) => router.layer(Extension(ctx.clone())).merge(admin_routes),
        None => admin_routes,
    };

    // --- Operator console --------------------------------------------------
    let brand_storage_for_op = crate::tenancy::branding::default_brand_storage();
    let operator_admin = operator_console::router_with_impersonation(
        self.registry,
        self.pools.clone().into_invalidator(),
        operator_secret,
        brand_storage_for_op,
        session_secret_for_tenant.clone(),
        self.routes.impersonation_handoff_url.clone(),
    );

    // --- Host-based dispatch -----------------------------------------------
    let app = Router::new().fallback_service(tower::service_fn({
        let operator = operator_admin.clone();
        let tenants = tenant_app.clone();
        let apex = self.apex.clone();
        move |req: axum::http::Request<axum::body::Body>| {
            let mut operator = operator.clone();
            let mut tenants = tenants.clone();
            let apex = apex.clone();
            async move {
                // Host header without the port; a missing or non-UTF-8 header
                // yields an empty host.
                // NOTE(review): splitting at the first ':' does not handle
                // bracketed IPv6 literals such as `[::1]:8080`.
                let host = req
                    .headers()
                    .get(axum::http::header::HOST)
                    .and_then(|v| v.to_str().ok())
                    .map(|s| s.split(':').next().unwrap_or(s).to_owned())
                    .unwrap_or_default();
                // Host names are case-insensitive, so compare ignoring ASCII case.
                let response = if host.eq_ignore_ascii_case(&apex) {
                    operator.as_service().oneshot(req).await
                } else {
                    tenants.as_service().oneshot(req).await
                };
                response.map_err(|e| -> std::convert::Infallible {
                    panic!("axum router service is Infallible: {e}")
                })
            }
        }
    }));

    let listener = tokio::net::TcpListener::bind(addr).await?;
    axum::serve(
        listener,
        app.into_make_service_with_connect_info::<std::net::SocketAddr>(),
    )
    .await?;
    Ok(())
}
}
/// Builds the tenant resolver chain for the given apex domain.
///
/// A [`SubdomainResolver`] for `apex` is pushed first, followed by a default
/// [`HeaderResolver`].
fn build_resolver(apex: &str) -> ChainResolver {
    let chain = ChainResolver::new();
    let by_subdomain = SubdomainResolver::new(apex.to_owned());
    let by_header = HeaderResolver::default();
    chain.push(by_subdomain).push(by_header)
}
/// Builds a router that forwards every admin-related path to `tenant_admin`.
///
/// Each route accepts any HTTP method. The handler rebuilds the request from
/// the original method, URI, headers and body only, then hands it to a clone of
/// the admin router. Request extensions are therefore not carried over.
///
/// The configured admin, login, logout, change-password, impersonation hand-off,
/// static and brand URLs are registered, plus `/__end-impersonation`. When the
/// configured admin URL is not `/__admin`, the `/__admin` paths are registered
/// as well.
fn build_admin_routes(tenant_admin: &Router, routes: &crate::tenancy::RouteConfig) -> Router {
    use axum::routing::any;

    // Produces a fresh forwarding handler; every route needs its own instance.
    let forward = || {
        let admin = tenant_admin.clone();
        move |req: axum::http::Request<axum::body::Body>| {
            let admin = admin.clone();
            async move {
                let (head, payload) = req.into_parts();
                let skeleton = axum::http::Request::builder()
                    .method(&head.method)
                    .uri(&head.uri);
                let rebuilt = head
                    .headers
                    .iter()
                    .fold(skeleton, |b, (name, value)| b.header(name, value))
                    .body(payload)
                    .expect("valid request");
                admin
                    .oneshot(rebuilt)
                    .await
                    .unwrap_or_else(|_| unreachable!("Router is Infallible"))
            }
        }
    };

    let admin_slash = format!("{}/", routes.admin_url);
    let admin_glob = format!("{}/{{*rest}}", routes.admin_url);
    let static_glob = format!("{}/{{*rest}}", routes.static_url);
    let brand_glob = format!("{}/{{*rest}}", routes.brand_url);

    // Registration order is kept identical to the hand-written chain.
    let configured: [&str; 10] = [
        &routes.admin_url,
        &admin_slash,
        &admin_glob,
        &routes.login_url,
        &routes.logout_url,
        &routes.change_password_url,
        &routes.impersonation_handoff_url,
        &static_glob,
        &brand_glob,
        "/__end-impersonation",
    ];

    let mut router = Router::new();
    for path in configured {
        router = router.route(path, any(forward()));
    }

    // Keep the `/__admin` paths reachable when the admin lives elsewhere;
    // skipped otherwise because the paths would collide.
    if routes.admin_url != "/__admin" {
        for path in ["/__admin", "/__admin/", "/__admin/{*rest}"] {
            router = router.route(path, any(forward()));
        }
    }
    router
}
/// Reports whether `root` directly contains at least one entry whose extension
/// is exactly `json`.
///
/// Only the top level of `root` is inspected. An unreadable or missing
/// directory yields `false`, and entries that fail to read are skipped.
fn root_has_json_files(root: &std::path::Path) -> bool {
    match std::fs::read_dir(root) {
        Err(_) => false,
        Ok(entries) => {
            for entry in entries.flatten() {
                let path = entry.path();
                let is_json = path.extension().map_or(false, |ext| ext == "json");
                if is_json {
                    return true;
                }
            }
            false
        }
    }
}