import argparse
import torch
from safetensors.torch import save_file
# Model sub-dicts that a Brain-Harmony checkpoint may store at its top level.
_COMPONENTS = ("encoder", "predictor", "target_encoder")


def _to_saveable(param):
    """Return a detached, contiguous float32 copy of *param* with its own storage.

    safetensors' ``save_file`` refuses non-contiguous tensors and tensors that
    share memory (tied weights, views into a larger buffer). ``copy=True``
    guarantees fresh storage even when *param* is already contiguous float32.
    ``detach()`` drops any autograd state carried by saved ``nn.Parameter``s.
    """
    # NOTE(review): integer/bool buffers are cast to float32 too, matching the
    # previous behaviour; confirm downstream loaders really expect all-F32.
    return param.detach().to(
        torch.float32, copy=True, memory_format=torch.contiguous_format
    )


def _collect(tensors, source, prefix=""):
    """Add every tensor in *source* to *tensors* under ``prefix + key``.

    One leading ``module.`` is removed from each key before the prefix is
    applied. Non-tensor values (for instance an epoch counter, if the
    checkpoint dict stores one) are skipped with a message instead of
    crashing. Returns the number of tensors added.
    """
    added = 0
    for key, value in source.items():
        if not isinstance(value, torch.Tensor):
            print(f" skipping non-tensor entry {key!r} ({type(value).__name__})")
            continue
        tensors[prefix + key.removeprefix("module.")] = _to_saveable(value)
        added += 1
    return added


def main():
    """Convert a Brain-Harmony ``.pth`` checkpoint into a ``.safetensors`` file.

    Reads ``--input`` with ``torch.load`` on CPU and writes ``--output`` with
    ``safetensors.torch.save_file``. Two checkpoint layouts are handled:

    * per-component sub-dicts stored at the top level under ``encoder``,
      ``predictor`` and/or ``target_encoder``; their keys are exported as
      ``<component>.<param>``;
    * a single flat state dict, taken from ``ckpt["model"]``, else
      ``ckpt["state_dict"]``, else the checkpoint itself.

    With ``--component all``, every present sub-dict is exported. If none is
    present, the flat state dict is exported with its keys unchanged. With a
    single component, its sub-dict is used if present; otherwise the flat keys
    starting with ``<component>.`` (optionally behind ``module.``) are used.

    One leading ``module.`` prefix is stripped from every key, and every
    tensor is written as a contiguous float32 copy.

    Raises:
        SystemExit: when no tensors were selected, rather than writing an
            empty file.
    """
    parser = argparse.ArgumentParser(description="Convert Brain-Harmony .pth to .safetensors")
    parser.add_argument("--input", required=True, help="Input .pth or .pth.tar checkpoint")
    parser.add_argument("--output", required=True, help="Output .safetensors file")
    parser.add_argument(
        "--component",
        default="all",
        choices=["all", *_COMPONENTS],
        help="Which component(s) to export",
    )
    args = parser.parse_args()

    print(f"Loading {args.input} ...")
    # SECURITY: weights_only=False runs the full pickle loader, which can
    # execute arbitrary code embedded in the file; only convert trusted
    # checkpoints. Presumably kept because the checkpoint holds non-tensor
    # Python objects the restricted loader rejects -- TODO confirm.
    ckpt = torch.load(args.input, map_location="cpu", weights_only=False)

    # Flat state dict: nested under "model" / "state_dict", or the checkpoint itself.
    if "model" in ckpt:
        state_dict = ckpt["model"]
    elif "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
    else:
        state_dict = ckpt

    tensors = {}
    if args.component == "all":
        present = [comp for comp in _COMPONENTS if comp in ckpt]
        for comp in present:
            n = _collect(tensors, ckpt[comp], prefix=f"{comp}.")
            print(f" {comp}: {n} tensors")
        if not present:
            # No per-component sub-dicts: export the flat state dict, keys unprefixed.
            _collect(tensors, state_dict)
            print(f" flat: {len(tensors)} tensors")
    else:
        comp = args.component
        if comp in ckpt:
            _collect(tensors, ckpt[comp], prefix=f"{comp}.")
        else:
            # Selected flat keys already carry "<comp>.", so no prefix is added.
            # Adding one would yield "<comp>.<comp>.<param>", unlike the branch above.
            wanted = (f"{comp}.", f"module.{comp}.")
            selected = {k: v for k, v in state_dict.items() if k.startswith(wanted)}
            _collect(tensors, selected)
        print(f" {args.component}: {len(tensors)} tensors")

    if not tensors:
        raise SystemExit(
            f"No tensors selected for component {args.component!r} in "
            f"{args.input}; nothing written."
        )

    print(f"Saving {len(tensors)} tensors to {args.output} ...")
    save_file(tensors, args.output)
    print("Done.")
if __name__ == "__main__":
    # Run the converter only when executed as a script, never on import.
    main()