#![allow(dead_code)]
use sqlx::sqlite::SqlitePoolOptions;
use umbral::db::{DatabaseRouter, RouteContext, Schema, TenantKey};
/// Minimal model for this test, mapped to the `sch_widget` table.
///
/// The test creates an identically-shaped `sch_widget` table in `main` and in
/// each attached tenant schema, so which physical table a `Widget` query hits
/// is decided by the installed router, not by this definition.
#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize, serde::Deserialize, umbral::orm::Model)]
#[umbral(table = "sch_widget")]
pub struct Widget {
// Primary key; every `sch_widget` table here declares it
// `INTEGER PRIMARY KEY AUTOINCREMENT`. Creates in this file pass `id: 0` —
// presumably the ORM lets SQLite assign the real id (TODO confirm).
pub id: i64,
// Row label; the test uses it to tell which tenant a fetched row came from.
pub name: String,
}
/// Router that sends each query to the schema named after the current tenant.
///
/// If the [`RouteContext`] carries no tenant, or `Schema::new` yields `None`
/// for the tenant's name, no schema is chosen and the query is left on the
/// default database schema.
struct TenantSchemaRouter;

impl DatabaseRouter for TenantSchemaRouter {
    fn schema_for(&self, ctx: &RouteContext) -> Option<Schema> {
        // No tenant in scope -> nothing to route to.
        let tenant = ctx.tenant()?;
        Schema::new(tenant.as_str())
    }
}
/// Awaits `fut` inside a route-context scope whose tenant key is `tenant`,
/// returning whatever `fut` produces.
///
/// ORM calls made while `fut` is being awaited see this tenant in their
/// [`RouteContext`]; the test relies on that to steer reads and writes.
async fn in_tenant<F, T>(tenant: &str, fut: F) -> T
where
    F: std::future::Future<Output = T>,
{
    let scoped_ctx = RouteContext::new().with_tenant(TenantKey::new(tenant));
    umbral::db::route_context_scope(scoped_ctx, fut).await
}
#[tokio::test(flavor = "multi_thread")]
async fn schema_router_isolates_tenant_data_via_attach() {
    // `sqlite::memory:` databases and ATTACHed schemas exist only on the
    // connection that created them, so the pool is capped at one connection:
    // every query below must see the same `main` plus the attached schemas.
    let pool = SqlitePoolOptions::new()
        .max_connections(1)
        .connect("sqlite::memory:")
        .await
        .expect("pool");

    // One attached in-memory schema per tenant, each with its own `sch_widget`.
    for schema in ["tenant_a", "tenant_b"] {
        sqlx::query(&format!("ATTACH DATABASE ':memory:' AS {schema}"))
            .execute(&pool)
            .await
            .expect("attach schema");
        sqlx::query(&format!(
            "CREATE TABLE {schema}.sch_widget \
             (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL)"
        ))
        .execute(&pool)
        .await
        .expect("create tenant table");
    }
    // The default (`main`) schema gets the same table.
    sqlx::query(
        "CREATE TABLE sch_widget (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL)",
    )
    .execute(&pool)
    .await
    .expect("create main table");

    umbral::App::builder()
        .settings(umbral::Settings::from_env().expect("settings"))
        .database("default", pool.clone())
        .router(TenantSchemaRouter)
        .model::<Widget>()
        .build()
        .expect("App::build");

    // Routed writes: each create runs inside its tenant's route context.
    in_tenant("tenant_a", async {
        Widget::objects()
            .create(Widget {
                id: 0,
                name: "a-row".into(),
            })
            .await
            .expect("create in tenant_a");
    })
    .await;
    in_tenant("tenant_b", async {
        Widget::objects()
            .create(Widget {
                id: 0,
                name: "b-row".into(),
            })
            .await
            .expect("create in tenant_b");
    })
    .await;

    // Routed reads: each tenant sees only its own row.
    let a_rows = in_tenant("tenant_a", async {
        Widget::objects().fetch().await.expect("fetch tenant_a")
    })
    .await;
    assert_eq!(a_rows.len(), 1, "tenant_a sees exactly its own row");
    assert_eq!(a_rows[0].name, "a-row");

    let b_rows = in_tenant("tenant_b", async {
        Widget::objects().fetch().await.expect("fetch tenant_b")
    })
    .await;
    assert_eq!(b_rows.len(), 1, "tenant_b sees exactly its own row");
    assert_eq!(b_rows[0].name, "b-row");

    // Raw (un-routed) counts confirm where the rows physically landed.
    let a_count = count_widgets(&pool, "tenant_a").await;
    let b_count = count_widgets(&pool, "tenant_b").await;
    assert_eq!(
        (a_count, b_count),
        (1, 1),
        "one row landed in each tenant schema"
    );
    // Precondition for the spawned-task check below: any row in `main` after
    // it can only have come from the spawned task.
    assert_eq!(
        count_widgets(&pool, "main").await,
        0,
        "tenant-routed writes never touched the default schema"
    );

    // A task spawned from inside `in_tenant` is expected NOT to inherit the
    // tenant context, so its write should land in the default schema.
    in_tenant("tenant_a", async {
        let handle = tokio::spawn(async {
            Widget::objects()
                .create(Widget {
                    id: 0,
                    name: "bg".into(),
                })
                .await
                .expect("spawned create");
        });
        handle.await.expect("join spawned task");
    })
    .await;

    assert_eq!(
        count_widgets(&pool, "main").await,
        1,
        "the spawned task wrote to the default schema, not a tenant"
    );
    assert_eq!(
        count_widgets(&pool, "tenant_a").await,
        1,
        "the spawned task did NOT inherit tenant_a (still only its original row)"
    );
    assert_eq!(
        count_widgets(&pool, "tenant_b").await,
        1,
        "the spawned task did not write into tenant_b either"
    );
}

/// Counts the rows in `{schema}.sch_widget` using a raw sqlx query rather than
/// the ORM, so the result reflects where rows physically live regardless of
/// routing. Panics with the schema name if the query fails.
async fn count_widgets(pool: &sqlx::sqlite::SqlitePool, schema: &str) -> i64 {
    sqlx::query_scalar(&format!("SELECT COUNT(*) FROM {schema}.sch_widget"))
        .fetch_one(pool)
        .await
        .unwrap_or_else(|e| panic!("count rows in {schema}.sch_widget: {e}"))
}