# interpn 0.11.0
#
# N-dimensional interpolation/extrapolation methods, no-std and no-alloc compatible.
# Documentation
"""Benchmarks examining memory usage"""

# ruff: noqa B023
import gc
import time
from pathlib import Path

from memory_profiler import memory_usage

import numpy as np
from scipy.interpolate import RegularGridInterpolator
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from interpn import (
    MultilinearRectilinear,
    MultilinearRegular,
    MulticubicRegular,
    MulticubicRectilinear,
    NearestRegular,
    NearestRectilinear,
)


def bench_eval_mem_vs_dims():
    """Benchmark peak memory use of interpolator evaluation vs. dimensionality.

    For each dimensionality from 1 through 8, builds a grid with ``ngrid``
    points per axis on [-1, 1] filled with uniform random values, constructs
    the Scipy and InterpN linear and cubic interpolators on it, and evaluates
    each one at ``nobs`` shuffled observation points that lie strictly inside
    the grid. The maximum of the samples returned by
    ``memory_profiler.memory_usage`` is recorded for every evaluation.

    The results are drawn as a two-panel (Linear / Cubic) log-scale plot that
    is written to ``../docs/ram_vs_dims.svg`` (relative to this file), to an
    HTML fragment with the same stem, and finally shown interactively.

    Returns:
        None. The function is run for its side effects (console output,
        written files, and the displayed figure).
    """
    # NOTE: the "Nearest" series are reserved here but no Nearest interpolator
    # is benchmarked below, so their lists stay empty and they are not plotted.
    usages = {
        "Scipy RegularGridInterpolator Linear": [],
        "Scipy RegularGridInterpolator Cubic": [],
        "InterpN MultilinearRegular": [],
        "InterpN MultilinearRectilinear": [],
        "InterpN MulticubicRegular": [],
        "InterpN MulticubicRectilinear": [],
        "InterpN NearestRegular": [],
        "InterpN NearestRectilinear": [],
    }

    # Which subplot each benchmarked series belongs to. This is loop-invariant,
    # so it is defined once up front rather than inside the benchmark loop.
    kinds = {
        "Scipy RegularGridInterpolator Linear": "Linear",
        "Scipy RegularGridInterpolator Cubic": "Cubic",
        "InterpN MultilinearRegular": "Linear",
        "InterpN MultilinearRectilinear": "Linear",
        "InterpN MulticubicRegular": "Cubic",
        "InterpN MulticubicRectilinear": "Cubic",
    }

    ndims_to_test = list(range(1, 9))
    nobs = 10000
    ngrid = 4  # Size of grid on each dimension

    for ndims in ndims_to_test:
        grids = [np.linspace(-1.0, 1.0, ngrid) for _ in range(ndims)]
        xgrid = np.meshgrid(*grids, indexing="ij")
        zgrid = np.random.uniform(-1.0, 1.0, xgrid[0].size)

        dims = [x.size for x in grids]
        starts = np.array([x[0] for x in grids])
        steps = np.array([x[1] - x[0] for x in grids])

        # Initialize all interpolator methods
        # Scipy RegularGridInterpolator is actually a more general rectilinear method
        rectilinear_sp = RegularGridInterpolator(
            grids, zgrid.reshape(xgrid[0].shape), bounds_error=None
        )
        cubic_rectilinear_sp = RegularGridInterpolator(
            grids, zgrid.reshape(xgrid[0].shape), bounds_error=None, method="cubic"
        )
        rectilinear_interpn = MultilinearRectilinear.new(grids, zgrid)
        regular_interpn = MultilinearRegular.new(dims, starts, steps, zgrid)
        cubic_regular_interpn = MulticubicRegular.new(dims, starts, steps, zgrid)
        cubic_rectilinear_interpn = MulticubicRectilinear.new(grids, zgrid)

        # Points per axis such that the full observation mesh holds at least
        # `nobs` points; it is trimmed to exactly `nobs` below.
        m = max(int(float(nobs) ** (1.0 / ndims) + 2), 2)

        # Baseline interpolating on the same domain,
        # keeping the points entirely inside the domain to give a clear
        # cut between interpolation and extrapolation
        obsgrid = np.meshgrid(
            *[np.linspace(-0.99, 0.99, m) for _ in range(ndims)], indexing="ij"
        )
        obsgrid = [
            x.flatten()[0:nobs] for x in obsgrid
        ]  # Trim to the exact right number

        # No preallocated output is passed to the InterpN evaluations, for a
        # 1:1 comparison with Scipy.
        # Each lambda must actually *call* eval: returning the bound method
        # would make the memory measurement cover no interpolation work.
        # The lambdas are only invoked within this loop iteration, so the
        # late binding of the loop-local interpolators is harmless.
        interps = {
            "Scipy RegularGridInterpolator Linear": rectilinear_sp,
            "Scipy RegularGridInterpolator Cubic": cubic_rectilinear_sp,
            "InterpN MultilinearRegular": lambda p: regular_interpn.eval(p),
            "InterpN MultilinearRectilinear": lambda p: rectilinear_interpn.eval(p),
            "InterpN MulticubicRegular": lambda p: cubic_regular_interpn.eval(p),
            "InterpN MulticubicRectilinear": lambda p: cubic_rectilinear_interpn.eval(
                p
            ),
        }

        # Interpolation in random order
        # InterpN receives a list of per-dimension coordinate arrays, while
        # Scipy receives a single contiguous (nobs, ndims) array.
        points_interpn = [np.random.permutation(x.flatten()) for x in obsgrid]
        points_sp = np.ascontiguousarray(np.array(points_interpn).T)
        points = {
            "Scipy RegularGridInterpolator Linear": points_sp,
            "Scipy RegularGridInterpolator Cubic": points_sp,
            "InterpN MultilinearRegular": points_interpn,
            "InterpN MultilinearRectilinear": points_interpn,
            "InterpN MulticubicRegular": points_interpn,
            "InterpN MulticubicRectilinear": points_interpn,
        }

        for name, func in interps.items():
            print(ndims, name)
            # Collect garbage and pause briefly before each measurement
            gc.collect()
            time.sleep(0.1)
            p = points[name]
            mems = memory_usage((func, (p,), {}), interval=1e-9, backend="psutil")
            usages[name].append(max(mems))

    dash_styles = ["dot", "solid", "dash", "dashdot", "longdashdot"]
    fig = make_subplots(
        rows=1,
        cols=2,
        shared_yaxes=True,
        subplot_titles=["Linear", "Cubic"],
        horizontal_spacing=0.08,
    )

    for col, kind in enumerate(["Linear", "Cubic"], start=1):
        # `.get` so that series without an assigned kind (the unbenchmarked
        # Nearest entries) are skipped instead of raising KeyError.
        usages_this_kind = [(k, v) for k, v in usages.items() if kinds.get(k) == kind]
        for idx, (label, values) in enumerate(usages_this_kind):
            fig.add_trace(
                go.Scatter(
                    x=ndims_to_test[: len(values)],
                    y=values,
                    mode="lines+markers",
                    line=dict(
                        color="black",
                        width=2,
                        dash=dash_styles[idx % len(dash_styles)],
                    ),
                    opacity=0.5 if idx == 0 else 1.0,
                    name=label,
                    showlegend=col == 1,
                ),
                row=1,
                col=col,
            )

    # Frame styling shared by every axis of both subplots
    axis_style = dict(
        showline=True,
        linecolor="black",
        linewidth=1,
        mirror=True,
        ticks="outside",
        tickcolor="black",
        showgrid=False,
        zeroline=False,
    )
    # Only the left panel carries the y-axis title
    fig.update_yaxes(
        type="log",
        title_text="Peak Memory Usage [MB]",
        row=1,
        col=1,
        **axis_style,
    )
    fig.update_yaxes(type="log", row=1, col=2, **axis_style)
    for col in (1, 2):
        fig.update_xaxes(
            title_text="Number of Dimensions",
            row=1,
            col=col,
            **axis_style,
        )

    fig.update_layout(
        title=dict(
            text=(
                f"Interpolation on {ngrid}x...x{ngrid} N-Dimensional Grid"
                f" — {nobs} Observation Points"
            ),
            y=0.97,
            yanchor="top",
        ),
        height=450,
        margin=dict(t=80, l=60, r=200, b=90),
        legend=dict(
            orientation="v",
            yanchor="top",
            y=1.0,
            x=1.02,
            xanchor="left",
        ),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        font=dict(color="black"),
    )

    output_path = Path(__file__).parent / "../docs/ram_vs_dims.svg"
    fig.write_image(str(output_path))
    fig.write_html(
        str(output_path.with_suffix(".html")), include_plotlyjs="cdn", full_html=False
    )
    fig.show()


# Run the memory benchmark only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    bench_eval_mem_vs_dims()