/// @module std::math::interpolation
/// Grid Interpolation — Trilinear and bilinear interpolation on flat grids
///
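/// Grids are stored as flat 1D arrays in row-major order (the last axis varies
/// fastest), with manual index computation.
/// Query points are provided as Mat<number> (Nx3 or Nx2), one point per row in
/// [x, y(, z)] order — the reverse of the [z, y, x] / [y, x] order used for grid
/// shapes and bounds.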
/// Restrict `x` to the closed range [lo, hi].
///
/// Values below `lo` become `lo`, values above `hi` become `hi`, and anything
/// in between is returned unchanged. The lower bound is tested first.
fn clamp<T: Ord>(x: T, lo: T, hi: T) -> T {
    if x < lo {
        lo
    } else {
        if x > hi { hi } else { x }
    }
}
/// Blend from `a` toward `b` by the fraction `t`.
///
/// Evaluates `a + (b - a) * t`, so `t = 0` yields `a` and `t = 1` yields `b`.
/// With floating-point inputs the `t = 1` result may differ from `b` by rounding.
fn lerp<T: Add + Sub + Mul>(a: T, b: T, t: T) -> T {
    let span = b - a
    a + span * t
}
/// Trilinear interpolation on a flat 3D grid.
///
/// Grid nodes are spaced uniformly along each axis, with node 0 at `lo` and the
/// last node at `hi`. Along each axis, query coordinates outside [lo, hi] are
/// clamped to the nearest edge node (constant extrapolation). An axis with a
/// single node, or with zero extent (lo == hi), always maps to node 0; this
/// avoids a division by zero that would otherwise produce NaN/inf coordinates.
///
/// @param grid - Flat array of grid values in [z][y][x] order (x varies fastest)
/// @param shape - Grid dimensions [nz, ny, nx]
/// @param lo - Lower bounds [z_lo, y_lo, x_lo]
/// @param hi - Upper bounds [z_hi, y_hi, x_hi]
/// @param points - Nx3 Mat<number> of query points, each row [x, y, z]
///                 (reversed relative to `shape`, `lo` and `hi`)
/// @returns Array<number> of interpolated values, one per query point
///
/// @example
/// trilinear([0,1,2,3,4,5,6,7], [2,2,2], [0,0,0], [1,1,1], points)
pub fn trilinear(grid, shape, lo, hi, points) {
    let nz = shape[0]
    let ny = shape[1]
    let nx = shape[2]
    // Number of flat-array entries in one z-slice; invariant across queries.
    let plane = ny * nx
    let n = points.shape()[0]
    let mut result = []
    let mut row = 0
    while row < n {
        let pt = points.row(row)
        let px = pt[0]
        let py = pt[1]
        let pz = pt[2]
        // Map world coords to fractional grid coords. A zero-extent axis is
        // treated like a single-node axis: dividing by (hi - lo) == 0 would
        // give NaN/inf, which passes through `clamp` and breaks indexing.
        let gx = if nx > 1 && hi[2] != lo[2] {
            (px - lo[2]) / (hi[2] - lo[2]) * (nx - 1)
        } else {
            0.0
        }
        let gy = if ny > 1 && hi[1] != lo[1] {
            (py - lo[1]) / (hi[1] - lo[1]) * (ny - 1)
        } else {
            0.0
        }
        let gz = if nz > 1 && hi[0] != lo[0] {
            (pz - lo[0]) / (hi[0] - lo[0]) * (nz - 1)
        } else {
            0.0
        }
        // Lower-corner integer indices
        let ix0 = floor(gx)
        let iy0 = floor(gy)
        let iz0 = floor(gz)
        // Fractional position inside the cell
        let fx = gx - ix0
        let fy = gy - iy0
        let fz = gz - iz0
        // Clamp both cell corners into the grid. Out-of-range points collapse
        // both corners onto the same edge node, giving constant extrapolation.
        let ix0c = clamp(ix0, 0, nx - 1)
        let ix1c = clamp(ix0 + 1, 0, nx - 1)
        let iy0c = clamp(iy0, 0, ny - 1)
        let iy1c = clamp(iy0 + 1, 0, ny - 1)
        let iz0c = clamp(iz0, 0, nz - 1)
        let iz1c = clamp(iz0 + 1, 0, nz - 1)
        // Flat offsets of the two z-slices and two y-rows touched by this cell:
        // grid[iz * ny*nx + iy * nx + ix]
        let z0 = iz0c * plane
        let z1 = iz1c * plane
        let y0 = iy0c * nx
        let y1 = iy1c * nx
        // Eight corner values, named c<z><y><x>
        let c000 = grid[z0 + y0 + ix0c]
        let c001 = grid[z0 + y0 + ix1c]
        let c010 = grid[z0 + y1 + ix0c]
        let c011 = grid[z0 + y1 + ix1c]
        let c100 = grid[z1 + y0 + ix0c]
        let c101 = grid[z1 + y0 + ix1c]
        let c110 = grid[z1 + y1 + ix0c]
        let c111 = grid[z1 + y1 + ix1c]
        // Interpolate along x
        let c00 = lerp(c000, c001, fx)
        let c01 = lerp(c010, c011, fx)
        let c10 = lerp(c100, c101, fx)
        let c11 = lerp(c110, c111, fx)
        // Interpolate along y
        let c0 = lerp(c00, c01, fy)
        let c1 = lerp(c10, c11, fy)
        // Interpolate along z
        let val = lerp(c0, c1, fz)
        result = result.push(val)
        row = row + 1
    }
    result
}
/// Bilinear interpolation on a flat 2D grid.
///
/// Grid nodes are spaced uniformly along each axis, with node 0 at `lo` and the
/// last node at `hi`. Along each axis, query coordinates outside [lo, hi] are
/// clamped to the nearest edge node (constant extrapolation). An axis with a
/// single node, or with zero extent (lo == hi), always maps to node 0; this
/// avoids a division by zero that would otherwise produce NaN/inf coordinates.
///
/// @param grid - Flat array of grid values in [y][x] order (x varies fastest)
/// @param shape - Grid dimensions [ny, nx]
/// @param lo - Lower bounds [y_lo, x_lo]
/// @param hi - Upper bounds [y_hi, x_hi]
/// @param points - Nx2 Mat<number> of query points, each row [x, y]
///                 (reversed relative to `shape`, `lo` and `hi`)
/// @returns Array<number> of interpolated values, one per query point
///
/// @example
/// bilinear([0,1,2,3], [2,2], [0,0], [1,1], points)
pub fn bilinear(grid, shape, lo, hi, points) {
    let ny = shape[0]
    let nx = shape[1]
    let n = points.shape()[0]
    let mut result = []
    let mut row = 0
    while row < n {
        let pt = points.row(row)
        let px = pt[0]
        let py = pt[1]
        // Map world coords to fractional grid coords. A zero-extent axis is
        // treated like a single-node axis: dividing by (hi - lo) == 0 would
        // give NaN/inf, which passes through `clamp` and breaks indexing.
        let gx = if nx > 1 && hi[1] != lo[1] {
            (px - lo[1]) / (hi[1] - lo[1]) * (nx - 1)
        } else {
            0.0
        }
        let gy = if ny > 1 && hi[0] != lo[0] {
            (py - lo[0]) / (hi[0] - lo[0]) * (ny - 1)
        } else {
            0.0
        }
        // Lower-corner integer indices
        let ix0 = floor(gx)
        let iy0 = floor(gy)
        // Fractional position inside the cell
        let fx = gx - ix0
        let fy = gy - iy0
        // Clamp both cell corners into the grid. Out-of-range points collapse
        // both corners onto the same edge node, giving constant extrapolation.
        let ix0c = clamp(ix0, 0, nx - 1)
        let ix1c = clamp(ix0 + 1, 0, nx - 1)
        let iy0c = clamp(iy0, 0, ny - 1)
        let iy1c = clamp(iy0 + 1, 0, ny - 1)
        // Flat offsets of the two rows touched by this cell: grid[iy * nx + ix]
        let y0 = iy0c * nx
        let y1 = iy1c * nx
        // Four corner values, named c<y><x>
        let c00 = grid[y0 + ix0c]
        let c01 = grid[y0 + ix1c]
        let c10 = grid[y1 + ix0c]
        let c11 = grid[y1 + ix1c]
        // Interpolate along x, then y
        let c0 = lerp(c00, c01, fx)
        let c1 = lerp(c10, c11, fx)
        let val = lerp(c0, c1, fy)
        result = result.push(val)
        row = row + 1
    }
    result
}