rvz 0.2.1

RVZ compression library.
Documentation
import argparse
import hashlib
import struct
from Crypto.Cipher import AES

# fmt:off
# Hard-coded values used by main(): AES test vectors for its printed
# self-checks, plus the padding digests for its hash-tree verification.

# AES key passed to AES.new() in get_title_key() to decrypt ticket title keys.
common_key = bytes(
    [
        0xEB, 0xE4, 0x2A, 0x22, 0x5E, 0x85, 0x93, 0xE4,
        0x48, 0xD9, 0xC5, 0x45, 0x73, 0x81, 0xAA, 0xF7,
    ]
)

# Sample encrypted title key; main() decrypts it with common_key and compares
# the result with ``decrypted``.
title_key = bytes(
    [
        0xA7, 0xDA, 0x4D, 0xEF, 0xCC, 0xAF, 0xB4, 0x1A,
        0xE8, 0x51, 0x91, 0xCE, 0xF4, 0x99, 0x54, 0x76,
    ]
)

# 8 bytes followed by 8 zero bytes: the same shape get_title_iv() builds
# (ticket bytes 0x1DC..0x1E3 + eight zeros). This used to be bound to ``iv``
# and was immediately overwritten by the ``iv`` assignment below, so it has
# its own name now to stay reachable.
title_key_iv = bytes(
    [
        0x00, 0x01, 0x00, 0x00, 0x00, 0x55, 0x50, 0x45,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
    ]
)

# Expected plaintext of ``title_key``; main() also uses it as the AES key for
# decrypting ``hashes_enc`` and ``data``.
decrypted = bytes(
    [
        0x45, 0xBD, 0x9A, 0xB0, 0x6A, 0xA4, 0xC7, 0x47,
        0xAC, 0x36, 0x63, 0x02, 0x1B, 0xBF, 0xAF, 0x2D,
    ]
)

# 32 bytes of sample ciphertext decrypted (and printed) by main().
hashes_enc = bytes(
    [
        0x7C, 0x6B, 0x48, 0xF6, 0x46, 0x21, 0x0C, 0x1F,
        0xAF, 0x56, 0xAA, 0xCB, 0x31, 0xFA, 0xA3, 0x12,
        0xBA, 0x31, 0xCA, 0xDD, 0x7A, 0xAB, 0x3A, 0x0A,
        0x4B, 0xEB, 0x3B, 0x9D, 0xC5, 0x49, 0x63, 0x8D,
    ]
)

# IV used by main()'s test-vector checks.
iv = bytes(
    [
        0x0B, 0x11, 0x14, 0xED, 0x7B, 0x33, 0x44, 0x4D,
        0x3A, 0x55, 0x4C, 0x38, 0xEC, 0x28, 0xDD, 0x2B,
    ]
)

# 16 bytes of sample ciphertext decrypted by main() right after hashes_enc.
data = bytes(
    [
        0x2C, 0x79, 0x79, 0xAB, 0x22, 0x5F, 0xF8, 0xBA,
        0x9F, 0x5B, 0x7C, 0x59, 0x31, 0x7D, 0x35, 0xAA,
    ]
)

# 20-byte digest main() appends to pad a trailing partial group of H1 hashes
# up to 8 entries.
H1_ZERO = bytes(
    [
        0x26, 0x94, 0x62, 0xC3, 0xC0, 0x85, 0xAD, 0x49,
        0x3D, 0x26, 0xCA, 0x70, 0xA0, 0x0C, 0xB7, 0x26,
        0x8C, 0x7B, 0xA0, 0x4C,
    ]
)

# 20-byte digest main() appends to pad the H2 list to a multiple of 8.
H2_ZERO = bytes(
    [
        0x9E, 0x69, 0xB0, 0xAC, 0x67, 0x72, 0x95, 0xA0,
        0xB4, 0x57, 0x14, 0xDD, 0x84, 0xD5, 0xFD, 0x4D,
        0x24, 0x63, 0x93, 0x66,
    ]
)
# fmt:on


def partition_tables(io):
    """Read the four partition-table descriptors stored at offset 0x40000.

    Each descriptor is a pair of big-endian 32-bit words: the number of
    partitions in that table, then the table's offset divided by 4.

    Args:
        io: binary file-like object; it is seeked to 0x40000.

    Returns:
        A list of four dicts with keys ``"total_partitions"`` and ``"offset"``
        (the byte offset, already multiplied by 4).

    Raises:
        struct.error: if fewer than 32 bytes can be read at 0x40000.
    """
    io.seek(0x40000)
    words = struct.unpack(">8I", io.read(32))
    # words alternate count, offset>>2, count, offset>>2, ...
    return [
        {"total_partitions": count, "offset": shifted_offset * 4}
        for count, shifted_offset in zip(words[0::2], words[1::2])
    ]


def get_partitions(io, table):
    """Read the partition entries of one partition table.

    Args:
        io: binary file-like object; it is seeked to ``table["offset"]``.
        table: a descriptor as returned by partition_tables().

    Returns:
        A list of ``{"offset": byte_offset, "type": partition_type}`` dicts,
        one per entry. Each entry on disk is two big-endian words: the
        offset divided by 4, then the type.
    """
    io.seek(table["offset"])
    entries = []
    for _ in range(table["total_partitions"]):
        shifted_offset, partition_type = struct.unpack(">II", io.read(8))
        entries.append({"offset": shifted_offset * 4, "type": partition_type})
    return entries


def get_partition(io, partition):
    """Read a partition header: a 0x2A4-byte ticket followed by seven words.

    Args:
        io: binary file-like object; it is seeked to ``partition["offset"]``.
        partition: an entry as returned by get_partitions().

    Returns:
        A dict with the raw ticket bytes and the header fields. The two
        ``*_size`` fields for the TMD and cert chain are returned as stored;
        every offset field, and ``data_size``, is stored divided by 4 and is
        multiplied back here.
    """
    io.seek(partition["offset"])
    ticket = io.read(0x2A4)
    (
        tmd_size,
        tmd_offset_q,
        cert_chain_size,
        cert_chain_offset_q,
        h3_offset_q,
        data_offset_q,
        data_size_q,
    ) = struct.unpack(">7I", io.read(7 * 4))

    header = {}
    header["ticket"] = ticket
    header["tmd_size"] = tmd_size
    header["tmd_offset"] = tmd_offset_q * 4
    header["cert_chain_size"] = cert_chain_size
    header["cert_chain_offset"] = cert_chain_offset_q * 4
    header["h3_offset"] = h3_offset_q * 4
    header["data_offset"] = data_offset_q * 4
    header["data_size"] = data_size_q * 4
    return header

def get_h3_hashes(io, partition_entry, partition):
    """Read a partition's 0x18000-byte H3 table as a list of 20-byte digests.

    Args:
        io: binary file-like object.
        partition_entry: entry from get_partitions() (supplies the partition
            start offset).
        partition: header from get_partition() (supplies ``h3_offset``,
            relative to the partition start).

    Returns:
        A list of 20-byte ``bytes`` entries. 0x18000 is not a multiple of 20,
        so only whole entries are returned: the trailing remainder (and any
        partial entry after a short read) is not turned into a bogus "hash".
    """
    io.seek(partition_entry["offset"] + partition["h3_offset"])
    table = io.read(0x18000)
    whole_len = len(table) - len(table) % 20
    return [table[start:start + 20] for start in range(0, whole_len, 20)]


def get_title_id(ticket):
    """Return the 8-byte title ID stored in a ticket.

    The title ID is read from offset 0x1DC, the same 8 bytes get_title_iv()
    uses as the first half of the title-key IV. (Offset 0x1D0, which was
    read before, holds a different 8-byte field -- the ticket ID in the
    standard ticket layout.)
    """
    return ticket[0x1DC:0x1DC + 8]


def get_title_iv(ticket):
    """Build the AES-CBC IV used to decrypt the ticket's title key.

    The IV is the 8 ticket bytes starting at 0x1DC followed by eight zero
    bytes.
    """
    id_start = 0x1DC
    id_end = id_start + 8
    return ticket[id_start:id_end] + bytes(8)


def get_title_key(ticket):
    """Decrypt and return the title key stored in ``ticket``.

    The 16 encrypted bytes at offset 0x1BF are decrypted with AES-CBC, using
    the module-level ``common_key`` as the key and get_title_iv(ticket) as
    the IV.
    """
    encrypted_key = ticket[0x1BF:0x1CF]
    key_cipher = AES.new(common_key, AES.MODE_CBC, iv=get_title_iv(ticket))
    return key_cipher.decrypt(encrypted_key)


def decrypt_hashes(title_key, sector):
    """Decrypt a sector's hash block.

    Args:
        title_key: AES key (the decrypted title key).
        sector: raw encrypted sector; only its first 0x400 bytes are used.

    Returns:
        The 0x400-byte hash block decrypted with AES-CBC and an all-zero IV.
    """
    zero_iv = bytes(16)
    hash_block = sector[:0x400]
    return AES.new(title_key, AES.MODE_CBC, iv=zero_iv).decrypt(hash_block)


def decrypt_data(title_key, sector):
    """Decrypt both halves of a raw sector.

    Args:
        title_key: AES key (the decrypted title key).
        sector: raw encrypted sector bytes.

    Returns:
        ``(hashes, data)``: the decrypted 0x400-byte hash block (see
        decrypt_hashes) and the decrypted remainder of the sector.
    """
    hashes = decrypt_hashes(title_key, sector)
    # The data area's IV is taken from the *encrypted* input at 0x3D0..0x3DF,
    # not from the decrypted hash block.
    data_cipher = AES.new(title_key, AES.MODE_CBC, iv=sector[0x3D0:0x3E0])
    return hashes, data_cipher.decrypt(sector[0x400:])


def compute_h0(data):
    """Return the SHA-1 digest of each 0x400-byte chunk of a sector's data.

    Always yields 31 digests (0x7C00 // 0x400), one per chunk offset in
    ``range(0, 0x7C00, 0x400)``; shorter input just hashes shorter/empty
    slices.
    """
    chunk = 0x400
    return [
        hashlib.sha1(data[offset:offset + chunk]).digest()
        for offset in range(0, 0x7C00, chunk)
    ]


def compute_h1(h0s):
    """Return the SHA-1 of the concatenation of exactly 31 H0 digests."""
    assert len(h0s) == 31
    concatenated = b"".join(h0s)
    return hashlib.sha1(concatenated).digest()


def compute_h2(h1s):
    """Return the SHA-1 of the concatenation of exactly 8 H1 digests."""
    assert len(h1s) == 8
    concatenated = b"".join(h1s)
    return hashlib.sha1(concatenated).digest()

def compute_h3(h2s):
    """Hash each consecutive run of up to 8 H2 digests into one H3 digest.

    Returns a list with one SHA-1 digest per group; the final group may hold
    fewer than 8 entries. An empty input yields an empty list.
    """
    return [
        hashlib.sha1(b"".join(h2s[start:start + 8])).digest()
        for start in range(0, len(h2s), 8)
    ]


def parse_h0(hashes):
    """Split the H0 table (31 x 20 bytes at offset 0) out of a hash block."""
    entry_size = 20
    return [
        hashes[offset:offset + entry_size]
        for offset in range(0, 31 * entry_size, entry_size)
    ]


def parse_h1(hashes):
    """Split the H1 table (8 x 20 bytes at offset 0x280) out of a hash block."""
    base = 0x280
    entry_size = 20
    return [
        hashes[offset:offset + entry_size]
        for offset in range(base, base + 8 * entry_size, entry_size)
    ]


def parse_h2(hashes):
    """Split the H2 table (8 x 20 bytes at offset 0x340) out of a hash block."""
    base = 0x340
    entry_size = 20
    return [
        hashes[offset:offset + entry_size]
        for offset in range(base, base + 8 * entry_size, entry_size)
    ]


def _print_test_vector_checks():
    """Print the outcome of the hard-coded AES test-vector checks.

    Only prints; it never aborts. Reads the module-level constants
    common_key, iv, title_key, decrypted, hashes_enc and data.
    """
    # Decrypting title_key with the common key should give ``decrypted``...
    cipher = AES.new(common_key, AES.MODE_CBC, iv=iv)
    print(cipher.decrypt(title_key) == decrypted)

    # ...and encrypting ``decrypted`` should give title_key back.
    cipher = AES.new(common_key, AES.MODE_CBC, iv=iv)
    print(cipher.encrypt(decrypted) == title_key)

    # Note: one CBC cipher object serves both calls, so the chaining state
    # left by the first decrypt carries into the second (kept as before).
    cipher = AES.new(decrypted, AES.MODE_CBC, iv=iv)
    print(cipher.decrypt(hashes_enc[0:32]).hex())
    print(cipher.decrypt(data).hex())


def main():
    """Decrypt the partitions of a disc image and verify their hash tree.

    Prints the AES test-vector checks, then reads the image named on the
    command line. For every partition listed in the *first* partition table
    it:
      * decrypts each 0x8000-byte sector and appends the decrypted hash block
        and data to ``decrypted.data`` in the current directory;
      * recomputes the H0/H1/H2 digests and compares them with the tables
        stored in the sector, then compares the recomputed H3 digests with
        the partition's H3 table.
    On the first mismatch it prints ``MATCH FAIL AT <level> <n>`` and returns.
    """
    _print_test_vector_checks()

    parser = argparse.ArgumentParser()
    parser.add_argument("filename")
    args = parser.parse_args()

    with open(args.filename, "rb") as file:
        tables = partition_tables(file)
        partitions = get_partitions(file, tables[0])
        print(f"tables: {tables}")
        print(f"partitions: {partitions}")

        with open("decrypted.data", "wb") as output_file:
            # Running totals across all partitions, used for the printed
            # numbering (and, for H0, the index within each sector's table).
            h0_count = 0
            h1_count = 0
            h2_count = 0

            for partition_entry in partitions:
                partition = get_partition(file, partition_entry)
                # Kept local so the module-level ``title_key`` test vector is
                # not overwritten by each partition's key.
                partition_key = get_title_key(partition["ticket"])
                print(f"partition: {partition}")
                print(get_title_id(partition["ticket"]).hex())
                print(get_title_iv(partition["ticket"]).hex())
                print(partition_key.hex())

                data_start = partition_entry["offset"] + partition["data_offset"]
                sector_start = data_start // 0x8000
                sector_end = sector_start + (partition["data_size"] + 0x7FFF) // 0x8000
                print(sector_start, sector_end, sector_end - sector_start)
                file.seek(data_start)

                h1s = []  # H1 digests of the current group of up to 8 sectors
                h2s = []  # all H2 digests of this partition (input to compute_h3)
                local_sector = 0
                # Stored H1/H2 tables of the most recently decrypted sector.
                h1_orig = []
                h2_orig = []
                for _ in range(sector_start, sector_end):
                    hashes, sector_data = decrypt_data(partition_key, file.read(0x8000))
                    output_file.write(hashes)
                    output_file.write(sector_data)
                    h0 = compute_h0(sector_data)
                    h0_orig = parse_h0(hashes)
                    h1_orig = parse_h1(hashes)
                    h2_orig = parse_h2(hashes)
                    for h in h0:
                        print(f"H0 {h0_count} {h.hex().upper()}")
                        # compute_h0 yields 31 digests per sector, so
                        # h0_count % 31 is this digest's index in the sector.
                        if h != h0_orig[h0_count % 31]:
                            print(f"MATCH FAIL AT H0 {h0_count}")
                            return
                        h0_count += 1
                    h1 = compute_h1(h0)
                    print(f"H1 {h1_count} {h1.hex().upper()}")
                    if h1 != h1_orig[local_sector % 8]:
                        print(f"MATCH FAIL AT H1 {h1_count}")
                        return

                    h1s.append(h1)
                    h1_count += 1

                    if (local_sector % 8) == 7:
                        h2 = compute_h2(h1s)
                        h2s.append(h2)
                        h1s = []
                        print(f"H2 {h2_count} {h2.hex().upper()}")
                        if h2 != h2_orig[(local_sector // 8) % 8]:
                            print(f"MATCH FAIL AT H2 {h2_count}")
                            return
                        h2_count += 1
                    local_sector += 1

                # Trailing partial group: pad it with H1_ZERO up to 8 entries
                # and fold it into one more H2. h1s is reset to [] after every
                # 8th sector, so it is never 8 long here; an empty list means
                # the last group was complete and must not be padded into a
                # fabricated all-zero group.
                if h1s:
                    while len(h1s) < 8:
                        h1s.append(H1_ZERO)
                        print(f"H1 {h1_count} {H1_ZERO.hex().upper()}")
                        if H1_ZERO != h1_orig[local_sector % 8]:
                            print(f"MATCH FAIL AT H1 {h1_count}")
                            return
                        h1_count += 1
                        local_sector += 1

                    h2 = compute_h2(h1s)
                    h2s.append(h2)
                    h1s = []
                    print(f"H2 {h2_count} {h2.hex().upper()}")
                    if h2 != h2_orig[((local_sector - 1) // 8) % 8]:
                        print(f"MATCH FAIL AT H2 {h2_count}")
                        return
                    h2_count += 1

                # Pad with H2_ZERO until the H2 count is a multiple of 8.
                while h2_count % 8 != 0:
                    print(f"H2 {h2_count} {H2_ZERO.hex().upper()}")
                    if H2_ZERO != h2_orig[(local_sector // 8) % 8]:
                        print(f"MATCH FAIL AT H2 {h2_count}")
                        return
                    h2s.append(H2_ZERO)
                    h2_count += 1
                    local_sector += 8

                h3s = compute_h3(h2s)
                h3_orig = get_h3_hashes(file, partition_entry, partition)
                for i, h3 in enumerate(h3s):
                    print(f"H3 {i} {h3.hex().upper()}")
                    if h3 != h3_orig[i]:
                        print(f"MATCH FAIL AT H3 {i}")
                        return


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()