use pgrx::pg_sys;
use pgrx::prelude::*;
use std::ffi::CStr;
use std::sync::{LazyLock, Mutex};
use crate::TViewError;
use crate::ddl::drop_tview;
/// ProcessUtility hook that was installed before ours (recorded by `install_hook`).
/// `None` means statements are forwarded straight to `standard_ProcessUtility`.
static mut PREV_PROCESS_UTILITY_HOOK: pg_sys::ProcessUtility_hook_type = None;
/// Re-entrancy flag: `true` while `tview_process_utility_hook` is handling a statement, so
/// utility statements issued during that work are passed straight through.
static mut HOOK_IN_PROGRESS: bool = false;
/// Chains `tview_process_utility_hook` in front of the currently installed
/// `ProcessUtility_hook`, remembering the previous hook so it can be called afterwards.
///
/// # Safety
/// Writes unsynchronized `static mut` state, so it must not race with any other access to
/// the hook globals. It must also run at most once per backend: a second call would record
/// this extension's own hook as the "previous" one, making the hook call itself.
/// [`ensure_hook_installed`] enforces the once-only rule.
pub unsafe fn install_hook() {
    unsafe {
        let previous = pg_sys::ProcessUtility_hook;
        PREV_PROCESS_UTILITY_HOOK = previous;
        pg_sys::ProcessUtility_hook = Some(tview_process_utility_hook);
    }
}
/// Installs the ProcessUtility hook on the first call; later calls do nothing.
///
/// # Safety
/// Same requirements as [`install_hook`]. The once-only flag is a plain `static mut` and is
/// not synchronized.
pub unsafe fn ensure_hook_installed() {
    static mut HOOK_INSTALLED: bool = false;
    unsafe {
        if HOOK_INSTALLED {
            return;
        }
        install_hook();
        HOOK_INSTALLED = true;
    }
}
#[pg_guard]
#[allow(clippy::too_many_arguments)] unsafe extern "C-unwind" fn tview_process_utility_hook(
pstmt: *mut pg_sys::PlannedStmt,
query_string: *const ::std::os::raw::c_char,
read_only_tree: bool,
context: pg_sys::ProcessUtilityContext::Type,
params: pg_sys::ParamListInfo,
query_env: *mut pg_sys::QueryEnvironment,
dest: *mut pg_sys::DestReceiver,
qc: *mut pg_sys::QueryCompletion,
) {
if unsafe { HOOK_IN_PROGRESS } {
unsafe {
call_prev_hook_or_standard(
pstmt,
query_string,
read_only_tree,
context,
params,
query_env,
dest,
qc,
)
};
return;
}
unsafe { HOOK_IN_PROGRESS = true };
if !pstmt.is_null() && unsafe { !(*pstmt).utilityStmt.is_null() } {
let utility_stmt = unsafe { (*pstmt).utilityStmt };
if unsafe { (*utility_stmt).type_ } == pg_sys::NodeTag::T_TransactionStmt {
#[allow(clippy::cast_ptr_alignment)] let xact_stmt = utility_stmt.cast::<pg_sys::TransactionStmt>();
if !xact_stmt.is_null() {
let kind = unsafe { (*xact_stmt).kind };
if kind == pg_sys::TransactionStmtKind::TRANS_STMT_COMMIT {
if let Err(e) = crate::queue::flush_refresh_queue() {
unsafe { HOOK_IN_PROGRESS = false };
error!("TVIEW refresh failed before COMMIT: {e:?}");
}
if let Err(e) = crate::audit::flush_audit_buffer() {
unsafe { HOOK_IN_PROGRESS = false };
error!("Audit flush failed before COMMIT: {e:?}");
}
}
if kind == pg_sys::TransactionStmtKind::TRANS_STMT_PREPARE
&& !crate::queue::is_queue_empty()
{
unsafe { HOOK_IN_PROGRESS = false };
error!(
"pg_tviews: PREPARE TRANSACTION is not supported when \
TVIEW refreshes are pending; commit or rollback first"
);
}
}
}
}
let result = std::panic::catch_unwind(|| -> Result<bool, TViewError> {
let query_str = if query_string.is_null() {
"[NULL]".to_string()
} else {
unsafe { CStr::from_ptr(query_string) }
.to_string_lossy()
.to_string()
};
let query_lower = query_str.to_lowercase();
if query_lower.contains("create extension") || query_lower.contains("drop extension") {
return Ok(false); }
if pstmt.is_null() {
return Ok(false); }
let pstmt_ref = unsafe { &*pstmt };
if pstmt_ref.utilityStmt.is_null() {
return Ok(false); }
let utility_stmt = pstmt_ref.utilityStmt;
let node_tag = unsafe { (*utility_stmt).type_ };
if node_tag == pg_sys::NodeTag::T_CreateTableAsStmt {
#[allow(clippy::cast_ptr_alignment)]
let ctas = utility_stmt.cast::<pg_sys::CreateTableAsStmt>();
match unsafe { handle_create_table_as(ctas, query_string) } {
Ok(true) => return Ok(true),
Ok(false) => {}
Err(e) => return Err(e),
}
}
if node_tag == pg_sys::NodeTag::T_DropStmt {
#[allow(clippy::cast_ptr_alignment)] let drop_stmt = utility_stmt.cast::<pg_sys::DropStmt>();
match unsafe { handle_drop_table(drop_stmt, query_string) } {
Ok(true) => return Ok(true),
Ok(false) => {}
Err(e) => return Err(e),
}
}
if node_tag == pg_sys::NodeTag::T_AlterTableStmt {
#[allow(clippy::cast_ptr_alignment)] let alter_stmt = utility_stmt.cast::<pg_sys::AlterTableStmt>();
match unsafe { handle_alter_table(alter_stmt, query_string) } {
Ok(true) => return Ok(true),
Ok(false) => {}
Err(e) => return Err(e),
}
}
Ok(false)
});
let should_pass_through = match result {
Ok(Ok(handled)) => !handled, Ok(Err(handler_err)) => {
unsafe { HOOK_IN_PROGRESS = false };
error!("{handler_err}");
#[allow(unreachable_code)] {
true
}
}
Err(panic_info) => {
unsafe { HOOK_IN_PROGRESS = false };
let panic_msg = panic_info
.downcast_ref::<&str>()
.map(|s| (*s).to_string())
.or_else(|| panic_info.downcast_ref::<String>().cloned())
.unwrap_or_else(|| format!("{panic_info:?}"));
error!(
"PANIC in ProcessUtility hook: {} - This is a bug in pg_tviews - please report it!",
panic_msg
);
#[allow(unreachable_code)] {
true
}
}
};
if should_pass_through {
unsafe {
call_prev_hook_or_standard(
pstmt,
query_string,
read_only_tree,
context,
params,
query_env,
dest,
qc,
);
}
drain_pending_populates();
}
unsafe { HOOK_IN_PROGRESS = false };
}
/// Captures the SELECT of a `CREATE TABLE tv_<entity> AS SELECT ...` statement.
///
/// Non-matching statements return `Ok(false)`: null pointers, no target relation, or a
/// name not starting with `tv_`. For a matching statement, the text after `tv_<entity> AS`
/// is extracted from `query_string`. It is heuristically checked with
/// `validate_tview_select`; a failed check only emits a WARNING. It is then recorded with
/// `store_pending_tview_select` under the table name, together with its schema name (empty
/// when unqualified). The statement itself is never consumed: every success path returns
/// `Ok(false)`, so PostgreSQL still executes it.
///
/// Only `<name> as` separated by a single space is matched. A quoted name (`"tv_x" AS`) or
/// a line break before `AS` is not found.
/// NOTE(review): presumably `query_string` can hold several statements; everything after
/// `AS` (minus trailing `;`) is taken as the SELECT — confirm for multi-statement input.
///
/// # Errors
/// * `InvalidTViewName` when the table name is exactly `tv_`.
/// * An internal error when `query_string` is null or not valid UTF-8, or when storing
///   the SELECT fails after it passed validation.
/// * `InvalidSelectStatement` when `<name> as` cannot be found in the query text.
///
/// # Safety
/// `ctas` must be null or point to a valid `CreateTableAsStmt`. `query_string` must be null
/// or a valid NUL-terminated string.
unsafe fn handle_create_table_as(
    ctas: *mut pg_sys::CreateTableAsStmt,
    query_string: *const ::std::os::raw::c_char,
) -> Result<bool, TViewError> {
    unsafe {
        if ctas.is_null() {
            return Ok(false);
        }
        let ctas_ref = &*ctas;
        if ctas_ref.into.is_null() {
            return Ok(false);
        }
        let into = &*ctas_ref.into;
        if into.rel.is_null() {
            return Ok(false);
        }
        let rel = &*into.rel;
        if rel.relname.is_null() {
            return Ok(false);
        }
        let Ok(table_name) = CStr::from_ptr(rel.relname).to_str() else {
            return Ok(false);
        };
        let Some(entity_name) = table_name.strip_prefix("tv_") else {
            return Ok(false);
        };
        let schema_name = if rel.schemaname.is_null() {
            String::new()
        } else {
            CStr::from_ptr(rel.schemaname)
                .to_str()
                .unwrap_or("")
                .to_string()
        };
        if entity_name.is_empty() {
            return Err(TViewError::InvalidTViewName {
                name: table_name.to_string(),
                reason: "must be tv_<entity>".to_string(),
            });
        }

        if query_string.is_null() {
            return Err(crate::internal_error!(
                "No query string provided for CREATE TABLE AS"
            ));
        }
        let Ok(sql) = CStr::from_ptr(query_string).to_str() else {
            return Err(crate::internal_error!("Failed to parse query string"));
        };
        let table_pattern = format!("{} as", table_name.to_ascii_lowercase());
        let Some(select_sql) = text_after_ascii_ci(sql, &table_pattern) else {
            return Err(TViewError::InvalidSelectStatement {
                sql: sql.to_string(),
                reason: format!("Could not find '{table_pattern}' in query"),
            });
        };

        match validate_tview_select(&select_sql) {
            Ok(()) => {
                if let Err(e) = store_pending_tview_select(table_name, &schema_name, &select_sql)
                {
                    return Err(crate::internal_error!(
                        "Failed to store SELECT for '{}': {}",
                        table_name,
                        e
                    ));
                }
            }
            Err(e) => {
                warning!(
                    "TVIEW syntax warning for '{}': {} — attempting conversion anyway",
                    table_name,
                    e
                );
                if let Err(store_err) =
                    store_pending_tview_select(table_name, &schema_name, &select_sql)
                {
                    warning!("Failed to store SELECT for '{}': {}", table_name, store_err);
                }
            }
        }
        // Let PostgreSQL run the CTAS; the stored SELECT is picked up later.
        Ok(false)
    }
}

/// Returns the text after the first ASCII-case-insensitive match of `marker` in `sql`,
/// trimmed of surrounding whitespace and trailing semicolons, or `None` if there is no match.
///
/// Only ASCII letters are folded, so byte offsets in the folded copy are identical to those
/// in `sql` and the slice always lands on the intended character boundary. Full Unicode
/// lowercasing can change byte lengths (`İ` → `i̇`), which would misalign such offsets.
fn text_after_ascii_ci(sql: &str, marker: &str) -> Option<String> {
    let start = sql
        .to_ascii_lowercase()
        .find(&marker.to_ascii_lowercase())?
        + marker.len();
    Some(sql[start..].trim().trim_end_matches(';').trim().to_string())
}
/// Heuristically checks that a TVIEW's SELECT exposes the required columns.
///
/// `SELECT *` is accepted as-is. Otherwise the lowercased text must mention an `id` column
/// and a `data` column (or call `jsonb_build_object`). These are substring checks, not a
/// parse.
///
/// # Errors
/// Returns a human-readable reason naming the first missing column (`id` is checked first).
fn validate_tview_select(select_sql: &str) -> Result<(), String> {
    const ID_MARKERS: [&str; 7] = [" as id", " id,", " id ", ".id,", ".id ", ".id\n", ".id::"];
    const DATA_MARKERS: [&str; 4] = ["jsonb_build_object", " as data", " data,", " data "];

    let lowered = select_sql.to_lowercase();

    let selects_star = lowered
        .find("select")
        .is_some_and(|idx| lowered[idx + 6..].trim_start().starts_with('*'));
    if selects_star {
        return Ok(());
    }

    let mentions_any = |markers: &[&str]| markers.iter().any(|m| lowered.contains(*m));
    if !mentions_any(&ID_MARKERS) {
        return Err("Missing required 'id' column (UUID)".to_string());
    }
    if !mentions_any(&DATA_MARKERS) {
        return Err("Missing required 'data' column (JSONB)".to_string());
    }
    Ok(())
}
/// Records the SELECT text of a `CREATE TABLE tv_* AS` statement under `table_name`,
/// together with its schema name. A later entry for the same table replaces the earlier one.
///
/// # Errors
/// Returns a message if the cache mutex is poisoned.
fn store_pending_tview_select(
    table_name: &str,
    schema_name: &str,
    select_sql: &str,
) -> Result<(), String> {
    let mut cache = match PENDING_TVIEW_SELECTS.lock() {
        Ok(guard) => guard,
        Err(poisoned) => return Err(format!("Failed to lock cache: {poisoned}")),
    };
    cache.insert(
        table_name.to_owned(),
        (schema_name.to_owned(), select_sql.to_owned()),
    );
    Ok(())
}

/// Captured SELECT texts keyed by TVIEW table name; values are `(schema_name, select_sql)`.
static PENDING_TVIEW_SELECTS: LazyLock<Mutex<std::collections::HashMap<String, (String, String)>>> =
    LazyLock::new(|| Mutex::new(std::collections::HashMap::new()));

/// Removes and returns the captured `(schema_name, select_sql)` for `table_name`.
///
/// Returns `None` when nothing is stored for it or the cache mutex is poisoned.
pub fn take_pending_tview_select(table_name: &str) -> Option<(String, String)> {
    let mut cache = PENDING_TVIEW_SELECTS.lock().ok()?;
    cache.remove(table_name)
}
/// A deferred initial load of `schema_name.tv_table_name` from `schema_name.view_name`.
struct PendingPopulate {
    tv_table_name: String,
    view_name: String,
    schema_name: String,
}

/// Populates waiting to be executed by `drain_pending_populates`, in enqueue order.
static PENDING_POPULATES: LazyLock<Mutex<Vec<PendingPopulate>>> =
    LazyLock::new(|| Mutex::new(Vec::new()));

/// Queues a deferred populate of `schema_name.tv_table_name` from `schema_name.view_name`.
///
/// The request is silently dropped if the queue's mutex is poisoned.
pub fn enqueue_pending_populate(tv_table_name: &str, view_name: &str, schema_name: &str) {
    let Ok(mut pending) = PENDING_POPULATES.lock() else {
        return;
    };
    pending.push(PendingPopulate {
        tv_table_name: tv_table_name.to_owned(),
        view_name: view_name.to_owned(),
        schema_name: schema_name.to_owned(),
    });
}
/// Runs the TVIEW populates queued by `enqueue_pending_populate`.
///
/// The queue is emptied up front, so the lock is not held while SPI runs. For each entry,
/// the view `<schema>.<view>` (relkind `v`) is looked up in `pg_class` and its columns are
/// fetched. Then `INSERT INTO <schema>.<tview> (cols) SELECT cols FROM <schema>.<view>` is
/// executed.
///
/// Any failure raises an ERROR via `error!`. Entries not yet processed at that point are
/// lost, because they were already taken off the queue.
fn drain_pending_populates() {
    let entries: Vec<PendingPopulate> = PENDING_POPULATES
        .lock()
        .map(|mut queue| std::mem::take(&mut *queue))
        .unwrap_or_default();

    for entry in entries {
        // Names are arbitrary identifiers that may contain quotes or backslashes, so they
        // are embedded as properly escaped string literals.
        let lookup_sql = format!(
            "SELECT c.oid FROM pg_class c JOIN pg_namespace n ON c.relnamespace = n.oid \
             WHERE c.relname::text = {} AND n.nspname::text = {} AND c.relkind = 'v'",
            pg_string_literal(&entry.view_name),
            pg_string_literal(&entry.schema_name),
        );
        let view_oid = match Spi::get_one::<pg_sys::Oid>(&lookup_sql) {
            Ok(Some(oid)) => oid,
            Ok(None) => {
                error!(
                    "pg_tviews: deferred populate failed — view {}.{} not found",
                    entry.schema_name, entry.view_name
                );
            }
            Err(e) => {
                error!(
                    "pg_tviews: deferred populate failed — cannot resolve view {}.{}: {e}",
                    entry.schema_name, entry.view_name
                );
            }
        };

        let view_columns = match crate::utils::get_view_columns_by_oid(view_oid) {
            Ok(cols) if !cols.is_empty() => cols,
            Ok(_) => {
                error!(
                    "pg_tviews: deferred populate failed — view {}.{} has no columns",
                    entry.schema_name, entry.view_name
                );
            }
            Err(e) => {
                error!(
                    "pg_tviews: deferred populate failed — cannot get columns for {}.{}: {e}",
                    entry.schema_name, entry.view_name
                );
            }
        };

        let qi_schema = crate::utils::quote_identifier(&entry.schema_name);
        let qi_tview = crate::utils::quote_identifier(&entry.tv_table_name);
        let qi_view = crate::utils::quote_identifier(&entry.view_name);
        let col_list = view_columns
            .iter()
            .map(|c| crate::utils::quote_identifier(c))
            .collect::<Vec<_>>()
            .join(", ");
        let insert_sql = format!(
            "INSERT INTO {qi_schema}.{qi_tview} ({col_list}) \
             SELECT {col_list} FROM {qi_schema}.{qi_view}"
        );
        if let Err(e) = Spi::run(&insert_sql) {
            error!(
                "pg_tviews: deferred populate failed for {}: {e}",
                entry.tv_table_name
            );
        }
    }
}

/// Renders `value` as a PostgreSQL string literal, the way `quote_literal()` does.
///
/// Single quotes are doubled. If a backslash is present, backslashes are doubled too and an
/// `E` prefix is added. The literal therefore reads back as `value` regardless of
/// `standard_conforming_strings`.
fn pg_string_literal(value: &str) -> String {
    let mut literal = String::with_capacity(value.len() + 3);
    if value.contains('\\') {
        literal.push('E');
    }
    literal.push('\'');
    for ch in value.chars() {
        if matches!(ch, '\'' | '\\') {
            literal.push(ch);
        }
        literal.push(ch);
    }
    literal.push('\'');
    literal
}
/// Routes `DROP TABLE` of `tv_*` tables through `drop_tview`.
///
/// Only `DROP TABLE` statements are considered. Each target's last name component is
/// checked for the `tv_` prefix; schema qualifiers are ignored. Matching tables are dropped
/// with `drop_tview`, passing along the statement's `IF EXISTS` flag.
///
/// Returns `Ok(true)` when every target was a `tv_*` table, meaning the statement is fully
/// handled. If other targets are listed too, the `tv_*` entries are deleted from `objects`
/// in place and `Ok(false)` is returned, so PostgreSQL drops the rest.
///
/// NOTE(review): with `IF EXISTS`, any `drop_tview` error is reported as "does not exist"
/// and skipped, even when the failure has another cause — confirm this is intended.
///
/// # Errors
/// The first `drop_tview` error, unless `IF EXISTS` was given.
///
/// # Safety
/// `drop_stmt` must be null or point to a valid `DropStmt` whose object list may be modified.
unsafe fn handle_drop_table(
    drop_stmt: *mut pg_sys::DropStmt,
    _query_string: *const ::std::os::raw::c_char,
) -> Result<bool, TViewError> {
    /// Unqualified (last) name of one DROP target, or `None` if it is missing or not UTF-8.
    unsafe fn last_name_component<'a>(name_list: *mut pg_sys::List) -> Option<&'a str> {
        unsafe {
            if name_list.is_null() {
                return None;
            }
            let parts = pg_sys::list_length(name_list);
            if parts == 0 {
                return None;
            }
            let last = pg_sys::list_nth(name_list, parts - 1) as *mut pg_sys::String;
            if last.is_null() || (*last).sval.is_null() {
                return None;
            }
            CStr::from_ptr((*last).sval).to_str().ok()
        }
    }

    unsafe {
        if drop_stmt.is_null() {
            return Ok(false);
        }
        let stmt = &*drop_stmt;
        if stmt.removeType != pg_sys::ObjectType::OBJECT_TABLE || stmt.objects.is_null() {
            return Ok(false);
        }
        let objects = stmt.objects;
        let if_exists = stmt.missing_ok;

        let mut tview_targets: Vec<(i32, String)> = Vec::new();
        let mut has_other_targets = false;
        for idx in 0..pg_sys::list_length(objects) {
            let name_list = pg_sys::list_nth(objects, idx) as *mut pg_sys::List;
            match last_name_component(name_list) {
                Some(name) if name.starts_with("tv_") => {
                    tview_targets.push((idx, name.to_owned()));
                }
                _ => has_other_targets = true,
            }
        }
        if tview_targets.is_empty() {
            return Ok(false);
        }

        for (_, name) in &tview_targets {
            if let Err(e) = drop_tview(name, if_exists) {
                if !if_exists {
                    return Err(e);
                }
                notice!("TVIEW '{}' does not exist, skipping", name);
            }
        }

        if !has_other_targets {
            return Ok(true);
        }
        // Delete back-to-front so the recorded indices of earlier entries stay valid.
        for &(idx, _) in tview_targets.iter().rev() {
            pg_sys::list_delete_nth_cell(objects, idx);
        }
        Ok(false)
    }
}
/// Hook point for `ALTER TABLE` on TVIEW tables; currently a pure pass-through.
///
/// No ALTER TABLE command needs TVIEW-specific handling, including `SET LOGGED` and
/// `SET UNLOGGED` on `tv_*` tables. Every statement is therefore left to PostgreSQL and
/// this always returns `Ok(false)`. The statement is not inspected, because no outcome
/// depends on its contents.
///
/// # Errors
/// Never fails; the `Result` keeps the signature uniform with the other utility handlers.
///
/// # Safety
/// `alter_stmt` is not dereferenced; any pointer value is accepted.
unsafe fn handle_alter_table(
    alter_stmt: *mut pg_sys::AlterTableStmt,
    _query_string: *const ::std::os::raw::c_char,
) -> Result<bool, TViewError> {
    let _ = alter_stmt;
    Ok(false)
}
/// Forwards a utility statement to the next handler in the chain.
///
/// That handler is the hook installed before ours, or PostgreSQL's
/// `standard_ProcessUtility` if there was none.
///
/// NOTE(review): `next_hook` is a raw C function pointer called without pgrx's FFI error
/// boundary. If it raises an ERROR, confirm how that interacts with the Rust frames above.
///
/// # Safety
/// The arguments must be exactly those PostgreSQL passed to the ProcessUtility hook.
#[allow(clippy::too_many_arguments)] // mirrors the ProcessUtility_hook signature
unsafe fn call_prev_hook_or_standard(
    pstmt: *mut pg_sys::PlannedStmt,
    query_string: *const ::std::os::raw::c_char,
    read_only_tree: bool,
    context: pg_sys::ProcessUtilityContext::Type,
    params: pg_sys::ParamListInfo,
    query_env: *mut pg_sys::QueryEnvironment,
    dest: *mut pg_sys::DestReceiver,
    qc: *mut pg_sys::QueryCompletion,
) {
    unsafe {
        if let Some(next_hook) = PREV_PROCESS_UTILITY_HOOK {
            next_hook(
                pstmt,
                query_string,
                read_only_tree,
                context,
                params,
                query_env,
                dest,
                qc,
            );
        } else {
            pg_sys::standard_ProcessUtility(
                pstmt,
                query_string,
                read_only_tree,
                context,
                params,
                query_env,
                dest,
                qc,
            );
        }
    }
}