ganesh 0.26.3

Minimization and sampling in Rust, simplified
Documentation
# /// script
# requires-python = ">=3.13"
# dependencies = [
#     "matplotlib",
#     "matplotloom",
#     "numpy",
#     "corner",
#     "joblib",
#     "polars",
# ]
# ///
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
from corner import corner, overplot_lines
from joblib import Parallel, delayed
from matplotloom import Loom

if __name__ == '__main__':
    print('Plotting dataset...')
    with Path('data.pkl').open('rb') as f:
        data = np.array(pickle.load(f)).transpose()
    plt.hist2d(*data, bins=100, cmap='gist_heat_r')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Sampled Dataset')
    plt.savefig('data.svg')
    plt.close()

    print('Plotting traces (no burn-in)...')
    parameter_labels = [
        r'$\mu_0$',
        r'$\mu_1$',
        r'$\Sigma_{00}$',
        r'$\Sigma_{01}$',
        r'$\Sigma_{11}$',
    ]
    with Path('fit.pkl').open('rb') as f:
        fit_result_data = pickle.load(f)
    truths = np.array(fit_result_data[0])
    fit_result = np.array(fit_result_data[1])
    fit_result_err = np.array(fit_result_data[2])
    with Path('chain.pkl').open('rb') as f:
        chain, burn = pickle.load(f)
    chain = np.array(chain)
    n_walkers, n_steps, n_parameters = chain.shape
    _, ax = plt.subplots(nrows=n_parameters, sharex=True, figsize=(10, 50))
    steps = np.arange(n_steps)
    for i in range(n_parameters):
        for j in range(n_walkers):
            ax[i].plot(steps[burn:], chain[j, burn:, i], color='k', alpha=0.1)
            ax[i].plot(steps[:burn], chain[j, :burn, i], color='k', ls='--', alpha=0.1)
        ax[i].plot(steps[burn:], chain[0, burn:, i], color='m', label='Walker 0')
        ax[i].plot(
            steps[:burn],
            chain[0, :burn, i],
            color='m',
            ls='--',
            label='Walker 0 (burn-in)',
        )
        ax[i].axhline(fit_result[i], color='b', label='Best fit')
        ax[i].axhline(truths[i], color='r', label='Truth')
        ax[i].set_xlabel('Step')
        ax[i].set_ylabel(parameter_labels[i])
        ax[i].legend()
    plt.savefig('traces.svg')
    plt.close()

    print('Plotting traces (with burn-in)...')
    _, ax = plt.subplots(nrows=n_parameters, sharex=True, figsize=(10, 50))
    steps = np.arange(n_steps)
    for i in range(n_parameters):
        for j in range(n_walkers):
            ax[i].plot(steps[burn:], chain[j, burn:, i], color='k', alpha=0.1)
        ax[i].plot(steps[burn:], chain[0, burn:, i], color='m', label='Walker 0')
        ax[i].axhline(fit_result[i], color='b', label='Best fit')
        ax[i].axhline(truths[i], color='r', label='Truth')
        ax[i].set_xlabel('Step')
        ax[i].set_ylabel(parameter_labels[i])
        ax[i].legend()
    plt.savefig('traces_burned.svg')
    plt.close()

    print('Plotting corner plot...')
    ci = 68.27
    with Path('flat_chain.pkl').open('rb') as f:
        flat_chain = np.array(pickle.load(f))
    fig = corner(
        flat_chain,
        labels=parameter_labels,
        truths=truths,
        truth_color='r',
        quantiles=[(50 - ci / 2) / 100, 0.5, (50 + ci / 2) / 100],
        show_titles=True,
        title_fmt='.4f',
    )
    overplot_lines(
        fig,
        fit_result,
        color='b',
    )
    plt.savefig('corner.svg')
    plt.close()

    def compute_ranges(chain, pad_frac=0.02):
        mins = chain.min(axis=(0, 1))
        maxs = chain.max(axis=(0, 1))
        widths = np.maximum(maxs - mins, 1e-9)
        mins = mins - pad_frac * widths
        maxs = maxs + pad_frac * widths
        return [(float(a), float(b)) for a, b in zip(mins, maxs, strict=True)]

    def make_frame(i, chain, labels, ranges, loom):
        j0 = max(0, i - 10)
        window = chain[:, j0 : i + 1, :]
        flat = window.reshape(-1, window.shape[-1])

        fig = corner(
            flat,
            labels=labels,
            range=ranges,
            plot_contours=False,
            show_titles=False,
        )
        loom.save_frame(fig, i)
        plt.close(fig)

    burned_chain = chain[:, burn:, :]
    n_steps = burned_chain.shape[1]
    ranges = compute_ranges(burned_chain)

    print('Making animated corner plot...')
    with Loom('walkers_corner.gif', fps=20, parallel=True, overwrite=True) as loom:
        Parallel(n_jobs=-1, prefer='processes')(
            delayed(make_frame)(i, burned_chain, parameter_labels, ranges, loom)
            for i in range(n_steps)
        )

    parameter_labels_unicode = ['μ₀', 'μ₁', 'Σ₀₀', 'Σ₀₁', 'Σ₁₁']
    qlo, qmid, qhi = (50 - ci / 2, 50, 50 + ci / 2)
    lo, mid, hi = np.percentile(flat_chain, [qlo, qmid, qhi], axis=0)
    mcmc_err_minus = mid - lo
    mcmc_err_plus = hi - mid
    fit_col = [f'{v:.6g}' for v in fit_result]
    cov_col = [f'±{e:.3g}' for e in fit_result_err]
    mcmc_col = [
        f'-{em:.3g} / +{ep:.3g}'
        for em, ep in zip(mcmc_err_minus, mcmc_err_plus, strict=True)
    ]
    truth_col = [f'{t:.6g}' for t in truths]
    print('Summary')
    print(
        pl.DataFrame(
            {
                'Parameter': parameter_labels_unicode,
                'Truths': truth_col,
                'Fit Result': fit_col,
                'Uncertainty (Covariance)': cov_col,
                'Uncertainty (MCMC)': mcmc_col,
            }
        )
    )