chela 0.0.2

High-performance Machine Learning, Auto-Differentiation and Tensor Algebra crate for Rust
Documentation
from collections import defaultdict
import colorsys

import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter, StrMethodFormatter
import numpy as np

from .result import Result


def hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
    hex_color = hex_color.lstrip("#")
    r = int(hex_color[0:2], 16) / 255
    g = int(hex_color[2:4], 16) / 255
    b = int(hex_color[4:6], 16) / 255
    return r, g, b


def scale_lightness(hex_color: str, scale_l: float) -> tuple[float, float, float]:
    rgb = hex_to_rgb(hex_color)
    h, l, s = colorsys.rgb_to_hls(*rgb)
    return colorsys.hls_to_rgb(h, min(1, l * scale_l), s=min(1, s * scale_l))


colors = {
    "NumPy":        "#4dabcf",
    "PyTorch CPU":  "#f2765d",
    "PyTorch MPS":  "#812ce5",
    "Chela CPU":    "#ce422b",
    "Chela Einsum": "#ba2323"
}


def plot_results(x, results: dict[str, list[Result]],
                 title: str,
                 xlog=True,
                 ylog=False) -> None:
    plt.figure(figsize=(10, 7))
    ax = plt.gca()

    plt.title(title)

    for label, result in results.items():
        upper_bound = [res.mean + res.se() for res in result]
        lower_bound = [res.mean - res.se() for res in result]

        color = colors.get(label)

        ax.fill_between(x, upper_bound, lower_bound, color=color, alpha=0.2, ec=None)
        ax.plot(x, result, label=label, color=color)

    ax.yaxis.set_major_formatter(FuncFormatter(lambda val, _: f"{val:.3f} ms"))

    if xlog:
        plt.xscale("symlog")
    if ylog:
        plt.yscale("symlog")

    plt.legend()

    plt.show()


def plot_barplot(results: dict[str, dict[str, Result]],
                 title: str,
                 normalize: str = "NumPy") -> None:

    suites = []
    data: dict[str, np.ndarray] | defaultdict = defaultdict(list)
    errors = defaultdict(list)
    max_ = 0

    for suite, results in results.items():
        suites.append(suite.replace("->", ""))

        for series, result in results.items():
            data[series].append(1000 * float(result))
            errors[series].append(1000 * result.se())

    normalizer = np.array(data[normalize])
    for (series, result), (_, error) in zip(data.items(), errors.items()):
        data[series] = np.array(result) / normalizer
        errors[series] = np.array(error) / normalizer
        max_ = max(max_, data[series].max())

    x = np.arange(len(suites))
    width = 0.25
    multiplier = 0

    fig, ax = plt.subplots(figsize=(max(7, int(1.5 * len(suites))), 7))
    ax.set_ylabel(f"Time (relative to {normalize})")
    ax.set_title(title)

    for (series, measurement), (_, error) in zip(data.items(), errors.items()):
        color = colors.get(series)
        ecolor = scale_lightness(color, 0.9)

        offset = width * multiplier

        rects = ax.bar(x + offset, height=measurement, width=width, color=color, label=series,
                       yerr=error, error_kw={"elinewidth": 1, "mew": 1, "capsize": 4, "ecolor": ecolor,
                                             "mfc":        ecolor, "mec": ecolor, "ms": 4, "fmt": "o"})
        ax.bar_label(rects, fmt="{:.2f}x", padding=3, size=6)
        multiplier += 1

    ax.set_xticks(x + width, suites, size=7)
    ax.set_ylim(0, min(1.1 * max_, 10))
    ax.legend(loc="best", ncols=3)
    ax.yaxis.set_major_formatter(StrMethodFormatter("{x:.1f}x"))

    plt.show()


__all__ = ["plot_results", "plot_barplot"]