import argparse
import sys
import os
import json
from supabase import create_client, Client
# Supabase connection settings, read from the environment at import time.
# Defaults are empty strings, so a missing configuration surfaces when the
# client is first used rather than at import.
SUPABASE_URL = os.getenv("SUPABASE_URL", "")
SUPABASE_ANON_KEY = os.getenv("SUPABASE_ANON_KEY", "")
# Lowercase CQL words that must be double-quoted when used as identifiers
# in generated CQL (consumed by quote_cql_identifier).
CQL_RESERVED_KEYWORDS = {
    "token",
    "schema",
    "table",
    "key",
    "create",
    "from",
    "primary",
    "boolean",
    "timestamp",
    "user",
    "uuid",
    "to",
    "update",
    "index",
    "text",
    "int",
    "decimal",
    "blob",
    "delete",
    "order",
}
# Lowercase SQL words that must be double-quoted when used as identifiers
# in generated SQL (consumed by quote_sql_identifier).
SQL_RESERVED_KEYWORDS = {
    "user",
    "table",
    "authorization",
    "default",
    "group",
    "primary",
    "to",
    "offset",
    "current_date",
    "current_time",
    "from",
    "order",
    "collate",
    "create",
}
def get_supabase_client() -> Client:
    """Create a Supabase client from the module-level URL and anon key."""
    supabase_client: Client = create_client(SUPABASE_URL, SUPABASE_ANON_KEY)
    return supabase_client
def fetch_schema_json():
    """Fetch the full schema description via the ``get_full_schema_json`` RPC.

    Returns the RPC payload (a JSON-decoded structure) on success.
    On failure, prints the raw response to stderr and exits with status 1.
    """
    client = get_supabase_client()
    response = client.rpc("get_full_schema_json", {}).execute()
    if hasattr(response, "data"):
        return response.data
    # Diagnostics belong on stderr so they don't pollute stdout (which a
    # caller may be piping), and the non-zero exit stays meaningful.
    print("Error fetching schema:", response, file=sys.stderr)
    sys.exit(1)
# PostgreSQL type name (as reported by the schema RPC) -> SQL DDL spelling.
# Consumed by map_pg_type_to_sql; unknown types fall back to upper-casing.
PG_TO_SQL_TYPE = {
    "bigint": "BIGINT",
    "integer": "INTEGER",
    "smallint": "SMALLINT",
    "text": "TEXT",
    "uuid": "UUID",
    "boolean": "BOOLEAN",
    "timestamp with time zone": "TIMESTAMPTZ",
    "timestamp without time zone": "TIMESTAMP",
    "numeric": "NUMERIC",
    "json": "JSON",
    "jsonb": "JSONB",
    "character varying": "VARCHAR",
    "bytea": "BYTEA",
    "text[]": "TEXT[]",
    "uuid[]": "UUID[]",
}
# PostgreSQL type name -> closest CQL (Cassandra) type. JSON-ish and varchar
# types degrade to text; consumed by map_pg_type_to_cql (fallback: text).
PG_TO_CQL_TYPE = {
    "bigint": "bigint",
    "integer": "int",
    "smallint": "smallint",
    "text": "text",
    "uuid": "uuid",
    "boolean": "boolean",
    "timestamp with time zone": "timestamp",
    "timestamp without time zone": "timestamp",
    "numeric": "decimal",
    "json": "text",
    "jsonb": "text",
    "character varying": "text",
    "bytea": "blob",
    "text[]": "list<text>",
    "uuid[]": "list<uuid>",
}
def map_pg_type_to_sql(pg_type):
    """Map a PostgreSQL data type name to its SQL DDL spelling.

    "character varying" (with or without a "(n)" length modifier) becomes
    VARCHAR / VARCHAR(n). Other known types go through PG_TO_SQL_TYPE;
    unknown types fall back to the upper-cased input.
    """
    if pg_type.startswith("character varying"):
        # Preserve any "(n)" length modifier while mapping the base name to
        # VARCHAR. Previously this branch upper-cased the whole string, which
        # made the "character varying" entry in PG_TO_SQL_TYPE unreachable
        # and produced "CHARACTER VARYING" instead of the intended "VARCHAR".
        return "VARCHAR" + pg_type[len("character varying"):]
    return PG_TO_SQL_TYPE.get(pg_type, pg_type.upper())
def map_pg_type_to_cql(pg_type):
    """Map a PostgreSQL type name to the closest CQL type (default: text).

    Array types ("foo[]") become list<mapped-element-type>; varchar types
    become text; anything not in PG_TO_CQL_TYPE degrades to text.
    """
    if pg_type.endswith("[]"):
        # Handle arrays before the varchar check so e.g. "character varying[]"
        # maps to list<text> instead of falling into the scalar branch below.
        # Recursing on the element type keeps the same fallback ("text").
        return f"list<{map_pg_type_to_cql(pg_type[:-2])}>"
    if pg_type.startswith("character varying"):
        return "text"
    return PG_TO_CQL_TYPE.get(pg_type, "text")
def quote_cql_identifier(name):
    """Return *name* double-quoted when CQL requires it, else unchanged.

    Quoting applies to empty names, names containing characters other than
    letters/digits/underscores, names starting with a digit, and CQL
    reserved words. Embedded double quotes are doubled, per CQL quoting
    rules, so the emitted identifier stays syntactically valid.
    """
    # Conditions reordered (pure `or` chain, behavior-equivalent); the
    # redundant `name.lower() == "user"` test was dropped because "user"
    # is already in CQL_RESERVED_KEYWORDS.
    if (
        not name
        or not name.replace("_", "").isalnum()
        or name[0].isdigit()
        or name.lower() in CQL_RESERVED_KEYWORDS
    ):
        escaped = name.replace('"', '""')  # double embedded quotes
        return f'"{escaped}"'
    return name
def quote_sql_identifier(name):
    """Return *name* double-quoted when SQL requires it, else unchanged.

    Quoting applies to empty names, names containing characters other than
    letters/digits/underscores, names starting with a digit, and SQL
    reserved words. Embedded double quotes are doubled, per the SQL
    standard, so the emitted identifier stays syntactically valid.
    """
    # Conditions reordered (pure `or` chain, behavior-equivalent); the
    # redundant `name.lower() == "user"` test was dropped because "user"
    # is already in SQL_RESERVED_KEYWORDS.
    if (
        not name
        or not name.replace("_", "").isalnum()
        or name[0].isdigit()
        or name.lower() in SQL_RESERVED_KEYWORDS
    ):
        escaped = name.replace('"', '""')  # double embedded quotes
        return f'"{escaped}"'
    return name
def cql_table_name(name):
    """Strip leading underscores from *name*, then quote it for CQL if needed."""
    stripped = name.lstrip("_")
    return quote_cql_identifier(stripped)
def sql_table_name(name):
    """Strip leading underscores from *name*, then quote it for SQL if needed."""
    stripped = name.lstrip("_")
    return quote_sql_identifier(stripped)
def render_sql_create_table(table, schema, if_not_exists=False):
    """Render one CREATE TABLE statement for a table description dict.

    ``table`` carries "table_name" and "columns" (each with "column_name",
    "data_type" and optional "is_nullable"/"default"). A bigint "id" column
    is emitted as an identity column; other columns get NOT NULL and
    DEFAULT clauses as described. ``if_not_exists`` toggles the
    CREATE TABLE IF NOT EXISTS form.
    """
    col_defs = []
    for column in table["columns"]:
        name_sql = quote_sql_identifier(column["column_name"])
        type_sql = map_pg_type_to_sql(column["data_type"])
        if column["column_name"] == "id" and column["data_type"] == "bigint":
            # Surrogate bigint primary keys become identity columns; any
            # schema-supplied default is intentionally ignored here.
            col_defs.append(
                f" {name_sql} {type_sql} GENERATED BY DEFAULT AS IDENTITY NOT NULL"
            )
            continue
        pieces = [f" {name_sql} {type_sql}"]
        if not column.get("is_nullable", True):
            pieces.append("NOT NULL")
        default_expr = column.get("default")
        if default_expr:
            pieces.append(f"DEFAULT {default_expr}")
        col_defs.append(" ".join(pieces))
    clause = "CREATE TABLE IF NOT EXISTS" if if_not_exists else "CREATE TABLE"
    quoted_name = sql_table_name(table["table_name"])
    body = ",\n".join(col_defs)
    return f"{clause} {schema}.{quoted_name} (\n{body}\n);\n"
def render_sql_create_view(view, schema):
    """Render a placeholder CREATE VIEW statement (definition not exported)."""
    quoted = sql_table_name(view["table_name"])
    header = f"-- View: {schema}.{quoted}\n"
    body = f"CREATE VIEW {schema}.{quoted} AS /* definition omitted */;\n"
    return header + body
def render_sql_extensions(extensions):
    """Render CREATE EXTENSION statements, one per entry.

    Each entry is a dict with "name" and "version". Returns "" for an empty
    list (the old code emitted a stray lone newline in that case).
    """
    if not extensions:
        return ""
    statements = [
        f"CREATE EXTENSION IF NOT EXISTS {ext['name']} WITH VERSION '{ext['version']}';"
        for ext in extensions
    ]
    return "\n".join(statements) + "\n"
def render_cql_create_table(table, cql_keyspace="public"):
    """Render one CQL CREATE TABLE statement for a table description dict.

    The partition key is the first column named "id" or "uuid"; when
    neither exists, the table's first column is used instead.
    """
    name_cql = cql_table_name(table["table_name"])
    cols = table["columns"]
    col_lines = []
    primary_key = None
    for column in cols:
        quoted = quote_cql_identifier(column["column_name"])
        mapped = map_pg_type_to_cql(column["data_type"])
        col_lines.append(f" {quoted} {mapped}")
        if primary_key is None and column["column_name"] in ("id", "uuid"):
            primary_key = quoted
    if not primary_key and cols:
        # Fall back to the first column when no id/uuid column exists.
        primary_key = quote_cql_identifier(cols[0]["column_name"])
    pieces = [f"CREATE TABLE IF NOT EXISTS {cql_keyspace}.{name_cql} (\n"]
    pieces.append(",\n".join(col_lines))
    if primary_key:
        pieces.append(f",\n PRIMARY KEY ({primary_key})")
    pieces.append("\n);\n")
    return "".join(pieces)
def write_seed_files(
    schema_json, sql_path="seed.sql", cql_path="seed.cql", if_not_exists=False
):
    """Write seed.sql / seed.cql plus views.sql and extensions.sql.

    Only entries in the "public" schema are considered. Tables become
    CREATE TABLE statements (SQL and CQL); views and extensions go to the
    fixed files views.sql and extensions.sql. Prints a summary when done.
    """
    entries = schema_json.get("schema", [])
    extensions = schema_json.get("extensions", [])

    def _public_entries(kind):
        # One pass per entity kind, restricted to the public schema.
        return [
            e
            for e in entries
            if e.get("table_type") == kind and e.get("table_schema") == "public"
        ]

    tables = _public_entries("table")
    views = _public_entries("view")
    sql_chunks = []
    cql_chunks = []
    for table in tables:
        rendered_sql = render_sql_create_table(
            table, table["table_schema"], if_not_exists=if_not_exists
        )
        sql_chunks.append(rendered_sql + "\n")
        cql_chunks.append(render_cql_create_table(table) + "\n")
    with open(sql_path, "w", encoding="utf-8") as f:
        f.write("".join(sql_chunks))
    with open(cql_path, "w", encoding="utf-8") as f:
        f.write("".join(cql_chunks))
    views_sql = "\n".join(render_sql_create_view(v, v["table_schema"]) for v in views)
    with open("views.sql", "w", encoding="utf-8") as f:
        f.write(views_sql)
    with open("extensions.sql", "w", encoding="utf-8") as f:
        f.write(render_sql_extensions(extensions))
    print(f"Wrote {sql_path}, {cql_path}, views.sql, extensions.sql")
def _extract_table_constraints(table):
constraints = (
table.get("constraints")
or table.get("table_constraints")
or table.get("constraints_info")
or []
)
return constraints if isinstance(constraints, list) else []
def _format_fk_rule(rule):
if not rule:
return ""
rule_upper = str(rule).strip().upper()
if rule_upper in {"NO ACTION", "RESTRICT", "CASCADE", "SET NULL", "SET DEFAULT"}:
return rule_upper
return ""
def render_table_constraints_sql(table, schema):
    """Build ALTER TABLE statements for one table's declared constraints.

    Returns a pair ``(non_fk_sql, fk_sql)``: PRIMARY KEY / UNIQUE statements
    in the first list, FOREIGN KEY statements in the second so callers can
    apply FKs after all tables exist. Constraint dicts may use several key
    spellings (columns/column_names/constrained_columns, foreign_table/
    referenced_table/..., on_delete/delete_rule/...). Entries without a
    recognized type or without columns are skipped, as are FKs missing the
    referenced table or columns (best-effort export).
    """
    table_name = sql_table_name(table["table_name"])
    constraints = _extract_table_constraints(table)
    non_fk_sql = []
    fk_sql = []
    for c in constraints:
        ctype = (c.get("constraint_type") or c.get("type") or "").upper()
        name = c.get("constraint_name") or c.get("name")
        columns = (
            c.get("columns")
            or c.get("column_names")
            or c.get("constrained_columns")
            or []
        )
        col_list = ", ".join(quote_sql_identifier(col) for col in columns)
        # "ADD CONSTRAINT" must be followed by a constraint name; when the
        # schema supplies no name, fall back to the bare "ADD <constraint>"
        # form. (The old code emitted invalid SQL such as
        # "ADD CONSTRAINT PRIMARY KEY (...)" for unnamed constraints.)
        add_clause = f"ADD CONSTRAINT {quote_sql_identifier(name)}" if name else "ADD"
        if ctype == "PRIMARY KEY" and columns:
            non_fk_sql.append(
                f"ALTER TABLE {schema}.{table_name} {add_clause} PRIMARY KEY ({col_list});"
            )
            continue
        if ctype == "UNIQUE" and columns:
            non_fk_sql.append(
                f"ALTER TABLE {schema}.{table_name} {add_clause} UNIQUE ({col_list});"
            )
            continue
        if ctype == "FOREIGN KEY" and columns:
            ref_table = (
                c.get("foreign_table")
                or c.get("referenced_table")
                or c.get("foreign_table_name")
            )
            ref_schema = (
                c.get("foreign_table_schema")
                or c.get("referenced_schema")
                or c.get("schema")
                or "public"
            )
            ref_columns = c.get("foreign_columns") or c.get("referenced_columns") or []
            if ref_table and ref_columns:
                ref_cols_sql = ", ".join(
                    quote_sql_identifier(col) for col in ref_columns
                )
                fk_stmt = (
                    f"ALTER TABLE {schema}.{table_name} {add_clause} "
                    f"FOREIGN KEY ({col_list}) REFERENCES {ref_schema}.{sql_table_name(ref_table)} ({ref_cols_sql})"
                )
                on_delete = _format_fk_rule(
                    c.get("on_delete") or c.get("delete_rule") or c.get("delete")
                )
                on_update = _format_fk_rule(
                    c.get("on_update") or c.get("update_rule") or c.get("update")
                )
                if on_delete:
                    fk_stmt += f" ON DELETE {on_delete}"
                if on_update:
                    fk_stmt += f" ON UPDATE {on_update}"
                fk_stmt += ";"
                fk_sql.append(fk_stmt)
    return non_fk_sql, fk_sql
def write_constraints_files(schema_json, base_constraints_path="constraints.sql"):
    """Write constraint DDL to three files and print a summary.

    Generates constraints_nofk.sql (PRIMARY KEY / UNIQUE),
    constraints_fk.sql (FOREIGN KEY), and a combined file at
    ``base_constraints_path``. Only public-schema tables are considered.
    """
    entries = schema_json.get("schema", [])
    public_tables = [
        t
        for t in entries
        if t.get("table_type") == "table" and t.get("table_schema") == "public"
    ]
    non_fk_statements = []
    fk_statements = []
    for table in public_tables:
        plain, foreign = render_table_constraints_sql(table, table["table_schema"])
        non_fk_statements += plain
        fk_statements += foreign

    def _dump(path, statements):
        # Trailing newline only when there is content, matching prior output.
        with open(path, "w", encoding="utf-8") as fh:
            fh.write("\n".join(statements) + ("\n" if statements else ""))

    non_fk_path = "constraints_nofk.sql"
    fk_path = "constraints_fk.sql"
    _dump(non_fk_path, non_fk_statements)
    _dump(fk_path, fk_statements)
    _dump(base_constraints_path, non_fk_statements + fk_statements)
    print(f"Wrote {base_constraints_path}, {non_fk_path}, {fk_path}")
def main():
    """CLI entry point: fetch the schema from Supabase and emit DDL files."""
    parser = argparse.ArgumentParser(
        description="Fetch full schema from Supabase and generate seed.sql, seed.cql, views.sql, and extensions.sql"
    )
    parser.add_argument("--sql", default="seed.sql", help="Output SQL file")
    parser.add_argument("--cql", default="seed.cql", help="Output CQL file")
    parser.add_argument(
        "--include-constraints-sql",
        action="store_true",
        help="Also generate constraints.sql, constraints_nofk.sql, and constraints_fk.sql",
    )
    parser.add_argument(
        "--if-not-exist-table",
        action="store_true",
        help="Use CREATE TABLE IF NOT EXISTS in SQL output",
    )
    options = parser.parse_args()
    payload = fetch_schema_json()
    # Some RPC wrappers return a single-element list; unwrap it.
    if isinstance(payload, list) and len(payload) == 1:
        payload = payload[0]
    write_seed_files(
        payload,
        sql_path=options.sql,
        cql_path=options.cql,
        if_not_exists=options.if_not_exist_table,
    )
    if options.include_constraints_sql:
        write_constraints_files(payload, base_constraints_path="constraints.sql")


if __name__ == "__main__":
    main()