#![allow(dead_code)]
use serde::{Deserialize, Serialize};
use tokio::sync::OnceCell;
use umbral::orm::{ForeignKey, ReverseSet};
use umbral_core::db;
/// Parent side of the reverse-relation fixtures, stored in table `rva_parent`.
///
/// Both `Child` (one FK) and `Link` (two FKs) hold `ForeignKey<Parent>` fields
/// that point at rows of this table.
#[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize, umbral::orm::Model)]
#[umbral(table = "rva_parent")]
pub struct Parent {
    pub id: i64,
    pub name: String,
    /// Declared reverse side of `Child.parent`. It is not a table column: it is
    /// skipped by both `sqlx::FromRow` and serde, and the `rva_parent` DDL has
    /// no such column. `declared_reverse_set_and_generic_accessor_agree` fills
    /// it via `prefetch_related("child_set")`.
    #[sqlx(skip)]
    #[serde(skip)]
    #[umbral(reverse_fk = "parent")]
    pub child_set: ReverseSet<Child>,
}
/// Child row stored in table `rva_child`.
///
/// `parent` is its only `ForeignKey<Parent>`, so `Parent::reverse::<Child>()`
/// can discover the join column without an explicit column name.
#[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize, umbral::orm::Model)]
#[umbral(table = "rva_child")]
pub struct Child {
    pub id: i64,
    pub label: String,
    pub parent: ForeignKey<Parent>,
}
/// Row stored in table `rva_link`. It carries two FKs to `Parent` (`fk_a` and
/// `fk_b`).
///
/// Having two candidate columns makes `Parent::reverse::<Link>()` ambiguous.
/// The tests use this model to check that ambiguity is reported and that
/// `reverse_via` resolves it when given a column name.
#[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize, umbral::orm::Model)]
#[umbral(table = "rva_link")]
pub struct Link {
    pub id: i64,
    pub note: String,
    pub fk_a: ForeignKey<Parent>,
    pub fk_b: ForeignKey<Parent>,
}
/// One-shot guard ensuring the fixture in [`boot`] is built at most once per
/// test binary.
static BOOT: OnceCell<()> = OnceCell::const_new();

/// Shared, idempotent fixture setup. Every test calls it first.
///
/// The first call:
/// 1. loads `umbral::Settings` from the environment;
/// 2. opens an in-memory SQLite pool and builds the `umbral::App`, registering
///    the pool as the `"default"` database together with `Parent`, `Child`
///    and `Link`;
/// 3. creates `rva_parent`, `rva_child` and `rva_link`;
/// 4. seeds parents `p1`, `p2`, `p3`; children `c1`, `c2` (parent 1) and `c3`
///    (parent 2); links `L1` (fk_a 1, fk_b 2) and `L2` (fk_a 1, fk_b 3).
///
/// The seed rows reference parents by literal id, which assumes AUTOINCREMENT
/// numbers the three parents 1, 2, 3 on the fresh database.
///
/// Later callers wait on the same `OnceCell` and do not re-run the setup.
///
/// NOTE(review): each `#[tokio::test]` has its own runtime. The pool is created
/// on whichever test's runtime runs this initialiser and is then used from the
/// other tests' runtimes. This is presumably fine for SQLite; confirm before
/// copying the pattern to a network-backed database.
///
/// # Panics
/// Panics (via `expect`) if any setup step fails.
async fn boot() {
    BOOT.get_or_init(|| async {
        let settings = umbral::Settings::from_env().expect("figment defaults");
        let pool = db::connect_sqlite("sqlite::memory:")
            .await
            .expect("in-memory sqlite");

        umbral::App::builder()
            .settings(settings)
            .database("default", pool.clone())
            .model::<Parent>()
            .model::<Child>()
            .model::<Link>()
            .build()
            .expect("App::build");

        // (panic message, DDL) pairs, executed in declaration order. The
        // `\n\` continuations keep each statement's text exactly as before.
        let schema = [
            (
                "CREATE TABLE rva_parent",
                "CREATE TABLE rva_parent (\n\
                 id INTEGER PRIMARY KEY AUTOINCREMENT,\n\
                 name TEXT NOT NULL\n\
                 )",
            ),
            (
                "CREATE TABLE rva_child",
                "CREATE TABLE rva_child (\n\
                 id INTEGER PRIMARY KEY AUTOINCREMENT,\n\
                 label TEXT NOT NULL,\n\
                 parent INTEGER NOT NULL REFERENCES rva_parent(id)\n\
                 )",
            ),
            (
                "CREATE TABLE rva_link",
                "CREATE TABLE rva_link (\n\
                 id INTEGER PRIMARY KEY AUTOINCREMENT,\n\
                 note TEXT NOT NULL,\n\
                 fk_a INTEGER NOT NULL REFERENCES rva_parent(id),\n\
                 fk_b INTEGER NOT NULL REFERENCES rva_parent(id)\n\
                 )",
            ),
        ];
        for (failure_msg, ddl) in schema {
            sqlx::query(ddl)
                .execute(&pool)
                .await
                .expect(failure_msg);
        }

        for parent_name in ["p1", "p2", "p3"] {
            sqlx::query("INSERT INTO rva_parent (name) VALUES (?)")
                .bind(parent_name)
                .execute(&pool)
                .await
                .expect("seed parent");
        }

        for (child_label, parent_id) in [("c1", 1), ("c2", 1), ("c3", 2)] {
            sqlx::query("INSERT INTO rva_child (label, parent) VALUES (?, ?)")
                .bind(child_label)
                .bind(parent_id)
                .execute(&pool)
                .await
                .expect("seed child");
        }

        for (link_note, a_id, b_id) in [("L1", 1, 2), ("L2", 1, 3)] {
            sqlx::query("INSERT INTO rva_link (note, fk_a, fk_b) VALUES (?, ?, ?)")
                .bind(link_note)
                .bind(a_id)
                .bind(b_id)
                .execute(&pool)
                .await
                .expect("seed link");
        }
    })
    .await;
}
use umbral::orm::ReverseRelations;
/// Loads the first `Parent` whose `name` column equals `name`.
///
/// # Panics
/// Panics if the query itself fails ("query parent") or if no row matches
/// ("parent row").
async fn get_parent(name: &str) -> Parent {
    let maybe_row = Parent::objects()
        .filter(parent::NAME.eq(name))
        .first()
        .await
        .expect("query parent");
    maybe_row.expect("parent row")
}
/// `p1.reverse::<Child>()` discovers `Child.parent` as the FK back to `Parent`
/// and returns exactly the children seeded against `p1`.
#[tokio::test]
async fn reverse_returns_children_pointing_at_this_instance() {
    boot().await;
    let p1 = get_parent("p1").await;

    let mut kids = p1
        .reverse::<Child>()
        .expect("discover FK column on Child")
        .fetch()
        .await
        .expect("fetch children");
    // Fetch order is not asserted; sort so the label comparison is stable.
    kids.sort_by(|a, b| a.label.cmp(&b.label));

    let labels: Vec<&str> = kids.iter().map(|c| c.label.as_str()).collect();
    assert_eq!(labels, vec!["c1", "c2"], "exactly p1's two children");

    // Compare against p1's actual primary key rather than the literal `1`, so
    // the check expresses "points at p1" and does not depend on AUTOINCREMENT
    // numbering.
    for c in &kids {
        assert_eq!(c.parent.id(), p1.id, "child FK must resolve to p1's id");
    }
}
/// The queryset returned by `reverse` composes like any other: it can be
/// narrowed with `filter` and aggregated with `count`.
#[tokio::test]
async fn reverse_is_chainable_filter_and_count() {
    boot().await;
    let p1 = get_parent("p1").await;

    // Narrow p1's children to the single "c2" row.
    let only_c2 = p1
        .reverse::<Child>()
        .expect("discover FK")
        .filter(child::LABEL.eq("c2"));
    let narrowed = only_c2.fetch().await.expect("filtered fetch");
    assert_eq!(narrowed.len(), 1);
    assert_eq!(narrowed[0].label, "c2");

    // Without the filter, the count covers both of p1's children.
    let all_children = p1.reverse::<Child>().expect("discover FK");
    let n = all_children.count().await.expect("count children");
    assert_eq!(n, 2, "p1 has two children");
}
/// A parent with no children gets `Ok` with an empty result from `reverse`,
/// not an error.
#[tokio::test]
async fn reverse_with_zero_children_is_empty_not_error() {
    boot().await;
    let childless = get_parent("p3").await;

    let children_qs = childless
        .reverse::<Child>()
        .expect("discover FK column");
    let kids = children_qs
        .fetch()
        .await
        .expect("zero children is Ok, not Err");

    assert!(kids.is_empty(), "p3 has no children → empty Vec");
}
/// `Link` has two FKs to `Parent`, so `reverse::<Link>()` must refuse to pick
/// one. Its error must name both candidate columns and point the caller to
/// `reverse_via`. `reverse_via` then resolves each column explicitly and
/// rejects an unknown column name.
#[tokio::test]
async fn ambiguous_two_fks_errors_then_reverse_via_resolves() {
    boot().await;
    let p1 = get_parent("p1").await;

    // `.err().expect(..)` extracts the error without requiring the Ok
    // (queryset) type to implement `Debug`, which `unwrap_err` would need.
    let msg = p1
        .reverse::<Link>()
        .err()
        .expect("two FKs to Parent must be ambiguous")
        .to_string();
    assert!(
        msg.contains("fk_a") && msg.contains("fk_b"),
        "error names both candidate FK columns: {msg}"
    );
    assert!(
        msg.contains("reverse_via"),
        "error directs the caller to reverse_via: {msg}"
    );

    // Seeded links are L1 (fk_a=1, fk_b=2) and L2 (fk_a=1, fk_b=3).
    let by_a = p1
        .reverse_via::<Link>("fk_a")
        .expect("explicit fk_a column")
        .fetch()
        .await
        .expect("fetch via fk_a");
    let mut a_notes: Vec<&str> = by_a.iter().map(|l| l.note.as_str()).collect();
    a_notes.sort();
    assert_eq!(a_notes, vec!["L1", "L2"], "fk_a links both point at p1");

    let by_b = p1
        .reverse_via::<Link>("fk_b")
        .expect("explicit fk_b column")
        .fetch()
        .await
        .expect("fetch via fk_b");
    assert!(by_b.is_empty(), "no link has fk_b pointing at p1");

    let bad_msg = p1
        .reverse_via::<Link>("nope_col")
        .err()
        .expect("unknown column must error")
        .to_string();
    // Include the actual text on failure, matching the assertions above.
    assert!(
        bad_msg.contains("nope_col"),
        "error names the bad column: {bad_msg}"
    );
}
/// `Parent` has no `ForeignKey<Child>` field, so asking a child for
/// `reverse::<Parent>()` must fail with a "no foreign key" error.
#[tokio::test]
async fn reverse_to_a_model_with_no_fk_back_errors() {
    boot().await;

    let lookup = Child::objects()
        .filter(child::LABEL.eq("c1"))
        .first()
        .await
        .expect("query child");
    let c1 = lookup.expect("child row");

    let err = c1
        .reverse::<Parent>()
        .err()
        .expect("Parent has no FK to Child");
    let rendered = err.to_string();
    assert!(
        rendered.contains("no foreign key"),
        "error explains there is no FK back: {err}"
    );
}
/// The declared `child_set` (filled via `prefetch_related`) and the generic
/// `reverse::<Child>()` accessor must return the same children for `p1`.
#[tokio::test]
async fn declared_reverse_set_and_generic_accessor_agree() {
    boot().await;
    let mut parents = Parent::objects()
        .filter(parent::NAME.eq("p1"))
        .prefetch_related("child_set")
        .fetch()
        .await
        .expect("prefetch child_set");
    // Check the row count first. If the filter matched nothing, `remove(0)`
    // alone would fail with an opaque index-out-of-bounds panic.
    assert_eq!(parents.len(), 1, "exactly one parent named p1");
    let p1 = parents.remove(0);

    let declared = p1
        .child_set
        .resolved()
        .expect("child_set resolved after prefetch");
    let mut declared_labels: Vec<&str> = declared.iter().map(|c| c.label.as_str()).collect();
    declared_labels.sort();
    assert_eq!(declared_labels, vec!["c1", "c2"]);

    let mut generic = p1
        .reverse::<Child>()
        .expect("discover FK")
        .fetch()
        .await
        .expect("generic fetch");
    generic.sort_by(|a, b| a.label.cmp(&b.label));
    let generic_labels: Vec<&str> = generic.iter().map(|c| c.label.as_str()).collect();
    assert_eq!(
        generic_labels, declared_labels,
        "declared ReverseSet and generic accessor agree"
    );
}