use pgrx::AllocatedByPostgres;
use pgrx::datum::DatumWithOid;
use pgrx::heap_tuple::PgHeapTuple;
use pgrx::pg_sys;
use pgrx::prelude::*;
use std::collections::HashMap;
use std::sync::{LazyLock, Mutex};
/// Executes one DDL statement through its own non-atomic SPI connection.
///
/// The connection is opened with `SPI_OPT_NONATOMIC`, and the statement runs with
/// `read_only = false`, `allow_nonatomic = true` and no row limit (`tcount = 0`).
/// `SPI_finish` is called right after execution, before the status code is inspected.
///
/// # Errors
///
/// Returns `Err` with a human-readable message when `sql` contains an interior NUL byte,
/// when `SPI_connect_ext` does not report `SPI_OK_CONNECT`, or when `SPI_execute_extended`
/// returns a negative status code.
pub fn spi_run_ddl(sql: &str) -> Result<(), String> {
    use std::ffi::CString;

    let statement =
        CString::new(sql).map_err(|e| format!("DDL SQL contains null byte: {e}"))?;

    // Both SPI constants are small positive `u32` values, so the `i32` casts cannot wrap.
    #[allow(clippy::cast_possible_wrap)]
    let (nonatomic_flag, connect_ok) = (
        pg_sys::SPI_OPT_NONATOMIC as i32,
        pg_sys::SPI_OK_CONNECT as i32,
    );

    // SAFETY: `statement` is a NUL-terminated C string that outlives both FFI calls, and
    // `options` is a fully initialised value that outlives `SPI_execute_extended`.
    // `SPI_finish` pairs with the successful `SPI_connect_ext` on the normal return path.
    let status = unsafe {
        let connected = pg_sys::SPI_connect_ext(nonatomic_flag);
        if connected != connect_ok {
            return Err(format!("SPI_connect_ext failed: {connected}"));
        }
        let options = pg_sys::SPIExecuteOptions {
            read_only: false,
            allow_nonatomic: true,
            tcount: 0,
            ..pg_sys::SPIExecuteOptions::default()
        };
        let status =
            pg_sys::SPI_execute_extended(statement.as_ptr(), std::ptr::from_ref(&options));
        pg_sys::SPI_finish();
        status
    };

    if status < 0 {
        return Err(format!(
            "SPI_execute_extended returned error code {status} for DDL: {sql}"
        ));
    }
    Ok(())
}
/// Runs `query` and returns the first column of its first row as text.
///
/// At most one row is fetched. Returns `Ok(None)` when the query produces no rows or when
/// that column is SQL `NULL`.
///
/// # Errors
///
/// Propagates SPI errors from executing the query or from converting the value to `String`.
pub fn spi_get_string(query: &str) -> spi::Result<Option<String>> {
    Spi::connect(|client| -> spi::Result<Option<String>> {
        let mut table = client.select(query, Some(1), &[])?;
        let Some(first_row) = table.next() else {
            return Ok(None);
        };
        // pgrx SPI column ordinals are 1-based, so `[1]` is the first column.
        Ok(first_row[1].value::<String>()?)
    })
}
use pgrx::pg_sys::Oid;
/// Outcome of reading an integer column from a heap tuple (see `tuple_get_i64`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IntExtraction {
    /// The column was read successfully and is non-NULL (INTEGER values are widened to `i64`).
    Value(i64),
    /// The column was read successfully but holds SQL `NULL`.
    Null,
    /// Neither a BIGINT nor an INTEGER read succeeded: the column is absent from the tuple
    /// or has some other type.
    Missing,
}
/// Reads column `col` of `tuple` as a 64-bit integer.
///
/// A BIGINT (`i64`) read is attempted first. If that read fails for any reason, an INTEGER
/// (`i32`) read follows, and its value is widened losslessly with `i64::from`. A NULL from
/// either successful read yields [`IntExtraction::Null`]. If both reads fail, the result is
/// [`IntExtraction::Missing`].
pub fn tuple_get_i64(tuple: &PgHeapTuple<'_, AllocatedByPostgres>, col: &str) -> IntExtraction {
    if let Ok(wide) = tuple.get_by_name::<i64>(col) {
        return wide.map_or(IntExtraction::Null, IntExtraction::Value);
    }
    match tuple.get_by_name::<i32>(col) {
        Ok(narrow) => narrow.map_or(IntExtraction::Null, |v| IntExtraction::Value(i64::from(v))),
        Err(_) => IntExtraction::Missing,
    }
}
/// Extracts the integer primary key (`pk_<entity>`) of the row that fired `trigger`.
///
/// Uses the NEW row when present and otherwise the OLD row. The entity name comes from the
/// trigger relation's OID via `crate::catalog::entity_for_table`. The key column is read as
/// BIGINT or INTEGER through `tuple_get_i64`.
///
/// # Errors
///
/// Returns an error, rather than panicking, in each of these cases:
/// - the trigger carries neither a NEW nor an OLD row (e.g. a statement-level trigger);
/// - the trigger relation cannot be obtained;
/// - the table is not managed by pg_tviews;
/// - the key column is NULL, absent, or not INTEGER/BIGINT.
pub fn extract_pk(trigger: &PgTrigger) -> spi::Result<i64> {
    // Row-level triggers carry NEW and/or OLD; a statement-level trigger carries neither.
    // Report that as a regular error instead of aborting with a panic.
    let Some(tuple) = trigger.new().or_else(|| trigger.old()) else {
        return Err(crate::TViewError::SpiError {
            query: "extract_pk".to_string(),
            error: "Row must exist for AFTER trigger (no NEW or OLD tuple)".to_string(),
        }
        .into());
    };
    let table_oid = trigger
        .relation()
        .map_err(|_| crate::TViewError::SpiError {
            query: "get trigger relation".to_string(),
            error: "Failed to get trigger relation".to_string(),
        })?
        .oid();
    let entity = crate::catalog::entity_for_table(table_oid)?.ok_or_else(|| {
        crate::TViewError::SpiError {
            query: "entity_for_table".to_string(),
            error: format!("Table OID {table_oid:?} not managed by pg_tviews"),
        }
    })?;
    let pk_column = format!("pk_{entity}");
    match tuple_get_i64(&tuple, &pk_column) {
        IntExtraction::Value(v) => Ok(v),
        IntExtraction::Null => Err(crate::TViewError::SpiError {
            query: pk_column.clone(),
            error: format!("{pk_column} must not be NULL"),
        }
        .into()),
        IntExtraction::Missing => Err(crate::TViewError::SpiError {
            query: pk_column.clone(),
            error: format!("{pk_column} column not found on tuple (expected INTEGER or BIGINT)"),
        }
        .into()),
    }
}
/// Returns the relation name of `view_oid` via the cached `relname_from_oid` lookup.
///
/// NOTE(review): despite its name, this performs no source-to-view mapping itself; it only
/// resolves the OID it is given. Confirm that callers pass the view's own OID.
///
/// # Errors
///
/// Propagates any error from `relname_from_oid`.
pub fn lookup_view_for_source(view_oid: Oid) -> spi::Result<String> {
    let view_name = relname_from_oid(view_oid)?;
    Ok(view_name)
}
/// Memoised bare `pg_class.relname` values keyed by relation OID (filled by `relname_from_oid`).
static OID_RELNAME_CACHE: LazyLock<Mutex<HashMap<Oid, String>>> = LazyLock::new(Mutex::default);
/// Memoised `quote_ident`-quoted, schema-qualified relation names keyed by relation OID
/// (filled by `qualified_relname_from_oid`).
static OID_QUALIFIED_RELNAME_CACHE: LazyLock<Mutex<HashMap<Oid, String>>> =
    LazyLock::new(Mutex::default);
/// Clears both OID-keyed name caches: bare relnames first, then qualified names.
///
/// A poisoned lock is recovered rather than propagated, so clearing always happens.
pub fn invalidate_oid_relname_cache() {
    for cache in [&*OID_RELNAME_CACHE, &*OID_QUALIFIED_RELNAME_CACHE] {
        cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clear();
    }
}
/// Column names (in `attnum` order) per view name, filled by `get_view_columns`.
///
/// Keyed by the exact name string passed to `get_view_columns`; cleared by
/// `invalidate_view_columns_cache`.
pub static VIEW_COLUMNS_CACHE: LazyLock<Mutex<HashMap<String, Vec<String>>>> =
    LazyLock::new(Mutex::default);
/// Drops every cached view column list, recovering the lock if it was poisoned.
pub fn invalidate_view_columns_cache() {
    VIEW_COLUMNS_CACHE
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .clear();
}
/// Map from a string key to a pair of strings, held in `DEDUP_DML_CACHE`.
///
/// NOTE(review): the entries are written elsewhere. Presumably the key identifies a target
/// and the pair holds two generated DML statements; confirm against the writer.
pub type DedupDmlCache = HashMap<String, (String, String)>;
/// Process-wide [`DedupDmlCache`] instance; cleared by `invalidate_dedup_dml_cache`.
pub static DEDUP_DML_CACHE: LazyLock<Mutex<DedupDmlCache>> =
    LazyLock::new(|| Mutex::new(DedupDmlCache::new()));
/// Empties `DEDUP_DML_CACHE`, recovering the lock if it was poisoned.
pub fn invalidate_dedup_dml_cache() {
    DEDUP_DML_CACHE
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .clear();
}
/// Returns the bare (unqualified, unquoted) `pg_class.relname` for `oid`.
///
/// Results are memoised in `OID_RELNAME_CACHE`. Entries persist until
/// `invalidate_oid_relname_cache` clears them. The cache lock is not held during the SPI
/// query.
///
/// # Errors
///
/// Fails if the SPI query errors, if `pg_class` has no row for `oid`, or if `relname`
/// comes back NULL.
pub fn relname_from_oid(oid: Oid) -> spi::Result<String> {
    const RELNAME_SQL: &str = "SELECT relname::text AS relname FROM pg_class WHERE oid = $1";

    let cached = OID_RELNAME_CACHE
        .lock()
        .unwrap_or_else(|e| e.into_inner())
        .get(&oid)
        .cloned();
    if let Some(hit) = cached {
        return Ok(hit);
    }

    let resolved = Spi::connect(|client| -> spi::Result<String> {
        // SAFETY: `oid` is a `pg_sys::Oid`, the Rust representation of an OIDOID datum.
        let args =
            vec![unsafe { DatumWithOid::new(oid, PgOid::BuiltIn(PgBuiltInOids::OIDOID).value()) }];
        let mut rows = client.select(RELNAME_SQL, None, &args)?;
        let row = rows.next().ok_or_else(|| {
            spi::Error::from(crate::TViewError::SpiError {
                query: RELNAME_SQL.to_string(),
                error: format!("No pg_class entry for oid: {oid:?}"),
            })
        })?;
        row["relname"].value::<String>()?.ok_or_else(|| {
            spi::Error::from(crate::TViewError::SpiError {
                query: RELNAME_SQL.to_string(),
                error: "relname column is NULL".to_string(),
            })
        })
    })?;

    OID_RELNAME_CACHE
        .lock()
        .unwrap_or_else(|e| e.into_inner())
        .insert(oid, resolved.clone());
    Ok(resolved)
}
pub fn qualified_relname_from_oid(oid: Oid) -> spi::Result<String> {
{
let cache = OID_QUALIFIED_RELNAME_CACHE
.lock()
.unwrap_or_else(|e| e.into_inner());
if let Some(name) = cache.get(&oid) {
return Ok(name.clone());
}
}
let qname: String = Spi::connect(|client| {
let args = vec![unsafe {
DatumWithOid::new(oid, PgOid::BuiltIn(PgBuiltInOids::OIDOID).value())
}];
let mut rows = client.select(
"SELECT quote_ident(n.nspname) || '.' || quote_ident(c.relname) AS qname \
FROM pg_class c \
JOIN pg_namespace n ON n.oid = c.relnamespace \
WHERE c.oid = $1",
None,
&args,
)?;
if let Some(row) = rows.next() {
row["qname"].value::<String>()?.ok_or_else(|| {
spi::Error::from(crate::TViewError::SpiError {
query: "qualified_relname_from_oid".to_string(),
error: "qname column is NULL".to_string(),
})
})
} else {
Err(spi::Error::from(crate::TViewError::SpiError {
query: "qualified_relname_from_oid".to_string(),
error: format!("No pg_class entry for oid: {oid:?}"),
}))
}
})?;
OID_QUALIFIED_RELNAME_CACHE
.lock()
.unwrap_or_else(|e| e.into_inner())
.insert(oid, qname.clone());
Ok(qname)
}
/// Returns the live (non-dropped, user-visible) column names of relation `view_name`, in
/// `attnum` order.
///
/// `view_name` is compared verbatim against `pg_class.relname`, with no identifier parsing
/// or case folding. Because relnames are unique only within a schema, the name is first
/// resolved to a single relation: the one visible on the current `search_path` (what an
/// unqualified reference would resolve to), falling back to the lowest OID. Without this,
/// same-named relations in other schemas would contribute interleaved or duplicated
/// columns. An unknown name yields an empty list.
///
/// Results are memoised in `VIEW_COLUMNS_CACHE` under `view_name`; the cache lock is not
/// held during the SPI query.
///
/// # Errors
///
/// Propagates SPI errors from the catalog query or from text conversion.
pub fn get_view_columns(view_name: &str) -> spi::Result<Vec<String>> {
    {
        let cache = VIEW_COLUMNS_CACHE.lock().unwrap_or_else(|e| e.into_inner());
        if let Some(cols) = cache.get(view_name) {
            return Ok(cols.clone());
        }
    }
    let cols: Vec<String> = Spi::connect(|client| -> spi::Result<Vec<String>> {
        // SAFETY: a `&str` converts to a `text` datum, matching TEXTOID.
        let args = vec![unsafe {
            DatumWithOid::new(view_name, PgOid::BuiltIn(PgBuiltInOids::TEXTOID).value())
        }];
        let rows = client.select(
            "SELECT a.attname::text \
             FROM pg_attribute a \
             WHERE a.attrelid = ( \
                 SELECT c.oid FROM pg_class c \
                 WHERE c.relname = $1 \
                 ORDER BY pg_catalog.pg_table_is_visible(c.oid) DESC, c.oid \
                 LIMIT 1) \
             AND a.attnum > 0 AND NOT a.attisdropped \
             ORDER BY a.attnum",
            None,
            &args,
        )?;
        let mut result = Vec::with_capacity(10);
        for r in rows {
            if let Some(name) = r["attname"].value::<String>()? {
                result.push(name);
            }
        }
        Ok(result)
    })?;
    VIEW_COLUMNS_CACHE
        .lock()
        .unwrap_or_else(|e| e.into_inner())
        .insert(view_name.to_string(), cols.clone());
    Ok(cols)
}
/// Returns the column names of the relation identified by `rel_oid`.
///
/// The OID is first mapped to its bare relname (cached), and the columns are then looked
/// up by that name through `get_view_columns`.
///
/// # Errors
///
/// Propagates errors from either lookup.
pub fn get_view_columns_by_oid(rel_oid: Oid) -> spi::Result<Vec<String>> {
    relname_from_oid(rel_oid).and_then(|name| get_view_columns(&name))
}
/// Quotes `name` as a PostgreSQL delimited identifier.
///
/// The result is wrapped in double quotes, and every embedded `"` is doubled, so that
/// `test"col` becomes `"test""col"`. Case is preserved exactly.
#[must_use]
pub fn quote_identifier(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            // An embedded quote is escaped by repeating it.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quote_identifier_normal() {
        let quoted = quote_identifier("post");
        assert_eq!(quoted, r#""post""#);
    }

    #[test]
    fn test_quote_identifier_uppercase() {
        // Case must survive quoting untouched.
        let quoted = quote_identifier("Post");
        assert_eq!(quoted, r#""Post""#);
    }

    #[test]
    fn test_quote_identifier_with_underscore() {
        let quoted = quote_identifier("pk_user");
        assert_eq!(quoted, r#""pk_user""#);
    }

    #[test]
    fn test_quote_identifier_with_internal_quotes() {
        // An embedded double quote is escaped by doubling it.
        let quoted = quote_identifier(r#"test"col"#);
        assert_eq!(quoted, r#""test""col""#);
    }

    #[test]
    fn test_oid_relname_cache_invalidation() {
        use pg_sys::Oid;
        let key = Oid::from(123);

        invalidate_oid_relname_cache();
        OID_RELNAME_CACHE
            .lock()
            .unwrap()
            .insert(key, "test_table".to_string());
        assert!(OID_RELNAME_CACHE.lock().unwrap().contains_key(&key));

        invalidate_oid_relname_cache();
        assert!(OID_RELNAME_CACHE.lock().unwrap().is_empty());
    }

    #[test]
    fn test_view_columns_cache_invalidation() {
        let owned = |items: &[&str]| items.iter().map(|s| String::from(*s)).collect::<Vec<_>>();

        invalidate_view_columns_cache();
        VIEW_COLUMNS_CACHE.lock().unwrap().extend([
            ("v_user".to_string(), owned(&["id", "name"])),
            ("v_post".to_string(), owned(&["id", "title", "user_id"])),
        ]);
        {
            let cache = VIEW_COLUMNS_CACHE.lock().unwrap();
            assert_eq!(cache.len(), 2);
            assert!(["v_user", "v_post"].iter().all(|k| cache.contains_key(*k)));
        }

        invalidate_view_columns_cache();
        assert!(VIEW_COLUMNS_CACHE.lock().unwrap().is_empty());
    }
}