import argparse
from pathlib import Path
import subprocess
import sys
def _find_checkpoints(root, recursive):
    """Return the checkpoint files below *root*.

    Args:
        root: Directory (``pathlib.Path``) to scan.
        recursive: When true, descend into subdirectories as well.

    Returns:
        A list of paths: every ``*.pth`` match first, then every ``*.pt``
        match, each group sorted. Only regular files are returned.
    """
    found = []
    for pattern in ("*.pth", "*.pt"):
        matches = root.rglob(pattern) if recursive else root.glob(pattern)
        # A directory can be named "something.pt" too; it is not a checkpoint
        # and would only make the converter fail, so keep regular files only.
        found.extend(sorted(p for p in matches if p.is_file()))
    return found


def main(argv=None):
    """Bulk convert ``.pth``/``.pt`` checkpoints to ``.safetensors``.

    Each checkpoint is converted by running the sibling script
    ``export_weights_to_safetensors.py`` in a subprocess with the current
    interpreter. A summary of converted/skipped/failed counts is printed
    at the end.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: If the data directory does not exist or is not a
            directory, or (with status 1) if at least one conversion failed.
    """
    ap = argparse.ArgumentParser(description="Bulk convert .pth/.pt checkpoints to .safetensors")
    ap.add_argument("--data-dir", default="data", help="Directory to scan for checkpoints")
    ap.add_argument("--recursive", action="store_true", help="Scan recursively")
    ap.add_argument("--overwrite", action="store_true", help="Overwrite existing .safetensors files")
    args = ap.parse_args(argv)

    root = Path(args.data_dir)
    if not root.exists():
        raise SystemExit(f"data dir not found: {root}")
    if not root.is_dir():
        # Globbing a regular file silently yields nothing, which would be
        # reported as "no checkpoints found" instead of a usage error.
        raise SystemExit(f"data dir is not a directory: {root}")

    checkpoints = _find_checkpoints(root, args.recursive)
    if not checkpoints:
        print(f"No .pth/.pt checkpoints found in {root}")
        return

    script = Path(__file__).parent / "export_weights_to_safetensors.py"
    converted = 0
    skipped = 0
    failed = 0
    for ckpt in checkpoints:
        # NOTE: "x.pth" and "x.pt" in the same directory map to the same
        # output path; the later one is skipped, or overwrites with --overwrite.
        out = ckpt.with_suffix(".safetensors")
        if out.exists() and not args.overwrite:
            print(f"[skip] {ckpt} -> {out} (exists)")
            skipped += 1
            continue

        cmd = [sys.executable, str(script), "--pth", str(ckpt), "--out", str(out)]
        print(f"[run ] {' '.join(cmd)}")
        # A non-zero exit is tallied rather than raised so that one bad
        # checkpoint does not stop the rest of the batch.
        proc = subprocess.run(cmd)
        if proc.returncode == 0:
            converted += 1
        else:
            failed += 1
            print(f"[fail] {ckpt}")

    print("\nSummary")
    print(f" converted: {converted}")
    print(f" skipped : {skipped}")
    print(f" failed : {failed}")
    if failed > 0:
        raise SystemExit(1)
# Script entry point: run the bulk conversion only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    # main() returns None on success, which sys.exit() maps to status 0;
    # a SystemExit raised inside main() propagates unchanged.
    sys.exit(main())