from __future__ import annotations
import argparse
import datetime as _dt
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List
def _is_ci() -> bool:
    """Report whether the ``CI`` environment variable holds a truthy marker.

    The value is lower-cased and compared against ``1``, ``true`` and
    ``yes``; an unset variable is treated as the empty string and therefore
    counts as "not CI".
    """
    truthy_markers = {"1", "true", "yes"}
    flag = os.getenv("CI", "")
    return flag.lower() in truthy_markers
def _load_fixture(path: Path) -> Dict[str, Any]:
    """Read the UTF-8 encoded JSON file at ``path`` and return the decoded value.

    The decoded document is returned as-is; its shape is not validated here.
    """
    document = path.read_text(encoding="utf-8")
    return json.loads(document)
def _build_langchain_pipeline(kg: List[Dict[str, str]], labels: Dict[str, str]) -> Any:
    """Build a LangChain ``GraphQAChain`` over an in-memory entity graph.

    LangChain is imported inside the function, so a missing installation is
    reported with an actionable message instead of an import-time traceback.

    Args:
        kg: Triples given as mappings with ``"s"``, ``"p"`` and ``"o"`` keys,
            passed to ``KnowledgeTriple`` in that order
            (subject, predicate, object).
        labels: Mapping attached to the returned chain as ``graph_labels``.

    Returns:
        The chain created by ``GraphQAChain.from_llm`` with
        ``OpenAI(temperature=0.0)`` and the populated graph.

    Raises:
        SystemExit: With status 2 when any of the LangChain imports fail.
        KeyError: When a triple lacks one of the ``"s"``/``"p"``/``"o"`` keys.
    """
    try:
        from langchain.chains import GraphQAChain
        from langchain.graphs import NetworkxEntityGraph
        from langchain.graphs.networkx_graph import KnowledgeTriple
        from langchain.llms import OpenAI
    except ImportError as exc:
        sys.stderr.write(
            "ERROR: This script requires the LangChain runtime. "
            "Install it locally with `pip install 'langchain==0.3.*'`.\n"
            f"Underlying ImportError: {exc}\n"
        )
        sys.exit(2)

    graph = NetworkxEntityGraph()
    for triple in kg:
        # add_triple() accepts a single KnowledgeTriple rather than three
        # positional strings, so wrap the fixture fields before inserting.
        graph.add_triple(KnowledgeTriple(triple["s"], triple["p"], triple["o"]))

    chain = GraphQAChain.from_llm(OpenAI(temperature=0.0), graph=graph)
    try:
        chain.graph_labels = labels
    except (AttributeError, ValueError):
        # NOTE(review): the chain is presumably a pydantic model that rejects
        # undeclared attributes; fall back to a raw attribute write so the
        # labels still travel with the chain. Confirm against the pinned
        # LangChain version.
        object.__setattr__(chain, "graph_labels", labels)
    return chain
def _rank_answers(chain: Any, question: str, top_k: int) -> List[str]:
    """Query ``chain`` and normalise its reply into at most ``top_k`` answers.

    The chain is invoked with ``{"query": question}``. For a dict response the
    ``"result"`` entry is used; any other response is stringified. The reply
    is then interpreted, in order of preference, as:

    1. a list/tuple of answers (each item stringified),
    2. a JSON array encoded in a string,
    3. one answer per non-blank line,
    4. whitespace-separated tokens.

    Args:
        chain: Object exposing ``invoke(dict)``.
        question: Question text sent under the ``"query"`` key.
        top_k: Maximum number of answers to return; zero or negative values
            yield an empty list.

    Returns:
        The leading answers as strings, or ``[]`` when the reply is empty.
    """
    response = chain.invoke({"query": question})
    raw = response.get("result") if isinstance(response, dict) else str(response)
    if not raw:
        return []

    # A negative bound would make the slices below trim answers from the tail
    # instead of capping the list, so clamp it to zero.
    limit = max(top_k, 0)

    if isinstance(raw, (list, tuple)):
        # The reply is already a sequence of answers; keep its order.
        return [str(item) for item in raw][:limit]

    # Non-string scalars are stringified so the text heuristics still apply.
    text = (raw if isinstance(raw, str) else str(raw)).strip()

    if text.startswith("["):
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            # Not valid JSON after all; fall through to the plain-text rules.
            parsed = None
        if isinstance(parsed, list):
            return [str(item) for item in parsed][:limit]

    if "\n" in text:
        return [line.strip() for line in text.splitlines() if line.strip()][:limit]

    # NOTE(review): whitespace splitting breaks a multi-word answer into
    # separate tokens; confirm this matches what the benchmark expects.
    return text.split()[:limit]
def main() -> int:
    """Command-line entry point: capture reference answers for a KGQA fixture.

    Parses ``--fixture``, ``--output``, ``--langchain-version`` and
    ``--top-k``, refuses to run when ``_is_ci()`` reports a CI environment,
    and otherwise writes one ``<qid>.json`` capture per fixture question into
    the output directory (created if missing).

    Returns:
        ``1`` when aborted because of CI, ``0`` once every capture is written.
    """
    cli = argparse.ArgumentParser(
        description=(
            "Capture LangChain GraphRAG reference outputs for "
            "oxirs-graphrag KGQA benchmark Phase 2."
        ),
    )
    cli.add_argument(
        "--fixture",
        type=Path,
        required=True,
        help="Path to webqsp_subset.json (or compatible KGQA fixture).",
    )
    cli.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Directory in which to write per-question JSON captures.",
    )
    cli.add_argument(
        "--langchain-version",
        type=str,
        required=True,
        help="Pinned LangChain version label recorded in each capture.",
    )
    cli.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of ranked answers to capture per question (default: 5).",
    )
    opts = cli.parse_args()

    # Guard clause: this tool is meant to be run by an operator, never by CI.
    if _is_ci():
        warning = (
            "WARNING: This is an operator-only script and should not be "
            "invoked inside CI. Aborting.\n"
        )
        sys.stderr.write(warning)
        return 1

    fixture = _load_fixture(opts.fixture)
    opts.output.mkdir(parents=True, exist_ok=True)
    chain = _build_langchain_pipeline(fixture["kg"], fixture.get("labels", {}))

    # ``written`` ends up as the number of questions processed (0 if none).
    written = 0
    for written, entry in enumerate(fixture["questions"], start=1):
        qid = entry["qid"]
        answers = _rank_answers(chain, entry["question"], opts.top_k)
        # UTC timestamp taken after the chain has answered.
        stamp = _dt.datetime.now(_dt.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
        record: Dict[str, Any] = {
            "qid": qid,
            "ranked_answers": answers,
            "langchain_version": opts.langchain_version,
            "captured_at": stamp,
        }
        target = opts.output / f"{qid}.json"
        with target.open("w", encoding="utf-8") as sink:
            json.dump(record, sink, ensure_ascii=False, indent=2)

    summary = (
        f"Captured {written} reference outputs to {opts.output}\n"
        "Use LANGCHAIN_REF_FIXTURES=$PWD/<output> with cargo bench to enable Phase 2.\n"
    )
    sys.stderr.write(summary)
    return 0
# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    sys.exit(main())