from hashlib import sha1
from argparse import ArgumentParser
import sys
import pygit2
obj_prepared = 'PREPARE obj(TEXT, TEXT) AS ' + \
'INSERT INTO objects VALUES ($1, decode($2, \'hex\'))' + \
' ON CONFLICT DO NOTHING;'
parser = ArgumentParser()
parser.add_argument("repository")
parser.add_argument("output")
parser.add_argument(
"--no-prepared-header",
help="Disables the PREPARE statements.",
action="store_true"
)
parser.add_argument(
"--update",
help="Disables Truncation",
action="store_true"
)
parser.add_argument(
"--total",
help="Total Number of Objects"
)
total = 0
args = parser.parse_args()
if args.total:
total = int(args.total)
repo = pygit2.Repository(args.repository)
def eprint(*args, **kwargs):
print(*args, file=sys.stderr, **kwargs)
def show_info(cnt: int):
if total == 0:
eprint("{0} objects generated".format(cnt))
else:
percent = (cnt / total) * 100.0
eprint("{0}% ({1} objects out of {2})".format(int(percent), cnt, total))
def translate_type_id(i):
if i == 1:
return "commit"
elif i == 2:
return "tree"
elif i == 3:
return "blob"
elif i == 4:
return "tag"
else:
raise Exception("Unknown Type: {}" % i)
def encode_git_object(oid: pygit2.Oid):
type_id, data = repo.read(oid)
type_name = translate_type_id(type_id)
encoded = bytearray()
encoded.extend(type_name.encode())
encoded.extend(' '.encode())
encoded.extend(str(len(data)).encode())
encoded.extend(b'\x00')
encoded.extend(data)
sha = sha1()
sha.update(encoded)
calc_hash = sha.hexdigest()
if str(oid) != calc_hash:
raise Exception(
"Invalid Object Encoding: expected {}, encoded {}" % oid % calc_hash)
return encoded
def generate_sql_object(oid):
data = encode_git_object(oid)
sql = "EXECUTE obj('" + str(oid) + "', '" + data.hex() + "');"
return sql
def generate_sql_objects():
if not args.no_prepared_header:
yield obj_prepared
count = 0
for oid in repo:
yield generate_sql_object(oid)
count += 1
show_info(count)
def generate_sql_ref(ref):
sql = 'INSERT INTO "refs" ("name", "target") VALUES ('
sql += "'" + ref.name + "'"
sql += ", '" + str(ref.target) + "'"
sql += ') ON CONFLICT ("name") DO UPDATE SET "target" = '
sql += "'" + str(ref.target) + "';"
return sql
def generate_sql_refs():
for ref in repo.references:
yield generate_sql_ref(repo.references[ref])
yield generate_sql_ref(repo.lookup_reference("HEAD"))
def generate_sql_file():
if not args.update:
yield 'TRUNCATE "objects";'
yield 'TRUNCATE "refs";'
yield from generate_sql_objects()
yield from generate_sql_refs()
if args.output == '-':
for line in generate_sql_file():
print(line)
else:
with open(args.output, mode='w+') as file:
file.seek(0)
for line in generate_sql_file():
file.write(line)
file.write('\n')