brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
Documentation
#!/usr/bin/env python3
"""Convert Brain-Harmony PyTorch checkpoint to safetensors format.

Usage:
    python scripts/convert_weights.py \
        --input checkpoints/harmonizer/model.pth \
        --output data/brainharmony.safetensors

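By default (--component all) the output file contains every component found
in the checkpoint (encoder, predictor, target_encoder), with any "module."
prefix stripped. Keys are prefixed by component name:
    encoder.blocks.0.norm1.weight
    predictor.predictor_blocks.0.attn.qkv.weight
    target_encoder.blocks.0.norm1.weight
If the checkpoint has no per-component entries, its flat state dict is
exported with only the "module." prefix removed. Pass --component to
export a single component.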
"""
import argparse
import torch
from safetensors.torch import save_file


def _as_state_dict(obj):
    """Return ``obj.state_dict()`` if *obj* is an ``nn.Module``, else *obj* unchanged.

    Some checkpoints pickle a whole model rather than a state dict; this lets
    the rest of the script treat both uniformly as key -> value mappings.
    """
    if isinstance(obj, torch.nn.Module):
        return obj.state_dict()
    return obj


def _collect_tensors(sd, prefix=""):
    """Convert a state dict into a ``{prefix + key: tensor}`` mapping for safetensors.

    A single leading "module." is stripped from every key before *prefix* is
    prepended. Each tensor is made contiguous and cast with ``.float()``
    (float32; integer/bool buffers are converted too). Values that are not
    tensors cannot be stored in a safetensors file, so they are skipped with
    a message instead of aborting the conversion.
    """
    tensors = {}
    for key, value in _as_state_dict(sd).items():
        if not torch.is_tensor(value):
            print(f"  skipping non-tensor entry {key!r} ({type(value).__name__})")
            continue
        tensors[prefix + key.removeprefix("module.")] = value.contiguous().float()
    return tensors


def main():
    """Command-line entry point: convert a Brain-Harmony checkpoint to safetensors.

    Loads ``--input`` on CPU, selects the tensors for ``--component`` and
    writes them to ``--output``. Exits with an error (writing nothing) when
    the selection is empty.
    """
    parser = argparse.ArgumentParser(description="Convert Brain-Harmony .pth to .safetensors")
    parser.add_argument("--input", required=True, help="Input .pth or .pth.tar checkpoint")
    parser.add_argument("--output", required=True, help="Output .safetensors file")
    parser.add_argument("--component", default="all",
                        choices=["all", "encoder", "predictor", "target_encoder"],
                        help="Which component(s) to export")
    args = parser.parse_args()

    print(f"Loading {args.input} ...")
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, which can execute code stored in the file. Only convert
    # checkpoints from a trusted source.
    ckpt = _as_state_dict(torch.load(args.input, map_location="cpu", weights_only=False))

    # Handle different checkpoint formats
    if "model" in ckpt:
        state_dict = ckpt["model"]
    elif "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
    else:
        state_dict = ckpt
    state_dict = _as_state_dict(state_dict)

    tensors = {}
    if args.component == "all":
        # Prefer the component-based layout: top-level "encoder" / "predictor" /
        # "target_encoder" entries, each holding that component's state dict.
        found = [c for c in ("encoder", "predictor", "target_encoder") if c in ckpt]
        for comp in found:
            part = _collect_tensors(ckpt[comp], prefix=f"{comp}.")
            print(f"  {comp}: {len(part)} tensors")
            tensors.update(part)
        if not found:
            # No component entries: export the flat state dict keys as-is.
            tensors = _collect_tensors(state_dict)
            print(f"  flat: {len(tensors)} tensors")
    else:
        comp = args.component
        if comp in ckpt:
            tensors = _collect_tensors(ckpt[comp], prefix=f"{comp}.")
        else:
            # Flat layout: the selected keys already read "<comp>.…" once
            # "module." is stripped, so they must NOT be prefixed again
            # (doing so produced keys like "encoder.encoder.blocks…").
            wanted = (f"{comp}.", f"module.{comp}.")
            tensors = _collect_tensors(
                {k: v for k, v in state_dict.items() if k.startswith(wanted)})
        print(f"  {comp}: {len(tensors)} tensors")

    if not tensors:
        # Refuse to write an empty .safetensors file (e.g. component missing).
        raise SystemExit(
            f"error: no tensors selected for component {args.component!r} "
            f"in {args.input}; nothing written")

    print(f"Saving {len(tensors)} tensors to {args.output} ...")
    save_file(tensors, args.output)
    print("Done.")


if __name__ == "__main__":
    # Run the converter only when executed as a script, not when imported.
    main()