import os
import re
import sys
import numpy as np
try:
import matplotlib.pyplot as plt HAS_PLT = True
except ImportError:
plt = None HAS_PLT = False
def parse_log(log_path):
    """Extract run statistics from an ``rsmith.log`` file.

    The log is scanned line by line.  Each recognised line contributes one or
    more entries to the result; when a pattern occurs several times, the last
    occurrence wins.

    Keys that may be present in the returned dict:

    * ``seed`` (int) -- from ``RNG seed (<word>): <n>``
    * ``chi2``, ``best_move``, ``accepted``, ``total_moves``, ``acceptance``
      -- from the ``Final chi2 = ...`` summary line
    * ``density`` (float) -- from ``density = <x> g/cm``
    * ``rho0`` (float) -- from ``rho0 = <x> atoms/A^3``
    * ``sq_chi2`` / ``gr_chi2`` (float) -- from the ``sq:`` / ``gr:`` parts of
      a line starting with ``Move <i>/<n>:``

    Keys whose pattern never appears are simply absent.
    """
    seed_re = re.compile(r"RNG seed \(\w+\): (\d+)")
    final_re = re.compile(
        r"Final chi2 = ([\d.]+) \(best at move (\d+)\), accepted (\d+)/(\d+) \(([\d.]+)%\)"
    )
    density_re = re.compile(r"density = ([\d.]+) g/cm")
    rho0_re = re.compile(r"rho0 = ([\d.]+) atoms/A\^3")
    # Anchored with match(): only lines that *start* with "Move i/n:" count.
    move_re = re.compile(r"Move \d+/\d+: (.+)")
    sq_re = re.compile(r"sq: ([\d.]+)")
    gr_re = re.compile(r"gr: ([\d.]+)")

    parsed = {}
    with open(log_path) as log_file:
        for text in log_file:
            if (hit := seed_re.search(text)) is not None:
                parsed["seed"] = int(hit.group(1))

            if (hit := final_re.search(text)) is not None:
                chi2, best_move, accepted, total, acceptance = hit.groups()
                parsed["chi2"] = float(chi2)
                parsed["best_move"] = int(best_move)
                parsed["accepted"] = int(accepted)
                parsed["total_moves"] = int(total)
                parsed["acceptance"] = float(acceptance)

            if (hit := density_re.search(text)) is not None:
                parsed["density"] = float(hit.group(1))

            if (hit := rho0_re.search(text)) is not None:
                parsed["rho0"] = float(hit.group(1))

            move_hit = move_re.match(text)
            if move_hit is None:
                continue
            # Per-component chi2 values are only looked for in the text that
            # follows the "Move i/n:" prefix.
            remainder = move_hit.group(1)
            if (part := sq_re.search(remainder)) is not None:
                parsed["sq_chi2"] = float(part.group(1))
            if (part := gr_re.search(remainder)) is not None:
                parsed["gr_chi2"] = float(part.group(1))
    return parsed
def parse_cn_file(cn_path):
    """Collect mean coordination numbers from a coordination-number file.

    Only lines that begin with ``# <label>: mean = <value>`` are used; every
    other line is ignored.  Returns a dict mapping each label to its mean as a
    float.  If a label appears more than once, the last value is kept.
    """
    mean_line = re.compile(r"# (\S+): mean = ([\d.]+)")
    with open(cn_path) as handle:
        hits = (mean_line.match(text) for text in handle)
        # The comprehension is evaluated while the file is still open.
        return {hit.group(1): float(hit.group(2)) for hit in hits if hit}
# ---------------------------------------------------------------------------
# Command line: compare_ensemble.py [config.toml] [run01 run02 ...]
# ---------------------------------------------------------------------------
config_path = sys.argv[1] if len(sys.argv) > 1 else "config.toml"
config_dir = os.path.dirname(config_path) or "."

if len(sys.argv) > 2:
    # Strip trailing path separators (e.g. shell-completed "run01/").
    # os.path.basename("run01/") is "", which would give every run an empty
    # label and make basename-based lookups of run directories ambiguous.
    # "or d" keeps an argument that consists only of separators unchanged.
    _trailing_seps = os.sep + (os.altsep or "")
    run_dirs = [d.rstrip(_trailing_seps) or d for d in sys.argv[2:]]
else:
    # No run directories given: use every sub-directory next to the config
    # file that contains an rsmith.log, in sorted order.
    run_dirs = sorted(
        d
        for d in os.listdir(config_dir)
        if os.path.isdir(os.path.join(config_dir, d))
        and os.path.isfile(os.path.join(config_dir, d, "rsmith.log"))
    )
    run_dirs = [os.path.join(config_dir, d) for d in run_dirs]

if not run_dirs:
    print("No run directories found. Usage: compare_ensemble.py [config.toml] [run01 run02 ...]")
    sys.exit(1)

# One dict per run: the parsed log values plus "dir" (the directory's base
# name) and, when a coordination-number file exists, "cn".
results = []
for run_dir in run_dirs:
    log_path = os.path.join(run_dir, "rsmith.log")
    if not os.path.isfile(log_path):
        print(f"  Skipping {run_dir}: no rsmith.log")
        continue
    info = parse_log(log_path)
    info["dir"] = os.path.basename(run_dir)
    # The first of these file names that exists is used.
    for cn_name in ["analysis_refined_cn.dat", "analysis_analysis_cn.dat"]:
        cn_path = os.path.join(run_dir, cn_name)
        if os.path.isfile(cn_path):
            info["cn"] = parse_cn_file(cn_path)
            break
    results.append(info)

if not results:
    print("No valid runs found.")
    sys.exit(1)

# Best (lowest) chi2 first; runs without a parsed chi2 sort to the end.
results.sort(key=lambda r: r.get("chi2", float("inf")))

print(f"\nEnsemble summary ({len(results)} runs):")
print("-" * 80)

# Column selection for the summary table.
has_components = any("sq_chi2" in r for r in results)
densities = [r.get("density") for r in results]
has_density = any(d is not None for d in densities)
# The density column is only shown when the runs do not all share one value.
show_density = has_density and len({d for d in densities if d is not None}) > 1
def fmt_row(r):
    """Format one run's result dict as a fixed-width summary-table row.

    The layout follows the module-level flags ``has_components`` (adds the
    ``sq``/``gr`` chi2 columns and prints chi2 with 4 instead of 6 decimals)
    and ``show_density`` (adds the density column).  Missing float values are
    rendered as ``nan``; missing ``best_move``/``seed`` as ``0``.
    """
    nan = float("nan")
    cells = [f"{r['dir']:<10} "]
    if has_components:
        cells.append(f"{r.get('chi2', nan):>10.4f} ")
        cells.append(f"{r.get('sq_chi2', nan):>10.4f} {r.get('gr_chi2', nan):>10.4f}")
    else:
        cells.append(f"{r.get('chi2', nan):>10.6f}")
    if show_density:
        cells.append(f" {r.get('density', nan):>8.4f}")
    cells.append(f" {r.get('acceptance', nan):>7.1f}% {r.get('best_move', 0):>10d} ")
    cells.append(f"{r.get('seed', 0):>12d}")
    return "".join(cells)
# ---------------------------------------------------------------------------
# Summary table (one row per run, already sorted by chi2)
# ---------------------------------------------------------------------------
density_hdr = f" {'g/cm3':>8}" if show_density else ""
if has_components:
    print(f"{'Run':<10} {'chi2':>10} {'sq':>10} {'gr':>10}{density_hdr} {'accept%':>8} {'best_move':>10} {'seed':>12}")
else:
    print(f"{'Run':<10} {'chi2':>10}{density_hdr} {'accept%':>8} {'best_move':>10} {'seed':>12}")
print("-" * 80)
for r in results:
    print(fmt_row(r))

# Spread of chi2 across the ensemble; needs at least two runs with a chi2.
chi2_vals = [r["chi2"] for r in results if "chi2" in r]
if len(chi2_vals) > 1:
    print("-" * 80)
    print(f"{'chi2:':<10} min = {min(chi2_vals):.6f}, max = {max(chi2_vals):.6f}, "
          f"mean = {np.mean(chi2_vals):.6f}, std = {np.std(chi2_vals):.6f}")

# results is sorted by chi2, so the first entry is the best run.
best = results[0]
print(f"\nBest run: {best['dir']} (chi2 = {best.get('chi2', float('nan')):.6f})")

# ---------------------------------------------------------------------------
# Coordination-number table (only for runs that had a CN file)
# ---------------------------------------------------------------------------
cn_results = [(r["dir"], r["cn"]) for r in results if "cn" in r]
if cn_results:
    # Columns are the union of all pair labels seen in any run.
    all_pairs = sorted({p for _, cn in cn_results for p in cn})
    cn_rule = "-" * (12 + 10 * len(all_pairs))
    print("\nCoordination numbers:")
    print(cn_rule)
    header = f"{'Run':<12}" + "".join(f"{p:>10}" for p in all_pairs)
    print(header)
    print(cn_rule)
    for name, cn in cn_results:
        # A pair missing from this run is shown as nan.
        row = f"{name:<12}" + "".join(f"{cn.get(p, float('nan')):>10.3f}" for p in all_pairs)
        print(row)
    if len(cn_results) > 1:
        print(cn_rule)
        # NaN-aware statistics: a pair that is missing from some runs is
        # averaged over the runs that do report it, instead of turning the
        # whole column's mean/std into nan.  Every pair in all_pairs occurs in
        # at least one run, so no column is entirely NaN.
        means = {p: np.nanmean([cn.get(p, np.nan) for _, cn in cn_results]) for p in all_pairs}
        stds = {p: np.nanstd([cn.get(p, np.nan) for _, cn in cn_results]) for p in all_pairs}
        print(f"{'mean':<12}" + "".join(f"{means[p]:>10.3f}" for p in all_pairs))
        print(f"{'std':<12}" + "".join(f"{stds[p]:>10.3f}" for p in all_pairs))
# ---------------------------------------------------------------------------
# Plots (skipped entirely when matplotlib could not be imported)
# ---------------------------------------------------------------------------
if HAS_PLT:
    # Imported here because the config file is only read for plotting.
    import tomllib

    if os.path.isfile(config_path):
        with open(config_path, "rb") as f:
            config = tomllib.load(f)
    else:
        # Without a config the experimental overlays and S(Q) panels are
        # unavailable, but the refined curves of the runs can still be drawn.
        print(f"\nConfig file not found: {config_path}; plotting without experimental data.")
        config = {}
    data_cfg = config.get("data", {})

    # S(Q) data sets configured under [data]; each one gets its own panel.
    sq_types = []
    for data_type, key in [("xray", "xray_sq"), ("neutron", "neutron_sq")]:
        if data_cfg.get(key) is not None:
            sq_types.append((data_type, key))

    # A real-space panel is added as soon as any run provides the file.
    has_partial_gr = any(
        os.path.isfile(os.path.join(d, "refined_gr.dat")) for d in run_dirs
    )
    has_total_gr = any(
        os.path.isfile(os.path.join(d, "refined_total_gr.dat")) for d in run_dirs
    )
    has_total_fr = any(
        os.path.isfile(os.path.join(d, "refined_total_fr.dat")) for d in run_dirs
    )
    has_gr_exp = data_cfg.get("xray_gr") is not None
    # Not used below; kept so the name stays defined at module level.
    has_fr_exp = data_cfg.get("xray_fr") is not None

    # Map each run label (directory base name) to its path once, instead of
    # scanning run_dirs for every curve.  setdefault keeps the first directory
    # for a given base name.
    run_path_by_name = {}
    for d in run_dirs:
        run_path_by_name.setdefault(os.path.basename(d), d)

    def _curve_label(run):
        """Legend label for a run: its name plus chi2 when one was parsed."""
        chi2_label = f" ({run['chi2']:.4f})" if "chi2" in run else ""
        return f"{run['dir']}{chi2_label}"

    n_panels = (len(sq_types) + (1 if has_total_gr else 0)
                + (1 if has_total_fr else 0) + (1 if has_partial_gr else 0))

    if n_panels > 0:
        # squeeze=False keeps `axes` two-dimensional even for a single panel,
        # so axes[panel, 0] always works.
        fig, axes = plt.subplots(n_panels, 1, figsize=(8, 4 * n_panels), squeeze=False)
        panel = 0

        # ---- S(Q): one panel per configured data set ----------------------
        for data_type, key in sq_types:
            ax = axes[panel, 0]
            panel += 1
            exp_path = os.path.join(config_dir, data_cfg[key]["file"])
            if os.path.isfile(exp_path):
                sq_exp = np.loadtxt(exp_path)
                ax.plot(sq_exp[:, 0], sq_exp[:, 1], "k--", lw=2, alpha=0.7, label="Experiment")
                q_max = sq_exp[:, 0].max()
            else:
                # Fallback upper Q limit when no experimental file is found.
                q_max = 25.0
            for r in results:
                run_dir = run_path_by_name[r["dir"]]
                # Prefer the data-type specific file, then the generic one.
                for name in (f"refined_{data_type}_sq.dat", "refined_sq.dat"):
                    sq_path = os.path.join(run_dir, name)
                    if os.path.isfile(sq_path):
                        sq = np.loadtxt(sq_path, comments="#")
                        mask = sq[:, 0] <= q_max
                        ax.plot(sq[mask, 0], sq[mask, 1], lw=0.8, alpha=0.6,
                                label=_curve_label(r))
                        break
            label_nice = "X-ray" if data_type == "xray" else "Neutron"
            ax.set(xlabel="Q (1/\u00c5)", ylabel="S(Q)", xlim=(0.3, q_max))
            ax.set_title(f"{label_nice} S(Q) \u2014 ensemble comparison")
            ax.legend(fontsize=7, ncol=2)

        # ---- Total g(r) ----------------------------------------------------
        if has_total_gr:
            ax = axes[panel, 0]
            panel += 1
            if has_gr_exp:
                gr_exp_path = os.path.join(config_dir, data_cfg["xray_gr"]["file"])
                if os.path.isfile(gr_exp_path):
                    gr_exp = np.loadtxt(gr_exp_path)
                    ax.plot(gr_exp[:, 0], gr_exp[:, 1], "k--", lw=2, alpha=0.7, label="Experiment")
            for r in results:
                tgr_path = os.path.join(run_path_by_name[r["dir"]], "refined_total_gr.dat")
                if not os.path.isfile(tgr_path):
                    continue
                tgr = np.loadtxt(tgr_path, comments="#")
                ax.plot(tgr[:, 0], tgr[:, 1], lw=0.8, alpha=0.6, label=_curve_label(r))
            ax.axhline(1, color="gray", ls="--", lw=0.5)
            ax.set(xlabel="r (\u00c5)", ylabel="g(r)", xlim=(0, 10))
            ax.set_title("Total X-ray g(r) \u2014 ensemble comparison")
            ax.legend(fontsize=7, ncol=2)

        # ---- Total f(r) ----------------------------------------------------
        if has_total_fr:
            ax = axes[panel, 0]
            panel += 1
            if sq_types:
                # The experimental curve is obtained by sine-transforming the
                # first configured experimental S(Q).
                _, first_key = sq_types[0]
                exp_sq_path = os.path.join(config_dir, data_cfg[first_key]["file"])
                if os.path.isfile(exp_sq_path):
                    sq_exp_data = np.loadtxt(exp_sq_path)
                    q_exp = sq_exp_data[:, 0]
                    s_exp = sq_exp_data[:, 1]
                    # Transform settings come from [data.xray_fr], falling back
                    # to [data.xray_gr], then to defaults.
                    fr_cfg = data_cfg.get("xray_fr", data_cfg.get("xray_gr", {}))
                    qmax_fr = fr_cfg.get("qmax", q_exp.max())
                    use_lorch = fr_cfg.get("lorch", True)
                    mask_q = q_exp <= qmax_fr
                    q_eff = q_exp[mask_q]
                    s_eff = s_exp[mask_q]
                    # Step taken from the first two points, i.e. the Q grid is
                    # treated as uniform.
                    dq_exp = q_eff[1] - q_eff[0] if len(q_eff) > 1 else 1.0
                    if use_lorch:
                        # Lorch window sin(pi*Q/Qmax) / (pi*Q/Qmax).  np.sinc
                        # is the normalised sinc and returns 1 at Q = 0 without
                        # a divide-by-zero warning.
                        window = np.sinc(q_eff / qmax_fr)
                    else:
                        window = np.ones_like(q_eff)
                    r_plot = np.linspace(0.1, 10, 500)
                    # f(r) = (2/pi) * sum_Q  Q * W(Q) * (S(Q) - 1) * sin(Q r) * dQ,
                    # evaluated for all r at once as a matrix-vector product.
                    integrand = q_eff * window * (s_eff - 1.0)
                    pref = 2.0 * dq_exp / np.pi
                    fr_exp = pref * (np.sin(np.outer(r_plot, q_eff)) @ integrand)
                    ax.plot(r_plot, fr_exp, "k--", lw=2, alpha=0.7, label="Experiment (from S(Q))")
            for r in results:
                tfr_path = os.path.join(run_path_by_name[r["dir"]], "refined_total_fr.dat")
                if not os.path.isfile(tfr_path):
                    continue
                tfr = np.loadtxt(tfr_path, comments="#")
                ax.plot(tfr[:, 0], tfr[:, 1], lw=0.8, alpha=0.6, label=_curve_label(r))
            ax.axhline(0, color="gray", ls="--", lw=0.5)
            ax.set(xlabel="r (\u00c5)", ylabel="f(r)", xlim=(0, 10))
            ax.set_title("Total X-ray f(r) \u2014 ensemble comparison")
            ax.legend(fontsize=7, ncol=2)

        # ---- Partial g(r): one colour per pair, all runs overlaid ----------
        if has_partial_gr:
            ax = axes[panel, 0]
            panel += 1
            # Pair labels are read from the header line of the first
            # refined_gr.dat that exists; the first header column is skipped.
            pair_labels: list[str] = []
            for d in run_dirs:
                gr_path = os.path.join(d, "refined_gr.dat")
                if os.path.isfile(gr_path):
                    with open(gr_path) as fh:
                        header = fh.readline().strip().lstrip("# ").split()
                    pair_labels = [col.replace("g_", "").replace("_", "") for col in header[1:]]
                    # Insert a dash between two adjacent capitalised symbols.
                    pair_labels = [re.sub(r"([A-Z][a-z]?)([A-Z])", r"\1-\2", lbl) for lbl in pair_labels]
                    break
            colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
            n_pairs = len(pair_labels)
            # Each pair gets a legend entry on the first curve that is actually
            # drawn for it.  Tying the label to results[0] would leave the
            # legend empty whenever the best run has no refined_gr.dat.
            labelled_pairs = set()
            for r in results:
                gr_path = os.path.join(run_path_by_name[r["dir"]], "refined_gr.dat")
                if not os.path.isfile(gr_path):
                    continue
                gr = np.loadtxt(gr_path, comments="#")
                r_vals = gr[:, 0]
                # Never index past the columns this particular file provides.
                for pair_idx in range(min(n_pairs, gr.shape[1] - 1)):
                    color = colors[pair_idx % len(colors)]
                    if pair_idx in labelled_pairs:
                        label = None
                    else:
                        label = pair_labels[pair_idx]
                        labelled_pairs.add(pair_idx)
                    ax.plot(r_vals, gr[:, 1 + pair_idx], lw=0.6, alpha=0.4, color=color, label=label)
            ax.axhline(1, color="gray", ls="--", lw=0.5)
            ax.set(xlabel="r (\u00c5)", ylabel="g(r)", xlim=(0, 10))
            ax.set_title("Partial g(r) \u2014 ensemble comparison")
            if pair_labels:
                ax.legend(fontsize=7, ncol=2)

        plt.tight_layout()
        outfile = os.path.join(config_dir, "ensemble_comparison.png")
        plt.savefig(outfile, dpi=150, bbox_inches="tight")
        print(f"\nSaved ensemble comparison to {outfile}")
        plt.show()