seizuretransformer 0.0.1

SeizureTransformer EEG model in Rust (Burn + wgpu)
Documentation
#!/usr/bin/env python3
import argparse
import shlex
import subprocess
import sys
from pathlib import Path


def main():
    """Convert every ``.pth``/``.pt`` checkpoint in ``--data-dir`` to ``.safetensors``.

    Each checkpoint is converted by running the sibling
    ``export_weights_to_safetensors.py`` script in a separate Python process
    with ``--pth <checkpoint> --out <checkpoint>.safetensors``. A failing
    conversion is therefore counted and reported without aborting the batch.

    Command-line options:
        --data-dir   Directory to scan (default: ``data``).
        --recursive  Also scan subdirectories.
        --overwrite  Re-convert even when the ``.safetensors`` output exists.

    Existing outputs are skipped unless ``--overwrite`` is given. Two
    checkpoints that share a stem (e.g. ``model.pth`` and ``model.pt``) would
    write the same output file. Only the first one is converted (``.pth``
    files are processed before ``.pt`` files); the other is reported as
    skipped instead of silently overwriting it.

    Raises:
        SystemExit: if the data directory is missing or is not a directory,
            if the exporter script is missing when a conversion is needed,
            or (with status 1) if any conversion failed.
    """
    ap = argparse.ArgumentParser(description="Bulk convert .pth/.pt checkpoints to .safetensors")
    ap.add_argument("--data-dir", default="data", help="Directory to scan for checkpoints")
    ap.add_argument("--recursive", action="store_true", help="Scan recursively")
    ap.add_argument("--overwrite", action="store_true", help="Overwrite existing .safetensors files")
    args = ap.parse_args()

    root = Path(args.data_dir)
    if not root.exists():
        raise SystemExit(f"data dir not found: {root}")
    if not root.is_dir():
        # glob() on a regular file yields nothing, which would otherwise be
        # misreported as "no checkpoints found".
        raise SystemExit(f"data dir is not a directory: {root}")

    # Pattern order matters: .pth files are handled first, so they win when
    # a .pt file with the same stem targets the same output.
    patterns = ["*.pth", "*.pt"]
    checkpoints = []
    for pat in patterns:
        it = root.rglob(pat) if args.recursive else root.glob(pat)
        checkpoints.extend(sorted(it))

    if not checkpoints:
        print(f"No .pth/.pt checkpoints found in {root}")
        return

    script = Path(__file__).parent / "export_weights_to_safetensors.py"

    converted = 0
    skipped = 0
    failed = 0
    claimed = {}  # output path -> checkpoint already scheduled to write it

    for ckpt in checkpoints:
        out = ckpt.with_suffix(".safetensors")
        if out in claimed:
            print(f"[skip] {ckpt} -> {out} (same output as {claimed[out]})")
            skipped += 1
            continue
        if out.exists() and not args.overwrite:
            print(f"[skip] {ckpt} -> {out} (exists)")
            skipped += 1
            continue
        claimed[out] = ckpt

        if not script.is_file():
            # Fail fast with one clear message instead of one confusing
            # interpreter error per checkpoint.
            raise SystemExit(f"exporter script not found: {script}")

        cmd = [sys.executable, str(script), "--pth", str(ckpt), "--out", str(out)]
        # shlex.join quotes paths with spaces so the line can be copy-pasted.
        # flush=True keeps this line ahead of the child's output when stdout
        # is piped or redirected (block-buffered).
        print(f"[run ] {shlex.join(cmd)}", flush=True)
        proc = subprocess.run(cmd)
        if proc.returncode == 0:
            converted += 1
        else:
            failed += 1
            print(f"[fail] {ckpt} (exit code {proc.returncode})")

    print("\nSummary")
    print(f"  converted: {converted}")
    print(f"  skipped  : {skipped}")
    print(f"  failed   : {failed}")

    if failed > 0:
        raise SystemExit(1)


if __name__ == "__main__":
    # main() returns None on success (exit status 0) and raises SystemExit
    # itself on errors, so this forwards its outcome as the process status.
    raise SystemExit(main())