elinor 0.4.0

elinor: an Evaluation Library in Information Retrieval.
See the project documentation for details.
#!/usr/bin/env python3

import argparse
import json
import subprocess
import sys


def run_elinor_evaluate(
    target_dir: str, qrels_jsonl: str, results_jsonl: str, metrics: list[str]
) -> dict[str, str]:
    """Run the elinor-evaluate binary and parse its metric report.

    Args:
        target_dir: Directory holding the ``elinor-evaluate`` binary,
            relative to the current working directory (e.g. "target/release").
        qrels_jsonl: Path to the true relevance judgments (qrels) JSONL file.
        results_jsonl: Path to the predicted results JSONL file.
        metrics: Metric names to request, passed as repeated ``-m`` flags.

    Returns:
        Mapping from metric name to score, both kept as strings exactly as
        printed by the binary (one whitespace-separated pair per line).

    Exits:
        Terminates the process with status 1 (after forwarding the binary's
        stderr) if the subprocess reports a non-zero return code.
    """
    # Build an argv list and avoid shell=True: a shell-interpolated string
    # would break on paths containing spaces and is an injection risk.
    command = [f"./{target_dir}/elinor-evaluate", "-t", qrels_jsonl, "-p", results_jsonl]
    for metric in metrics:
        command.extend(["-m", metric])
    # encoding="utf-8" keeps the original behavior of decoding output as UTF-8.
    result = subprocess.run(command, capture_output=True, encoding="utf-8")
    if result.returncode != 0:
        print(result.stderr, file=sys.stderr)
        sys.exit(1)
    parsed: dict[str, str] = {}
    for line in result.stdout.splitlines():
        if not line:
            continue
        metric, value = line.split()
        parsed[metric] = value
    return parsed


def compare_decimal_places(a: str, b: str, decimal_places: int) -> bool:
    """Return True iff the two numeric strings agree after rounding.

    Each value is parsed as a float and rounded to ``decimal_places``
    digits before comparison.
    """
    rounded_a = round(float(a), decimal_places)
    rounded_b = round(float(b), decimal_places)
    return rounded_a == rounded_b


if __name__ == "__main__":
    # Compare trec_eval's reference scores against elinor-evaluate's output
    # and exit with status 1 if any metric disagrees at the given precision.
    p = argparse.ArgumentParser()
    p.add_argument("target_dir", help="e.g., target/release")
    p.add_argument("qrels_jsonl")
    p.add_argument("results_jsonl")
    p.add_argument("trec_output_json")
    p.add_argument("--decimal-places", type=int, default=3)
    args = p.parse_args()

    target_dir: str = args.target_dir
    qrels_jsonl: str = args.qrels_jsonl
    results_jsonl: str = args.results_jsonl
    trec_output_json: str = args.trec_output_json
    decimal_places: int = args.decimal_places

    # Reference scores produced by trec_eval, stored as JSON under the
    # "trec_eval_output" key (values are strings).
    with open(trec_output_json) as f:
        trec_results = json.load(f)

    # Pairs of (trec_eval metric name, elinor metric name) to compare.
    metric_pairs = []
    metric_pairs.extend([(f"success_{k}", f"success@{k}") for k in [1, 5, 10]])
    metric_pairs.extend(
        [
            ("set_P", "precision"),
            ("set_recall", "recall"),
            ("set_F", "f1"),
            ("Rprec", "r_precision"),
            ("map", "ap"),
            ("recip_rank", "rr"),
            ("ndcg", "ndcg"),
            ("bpref", "bpref"),
        ]
    )
    # Cutoff depths for the per-k metric variants.
    ks = [5, 10, 15, 20, 30, 100, 200, 500, 1000]
    metric_pairs.extend([(f"P_{k}", f"precision@{k}") for k in ks])
    metric_pairs.extend([(f"recall_{k}", f"recall@{k}") for k in ks])
    metric_pairs.extend([(f"map_cut_{k}", f"ap@{k}") for k in ks])
    metric_pairs.extend([(f"ndcg_cut_{k}", f"ndcg@{k}") for k in ks])

    elinor_results = run_elinor_evaluate(
        target_dir,
        qrels_jsonl,
        results_jsonl,
        [metric for _, metric in metric_pairs],
    )

    # Basic count statistics, compared in addition to the requested metrics
    # (these keys are expected in elinor's output without an explicit -m flag).
    # NOTE: num_q is intentionally checked against BOTH query counts.
    metric_pairs.extend(
        [
            ("num_q", "n_queries_in_true"),
            ("num_q", "n_queries_in_pred"),
            ("num_ret", "n_docs_in_pred"),
            ("num_rel", "n_relevant_docs"),
        ]
    )

    failed_rows: list[str] = []
    for trec_metric, elinor_metric in metric_pairs:
        trec_score = trec_results["trec_eval_output"][trec_metric]
        elinor_score = elinor_results[elinor_metric]
        match = compare_decimal_places(trec_score, elinor_score, decimal_places)
        row = f"{trec_metric}\t{elinor_metric}\t{trec_score}\t{elinor_score}\t{match}"
        # Fix: print the already-built row instead of rebuilding the same
        # f-string a second time.
        print(row)
        if not match:
            failed_rows.append(row)

    if failed_rows:
        print("Mismatched cases:", file=sys.stderr)
        for row in failed_rows:
            print(row, file=sys.stderr)
        sys.exit(1)
    else:
        print(f"All metrics match 🎉 with {decimal_places=}", file=sys.stderr)