import gc
import time
from pathlib import Path
from memory_profiler import memory_usage
import numpy as np
from scipy.interpolate import RegularGridInterpolator
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from interpn import (
MultilinearRectilinear,
MultilinearRegular,
MulticubicRegular,
MulticubicRectilinear,
NearestRegular,
NearestRectilinear,
)
def _build_cases(ndims, nobs, ngrid, rng):
    """Build every benchmarked interpolator for one grid dimensionality.

    A regular ``ngrid x ... x ngrid`` grid spanning [-1, 1] along each of
    ``ndims`` axes is filled with uniform random values in [-1, 1), and
    ``nobs`` query points inside the grid are prepared both in the
    point-major array layout passed to SciPy and as the per-dimension list
    of coordinate arrays passed to InterpN.

    Args:
        ndims: number of grid dimensions.
        nobs: number of observation (query) points.
        ngrid: grid points per axis (must be >= 2 so a step can be formed).
        rng: ``numpy.random.Generator`` used for the grid values and the
            shuffling of query coordinates.

    Returns:
        dict mapping benchmark name -> ``(kind, func, arg)``. ``kind`` is the
        plot panel ("Linear" or "Cubic"), and ``func(arg)`` is the
        evaluation whose memory use is measured.
    """
    grids = [np.linspace(-1.0, 1.0, ngrid) for _ in range(ndims)]
    dims = [g.size for g in grids]
    starts = np.array([g[0] for g in grids])
    steps = np.array([g[1] - g[0] for g in grids])
    # InterpN takes the flat value array; SciPy takes it reshaped onto the
    # grid in "ij" (C) order.
    zgrid = rng.uniform(-1.0, 1.0, int(np.prod(dims)))
    zgrid_nd = zgrid.reshape(dims)

    # Pick m so that an m**ndims tensor grid holds at least nobs points, then
    # keep the first nobs of them.
    m = max(int(float(nobs) ** (1.0 / ndims) + 2), 2)
    obs_axes = [np.linspace(-0.99, 0.99, m) for _ in range(ndims)]
    obsgrid = [x.ravel()[:nobs] for x in np.meshgrid(*obs_axes, indexing="ij")]
    # Each coordinate is shuffled independently, so the query points are
    # scattered rather than visited in grid order.
    points_interpn = [rng.permutation(x) for x in obsgrid]
    points_sp = np.ascontiguousarray(np.array(points_interpn).T)

    linear_sp = RegularGridInterpolator(grids, zgrid_nd, bounds_error=False)
    cubic_sp = RegularGridInterpolator(
        grids, zgrid_nd, bounds_error=False, method="cubic"
    )
    linear_regular = MultilinearRegular.new(dims, starts, steps, zgrid)
    linear_rectilinear = MultilinearRectilinear.new(grids, zgrid)
    cubic_regular = MulticubicRegular.new(dims, starts, steps, zgrid)
    cubic_rectilinear = MulticubicRectilinear.new(grids, zgrid)

    # The InterpN ``eval`` methods are passed as bound methods so that
    # ``memory_usage`` actually calls them with the query points.
    # NearestRegular / NearestRectilinear are not benchmarked: the figure
    # only has "Linear" and "Cubic" panels.
    return {
        "Scipy RegularGridInterpolator Linear": ("Linear", linear_sp, points_sp),
        "Scipy RegularGridInterpolator Cubic": ("Cubic", cubic_sp, points_sp),
        "InterpN MultilinearRegular": (
            "Linear",
            linear_regular.eval,
            points_interpn,
        ),
        "InterpN MultilinearRectilinear": (
            "Linear",
            linear_rectilinear.eval,
            points_interpn,
        ),
        "InterpN MulticubicRegular": (
            "Cubic",
            cubic_regular.eval,
            points_interpn,
        ),
        "InterpN MulticubicRectilinear": (
            "Cubic",
            cubic_rectilinear.eval,
            points_interpn,
        ),
    }


def _peak_memory_mb(func, arg):
    """Return the maximum process memory sampled by memory_profiler during ``func(arg)``."""
    gc.collect()
    # Short pause so the collection above has settled before sampling starts.
    time.sleep(0.1)
    mems = memory_usage((func, (arg,), {}), interval=1e-9, backend="psutil")
    return max(mems)


def _plot_usages(usages, kinds, ndims_to_test, nobs, ngrid):
    """Build the two-panel (Linear / Cubic) peak-memory-vs-dimension figure.

    Args:
        usages: benchmark name -> list of peak memory values, one per entry
            of ``ndims_to_test`` (possibly shorter).
        kinds: benchmark name -> panel name ("Linear" or "Cubic").
        ndims_to_test: dimensionalities used for the x axis.
        nobs: number of observation points (shown in the title).
        ngrid: grid points per axis (shown in the title).

    Returns:
        The plotly figure.
    """
    panels = ["Linear", "Cubic"]
    dash_styles = ["dot", "solid", "dash", "dashdot", "longdashdot"]
    fig = make_subplots(
        rows=1,
        cols=len(panels),
        shared_yaxes=True,
        subplot_titles=panels,
        horizontal_spacing=0.08,
    )
    for col, kind in enumerate(panels, start=1):
        series = [(name, vals) for name, vals in usages.items() if kinds.get(name) == kind]
        for idx, (label, values) in enumerate(series):
            fig.add_trace(
                go.Scatter(
                    x=ndims_to_test[: len(values)],
                    y=values,
                    mode="lines+markers",
                    line=dict(
                        color="black",
                        width=2,
                        dash=dash_styles[idx % len(dash_styles)],
                    ),
                    # The first series in each panel (SciPy) is drawn translucent.
                    opacity=0.5 if idx == 0 else 1.0,
                    name=label,
                    # Show each legend entry once, from the first panel only.
                    showlegend=(col == 1),
                ),
                row=1,
                col=col,
            )

    axis_style = dict(
        showline=True,
        linecolor="black",
        linewidth=1,
        mirror=True,
        ticks="outside",
        tickcolor="black",
        showgrid=False,
        zeroline=False,
    )
    for col in range(1, len(panels) + 1):
        fig.update_yaxes(type="log", row=1, col=col, **axis_style)
        fig.update_xaxes(
            title_text="Number of Dimensions", row=1, col=col, **axis_style
        )
    fig.update_yaxes(title_text="Peak Memory Usage [MB]", row=1, col=1)

    fig.update_layout(
        title=dict(
            text=(
                f"Interpolation on {ngrid}x...x{ngrid} N-Dimensional Grid — "
                f"{nobs} Observation Points"
            ),
            y=0.97,
            yanchor="top",
        ),
        height=450,
        margin=dict(t=80, l=60, r=200, b=90),
        legend=dict(
            orientation="v",
            yanchor="top",
            y=1.0,
            x=1.02,
            xanchor="left",
        ),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        font=dict(color="black"),
    )
    return fig


def bench_eval_mem_vs_dims():
    """Benchmark peak memory of interpolator evaluation vs. number of dimensions.

    For 1..8 dimensions on a 4x...x4 grid, the SciPy RegularGridInterpolator
    (linear and cubic) and the InterpN multilinear/multicubic interpolators
    are evaluated at 10000 points. Peak process memory is sampled with
    memory_profiler during each evaluation. The results are plotted and
    written to ``../docs/ram_vs_dims.svg`` and ``.html`` relative to this
    file, and the figure is shown.
    """
    ndims_to_test = list(range(1, 9))
    nobs = 10000
    ngrid = 4
    # Seeded so repeated runs measure identical data.
    rng = np.random.default_rng(0)

    usages = {}
    kinds = {}
    for ndims in ndims_to_test:
        cases = _build_cases(ndims, nobs, ngrid, rng)
        for name, (kind, func, arg) in cases.items():
            print(ndims, name)
            kinds[name] = kind
            usages.setdefault(name, []).append(_peak_memory_mb(func, arg))
        # Release this dimension's interpolators and points before building
        # the next set, so they do not inflate later measurements.
        del cases

    fig = _plot_usages(usages, kinds, ndims_to_test, nobs, ngrid)
    output_path = Path(__file__).parent / "../docs/ram_vs_dims.svg"
    fig.write_image(str(output_path))
    fig.write_html(
        str(output_path.with_suffix(".html")), include_plotlyjs="cdn", full_html=False
    )
    fig.show()
if __name__ == "__main__":
    # Run the benchmark (and write/show the figure) only when this file is
    # executed as a script, not when it is imported.
    bench_eval_mem_vs_dims()