use crate::error::{TViewError, TViewResult};
use crate::schema::TViewSchema;
use crate::utils::quote_identifier;
use pgrx::datum::DatumWithOid;
use pgrx::prelude::*;
/// Converts an existing `tv_<entity>` table into a TVIEW.
///
/// The entity name is the part of `table_name` after the `tv_` prefix; the
/// actual work (validation, backup, drop, rebuild) is done by `do_conversion`.
///
/// # Errors
/// Fails when `table_name` does not start with `tv_`, or when any conversion
/// step fails.
pub fn convert_existing_table_to_tview(table_name: &str) -> TViewResult<()> {
    extract_entity_name(table_name)
        .and_then(|entity_name| do_conversion(table_name, entity_name))
}
fn do_conversion(table_name: &str, entity_name: &str) -> TViewResult<()> {
validate_tview_structure(table_name, entity_name)?;
let schema = infer_schema_from_table(table_name)?;
let data_backup = backup_table_data(table_name, &schema)?;
let base_tables = infer_base_tables(table_name)?;
let qi_table = quote_identifier(table_name);
crate::utils::spi_run_ddl(&format!("DROP TABLE {qi_table} CASCADE")).map_err(|e| {
TViewError::SpiError {
query: format!("DROP TABLE {qi_table} CASCADE"),
error: e,
}
})?;
reconstruct_as_tview(table_name, entity_name, &schema, &base_tables, &data_backup)?;
Ok(())
}
/// Checks that `table_name` has the two columns every TVIEW needs:
/// `id` of type `uuid` and `data` of type `jsonb`.
///
/// `_entity_name` is currently unused.
///
/// # Errors
/// - `TViewError::RequiredColumnMissing` when `id` or `data` is absent; the
///   message lists the columns that were actually found.
/// - `TViewError::InvalidSelectStatement` when either column has the wrong type.
fn validate_tview_structure(table_name: &str, _entity_name: &str) -> TViewResult<()> {
    let columns = get_table_columns(table_name)?;

    // Comma-separated list of the existing column names, used in error messages.
    let found = || {
        columns
            .iter()
            .map(|c| c.name.as_str())
            .collect::<Vec<_>>()
            .join(", ")
    };

    let Some(id_col) = columns.iter().find(|c| c.name == "id") else {
        return Err(TViewError::RequiredColumnMissing {
            column_name: "id".to_string(),
            context: format!(
                "Table '{}' must have an 'id' column (UUID). Found: {}",
                table_name,
                found()
            ),
        });
    };
    if id_col.data_type != "uuid" {
        return Err(TViewError::InvalidSelectStatement {
            sql: table_name.to_string(),
            reason: format!("Column 'id' must be UUID, found {}", id_col.data_type),
        });
    }

    let Some(data_col) = columns.iter().find(|c| c.name == "data") else {
        return Err(TViewError::RequiredColumnMissing {
            column_name: "data".to_string(),
            context: format!(
                "Table '{}' must have a 'data' column (JSONB). Found: {}",
                table_name,
                found()
            ),
        });
    };
    if data_col.data_type != "jsonb" {
        return Err(TViewError::InvalidSelectStatement {
            sql: table_name.to_string(),
            reason: format!("Column 'data' must be JSONB, found {}", data_col.data_type),
        });
    }

    Ok(())
}
/// Metadata for one column of a table, as read from `information_schema.columns`.
#[derive(Debug)]
struct ColumnInfo {
    /// Column name.
    name: String,
    /// Type name as reported by `information_schema.columns.data_type`
    /// (compared against `uuid` / `jsonb` during validation).
    data_type: String,
    /// `true` when `is_nullable` is `YES`. Populated but not read anywhere yet,
    /// hence the lint allowance.
    #[allow(dead_code)]
    is_nullable: bool,
}
fn get_table_columns(table_name: &str) -> TViewResult<Vec<ColumnInfo>> {
let mut columns = Vec::new();
Spi::connect(|client| {
let args = vec![unsafe {
DatumWithOid::new(table_name, PgOid::BuiltIn(PgBuiltInOids::TEXTOID).value())
}];
let results = client.select(
"SELECT column_name, data_type, is_nullable
FROM information_schema.columns
WHERE table_name = $1
ORDER BY ordinal_position",
None,
&args,
)?;
for row in results {
columns.push(ColumnInfo {
name: row["column_name"].value()?.unwrap_or_default(),
data_type: row["data_type"].value()?.unwrap_or_default(),
is_nullable: row["is_nullable"].value::<String>()?.unwrap_or_default() == "YES",
});
}
Ok::<_, spi::Error>(())
})?;
Ok(columns)
}
/// Builds a `TViewSchema` from the columns of `table_name`.
///
/// The first column whose name starts with `pk_` is taken as the primary key.
/// The text after that prefix becomes the entity name. `id` and `data` are
/// recorded as the id/data columns, the primary-key column is also used as the
/// identifier column, and every FK / additional-column list is left empty.
///
/// # Errors
/// `TViewError::RequiredColumnMissing` when no `pk_*` column exists (the
/// message lists the columns found). SPI failures are propagated.
fn infer_schema_from_table(table_name: &str) -> TViewResult<TViewSchema> {
    let columns = get_table_columns(table_name)?;

    let pk = columns.iter().find_map(|col| {
        col.name
            .strip_prefix("pk_")
            .map(|entity| (col.name.as_str(), entity))
    });

    let Some((pk_name, entity_name)) = pk else {
        let listing = columns
            .iter()
            .map(|c| c.name.clone())
            .collect::<Vec<_>>()
            .join(", ");
        return Err(TViewError::RequiredColumnMissing {
            column_name: "pk_<entity>".to_string(),
            context: format!(
                "Table '{}' must have a primary key column named 'pk_<entity>' \
                 (e.g., pk_user, pk_post). Found: {}",
                table_name, listing
            ),
        });
    };

    Ok(TViewSchema {
        entity_name: Some(entity_name.to_string()),
        pk_column: Some(pk_name.to_string()),
        id_column: Some("id".to_string()),
        data_column: Some("data".to_string()),
        identifier_column: Some(pk_name.to_string()),
        fk_columns: vec![],
        uuid_fk_columns: vec![],
        additional_columns: vec![],
        additional_columns_with_types: vec![],
    })
}
/// Captures the `id` and `data` values of every row in `table_name`, so the
/// table can be rebuilt after it is dropped.
///
/// `id` (uuid) and `data` (jsonb) are not binary-coercible to `text`, so
/// pgrx's typed `value::<String>()` refuses them. Both are therefore cast to
/// `text` in SQL before decoding. Only these two columns are read; SQL NULL
/// becomes `None`. `_schema` is currently unused.
///
/// # Errors
/// Propagates SPI failures.
fn backup_table_data(table_name: &str, _schema: &TViewSchema) -> TViewResult<Vec<BackupRow>> {
    let qi_table = quote_identifier(table_name);
    let query = format!("SELECT id::text AS id, data::text AS data FROM {qi_table}");
    let backup = Spi::connect(|client| {
        let results = client.select(&query, None, &[])?;
        let mut backup = Vec::new();
        for row in results {
            backup.push(BackupRow {
                id: row["id"].value()?,
                data: row["data"].value()?,
            });
        }
        Ok::<_, spi::Error>(backup)
    })?;
    Ok(backup)
}
/// Determines which base tables feed `table_name`.
///
/// Explicit hints in the table's comment (after a `TVIEW_BASES:` marker)
/// win. Otherwise the names are guessed from the stored JSON payloads. The
/// result may be empty.
///
/// # Errors
/// Propagates failures from either lookup.
fn infer_base_tables(table_name: &str) -> TViewResult<Vec<String>> {
    match get_base_table_hints(table_name)? {
        Some(hinted_tables) => Ok(hinted_tables),
        None => infer_base_tables_from_data(table_name),
    }
}
/// Guesses base-table names from the JSON in the `data` column of `table_name`.
///
/// At most 5 rows are sampled. Candidate names come from
/// `extract_table_references`; they are sorted, de-duplicated, and filtered to
/// those for which `table_exists` returns `true`. Rows whose `data` does not
/// parse as JSON are skipped.
///
/// `data` is `jsonb`, which is not binary-coercible to `text`, so it is cast
/// in SQL before being read back as a Rust `String`. Reading the raw `jsonb`
/// datum as `String` makes pgrx report incompatible types.
///
/// # Errors
/// Propagates SPI failures.
fn infer_base_tables_from_data(table_name: &str) -> TViewResult<Vec<String>> {
    let mut base_tables = Vec::new();
    Spi::connect(|client| {
        let qi_table = quote_identifier(table_name);
        let query = format!("SELECT data::text AS data FROM {qi_table} LIMIT 5");
        let results = client.select(&query, None, &[])?;
        for row in results {
            if let Some(data) = row["data"].value::<String>()? {
                if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(&data) {
                    extract_table_references(&json_value, &mut base_tables);
                }
            }
        }
        Ok::<_, spi::Error>(())
    })?;
    base_tables.sort_unstable();
    base_tables.dedup();
    base_tables.retain(|table| table_exists(table));
    Ok(base_tables)
}
/// Recursively collects guessed base-table names from the keys of `json`.
///
/// For every object key:
/// - `fk_<name>` with a non-empty `<name>` yields `tb_<name>`;
/// - otherwise `<name>_id` with a non-empty `<name>` yields `tb_<name>`.
///
/// Nested objects and arrays are walked in order; scalars contribute nothing.
/// Names are appended to `tables` as found, so duplicates are possible.
fn extract_table_references(json: &serde_json::Value, tables: &mut Vec<String>) {
    match json {
        serde_json::Value::Object(map) => {
            for (key, nested) in map {
                let referenced = key
                    .strip_prefix("fk_")
                    .filter(|rest| !rest.is_empty())
                    .or_else(|| key.strip_suffix("_id").filter(|rest| !rest.is_empty()));
                if let Some(name) = referenced {
                    tables.push(format!("tb_{name}"));
                }
                extract_table_references(nested, tables);
            }
        }
        serde_json::Value::Array(items) => items
            .iter()
            .for_each(|item| extract_table_references(item, tables)),
        _ => {}
    }
}
/// Reports whether `table_name` resolves, through the current `search_path`,
/// to a table-like relation (ordinary or partitioned table, view, or foreign
/// table).
///
/// This replaces a lookup in `information_schema.tables` by name alone. That
/// lookup also accepted tables in schemas that an unqualified reference can
/// never reach.
///
/// Any SPI failure is deliberately treated as "does not exist". This helper
/// only filters best-effort base-table guesses.
fn table_exists(table_name: &str) -> bool {
    // SAFETY: `table_name` is a `&str` datum declared with TEXTOID, its matching SQL type.
    let args = [unsafe {
        DatumWithOid::new(table_name, PgOid::BuiltIn(PgBuiltInOids::TEXTOID).value())
    }];
    Spi::get_one_with_args::<bool>(
        "SELECT EXISTS(
             SELECT 1 FROM pg_catalog.pg_class
             WHERE oid = to_regclass(quote_ident($1))
               AND relkind IN ('r', 'p', 'v', 'f'))",
        &args,
    )
    .ok()
    .flatten()
    .unwrap_or(false)
}
/// Reads explicit base-table hints from the comment on `table_name`.
///
/// The comment is searched for the marker `TVIEW_BASES:`. The text after the
/// first marker (up to any later marker) is split on commas, each entry is
/// trimmed, and empty entries are dropped, e.g.
/// `TVIEW_BASES: tb_user, tb_post` yields `["tb_user", "tb_post"]`.
///
/// The relation is resolved through `search_path` (`to_regclass`), like an
/// unqualified SQL reference. A bare `relname` match could hit an index or a
/// same-named relation in another schema, and it failed outright when nothing
/// matched.
///
/// Returns `Ok(None)` when the name does not resolve, there is no comment,
/// there is no marker, or no non-empty entry follows the marker.
///
/// # Errors
/// Propagates SPI failures.
fn get_base_table_hints(table_name: &str) -> TViewResult<Option<Vec<String>>> {
    // SAFETY: `table_name` is a `&str` datum declared with TEXTOID, its matching SQL type.
    let args = [unsafe {
        DatumWithOid::new(table_name, PgOid::BuiltIn(PgBuiltInOids::TEXTOID).value())
    }];
    let comment: Option<String> = Spi::get_one_with_args(
        "SELECT obj_description(to_regclass(quote_ident($1)), 'pg_class') AS comment",
        &args,
    )?;

    let Some(comment) = comment else {
        return Ok(None);
    };
    let Some(bases_part) = comment.split("TVIEW_BASES:").nth(1) else {
        return Ok(None);
    };

    let tables: Vec<String> = bases_part
        .split(',')
        .map(str::trim)
        .filter(|s| !s.is_empty())
        .map(str::to_string)
        .collect();
    Ok((!tables.is_empty()).then_some(tables))
}
/// Recreates `table_name` as a TVIEW after the original table was dropped.
///
/// Two views are created:
/// - `v_<entity>`: the backed-up rows frozen as a `VALUES` list of
///   `(uuid, jsonb)` literals. When there is nothing to restore, it is an
///   empty relation with `id uuid` / `data jsonb` columns.
/// - `table_name`: a view selecting everything from `v_<entity>`.
///
/// The pair is then recorded via `register_tview_metadata`. Backup rows whose
/// `id` or `data` is NULL are skipped. `_base_tables` is currently unused.
///
/// Previously, if every backup row was skipped, this emitted `VALUES ()`, a
/// SQL syntax error. It now falls back to the empty view instead.
///
/// # Errors
/// Propagates SPI failures from literal quoting, view creation, and metadata
/// registration.
fn reconstruct_as_tview(
    table_name: &str,
    entity_name: &str,
    schema: &TViewSchema,
    _base_tables: &[String],
    data_backup: &[BackupRow],
) -> TViewResult<()> {
    let view_name = format!("v_{entity_name}");
    let qi_view = quote_identifier(&view_name);
    let qi_table = quote_identifier(table_name);

    let mut values = Vec::with_capacity(data_backup.len());
    for row in data_backup {
        // A NULL id or data cannot be rendered as a typed literal; skip the row.
        if let (Some(id), Some(data)) = (row.id.as_deref(), row.data.as_deref()) {
            values.push(format!("({})", quote_backup_row(id, data)?));
        }
    }

    // An empty VALUES list is invalid SQL, so fall back to an empty typed
    // relation whenever no row survived, not only when the backup was empty.
    let view_sql = if values.is_empty() {
        format!(
            "CREATE VIEW {qi_view} AS SELECT NULL::uuid AS id, NULL::jsonb AS data WHERE false"
        )
    } else {
        format!(
            "CREATE VIEW {qi_view} AS SELECT * FROM (VALUES {}) AS t(id, data)",
            values.join(", ")
        )
    };
    Spi::run(&view_sql)?;
    Spi::run(&format!("CREATE VIEW {qi_table} AS SELECT * FROM {qi_view}"))?;

    register_tview_metadata(entity_name, &view_name, table_name, schema)
}

/// Renders one backup row as `'<id>'::uuid, '<data>'::jsonb`.
///
/// Quoting is done by the server's `quote_literal`, so both values are
/// escaped as SQL string literals.
fn quote_backup_row(id: &str, data: &str) -> TViewResult<String> {
    const QUERY_LABEL: &str = "quote_literal for backup row";
    Spi::get_one_with_args::<String>(
        "SELECT quote_literal($1)::text || '::uuid, ' || quote_literal($2)::text || '::jsonb'",
        &[
            // SAFETY: both datums are `&str` values declared with TEXTOID, their matching SQL type.
            unsafe { DatumWithOid::new(id, PgOid::BuiltIn(PgBuiltInOids::TEXTOID).value()) },
            unsafe { DatumWithOid::new(data, PgOid::BuiltIn(PgBuiltInOids::TEXTOID).value()) },
        ],
    )
    .map_err(|e| TViewError::SpiError {
        query: QUERY_LABEL.to_string(),
        error: e.to_string(),
    })?
    .ok_or_else(|| TViewError::SpiError {
        query: QUERY_LABEL.to_string(),
        error: "NULL result from quote_literal".to_string(),
    })
}
/// Records (or updates) the `pg_tview_meta` row for `entity_name`.
///
/// It stores the OIDs of the backing view `view_name` and the TVIEW
/// `tview_name`, plus the definition `SELECT * FROM <view_name>`. Both FK
/// column arrays are set to empty. An existing row for the same entity is
/// overwritten (`ON CONFLICT (entity) DO UPDATE`). `_schema` is currently
/// unused.
///
/// Both names are resolved through `search_path` (`to_regclass`), the same way
/// the unqualified `CREATE VIEW` statements that created them were. The
/// previous `relname`-only lookup could return an index or a same-named
/// relation in another schema.
///
/// # Errors
/// - `TViewError::CatalogError` when either relation cannot be resolved.
/// - SPI failures are propagated.
fn register_tview_metadata(
    entity_name: &str,
    view_name: &str,
    tview_name: &str,
    _schema: &TViewSchema,
) -> TViewResult<()> {
    let view_oid = resolve_relation_oid(view_name)?.ok_or_else(|| TViewError::CatalogError {
        operation: format!("Get OID for view {view_name}"),
        pg_error: "View not found".to_string(),
    })?;
    let table_oid = resolve_relation_oid(tview_name)?.ok_or_else(|| TViewError::CatalogError {
        operation: format!("Get OID for table {tview_name}"),
        pg_error: "Table not found".to_string(),
    })?;

    let definition = format!("SELECT * FROM {}", quote_identifier(view_name));
    // The OIDs are plain integers, so interpolating them cannot inject SQL;
    // the free-text values go through bind parameters.
    let insert_sql = format!(
        "INSERT INTO pg_tview_meta (entity, view_oid, table_oid, definition, fk_columns, uuid_fk_columns)
VALUES ($1, {}, {}, $2, '{{}}', '{{}}')
ON CONFLICT (entity) DO UPDATE SET
view_oid = EXCLUDED.view_oid,
table_oid = EXCLUDED.table_oid,
definition = EXCLUDED.definition,
fk_columns = EXCLUDED.fk_columns,
uuid_fk_columns = EXCLUDED.uuid_fk_columns",
        view_oid.to_u32(),
        table_oid.to_u32(),
    );
    // SAFETY: both datums are text values declared with TEXTOID, their matching SQL type.
    let args = [
        unsafe { DatumWithOid::new(entity_name, PgOid::BuiltIn(PgBuiltInOids::TEXTOID).value()) },
        unsafe { DatumWithOid::new(definition, PgOid::BuiltIn(PgBuiltInOids::TEXTOID).value()) },
    ];
    Spi::run_with_args(&insert_sql, &args)?;
    Ok(())
}

/// Resolves an unqualified relation name to its OID the way unqualified SQL
/// does (through `search_path`). Returns `Ok(None)` when nothing matches.
fn resolve_relation_oid(relname: &str) -> TViewResult<Option<pg_sys::Oid>> {
    // SAFETY: `relname` is a `&str` datum declared with TEXTOID, its matching SQL type.
    let args = [unsafe {
        DatumWithOid::new(relname, PgOid::BuiltIn(PgBuiltInOids::TEXTOID).value())
    }];
    Ok(Spi::get_one_with_args::<pg_sys::Oid>(
        "SELECT to_regclass(quote_ident($1))::oid",
        &args,
    )?)
}
/// One `(id, data)` pair captured from a table before it is dropped.
///
/// Both values are held as strings; `None` stands for SQL NULL.
#[derive(Debug)]
struct BackupRow {
    /// The row's `id` value, or `None` for SQL NULL.
    id: Option<String>,
    /// The row's `data` value, or `None` for SQL NULL.
    data: Option<String>,
}
/// Returns the entity part of a `tv_<entity>` table name, borrowed from the
/// input. The result is empty when the name is exactly `tv_`.
///
/// # Errors
/// `TViewError::InvalidSelectStatement` when `table_name` lacks the `tv_` prefix.
fn extract_entity_name(table_name: &str) -> TViewResult<&str> {
    match table_name.strip_prefix("tv_") {
        Some(entity) => Ok(entity),
        None => Err(TViewError::InvalidSelectStatement {
            sql: table_name.to_string(),
            reason: "Table name must start with tv_".to_string(),
        }),
    }
}