import argparse
import os
import sys
import time
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
CANDIDATES = {
"potion-code-16M": {
"hf_id": "minishlab/potion-code-16M",
"type": "m2v_download",
"source_dim": 256,
},
"coderankedembed": {
"hf_id": "nomic-ai/CodeRankEmbed",
"type": "distill",
"source_dim": 768,
},
"jina-code-v2": {
"hf_id": "jinaai/jina-embeddings-v2-base-code",
"type": "distill",
"source_dim": 768,
},
"codexembed-400m": {
"hf_id": "Salesforce/SFR-Embedding-Code-400M_R",
"type": "distill",
"source_dim": 768,
},
"codexembed-2b": {
"hf_id": "Salesforce/SFR-Embedding-Code-2B_R",
"type": "distill",
"source_dim": 2048,
},
}
def download_m2v_model(hf_id: str, output_dir: str) -> None:
from model2vec import StaticModel
print(f" Downloading Model2Vec model from {hf_id}...")
model = StaticModel.from_pretrained(hf_id)
model.save_pretrained(output_dir)
print(f" Saved to {output_dir}/")
def distill_model(hf_id: str, output_dir: str, pca_dims: int) -> None:
from model2vec.distill import distill, distill_from_model
print(f" Distilling {hf_id} -> {pca_dims}d Model2Vec...")
try:
model = distill(hf_id, pca_dims=pca_dims, trust_remote_code=True)
except Exception:
from transformers import AutoModel, AutoTokenizer
print(f" Standard distill failed, loading via AutoModel...")
tokenizer = AutoTokenizer.from_pretrained(hf_id, trust_remote_code=True)
base_model = AutoModel.from_pretrained(hf_id, trust_remote_code=True)
model = distill_from_model(base_model, tokenizer, pca_dims=pca_dims)
model.save_pretrained(output_dir)
print(f" Saved to {output_dir}/")
def main():
parser = argparse.ArgumentParser(description="Distill candidate models for prx eval")
parser.add_argument(
"--only",
type=str,
default=None,
help=f"Process only this model key. Choices: {', '.join(CANDIDATES.keys())}",
)
parser.add_argument(
"--dims",
type=int,
default=256,
help="PCA target dimensionality for distilled models (default: 256)",
)
parser.add_argument(
"--output-root",
type=str,
default="models/eval",
help="Root directory for output models (default: models/eval)",
)
args = parser.parse_args()
if args.only and args.only not in CANDIDATES:
print(f"Error: unknown model key '{args.only}'")
print(f"Available: {', '.join(CANDIDATES.keys())}")
sys.exit(1)
models_to_process = (
{args.only: CANDIDATES[args.only]} if args.only else CANDIDATES
)
os.makedirs(args.output_root, exist_ok=True)
results = {}
for key, spec in models_to_process.items():
output_dir = os.path.join(args.output_root, key)
print(f"\n[{key}] {spec['hf_id']}")
if os.path.exists(os.path.join(output_dir, "model.safetensors")):
print(f" Already exists at {output_dir}/, skipping. Delete to re-run.")
results[key] = "skipped"
continue
os.makedirs(output_dir, exist_ok=True)
start = time.time()
try:
if spec["type"] == "m2v_download":
download_m2v_model(spec["hf_id"], output_dir)
else:
distill_model(spec["hf_id"], output_dir, args.dims)
elapsed = time.time() - start
size_mb = os.path.getsize(os.path.join(output_dir, "model.safetensors")) / 1e6
print(f" Done in {elapsed:.1f}s, safetensors size: {size_mb:.1f} MB")
results[key] = f"ok ({size_mb:.1f} MB, {elapsed:.1f}s)"
except Exception as e:
print(f" FAILED: {e}")
results[key] = f"FAILED: {e}"
print("\n" + "=" * 60)
print("Summary:")
for key, status in results.items():
print(f" {key:25s} {status}")
print(f"\nModels saved to: {args.output_root}/")
print(f"Run eval with: python scripts/run_model_eval.py")
if __name__ == "__main__":
main()