use narwhal_core::{Error, Result, Value};
use std::fmt::Write as _;
/// Renders `name` as a double-quoted SQL identifier.
///
/// Every embedded `"` is written twice, and the whole name is wrapped in a
/// leading and trailing `"`. No other characters are altered.
fn quote_ident(name: &str) -> String {
    // Two extra bytes for the surrounding quotes; doubled quotes may grow it further.
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
/// One column of a table, as decoded by `fetch_columns` from
/// `pg_attribute` (left-joined with `pg_attrdef` for the expression).
struct ColumnInfo {
    /// Column name (`pg_attribute.attname`), unquoted.
    name: String,
    /// Type text produced by `pg_catalog.format_type(atttypid, atttypmod)`,
    /// or `"unknown"` when the value could not be decoded.
    data_type: String,
    /// Value of `pg_attribute.attnotnull`; `false` if it was not decoded as `true`.
    not_null: bool,
    /// Expression text from `pg_get_expr(adbin, adrelid)`. `build_create_table`
    /// uses it as the column default, or as the generation expression when the
    /// column is generated. `None` when no text value was returned.
    default: Option<String>,
    /// `pg_attribute.attidentity` cast to text. `build_create_table` treats
    /// `"a"` as GENERATED ALWAYS and `"d"` as GENERATED BY DEFAULT identity.
    /// Empty when the value could not be decoded.
    identity: String,
    /// `pg_attribute.attgenerated` cast to text. `build_create_table` treats
    /// `"s"` as a STORED generated column. Empty when the value could not be
    /// decoded.
    generated: String,
}
/// Primary-key columns of a table, as returned by `fetch_pk`.
struct PkInfo {
    /// Unquoted column names, ordered by their position in the key
    /// (ordinality of `pg_constraint.conkey`). Empty when no primary-key
    /// constraint was found.
    columns: Vec<String>,
}
pub(crate) async fn build_create_table(
conn: &super::PostgresConnection,
schema: &str,
table: &str,
) -> Result<String> {
let columns = fetch_columns(conn, schema, table).await?;
if columns.is_empty() {
return Err(Error::Schema(format!("table {schema}.{table} not found")));
}
let pk = fetch_pk(conn, schema, table).await?;
let qualified = format!("{}.{}", quote_ident(schema), quote_ident(table));
let mut out = String::with_capacity(256);
writeln!(&mut out, "CREATE TABLE {qualified} (").map_err(|e| Error::Other(e.to_string()))?;
let composite_pk = pk.columns.len() > 1;
let mut column_lines = Vec::with_capacity(columns.len());
for col in &columns {
let is_pk = pk.columns.contains(&col.name);
let mut line = format!(" {} {}", quote_ident(&col.name), col.data_type);
if col.identity == "a" {
line.push_str(" GENERATED ALWAYS AS IDENTITY");
} else if col.identity == "d" {
line.push_str(" GENERATED BY DEFAULT AS IDENTITY");
} else if col.generated == "s" {
let expr = col.default.as_deref().ok_or_else(|| {
Error::Other(format!(
"column '{}' marked generated STORED but has no expression",
col.name
))
})?;
write!(&mut line, " GENERATED ALWAYS AS ({expr}) STORED")
.map_err(|e| Error::Other(e.to_string()))?;
} else if let Some(default) = &col.default {
write!(&mut line, " DEFAULT {default}").map_err(|e| Error::Other(e.to_string()))?;
}
if col.not_null && (composite_pk || !is_pk) {
line.push_str(" NOT NULL");
}
if !composite_pk && is_pk {
line.push_str(" PRIMARY KEY");
}
column_lines.push(line);
}
if composite_pk {
let quoted: Vec<String> = pk.columns.iter().map(|c| quote_ident(c)).collect();
column_lines.push(format!(" PRIMARY KEY ({})", quoted.join(", ")));
}
out.push_str(&column_lines.join(",\n"));
out.push_str("\n);\n");
Ok(out)
}
async fn fetch_columns(
conn: &super::PostgresConnection,
schema: &str,
table: &str,
) -> Result<Vec<ColumnInfo>> {
const SQL: &str = "
SELECT a.attname,
pg_catalog.format_type(a.atttypid, a.atttypmod),
a.attnotnull,
pg_catalog.pg_get_expr(d.adbin, d.adrelid),
a.attidentity::text,
a.attgenerated::text
FROM pg_catalog.pg_attribute a
JOIN pg_catalog.pg_class c ON c.oid = a.attrelid
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
LEFT JOIN pg_catalog.pg_attrdef d
ON d.adrelid = a.attrelid AND d.adnum = a.attnum
WHERE n.nspname = $1
AND c.relname = $2
AND a.attnum > 0
AND NOT a.attisdropped
ORDER BY a.attnum";
let result = conn
.run(
SQL,
&[
Value::String(schema.to_owned()),
Value::String(table.to_owned()),
],
)
.await?;
let mut columns = Vec::with_capacity(result.rows.len());
for row in result.rows {
let mut iter = row.0.into_iter();
let name = match iter.next() {
Some(Value::String(s)) => s,
_ => continue,
};
let data_type = match iter.next() {
Some(Value::String(s) | Value::Unknown(s)) => s,
_ => "unknown".into(),
};
let not_null = matches!(iter.next(), Some(Value::Bool(true)));
let default = match iter.next() {
Some(Value::String(s) | Value::Unknown(s)) => Some(s),
_ => None,
};
let identity = match iter.next() {
Some(Value::String(s) | Value::Unknown(s)) => s,
_ => String::new(),
};
let generated = match iter.next() {
Some(Value::String(s) | Value::Unknown(s)) => s,
_ => String::new(),
};
columns.push(ColumnInfo {
name,
data_type,
not_null,
default,
identity,
generated,
});
}
Ok(columns)
}
async fn fetch_pk(conn: &super::PostgresConnection, schema: &str, table: &str) -> Result<PkInfo> {
const SQL: &str = "
SELECT a.attname
FROM pg_catalog.pg_constraint con
JOIN pg_catalog.pg_class c ON c.oid = con.conrelid
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
CROSS JOIN LATERAL unnest(con.conkey) WITH ORDINALITY AS k(num, ord)
JOIN pg_catalog.pg_attribute a
ON a.attrelid = c.oid AND a.attnum = k.num
WHERE con.contype = 'p'
AND n.nspname = $1
AND c.relname = $2
ORDER BY k.ord";
let result = conn
.run(
SQL,
&[
Value::String(schema.to_owned()),
Value::String(table.to_owned()),
],
)
.await?;
let columns: Vec<String> = result
.rows
.into_iter()
.filter_map(|row| match row.0.into_iter().next() {
Some(Value::String(s)) => Some(s),
_ => None,
})
.collect();
Ok(PkInfo { columns })
}
/// Unit tests for the pure helpers in this module.
#[cfg(test)]
mod tests {
    use super::quote_ident;

    #[test]
    fn quote_ident_doubles_quotes() {
        assert_eq!(quote_ident("my\"table"), "\"my\"\"table\"");
    }

    #[test]
    fn quote_ident_wraps_plain_and_empty_names() {
        assert_eq!(quote_ident("users"), "\"users\"");
        assert_eq!(quote_ident(""), "\"\"");
    }

    #[test]
    fn quote_ident_preserves_case_spaces_and_repeated_quotes() {
        assert_eq!(quote_ident("My Table"), "\"My Table\"");
        assert_eq!(quote_ident("\"\""), "\"\"\"\"\"\"");
    }
}