use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use crate::region::wrap_anonymous;
/// Operation id for this kernel. It is the `id` of the harness `OpEntry` registered below,
/// and it is the id handed to `wrap_anonymous` when the program body is built.
const OP_ID: &str = "vyre-libs::math::conv::conv2d_3x3_direct";
pub fn conv2d_3x3_direct(
input: &str,
kernel: &str,
output: &str,
h: u32,
w: u32,
) -> Result<Program, String> {
if h == 0 || w == 0 {
return Err("Fix: conv2d_3x3_direct requires non-zero height and width.".to_string());
}
let elements = h.checked_mul(w).ok_or_else(|| {
"Fix: conv2d_3x3_direct h*w overflows u32; reduce dimensions.".to_string()
})?;
let body = vec![
Node::let_bind("flat", Expr::InvocationId { axis: 0 }),
Node::if_then(
Expr::lt(Expr::var("flat"), Expr::u32(elements)),
vec![
Node::let_bind("y", Expr::div(Expr::var("flat"), Expr::u32(w))),
Node::let_bind("x", Expr::rem(Expr::var("flat"), Expr::u32(w))),
Node::let_bind("acc", Expr::f32(0.0)),
{
let mut taps: Vec<Node> = Vec::new();
for ky in 0..3u32 {
for kx in 0..3u32 {
let dy = (ky as i32) - 1;
let dx = (kx as i32) - 1;
let y_in_bounds = if dy < 0 {
Expr::ge(Expr::var("y"), Expr::u32((-dy) as u32))
} else {
Expr::lt(
Expr::add(Expr::var("y"), Expr::u32(dy as u32)),
Expr::u32(h),
)
};
let x_in_bounds = if dx < 0 {
Expr::ge(Expr::var("x"), Expr::u32((-dx) as u32))
} else {
Expr::lt(
Expr::add(Expr::var("x"), Expr::u32(dx as u32)),
Expr::u32(w),
)
};
let ny = if dy < 0 {
Expr::sub(Expr::var("y"), Expr::u32((-dy) as u32))
} else if dy > 0 {
Expr::add(Expr::var("y"), Expr::u32(dy as u32))
} else {
Expr::var("y")
};
let nx = if dx < 0 {
Expr::sub(Expr::var("x"), Expr::u32((-dx) as u32))
} else if dx > 0 {
Expr::add(Expr::var("x"), Expr::u32(dx as u32))
} else {
Expr::var("x")
};
let load_idx = Expr::add(Expr::mul(ny, Expr::u32(w)), nx);
let kernel_val = Expr::load(kernel, Expr::u32(ky * 3 + kx));
let in_bounds = Expr::and(y_in_bounds, x_in_bounds);
let tap = Expr::select(
in_bounds,
Expr::mul(Expr::load(input, load_idx), kernel_val),
Expr::f32(0.0),
);
taps.push(Node::assign("acc", Expr::add(Expr::var("acc"), tap)));
}
}
Node::Block(taps)
},
Node::store(output, Expr::var("flat"), Expr::var("acc")),
],
),
];
Ok(Program::wrapped(
vec![
BufferDecl::storage(input, 0, BufferAccess::ReadOnly, DataType::F32)
.with_count(elements),
BufferDecl::storage(kernel, 1, BufferAccess::ReadOnly, DataType::F32).with_count(9),
BufferDecl::output(output, 2, DataType::F32).with_count(elements),
],
[64, 1, 1],
vec![wrap_anonymous(OP_ID, body)],
))
}
// Register this op with the harness. `build` constructs a 4x4 instance; if construction
// fails, it falls back to `invalid_output_program`, which receives the error message.
// The test vectors below pin the expected output for that 4x4 instance.
inventory::submit! {
crate::harness::OpEntry {
id: OP_ID,
build: || {
conv2d_3x3_direct("input", "kernel", "output", 4, 4).unwrap_or_else(|error| {
crate::builder::invalid_output_program(
OP_ID,
"output",
DataType::F32,
error,
)
})
},
// One case, with buffers listed in binding order (input, kernel):
// a 4x4 row-major image with ones on the diagonal, and an all-ones 3x3 kernel.
test_inputs: Some(|| {
let input = crate::test_support::byte_pack::f32_bytes(&[
1.0, 0.0, 0.0, 0.0,
0.0, 1.0, 0.0, 0.0,
0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 1.0,
]);
let kernel = crate::test_support::byte_pack::f32_bytes(&[1.0; 9]);
vec![vec![input, kernel]]
}),
// With an all-ones kernel, each output pixel counts the diagonal ones inside its
// zero-padded 3x3 neighbourhood. For example, (1,1) and (2,2) see three ones, the
// corners (0,0) and (3,3) see two, and (0,3) and (3,0) see none.
expected_output: Some(|| {
vec![vec![crate::test_support::byte_pack::f32_bytes(&[
2.0, 2.0, 1.0, 0.0,
2.0, 3.0, 2.0, 1.0,
1.0, 2.0, 3.0, 2.0,
0.0, 1.0, 2.0, 2.0,
])]]
}),
category: Some("math"),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::byte_pack::f32_bytes;
    use vyre_reference::value::Value;

    /// Decode a buffer of little-endian f32 words.
    fn decode(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes(c.try_into().unwrap()))
            .collect()
    }

    /// CPU oracle: zero-padded 3x3 convolution (kernel not flipped) over a row-major
    /// `h` x `w` image. Taps are accumulated in the same ky-major order as the IR.
    fn naive_conv2d_3x3(input: &[f32], kernel: &[f32], h: usize, w: usize) -> Vec<f32> {
        let mut out = vec![0.0_f32; h * w];
        for y in 0..h {
            for x in 0..w {
                let mut acc = 0.0_f32;
                for ky in 0..3usize {
                    for kx in 0..3usize {
                        let ny = (y as i32) + (ky as i32) - 1;
                        let nx = (x as i32) + (kx as i32) - 1;
                        if ny < 0 || ny >= h as i32 || nx < 0 || nx >= w as i32 {
                            continue;
                        }
                        let pixel = input[(ny as usize) * w + (nx as usize)];
                        let k = kernel[ky * 3 + kx];
                        acc += pixel * k;
                    }
                }
                out[y * w + x] = acc;
            }
        }
        out
    }

    /// Build the program for an `h` x `w` image and evaluate it in the reference interpreter.
    fn run(h: u32, w: u32, input: &[f32], kernel: &[f32]) -> Vec<f32> {
        let prog = conv2d_3x3_direct("input", "kernel", "output", h, w).expect("Fix: build");
        let outputs = vyre_reference::reference_eval(
            &prog,
            &[
                Value::from(f32_bytes(input)),
                Value::from(f32_bytes(kernel)),
            ],
        )
        .expect("Fix: conv2d_3x3_direct must execute in the reference interpreter.");
        decode(&outputs[0].to_bytes())
    }

    /// Assert that `actual` has exactly as many lanes as `expected` and matches it
    /// element-wise within `tol`. The length check matters because `zip` alone would
    /// silently accept a truncated output.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len(), "output lane count mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a - e).abs() <= tol,
                "lane {i}: direct={a} naive={e} diff={}",
                (a - e).abs()
            );
        }
    }

    /// Deterministic LCG yielding f32 samples spread uniformly over [-1, 1).
    ///
    /// The top 24 bits of the state are converted exactly to an f32 in [0, 1), then rescaled.
    /// Mixed signs matter: they exercise cancellation between taps.
    fn lcg_stream(seed: u64) -> impl FnMut() -> f32 {
        let mut state = seed;
        move || {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((state >> 40) as f32 / (1u32 << 24) as f32) * 2.0 - 1.0
        }
    }

    #[test]
    fn conv2d_identity_box_matches_naive() {
        let input = vec![
            1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
        ];
        let kernel = vec![1.0; 9];
        let actual = run(4, 4, &input, &kernel);
        let expected = naive_conv2d_3x3(&input, &kernel, 4, 4);
        assert_eq!(actual.len(), 16);
        assert_close(&actual, &expected, 1.0e-5);
    }

    #[test]
    fn conv2d_identity_kernel_passes_input_through() {
        let input: Vec<f32> = (0..16).map(|i| i as f32 - 7.5).collect();
        let kernel = vec![0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0];
        let actual = run(4, 4, &input, &kernel);
        assert_close(&actual, &input, 1.0e-5);
    }

    #[test]
    fn conv2d_matches_naive_on_random_fuzz() {
        let mut next = lcg_stream(0xDEADC0DE_u64);
        for _ in 0..30 {
            let input: Vec<f32> = (0..25).map(|_| next()).collect();
            let kernel: Vec<f32> = (0..9).map(|_| next()).collect();
            let actual = run(5, 5, &input, &kernel);
            let expected = naive_conv2d_3x3(&input, &kernel, 5, 5);
            assert_close(&actual, &expected, 1.0e-4);
        }
    }

    #[test]
    fn conv2d_non_square_matches_naive() {
        // Non-square and single-row/column shapes catch any h/w transposition in the
        // index math, which square images cannot detect.
        let mut next = lcg_stream(0x5EED0FC0DE_u64);
        for &(h, w) in &[(3u32, 5u32), (5, 3), (2, 6), (1, 7), (7, 1)] {
            let count = (h * w) as usize;
            let input: Vec<f32> = (0..count).map(|_| next()).collect();
            let kernel: Vec<f32> = (0..9).map(|_| next()).collect();
            let actual = run(h, w, &input, &kernel);
            let expected = naive_conv2d_3x3(&input, &kernel, h as usize, w as usize);
            assert_close(&actual, &expected, 1.0e-4);
        }
    }

    #[test]
    fn conv2d_1x1_image() {
        let input = vec![5.0_f32];
        let kernel = vec![0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0];
        let actual = run(1, 1, &input, &kernel);
        assert_eq!(actual.len(), 1);
        assert!(
            (actual[0] - 5.0).abs() <= 1.0e-5,
            "1x1 conv with identity kernel = 5.0, got {}",
            actual[0]
        );
    }

    #[test]
    fn conv2d_nan_input_propagates() {
        let input = vec![f32::NAN; 16];
        let kernel = vec![1.0_f32; 9];
        let actual = run(4, 4, &input, &kernel);
        assert_eq!(actual.len(), 16);
        for (i, &v) in actual.iter().enumerate() {
            assert!(
                v.is_nan(),
                "conv2d output[{i}] must be NaN when input is NaN"
            );
        }
    }

    #[test]
    fn conv2d_inf_input_propagates() {
        let input = vec![f32::INFINITY; 16];
        let kernel = vec![1.0_f32; 9];
        let actual = run(4, 4, &input, &kernel);
        assert_eq!(actual.len(), 16);
        for (i, &v) in actual.iter().enumerate() {
            assert!(
                v.is_infinite(),
                "conv2d output[{i}] must be Inf when input is Inf"
            );
        }
    }

    #[test]
    fn conv2d_zero_dimensions_should_error() {
        let err = conv2d_3x3_direct("input", "kernel", "output", 0, 0)
            .expect_err("0x0 conv2d must error instead of returning empty program");
        assert!(
            err.contains("non-zero height and width"),
            "0x0 conv2d error must name the dimension contract: {err}"
        );
    }

    #[test]
    fn conv2d_single_zero_dimension_should_error() {
        for &(h, w) in &[(0u32, 4u32), (4, 0)] {
            let err = conv2d_3x3_direct("input", "kernel", "output", h, w)
                .expect_err("conv2d with a zero dimension must error");
            assert!(
                err.contains("non-zero height and width"),
                "{h}x{w} conv2d error must name the dimension contract: {err}"
            );
        }
    }

    #[test]
    fn conv2d_element_count_overflow_should_error() {
        let err = conv2d_3x3_direct("input", "kernel", "output", u32::MAX, 2)
            .expect_err("h*w overflowing u32 must error");
        assert!(
            err.contains("overflows u32"),
            "overflow error must explain the cause: {err}"
        );
    }
}