use numr::prelude::*;
/// Demonstrates that a direct `conv2d` and a manual im2col path
/// (unfold the input into patches, then matmul against the flattened
/// kernel) produce the same result on a small 4x4 example.
fn main() -> Result<()> {
    let device = CpuDevice::new();
    let client = CpuRuntime::default_client(&device);

    // 4x4 single-channel input, NCHW layout: [batch=1, channels=1, 4, 4].
    #[rustfmt::skip]
    let input_data: &[f32] = &[
        1.0, 2.0, 3.0, 4.0,
        5.0, 6.0, 7.0, 8.0,
        9.0, 10.0, 11.0, 12.0,
        13.0, 14.0, 15.0, 16.0,
    ];
    let input = Tensor::<CpuRuntime>::from_slice(input_data, &[1, 1, 4, 4], &device);

    // 3x3 vertical-edge detection kernel, [out_ch=1, in_ch=1, 3, 3].
    #[rustfmt::skip]
    let kernel_data: &[f32] = &[
        1.0, 0.0, -1.0,
        1.0, 0.0, -1.0,
        1.0, 0.0, -1.0,
    ];
    let kernel = Tensor::<CpuRuntime>::from_slice(kernel_data, &[1, 1, 3, 3], &device);

    // Reference path: direct convolution, stride (1, 1), valid padding
    // (no border), dilation (1, 1), a single group, no bias.
    let direct_out = client.conv2d(
        &input,
        &kernel,
        None,
        (1, 1),
        PaddingMode::Valid,
        (1, 1),
        1,
    )?;
    println!("Direct conv2d output (shape {:?}):", direct_out.shape());
    println!("{:?}\n", direct_out.to_vec::<f32>());

    // im2col path. Derive the sliding-window sizes from the kernel shape
    // instead of hard-coding 3, so the example stays consistent if the
    // kernel above is edited.
    let (k_h, k_w) = (kernel.shape()[2], kernel.shape()[3]);
    // Unfold height (dim 2) then width (dim 3), step 1, yielding a patch
    // tensor whose trailing dims are the k_h x k_w windows.
    let unfolded_h = client.unfold(&input, 2, k_h, 1)?;
    let unfolded_hw = client.unfold(&unfolded_h, 3, k_w, 1)?;
    println!("Unfolded patches shape: {:?}", unfolded_hw.shape());

    // Shape is [N, C, out_h, out_w, k_h, k_w]; N == C == 1 here, so the
    // patches flatten to one row per output position.
    let out_h = unfolded_hw.shape()[2];
    let out_w = unfolded_hw.shape()[3];
    let patches = unfolded_hw
        .contiguous()
        .reshape(&[out_h * out_w, k_h * k_w])?;
    // Flatten the kernel straight into a column vector. A tensor freshly
    // built by `from_slice` is contiguous in row-major order, so this is
    // equivalent to the reshape-to-row + transpose + contiguous detour.
    let kernel_col = kernel.reshape(&[k_h * k_w, 1])?;
    // Each row of `patches` dotted with the kernel column is one output
    // pixel: [out_h*out_w, k_h*k_w] x [k_h*k_w, 1] -> [out_h*out_w, 1].
    let im2col_flat = client.matmul(&patches, &kernel_col)?;
    let im2col_out = im2col_flat.reshape(&[1, 1, out_h, out_w])?;
    println!("im2col conv output (shape {:?}):", im2col_out.shape());
    println!("{:?}", im2col_out.to_vec::<f32>());

    // Compare element-wise; guard against silent truncation by `zip` if
    // the two paths ever disagreed on output size.
    let direct_vec: Vec<f32> = direct_out.to_vec();
    let im2col_vec: Vec<f32> = im2col_out.to_vec();
    assert_eq!(
        direct_vec.len(),
        im2col_vec.len(),
        "direct and im2col outputs must have the same number of elements"
    );
    let max_diff: f32 = direct_vec
        .iter()
        .zip(im2col_vec.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, f32::max);
    println!("\nMax difference between direct and im2col: {max_diff:.6e}");
    assert!(max_diff < 1e-5, "Results should match within FP tolerance");

    println!("\nConv/unfold im2col example completed successfully!");
    Ok(())
}