import argparse
import re
from typing import Optional

import psycopg
import yaml
from psycopg import sql

import common
from common import Database
def build_clap() -> Optional[argparse.Namespace]:
    """Parse the command-line arguments of the ``setup_db`` script.

    The config file option is required because the caller opens ``args.f``
    unconditionally; a missing option is now reported by argparse (usage
    message + ``SystemExit``) instead of surfacing later as a ``TypeError``
    from ``open(None)``.

    Returns:
        The parsed namespace; the YAML path is available as ``args.f``.
        The ``Optional`` return annotation is kept for interface
        compatibility, but ``parse_args`` either returns a namespace or
        exits the process.
    """
    parser = argparse.ArgumentParser(prog="setup_db")
    parser.add_argument(
        "-f",
        "--file",
        dest="f",  # keep the attribute name existing callers read
        required=True,
        metavar="FILE",
        help="YAML file containing database config data.",
    )
    return parser.parse_args()
def create_database(
    db: common.Database
) -> bool:
    """Create the database ``db.name`` if it does not exist yet.

    Connects with the admin credentials from ``db`` (no ``dbname`` is
    passed to the driver) and looks the name up in ``pg_database`` before
    issuing ``CREATE DATABASE``.

    Args:
        db: Database description; ``host``, ``port``, ``admin_user``,
            ``admin_password`` and ``name`` are read.

    Returns:
        True only when the database was created by this call. False when
        it already existed *or* when an error occurred (the error is
        printed, not raised).
    """
    try:
        # autocommit is required: CREATE DATABASE cannot run inside a
        # transaction block.
        with psycopg.connect(
            host=db.host,
            port=db.port,
            user=db.admin_user,
            password=db.admin_password,
            autocommit=True,
        ) as conn, conn.cursor() as cur:
            cur.execute(
                "SELECT 1 FROM pg_database WHERE datname = %s",
                (db.name,),
            )
            if cur.fetchone() is not None:
                print(f"✓ Database '{db.name}' already exists")
                return False

            # Quote the name as an identifier so the database is created
            # with exactly the spelling used by the lookup above and by
            # later ``dbname=db.name`` connections (an unquoted name would
            # be case-folded and could not contain special characters).
            cur.execute(
                sql.SQL("CREATE DATABASE {} WITH OWNER = postgres;").format(
                    sql.Identifier(db.name)
                )
            )
            print(f"✓ Database '{db.name}' created successfully")
            return True
    except Exception as e:
        print(f"✗ Error creating database: {e}")
        return False
def _enum_catalog_name(enum_sql: str) -> Optional[str]:
    """Return the enum's name as stored in ``pg_type.typname``, or None.

    Accepts ``CREATE TYPE [schema.]name AS ENUM ...`` with any keyword
    casing and whitespace. The schema qualifier is dropped, a quoted name
    keeps its spelling, and an unquoted name is folded to lower case the
    way PostgreSQL does.
    """
    match = re.search(
        r'CREATE\s+TYPE\s+(?:(?:"[^"]+"|[\w$]+)\.)?("[^"]+"|[\w$]+)\s+AS\s+ENUM',
        enum_sql,
        re.IGNORECASE,
    )
    if match is None:
        return None
    name = match.group(1)
    if name.startswith('"'):
        return name[1:-1]
    return name.lower()


def create_enums(
    db: common.Database
) -> bool:
    """Create every enum type listed in ``db.enums`` that is missing.

    Each entry of ``db.enums`` is a complete ``CREATE TYPE ... AS ENUM``
    statement. The type name is extracted from it and looked up in
    ``pg_type``; the statement is executed only when no type of that name
    exists.

    Args:
        db: Database description; connection settings, ``name`` and
            ``enums`` are read.

    Returns:
        True when every statement was either executed or found to exist
        already. False when a statement could not be parsed (the remaining
        ones are still processed) or when a database error occurred; the
        problem is printed, not raised.
    """
    success = True
    try:
        with psycopg.connect(
            host=db.host,
            port=db.port,
            user=db.admin_user,
            password=db.admin_password,
            dbname=db.name,
            autocommit=True,
        ) as conn, conn.cursor() as cur:
            for enum_sql in db.enums:
                enum_name = _enum_catalog_name(enum_sql)
                if enum_name is None:
                    print(f"✗ Error parsing enum SQL: {enum_sql}")
                    success = False
                    continue

                cur.execute(
                    "SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = %s)",
                    (enum_name,),
                )
                if cur.fetchone()[0]:
                    print(f"✓ Enum '{enum_name}' already exists")
                    continue

                cur.execute(enum_sql)
                print(f"✓ Enum '{enum_name}' created successfully")
    except Exception as e:
        # No explicit rollback: the connection is in autocommit mode, so
        # there is no open transaction to undo.
        print(f"✗ Error creating enums: {e}")
        return False
    return success
def _domain_catalog_name(domain_sql: str) -> Optional[str]:
    """Return the domain's name as stored in ``pg_type.typname``, or None.

    Accepts ``CREATE DOMAIN [schema.]name ...`` with any keyword casing
    and whitespace; the ``AS`` keyword is optional in PostgreSQL and is
    therefore not required here. The schema qualifier is dropped, a quoted
    name keeps its spelling, and an unquoted name is folded to lower case
    the way PostgreSQL does.
    """
    match = re.search(
        r'CREATE\s+DOMAIN\s+(?:(?:"[^"]+"|[\w$]+)\.)?("[^"]+"|[\w$]+)',
        domain_sql,
        re.IGNORECASE,
    )
    if match is None:
        return None
    name = match.group(1)
    if name.startswith('"'):
        return name[1:-1]
    return name.lower()


def create_domains(
    db: common.Database
) -> bool:
    """Create every domain listed in ``db.domains`` that is missing.

    Each entry of ``db.domains`` is a complete ``CREATE DOMAIN`` statement.
    The domain name is extracted from it and looked up in ``pg_type``; the
    statement is executed only when no type of that name exists.

    Args:
        db: Database description; connection settings, ``name`` and
            ``domains`` are read.

    Returns:
        True when every statement was either executed or found to exist
        already. False when a statement could not be parsed (the remaining
        ones are still processed) or when a database error occurred; the
        problem is printed, not raised.
    """
    success = True
    try:
        with psycopg.connect(
            host=db.host,
            port=db.port,
            user=db.admin_user,
            password=db.admin_password,
            dbname=db.name,
            autocommit=True,
        ) as conn, conn.cursor() as cur:
            for domain_sql in db.domains:
                domain_name = _domain_catalog_name(domain_sql)
                if domain_name is None:
                    print(f"✗ Error parsing domain SQL: {domain_sql}")
                    success = False
                    continue

                cur.execute(
                    "SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = %s)",
                    (domain_name,),
                )
                if cur.fetchone()[0]:
                    print(f"✓ Domain '{domain_name}' already exists")
                    continue

                cur.execute(domain_sql)
                print(f"✓ Domain '{domain_name}' created successfully")
    except Exception as e:
        # No explicit rollback: the connection is in autocommit mode, so
        # there is no open transaction to undo.
        print(f"✗ Error creating domains: {e}")
        return False
    return success
def _composite_catalog_name(type_name: str) -> str:
    """Return ``type_name`` as PostgreSQL stores it in ``pg_type.typname``.

    The schema qualifier is dropped, a double-quoted name keeps its
    spelling, and an unquoted name is folded to lower case.
    """
    bare = type_name.rsplit(".", 1)[-1].strip()
    if len(bare) >= 2 and bare.startswith('"') and bare.endswith('"'):
        return bare[1:-1]
    return bare.lower()


def create_types(
    db: common.Database
) -> bool:
    """Create every composite type listed in ``db.types`` that is missing.

    Each entry of ``db.types`` is a mapping with a ``'name'`` and a list
    of ``'attributes'``; every attribute is a ``"<name> <sql type>"``
    fragment that is placed verbatim into
    ``CREATE TYPE <name> AS (<attributes>)``.

    Args:
        db: Database description; connection settings, ``name`` and
            ``types`` are read.

    Returns:
        True when every type was either created or found to exist already,
        False when an error occurred (the error is printed, not raised).
    """
    try:
        with psycopg.connect(
            host=db.host,
            port=db.port,
            user=db.admin_user,
            password=db.admin_password,
            dbname=db.name,
            autocommit=True,
        ) as conn, conn.cursor() as cur:
            for type_def in db.types:
                type_name = type_def['name']
                # Look the type up under the spelling PostgreSQL stores,
                # otherwise an unquoted mixed-case name is never found and
                # the CREATE fails on every re-run.
                cur.execute(
                    "SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = %s)",
                    (_composite_catalog_name(type_name),),
                )
                if cur.fetchone()[0]:
                    print(f"✓ Type '{type_name}' already exists")
                    continue

                # The name and attribute fragments come straight from the
                # config file and are interpolated as raw SQL; the config
                # must therefore be trusted.
                attributes_str = ', '.join(type_def['attributes'])
                cur.execute(f"CREATE TYPE {type_name} AS ({attributes_str});")
                print(f"✓ Type '{type_name}' created successfully")
    except Exception as e:
        # No explicit rollback: the connection is in autocommit mode, so
        # there is no open transaction to undo.
        print(f"✗ Error creating types: {e}")
        return False
    return True
def create_tables(
    db: common.Database
) -> bool:
    """Execute the CREATE statement of every table in ``db.tables``.

    The SQL comes from ``table.to_sql_create()`` and is printed before it
    is executed. No existence check is made here, so whether a re-run
    succeeds depends on the SQL that ``to_sql_create()`` produces.

    Args:
        db: Database description; connection settings, ``name`` and
            ``tables`` are read.

    Returns:
        True when every statement was executed, False when an error
        occurred (the error is printed, not raised; tables created before
        the failing one stay in place because the connection autocommits).
    """
    try:
        with psycopg.connect(
            host=db.host,
            port=db.port,
            user=db.admin_user,
            password=db.admin_password,
            dbname=db.name,
            autocommit=True,
        ) as conn, conn.cursor() as cur:
            for table in db.tables:
                # Build the statement once so the text that is printed is
                # exactly the text that is executed.
                create_sql = table.to_sql_create()
                print(create_sql)
                cur.execute(create_sql)
                print(f"✓ Table '{table.name}' created successfully")
    except Exception as e:
        # No explicit rollback: the connection is in autocommit mode, so
        # there is no open transaction to undo.
        print(f"✗ Error creating tables: {e}")
        return False
    return True
def create_users():
    """Placeholder for user creation; performs no work and returns None."""
    # TODO: not implemented yet.
    return None
def main():
    """Load the YAML config and run every setup step for each database.

    The config file must contain a top-level ``db`` list; each entry is
    turned into a ``Database`` via ``common.Database.from_dict``. For every
    database the steps run in order: database, enums, domains, types,
    tables. All steps are attempted even if an earlier one fails; failures
    are collected and reported in a single "Setup failed" banner at the
    end, together with any unexpected exception.
    """
    args = build_clap()
    if args is None:
        print("Error - Failed to parse arguments")
        return
    if args.f is None:
        # Without this guard open(None) would raise a TypeError below.
        print("Error - No config file given (use -f FILE)")
        return

    with open(args.f, 'r', encoding="utf-8") as f:
        config_dict = yaml.safe_load(f)
    databases: list[Database] = [
        common.Database.from_dict(entry) for entry in config_dict['db']
    ]
    # NOTE(review): this prints the repr of every Database; if that repr
    # includes admin_password the credentials end up on stdout - confirm.
    print(databases)

    problems: list[str] = []
    try:
        for db in databases:
            print(f"Starting PostgreSQL database setup for {db.name}\n")
            # The result is not checked: create_database returns False
            # both for "already exists" and for errors.
            create_database(db)
            # These steps catch their own exceptions and signal failure
            # through their return value, so it has to be inspected here.
            for step in (create_enums, create_domains, create_types, create_tables):
                if not step(db):
                    problems.append(f"{db.name}: {step.__name__} failed")
        create_users()
    except Exception as e:
        problems.append(f"unexpected error: {e}")

    if problems:
        print("\n" + "=" * 50)
        print("✗ Setup failed!")
        for problem in problems:
            print(f"  - {problem}")
        print("=" * 50)
    return
# Run the setup only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()