use pgrx::datum::DatumWithOid;
use pgrx::prelude::*;
use std::collections::HashMap;
use crate::queue::RefreshKey;
/// Resolves the refresh keys of every parent row affected by a change to `key`.
///
/// For each entity that `graph` lists as a parent of `key.entity`, the parent's
/// `tv_*` table is queried for rows whose foreign key equals `key.pk`. Each match
/// yields one [`RefreshKey::pk`]. Dedup keys, and entities without parents,
/// produce an empty list.
///
/// # Errors
/// Propagates any error from the parent-entity lookup or from the SPI queries.
pub fn find_parents_for(
    key: &RefreshKey,
    graph: &crate::queue::EntityDepGraph,
) -> crate::TViewResult<Vec<RefreshKey>> {
    let parent_entities = find_parent_entities(&key.entity, graph)?;
    if parent_entities.is_empty() || key.is_dedup() {
        return Ok(Vec::new());
    }
    // Heuristic starting capacity: assume about eight affected rows per parent entity.
    let mut parent_keys = Vec::with_capacity(parent_entities.len().saturating_mul(8));
    for parent_entity in &parent_entities {
        let affected = find_affected_pks(parent_entity, &key.entity, key.pk)?;
        parent_keys.extend(
            affected
                .into_iter()
                .map(|pk| RefreshKey::pk(parent_entity, pk)),
        );
    }
    Ok(parent_keys)
}
/// Batched variant of [`find_parents_for`]. It resolves the parent refresh keys
/// of many child keys with one SPI query per (parent entity, child entity) pair,
/// instead of one query per key.
///
/// Dedup keys (`is_dedup()`) are ignored. A key that appears several times in
/// `keys` is resolved once, so each of its parents is listed once, matching
/// [`find_parents_for`]. Keys with no affected parent rows are absent from the
/// returned map.
///
/// # Errors
/// Propagates any error from building the batch groups or from the SPI queries.
pub fn find_parents_batch(
    keys: &[RefreshKey],
    graph: &crate::queue::EntityDepGraph,
) -> crate::TViewResult<HashMap<RefreshKey, Vec<RefreshKey>>> {
    let mut result: HashMap<RefreshKey, Vec<RefreshKey>> = HashMap::with_capacity(keys.len());

    // Drop dedup keys and repeated keys. Otherwise a repeated key would be
    // visited once per occurrence and collect every parent multiple times.
    let mut seen = std::collections::HashSet::with_capacity(keys.len());
    let pk_keys: Vec<RefreshKey> = keys
        .iter()
        .filter(|k| !k.is_dedup() && seen.insert(*k))
        .cloned()
        .collect();
    if pk_keys.is_empty() {
        return Ok(result);
    }

    // Index keys by entity so each batch group walks only its own children,
    // instead of rescanning every key for every group (O(groups x keys)).
    let mut keys_by_entity: HashMap<&str, Vec<&RefreshKey>> = HashMap::new();
    for key in &pk_keys {
        keys_by_entity
            .entry(key.entity.as_str())
            .or_default()
            .push(key);
    }

    let batch_groups = build_batch_groups(&pk_keys, graph)?;
    for ((parent_entity, child_entity), child_pks) in batch_groups {
        let Some(child_keys) = keys_by_entity.get(child_entity.as_str()) else {
            continue;
        };
        let affected_pk_map = find_affected_pks_batch(&parent_entity, &child_entity, &child_pks)?;
        for &key in child_keys {
            if let Some(affected_pks) = affected_pk_map.get(&key.pk) {
                result
                    .entry(key.clone())
                    .or_insert_with(|| Vec::with_capacity(affected_pks.len()))
                    .extend(
                        affected_pks
                            .iter()
                            .map(|&pk| RefreshKey::pk(&parent_entity, pk)),
                    );
            }
        }
    }
    Ok(result)
}
/// Groups the pks of `keys` by `(parent entity, child entity)` pair, using the
/// parent lists in `graph.parents`, so each pair can be resolved with a single
/// batched query.
///
/// A key whose entity has no entry in `graph.parents` contributes nothing. Each
/// occurrence of a key in `keys` pushes its pk once into every matching group.
///
/// # Errors
/// As written this never returns `Err`.
fn build_batch_groups(
    keys: &[RefreshKey],
    graph: &crate::queue::EntityDepGraph,
) -> crate::TViewResult<HashMap<(String, String), Vec<i64>>> {
    let mut groups: HashMap<(String, String), Vec<i64>> = HashMap::with_capacity(4);
    for key in keys {
        // Borrow the parent list instead of cloning the whole Vec<String> per key.
        let Some(parent_entities) = graph.parents.get(&key.entity) else {
            continue;
        };
        for parent_entity in parent_entities {
            groups
                .entry((parent_entity.clone(), key.entity.clone()))
                .or_insert_with(|| Vec::with_capacity(8))
                .push(key.pk);
        }
    }
    Ok(groups)
}
/// Looks up, in a single SPI query, which rows of `tv_{parent_entity}` reference
/// each of `child_pks` through their `fk_{child_entity}` column.
///
/// Returns a map from each child pk to the parent pks (`pk_{parent_entity}`)
/// whose foreign key points at it. Child pks with no referencing row are absent
/// from the map. Rows where either the fk or the pk column is NULL are skipped.
///
/// # Errors
/// Propagates any SPI error from running the query or from reading a column.
fn find_affected_pks_batch(
    parent_entity: &str,
    child_entity: &str,
    child_pks: &[i64],
) -> spi::Result<HashMap<i64, Vec<i64>>> {
    // `= ANY` over an empty array matches nothing, so skip the SPI round trip.
    if child_pks.is_empty() {
        return Ok(HashMap::new());
    }

    // NOTE(review): identifiers are interpolated unquoted. Entity names are
    // presumably validated when a TVIEW is registered — confirm, because a name
    // that needs quoting would break (or inject into) this query.
    let fk_col = format!("fk_{child_entity}");
    let parent_table = format!("tv_{parent_entity}");
    let parent_pk_col = format!("pk_{parent_entity}");
    let query =
        format!("SELECT {fk_col}, {parent_pk_col} FROM {parent_table} WHERE {fk_col} = ANY($1)");

    // SAFETY: the value is a `Vec<i64>`, i.e. a `bigint[]`, so INT8ARRAYOID
    // correctly describes the datum bound to `$1`.
    let args = vec![unsafe {
        DatumWithOid::new(
            child_pks.to_vec(),
            PgOid::BuiltIn(PgBuiltInOids::INT8ARRAYOID).value(),
        )
    }];

    Spi::connect(|client| {
        let rows = client.select(&query, None, &args)?;
        let mut result: HashMap<i64, Vec<i64>> = HashMap::with_capacity(child_pks.len());
        for row in rows {
            if let (Some(child_pk), Some(parent_pk)) = (
                row[fk_col.as_str()].value::<i64>()?,
                row[parent_pk_col.as_str()].value::<i64>()?,
            ) {
                result
                    .entry(child_pk)
                    .or_insert_with(|| Vec::with_capacity(4))
                    .push(parent_pk);
            }
        }
        Ok(result)
    })
}
/// Returns a copy of the entities that `graph` records as parents of
/// `child_entity`, or an empty list when it has none.
///
/// As written this never returns `Err`.
fn find_parent_entities(
    child_entity: &str,
    graph: &crate::queue::EntityDepGraph,
) -> spi::Result<Vec<String>> {
    let parents = match graph.parents.get(child_entity) {
        Some(list) => list.clone(),
        None => Vec::new(),
    };
    Ok(parents)
}
/// Returns the primary keys (`pk_{parent_entity}`) of every `tv_{parent_entity}`
/// row whose `fk_{child_entity}` column equals `child_pk`. Rows with a NULL pk
/// are skipped.
///
/// # Errors
/// Propagates the first SPI error from running the query or reading a column.
fn find_affected_pks(
    parent_entity: &str,
    child_entity: &str,
    child_pk: i64,
) -> spi::Result<Vec<i64>> {
    let pk_column = format!("pk_{parent_entity}");
    let sql = format!("SELECT {pk_column} FROM tv_{parent_entity} WHERE fk_{child_entity} = $1");
    Spi::connect(|client| {
        // SAFETY: `child_pk` is an i64, which matches PostgreSQL `bigint` (INT8OID).
        let params = vec![unsafe {
            DatumWithOid::new(child_pk, PgOid::BuiltIn(PgBuiltInOids::INT8OID).value())
        }];
        client
            .select(&sql, None, &params)?
            .map(|row| row[pk_column.as_str()].value::<i64>())
            // Result<Option<i64>> -> Option<Result<i64>>: drops NULLs, keeps errors.
            .filter_map(Result::transpose)
            .collect()
    })
}
#[cfg(any(test, feature = "pg_test"))]
#[pg_schema]
mod tests {
    use super::*;
    use pgrx::prelude::Spi;

    /// Users 1 and 2 both author posts, so a batched lookup must report `post`
    /// parents for each of them.
    #[pg_test]
    fn test_find_parents_batch_pre_allocation() {
        Spi::run("CREATE TABLE tb_user (pk_user BIGSERIAL PRIMARY KEY, name TEXT)").unwrap();
        Spi::run(
            "CREATE TABLE tb_post (
pk_post BIGSERIAL PRIMARY KEY,
fk_user BIGINT REFERENCES tb_user(pk_user),
title TEXT
)",
        )
        .unwrap();
        Spi::run(
            "CREATE TABLE tb_comment (
pk_comment BIGSERIAL PRIMARY KEY,
fk_user BIGINT REFERENCES tb_user(pk_user),
fk_post BIGINT REFERENCES tb_post(pk_post),
text TEXT
)",
        )
        .unwrap();
        Spi::run("INSERT INTO tb_user (pk_user, name) VALUES (1, 'Alice'), (2, 'Bob')").unwrap();
        Spi::run("INSERT INTO tb_post (pk_post, fk_user, title) VALUES (1, 1, 'Post 1'), (2, 1, 'Post 2'), (3, 2, 'Post 3')").unwrap();
        Spi::run(
            "INSERT INTO tb_comment (pk_comment, fk_user, fk_post, text)
VALUES (1, 1, 1, 'Comment 1'), (2, 1, 2, 'Comment 2'), (3, 2, 3, 'Comment 3')",
        )
        .unwrap();
        Spi::run(
            "
SELECT pg_tviews_create('user', $$
SELECT pk_user, jsonb_build_object('name', name) AS data
FROM tb_user
$$)
",
        )
        .unwrap();
        Spi::run(
            "
SELECT pg_tviews_create('post', $$
SELECT pk_post, fk_user,
jsonb_build_object('title', title, 'author', v_user.data) AS data
FROM tb_post
LEFT JOIN v_user ON v_user.pk_user = tb_post.fk_user
$$)
",
        )
        .unwrap();
        Spi::run(
            "
SELECT pg_tviews_create('comment', $$
SELECT pk_comment, fk_user, fk_post,
jsonb_build_object('text', text) AS data
FROM tb_comment
$$)
",
        )
        .unwrap();
        let graph = crate::queue::EntityDepGraph::load().unwrap();
        let keys = vec![
            crate::queue::RefreshKey::pk("user", 1),
            crate::queue::RefreshKey::pk("user", 2),
        ];
        let batched_result = find_parents_batch(&keys, &graph).unwrap();
        assert!(!batched_result.is_empty(), "Should find parent entities");
        // Both users author posts, so each must be present with `post` parents.
        // A missing entry is a failure, not something to skip silently.
        for (key, label) in keys.iter().zip(["User 1", "User 2"]) {
            let parents = batched_result
                .get(key)
                .unwrap_or_else(|| panic!("{label} should be present in the batch result"));
            assert!(
                parents.iter().any(|p| p.entity == "post"),
                "{label} should have post parents"
            );
        }
    }

    /// An entity that no other TVIEW references yields no parent entries.
    #[pg_test]
    fn test_find_parents_batch_no_parents() {
        Spi::run("CREATE TABLE tb_tag (pk_tag BIGSERIAL PRIMARY KEY, name TEXT)").unwrap();
        Spi::run("INSERT INTO tb_tag (pk_tag, name) VALUES (1, 'Tag1'), (2, 'Tag2')").unwrap();
        Spi::run(
            "
SELECT pg_tviews_create('tag', $$
SELECT pk_tag, jsonb_build_object('name', name) AS data
FROM tb_tag
$$)
",
        )
        .unwrap();
        let graph = crate::queue::EntityDepGraph::load().unwrap();
        let keys = vec![
            crate::queue::RefreshKey::pk("tag", 1),
            crate::queue::RefreshKey::pk("tag", 2),
        ];
        let result = find_parents_batch(&keys, &graph).unwrap();
        // Result entries are created only when a parent is pushed. With no parent
        // entities for `tag`, the map must therefore be empty.
        assert!(result.is_empty(), "Tag should have no parents");
    }
}