use crate::error::{
ArithmeticOverflowPayload, CapExceededPayload, Error, LengthMismatchPayload, OutOfRangePayload,
Result, try_with_capacity,
};
/// Resampling kernel used when resizing an image.
///
/// Each variant maps to a support radius in `filter_support` and a kernel
/// function in `filter_eval`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Filter {
/// Nearest-neighbour sampling (support 0). `resize_rgba8` routes this
/// variant to `resize_nearest` instead of the convolution passes.
Nearest,
/// Triangle kernel `1 - |x|` for `|x| < 1` (support 1).
Bilinear,
/// Piecewise-cubic kernel with `a = -0.5` (support 2).
Bicubic,
/// Windowed sinc `sinc(x) * sinc(x / 3)` for `|x| < 3` (support 3).
Lanczos3,
}
/// Number of fractional bits in the fixed-point filter weights (22).
/// Weights are scaled by `1 << PRECISION_BITS` and accumulators are shifted
/// right by the same amount in `clip8`.
// NOTE(review): the `32 - 8 - 2` spelling presumably means "32-bit accumulator,
// 8-bit samples, 2 bits of headroom" — confirm against the reference algorithm.
const PRECISION_BITS: u32 = 32 - 8 - 2;
/// Half of one fixed-point unit; accumulators start from this value so the
/// final right shift rounds to nearest instead of truncating.
const ROUND_BIAS: i32 = 1 << (PRECISION_BITS - 1);
/// Bytes per pixel in the interleaved RGBA8 buffers handled by this module.
const CHANNELS: usize = 4;
/// Upper bound (512 MiB) on the byte size of any single buffer, enforced by
/// `checked_buffer_bytes`.
const MAX_DECODED_IMAGE_BYTES: usize = 512 * 1024 * 1024;
fn checked_buffer_bytes(elems: usize, elem_size: usize, what: &'static str) -> Result<usize> {
let bytes = elems.checked_mul(elem_size).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"resize: buffer size (elems * elem_size)",
"usize",
[("elems", elems as u64), ("elem_size", elem_size as u64)],
))
})?;
if bytes > MAX_DECODED_IMAGE_BYTES {
return Err(Error::CapExceeded(CapExceededPayload::new(
what,
"MAX_DECODED_IMAGE_BYTES",
MAX_DECODED_IMAGE_BYTES as u64,
bytes as u64,
)));
}
Ok(bytes)
}
/// Returns the support radius of `f`: the distance (in kernel units) beyond
/// which `filter_eval` yields zero.
fn filter_support(f: Filter) -> f64 {
    // Radii are small whole numbers; pick the integer, then widen losslessly.
    let radius: u8 = match f {
        Filter::Nearest => 0,
        Filter::Bilinear => 1,
        Filter::Bicubic => 2,
        Filter::Lanczos3 => 3,
    };
    f64::from(radius)
}
/// Evaluates the kernel of filter `f` at offset `x` (in kernel units).
///
/// All kernels depend only on `|x|` and return `0.0` at or beyond their
/// support radius. `Filter::Nearest` always evaluates to `0.0`.
fn filter_eval(f: Filter, x: f64) -> f64 {
    /// Sharpness parameter of the bicubic kernel.
    const A: f64 = -0.5;

    let t = x.abs();
    match f {
        Filter::Nearest => 0.0,
        // Triangle: linear falloff to zero at |x| = 1.
        Filter::Bilinear if t < 1.0 => 1.0 - t,
        // Cubic convolution, inner lobe then outer lobe.
        Filter::Bicubic if t < 1.0 => ((A + 2.0) * t - (A + 3.0)) * t * t + 1.0,
        Filter::Bicubic if t < 2.0 => (((t - 5.0) * t + 8.0) * t - 4.0) * A,
        // Sinc windowed by a three-times-wider sinc.
        Filter::Lanczos3 if t < 3.0 => sinc(t) * sinc(t / 3.0),
        // Outside the support (or NaN, for which every guard is false).
        Filter::Bilinear | Filter::Bicubic | Filter::Lanczos3 => 0.0,
    }
}
/// Normalized sinc: `sin(pi * x) / (pi * x)`, defined as `1.0` at `x == 0`.
fn sinc(x: f64) -> f64 {
    if x != 0.0 {
        let scaled = x * std::f64::consts::PI;
        return scaled.sin() / scaled;
    }
    // Limit value at the removable singularity.
    1.0
}
/// Precomputed fixed-point filter taps for one resize axis, as produced by
/// `precompute_coeffs`.
struct Coeffs {
// One `(first source index, tap count)` pair per output sample.
bounds: Vec<(usize, usize)>,
// Row-major tap table: output sample `i` owns `weights[i * ksize..]`, of
// which only the first `tap count` entries are read; the rest stay zero.
// Values are scaled by `1 << PRECISION_BITS`.
weights: Vec<i32>,
// Stride, in taps, between consecutive rows of `weights`.
ksize: usize,
}
fn precompute_coeffs(in_size: usize, out_size: usize, filter: Filter) -> Result<Coeffs> {
if in_size == 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"precompute_coeffs: in_size",
"must be non-zero",
format!("{in_size}"),
)));
}
if out_size == 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"precompute_coeffs: out_size",
"must be non-zero",
format!("{out_size}"),
)));
}
let scale = in_size as f64 / out_size as f64;
let filterscale = if scale < 1.0 { 1.0 } else { scale };
let support = filter_support(filter) * filterscale;
let ksize_unclamped = (support.ceil() as usize)
.checked_mul(2)
.and_then(|v| v.checked_add(1))
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"precompute_coeffs: ksize (ceil(support) * 2 + 1)",
"usize",
[("ceil(support)", support.ceil() as u64)],
))
})?;
let ksize = ksize_unclamped.min(in_size.max(1));
checked_buffer_bytes(
out_size,
std::mem::size_of::<(usize, usize)>(),
"coefficient bounds table",
)?;
let mut bounds: Vec<(usize, usize)> = try_with_capacity(out_size)?;
let weight_len = out_size.checked_mul(ksize).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"precompute_coeffs: weight_len (out_size * ksize)",
"usize",
[("out_size", out_size as u64), ("ksize", ksize as u64)],
))
})?;
checked_buffer_bytes(
weight_len,
std::mem::size_of::<i32>(),
"coefficient weight table",
)?;
let mut weights: Vec<i32> = try_with_capacity(weight_len)?;
weights.resize(weight_len, 0i32);
checked_buffer_bytes(ksize, std::mem::size_of::<f64>(), "coefficient row scratch")?;
let mut row: Vec<f64> = try_with_capacity(ksize)?;
let inv_filterscale = 1.0 / filterscale;
for xx in 0..out_size {
let center = (xx as f64 + 0.5) * scale;
let xmin = {
let v = (center - support + 0.5).floor();
if v < 0.0 { 0 } else { v as usize }
};
let xmax = {
let v = (center + support + 0.5).floor();
let v = if v < 0.0 { 0usize } else { v as usize };
v.min(in_size)
};
let n = xmax.saturating_sub(xmin);
row.clear();
let mut wsum = 0.0f64;
for i in 0..n {
let w = filter_eval(
filter,
(xmin as f64 + i as f64 - center + 0.5) * inv_filterscale,
);
row.push(w);
wsum += w;
}
let base = xx * ksize;
if wsum != 0.0 {
let inv = 1.0 / wsum;
for (i, &w) in row.iter().enumerate() {
let scaled = (w * inv) * f64::from(1i32 << PRECISION_BITS);
weights[base + i] = scaled.round() as i32;
}
}
debug_assert!(
n <= ksize,
"precompute_coeffs: window n={n} exceeds ksize={ksize}"
);
bounds.push((xmin, n));
}
Ok(Coeffs {
bounds,
weights,
ksize,
})
}
/// Converts a fixed-point accumulator to an 8-bit sample: drops the
/// `PRECISION_BITS` fractional bits, then saturates to `0..=255`.
#[inline]
fn clip8(acc: i32) -> u8 {
    // The clamp guarantees the value fits, so the cast cannot truncate.
    (acc >> PRECISION_BITS).clamp(0, 255) as u8
}
pub(crate) fn resize_rgba8(
src: &[u8],
src_w: usize,
src_h: usize,
dst_w: usize,
dst_h: usize,
filter: Filter,
) -> Result<Vec<u8>> {
if src_w == 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"resize_rgba8: src_w",
"must be non-zero",
format!("{src_w}"),
)));
}
if src_h == 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"resize_rgba8: src_h",
"must be non-zero",
format!("{src_h}"),
)));
}
if dst_w == 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"resize_rgba8: dst_w",
"must be non-zero",
format!("{dst_w}"),
)));
}
if dst_h == 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"resize_rgba8: dst_h",
"must be non-zero",
format!("{dst_h}"),
)));
}
let src_len = src_w
.checked_mul(src_h)
.and_then(|v| v.checked_mul(CHANNELS))
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"resize_rgba8: src_len (src_w * src_h * CHANNELS)",
"usize",
[
("src_w", src_w as u64),
("src_h", src_h as u64),
("CHANNELS", CHANNELS as u64),
],
))
})?;
if src.len() != src_len {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"resize_rgba8: src buffer bytes vs src_w * src_h * CHANNELS",
src_len,
src.len(),
)));
}
checked_buffer_bytes(src_len, 1, "resize_rgba8: RGBA8 source")?;
let dst_len = dst_w
.checked_mul(dst_h)
.and_then(|v| v.checked_mul(CHANNELS))
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"resize_rgba8: dst_len (dst_w * dst_h * CHANNELS)",
"usize",
[
("dst_w", dst_w as u64),
("dst_h", dst_h as u64),
("CHANNELS", CHANNELS as u64),
],
))
})?;
checked_buffer_bytes(dst_len, 1, "resize_rgba8: destination RGBA8")?;
if filter == Filter::Nearest {
return resize_nearest(src, src_w, src_h, dst_w, dst_h, dst_len);
}
let src_pm = premultiply_rgba(src)?;
let hcoeffs = precompute_coeffs(src_w, dst_w, filter)?;
let vcoeffs = precompute_coeffs(src_h, dst_h, filter)?;
let inter_len = src_h
.checked_mul(dst_w)
.and_then(|v| v.checked_mul(CHANNELS))
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"resize_rgba8: inter_len (src_h * dst_w * CHANNELS)",
"usize",
[
("src_h", src_h as u64),
("dst_w", dst_w as u64),
("CHANNELS", CHANNELS as u64),
],
))
})?;
checked_buffer_bytes(
inter_len,
1,
"resize_rgba8: horizontal-pass intermediate RGBA8",
)?;
let mut inter: Vec<u8> = try_with_capacity(inter_len)?;
inter.resize(inter_len, 0u8);
let mut dst: Vec<u8> = try_with_capacity(dst_len)?;
dst.resize(dst_len, 0u8);
convolve_axis(&src_pm, src_w, src_h, &mut inter, dst_w, &hcoeffs);
convolve_vertical(&inter, dst_w, src_h, &mut dst, dst_h, &vcoeffs);
unpremultiply_rgba(&mut dst);
Ok(dst)
}
/// Computes `c * a / 255` rounded to nearest, using shifts instead of a
/// division.
#[inline]
fn muldiv255(c: u8, a: u8) -> u8 {
    let biased = u32::from(c) * u32::from(a) + 128;
    // Adding the value's own high byte before the final shift turns the
    // divide-by-256 into a divide-by-255.
    let folded = biased + (biased >> 8);
    (folded >> 8) as u8
}
/// Returns a copy of `src` with each pixel's R, G and B scaled by its alpha
/// (via `muldiv255`); alpha itself is copied through unchanged.
///
/// `src` is expected to hold whole RGBA pixels; trailing bytes that do not
/// form a full pixel are not copied (and trip the debug assertion).
///
/// # Errors
///
/// Propagates the error from `try_with_capacity` when the output buffer
/// cannot be allocated.
fn premultiply_rgba(src: &[u8]) -> Result<Vec<u8>> {
    let mut out: Vec<u8> = try_with_capacity(src.len())?;
    for px in src.chunks_exact(CHANNELS) {
        let alpha = px[3];
        out.extend_from_slice(&[
            muldiv255(px[0], alpha),
            muldiv255(px[1], alpha),
            muldiv255(px[2], alpha),
            alpha,
        ]);
    }
    debug_assert_eq!(
        out.len(),
        src.len(),
        "premultiply_rgba: src length must be a multiple of CHANNELS"
    );
    Ok(out)
}
/// Converts premultiplied RGBA pixels in `buf` back to straight alpha, in
/// place.
///
/// Pixels whose alpha is 0 or 255 are left untouched. Bytes past the last
/// whole pixel are ignored.
fn unpremultiply_rgba(buf: &mut [u8]) {
    for px in buf.chunks_exact_mut(CHANNELS) {
        let alpha = u32::from(px[3]);
        // Alpha 0 would divide by zero; alpha 255 is already the identity.
        if matches!(alpha, 0 | 255) {
            continue;
        }
        for channel in &mut px[..3] {
            *channel = clip8_div(u32::from(*channel), alpha);
        }
    }
}
/// Computes `255 * c / a` with truncating division, saturated at 255.
///
/// `a` must be non-zero; the caller skips fully transparent pixels.
#[inline]
fn clip8_div(c: u32, a: u32) -> u8 {
    let quotient = (255 * c) / a;
    // The `min` bounds the value to a byte, so the cast cannot truncate.
    quotient.min(255) as u8
}
/// Nearest-neighbour resize of an interleaved RGBA8 image.
///
/// Each destination pixel copies the source pixel under its centre. The
/// column mapping is computed once and reused for every row. `dst_len` is the
/// byte length of the output buffer and is expected to equal
/// `dst_w * dst_h * CHANNELS`. Dimensions are assumed non-zero, which
/// `resize_rgba8` verifies before calling.
///
/// # Errors
///
/// Propagates errors from `checked_buffer_bytes` and `try_with_capacity`.
fn resize_nearest(
    src: &[u8],
    src_w: usize,
    src_h: usize,
    dst_w: usize,
    dst_h: usize,
    dst_len: usize,
) -> Result<Vec<u8>> {
    checked_buffer_bytes(
        dst_w,
        std::mem::size_of::<usize>(),
        "resize_nearest: x-index map",
    )?;
    let mut xmap: Vec<usize> = try_with_capacity(dst_w)?;
    // Source column for every destination column, clamped to the last column.
    xmap.extend((0..dst_w).map(|ox| {
        let sx = ((ox as f64 + 0.5) * src_w as f64 / dst_w as f64).floor() as usize;
        sx.min(src_w - 1)
    }));

    checked_buffer_bytes(dst_len, 1, "resize_nearest: destination RGBA8")?;
    let mut dst: Vec<u8> = try_with_capacity(dst_len)?;
    dst.resize(dst_len, 0u8);

    for oy in 0..dst_h {
        let sy_unclamped = ((oy as f64 + 0.5) * src_h as f64 / dst_h as f64).floor() as usize;
        let sy = sy_unclamped.min(src_h - 1);
        let src_row = &src[sy * src_w * CHANNELS..(sy + 1) * src_w * CHANNELS];
        let dst_row = &mut dst[oy * dst_w * CHANNELS..(oy + 1) * dst_w * CHANNELS];
        // `dst_row` holds exactly `dst_w` pixels, one per `xmap` entry.
        for (dst_px, &sx) in dst_row.chunks_exact_mut(CHANNELS).zip(xmap.iter()) {
            dst_px.copy_from_slice(&src_row[sx * CHANNELS..sx * CHANNELS + CHANNELS]);
        }
    }
    Ok(dst)
}
/// Horizontal resampling pass: convolves each of the `rows` rows of `src`
/// (`src_w` pixels wide) into `out` (`out_w` pixels wide) using `coeffs`.
///
/// Dispatches to the NEON implementation on aarch64 when the feature is
/// detected at runtime (unless compiled with `mlxrs_force_scalar`), and to
/// the scalar implementation otherwise.
///
/// # Panics
///
/// Panics when `src`, `out` or `coeffs.bounds` do not have the lengths
/// implied by the dimensions.
fn convolve_axis(
    src: &[u8],
    src_w: usize,
    rows: usize,
    out: &mut [u8],
    out_w: usize,
    coeffs: &Coeffs,
) {
    let expected_src = src_w * rows * CHANNELS;
    assert_eq!(src.len(), expected_src, "convolve_axis: src len");
    let expected_out = out_w * rows * CHANNELS;
    assert_eq!(out.len(), expected_out, "convolve_axis: out len");
    assert_eq!(coeffs.bounds.len(), out_w, "convolve_axis: bounds len");

    #[cfg(all(target_arch = "aarch64", not(mlxrs_force_scalar)))]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: runtime detection just confirmed `neon`, the target
            // feature the callee is compiled with.
            unsafe { convolve_axis_neon(src, src_w, rows, out, out_w, coeffs) };
            return;
        }
    }

    convolve_axis_scalar(src, src_w, rows, out, out_w, coeffs);
}
/// Vertical resampling pass: convolves the columns of `inter`
/// (`out_w` x `src_h` pixels) into `out` (`out_w` x `out_h` pixels) using
/// `coeffs`.
///
/// Dispatches to the NEON implementation on aarch64 when the feature is
/// detected at runtime (unless compiled with `mlxrs_force_scalar`), and to
/// the scalar implementation otherwise.
///
/// # Panics
///
/// Panics when `inter`, `out` or `coeffs.bounds` do not have the lengths
/// implied by the dimensions.
fn convolve_vertical(
    inter: &[u8],
    out_w: usize,
    src_h: usize,
    out: &mut [u8],
    out_h: usize,
    coeffs: &Coeffs,
) {
    let expected_inter = out_w * src_h * CHANNELS;
    assert_eq!(inter.len(), expected_inter, "convolve_vertical: inter len");
    let expected_out = out_w * out_h * CHANNELS;
    assert_eq!(out.len(), expected_out, "convolve_vertical: out len");
    assert_eq!(coeffs.bounds.len(), out_h, "convolve_vertical: bounds len");

    #[cfg(all(target_arch = "aarch64", not(mlxrs_force_scalar)))]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            // SAFETY: runtime detection just confirmed `neon`, the target
            // feature the callee is compiled with.
            unsafe { convolve_vertical_neon(inter, out_w, src_h, out, out_h, coeffs) };
            return;
        }
    }

    convolve_vertical_scalar(inter, out_w, src_h, out, out_h, coeffs);
}
/// Portable horizontal convolution.
///
/// For every output pixel, sums the source pixels in its tap window weighted
/// by the fixed-point taps from `coeffs`, starting from `ROUND_BIAS`, and
/// writes the result through `clip8`.
fn convolve_axis_scalar(
    src: &[u8],
    src_w: usize,
    rows: usize,
    out: &mut [u8],
    out_w: usize,
    coeffs: &Coeffs,
) {
    let ksize = coeffs.ksize;
    for y in 0..rows {
        let src_row = &src[y * src_w * CHANNELS..(y + 1) * src_w * CHANNELS];
        let out_row = &mut out[y * out_w * CHANNELS..(y + 1) * out_w * CHANNELS];
        for ox in 0..out_w {
            let (xmin, n) = coeffs.bounds[ox];
            let taps = &coeffs.weights[ox * ksize..ox * ksize + n];

            // One accumulator per channel, pre-loaded with the rounding bias.
            let mut acc = [ROUND_BIAS; CHANNELS];
            for (i, &w) in taps.iter().enumerate() {
                let start = (xmin + i) * CHANNELS;
                let px = &src_row[start..start + CHANNELS];
                for (sum, &sample) in acc.iter_mut().zip(px) {
                    *sum += i32::from(sample) * w;
                }
            }

            let o = &mut out_row[ox * CHANNELS..ox * CHANNELS + CHANNELS];
            for (dst, &sum) in o.iter_mut().zip(acc.iter()) {
                *dst = clip8(sum);
            }
        }
    }
}
/// Portable vertical convolution.
///
/// For every output row, sums the rows of `inter` in its tap window weighted
/// by the fixed-point taps from `coeffs`, column by column, starting from
/// `ROUND_BIAS`, and writes the result through `clip8`. The source height is
/// accepted for signature parity with the dispatcher but is not read.
fn convolve_vertical_scalar(
    inter: &[u8],
    out_w: usize,
    _src_h: usize,
    out: &mut [u8],
    out_h: usize,
    coeffs: &Coeffs,
) {
    let ksize = coeffs.ksize;
    let row_stride = out_w * CHANNELS;
    for oy in 0..out_h {
        let (ymin, n) = coeffs.bounds[oy];
        let taps = &coeffs.weights[oy * ksize..oy * ksize + n];
        let out_row = &mut out[oy * row_stride..(oy + 1) * row_stride];
        for ox in 0..out_w {
            // One accumulator per channel, pre-loaded with the rounding bias.
            let mut acc = [ROUND_BIAS; CHANNELS];
            for (i, &w) in taps.iter().enumerate() {
                let base = (ymin + i) * row_stride + ox * CHANNELS;
                let px = &inter[base..base + CHANNELS];
                for (sum, &sample) in acc.iter_mut().zip(px) {
                    *sum += i32::from(sample) * w;
                }
            }

            let o = &mut out_row[ox * CHANNELS..ox * CHANNELS + CHANNELS];
            for (dst, &sum) in o.iter_mut().zip(acc.iter()) {
                *dst = clip8(sum);
            }
        }
    }
}
/// NEON horizontal convolution; produces the same output as
/// `convolve_axis_scalar`, processing the four channels of a pixel in one
/// 4-lane `i32` vector.
///
/// # Safety
///
/// The CPU must support the `neon` target feature.
#[cfg(all(target_arch = "aarch64", not(mlxrs_force_scalar)))]
#[target_feature(enable = "neon")]
unsafe fn convolve_axis_neon(
    src: &[u8],
    src_w: usize,
    rows: usize,
    out: &mut [u8],
    out_w: usize,
    coeffs: &Coeffs,
) {
    use std::arch::aarch64::*;
    let ksize = coeffs.ksize;
    for y in 0..rows {
        let src_row = &src[y * src_w * CHANNELS..(y + 1) * src_w * CHANNELS];
        let out_row = &mut out[y * out_w * CHANNELS..(y + 1) * out_w * CHANNELS];
        for ox in 0..out_w {
            let (xmin, n) = coeffs.bounds[ox];
            let taps = &coeffs.weights[ox * ksize..ox * ksize + n];

            // All four lanes start from the rounding bias.
            let mut acc = vdupq_n_s32(ROUND_BIAS);
            for (i, &w) in taps.iter().enumerate() {
                let off = (xmin + i) * CHANNELS;
                let px4 = [
                    src_row[off],
                    src_row[off + 1],
                    src_row[off + 2],
                    src_row[off + 3],
                ];
                // SAFETY: the helper takes the pixel by value and needs only
                // `neon`, which this function enables as well.
                let v8 = unsafe { neon_load_rgba(px4) };
                // Widen u8 -> u16 -> u32 and keep the low four lanes (RGBA).
                let widened = vmovl_u16(vget_low_u16(vmovl_u8(v8)));
                let samples = vreinterpretq_s32_u32(widened);
                acc = vmlaq_s32(acc, samples, vdupq_n_s32(w));
            }

            // Drop the fractional bits, then saturate i32 -> u16 -> u8.
            let shifted = vshrq_n_s32::<{ PRECISION_BITS as i32 }>(acc);
            let narrowed16 = vcombine_u16(vqmovun_s32(shifted), vdup_n_u16(0));
            let narrowed8 = vqmovn_u16(narrowed16);
            let o = &mut out_row[ox * CHANNELS..ox * CHANNELS + CHANNELS];
            // SAFETY: `o` is exactly one RGBA pixel, and `neon` is enabled
            // for this function.
            unsafe { neon_store_rgba(narrowed8, o) };
        }
    }
}
/// Loads one RGBA pixel into the low four lanes of a `uint8x8_t`; the high
/// four lanes are zero.
///
/// # Safety
///
/// The CPU must support the `neon` target feature.
#[cfg(all(target_arch = "aarch64", not(mlxrs_force_scalar)))]
#[target_feature(enable = "neon")]
unsafe fn neon_load_rgba(px4: [u8; CHANNELS]) -> std::arch::aarch64::uint8x8_t {
    use std::arch::aarch64::*;
    // Pad the pixel to the eight bytes the vector load reads.
    let mut lanes = [0u8; 8];
    lanes[..CHANNELS].copy_from_slice(&px4);
    // SAFETY: `lanes` is a live, initialized 8-byte array, which is exactly
    // the span `vld1_u8` reads.
    unsafe { vld1_u8(lanes.as_ptr()) }
}
/// Writes the low four lanes of `v` into `out` as one RGBA pixel.
///
/// # Panics
///
/// Panics when `out` is not exactly `CHANNELS` bytes long.
///
/// # Safety
///
/// The CPU must support the `neon` target feature.
#[cfg(all(target_arch = "aarch64", not(mlxrs_force_scalar)))]
#[target_feature(enable = "neon")]
unsafe fn neon_store_rgba(v: std::arch::aarch64::uint8x8_t, out: &mut [u8]) {
    use std::arch::aarch64::*;
    assert_eq!(
        out.len(),
        CHANNELS,
        "neon_store_rgba: out must be one RGBA pixel"
    );
    // Spill all eight lanes to the stack, then keep only the pixel bytes.
    let mut lanes = [0u8; 8];
    // SAFETY: `lanes` is a live, writable 8-byte array, which is exactly the
    // span `vst1_u8` writes.
    unsafe { vst1_u8(lanes.as_mut_ptr(), v) };
    let (rgba, _padding) = lanes.split_at(CHANNELS);
    out.copy_from_slice(rgba);
}
/// NEON vertical convolution; produces the same output as
/// `convolve_vertical_scalar`, processing the four channels of a pixel in
/// one 4-lane `i32` vector. The source height is accepted for signature
/// parity with the dispatcher but is not read.
///
/// # Safety
///
/// The CPU must support the `neon` target feature.
#[cfg(all(target_arch = "aarch64", not(mlxrs_force_scalar)))]
#[target_feature(enable = "neon")]
unsafe fn convolve_vertical_neon(
    inter: &[u8],
    out_w: usize,
    _src_h: usize,
    out: &mut [u8],
    out_h: usize,
    coeffs: &Coeffs,
) {
    use std::arch::aarch64::*;
    let ksize = coeffs.ksize;
    let row_stride = out_w * CHANNELS;
    for oy in 0..out_h {
        let (ymin, n) = coeffs.bounds[oy];
        let taps = &coeffs.weights[oy * ksize..oy * ksize + n];
        let out_row = &mut out[oy * row_stride..(oy + 1) * row_stride];
        for ox in 0..out_w {
            // All four lanes start from the rounding bias.
            let mut acc = vdupq_n_s32(ROUND_BIAS);
            for (i, &w) in taps.iter().enumerate() {
                let base = (ymin + i) * row_stride + ox * CHANNELS;
                let px4 = [
                    inter[base],
                    inter[base + 1],
                    inter[base + 2],
                    inter[base + 3],
                ];
                // SAFETY: the helper takes the pixel by value and needs only
                // `neon`, which this function enables as well.
                let v8 = unsafe { neon_load_rgba(px4) };
                // Widen u8 -> u16 -> u32 and keep the low four lanes (RGBA).
                let widened = vmovl_u16(vget_low_u16(vmovl_u8(v8)));
                let samples = vreinterpretq_s32_u32(widened);
                acc = vmlaq_s32(acc, samples, vdupq_n_s32(w));
            }

            // Drop the fractional bits, then saturate i32 -> u16 -> u8.
            let shifted = vshrq_n_s32::<{ PRECISION_BITS as i32 }>(acc);
            let narrowed16 = vcombine_u16(vqmovun_s32(shifted), vdup_n_u16(0));
            let narrowed8 = vqmovn_u16(narrowed16);
            let o = &mut out_row[ox * CHANNELS..ox * CHANNELS + CHANNELS];
            // SAFETY: `o` is exactly one RGBA pixel, and `neon` is enabled
            // for this function.
            unsafe { neon_store_rgba(narrowed8, o) };
        }
    }
}
#[cfg(test)]
mod tests;