use alloc::vec;
use alloc::vec::Vec;
use burn_backend::DType;
use burn_backend::ops::ConvOptions;
use burn_backend::ops::conv::calculate_conv_output_size;
use burn_std::{Bytes, Shape, f16};
use crate::{FlexTensor, Layout};
use super::conv_common::{add_bias, squeeze_3d_to_1d, squeeze_3d_to_2d};
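// Generates a typed entry point for the pointwise (1x1x1) convolution path,
// binding one float format's element type, dtype tag, zero/one scalars, and
// scalar add.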
macro_rules! conv3d_1x1_typed {
($fn_name:ident, $T:ty, $dtype:expr, $zero:expr, $one:expr, $add_fn:expr) => {
fn $fn_name(
x: FlexTensor,
weight: FlexTensor,
bias: Option<FlexTensor>,
options: &ConvOptions<3>,
) -> FlexTensor {
conv3d_1x1_impl::<$T>(x, weight, bias, options, $dtype, $zero, $one, $add_fn)
}
};
}
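// Generates the public typed conv3d entry point. The expansion dispatches, in
// order, to the 1x1x1, depthwise, small-channel, and (when supplied) direct
// specializations, falling back to the generic im2col + GEMM kernel.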
macro_rules! conv3d_typed {
($fn_name:ident, $T:ty, $dtype:expr, $zero:expr, $gemm_fn:ident, $add_fn:expr, $fn_1x1:ident, $fn_depthwise:ident, $fn_small_channel:ident $(, $fn_direct:ident)?) => {
pub fn $fn_name(
x: FlexTensor,
weight: FlexTensor,
bias: Option<FlexTensor>,
options: &ConvOptions<3>,
) -> FlexTensor {
let w_shape = weight.layout().shape();
if is_1x1_conv(w_shape[2], w_shape[3], w_shape[4], options) {
return $fn_1x1(x, weight, bias, options);
}
let x_shape = x.layout().shape();
if should_use_depthwise_conv(x_shape, w_shape, options) {
return $fn_depthwise(x, weight, bias, options);
}
if should_use_small_channel_conv(x_shape, w_shape, options) {
return $fn_small_channel(x, weight, bias, options);
}
$(
if should_use_direct_conv(x_shape, w_shape, options) {
return $fn_direct(x, weight, bias, options);
}
)?
conv3d_impl::<$T>(x, weight, bias, options, $dtype, $zero, $gemm_fn, $add_fn)
}
};
}
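// 1D and 2D convolutions are routed through the 3D kernels: inputs are
// reshaped with unit depth/height axes, convolved in 3D, then squeezed back
// to their original rank.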
conv_nd_via_3d!(
conv1d_f32,
conv3d_f32,
expand_1d_to_3d,
squeeze_3d_to_1d,
1,
ConvOptions
);
conv_nd_via_3d!(
conv1d_f64,
conv3d_f64,
expand_1d_to_3d,
squeeze_3d_to_1d,
1,
ConvOptions
);
conv_nd_via_3d!(
conv1d_f16,
conv3d_f16,
expand_1d_to_3d,
squeeze_3d_to_1d,
1,
ConvOptions
);
bf16_via_f32!(conv1d_bf16, conv1d_f32, 1, ConvOptions);
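/// Lifts a 1D convolution `[N, C, W]` to the equivalent 3D problem
/// `[N, C, 1, 1, W]`, moving stride/padding/dilation onto the last axis.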
fn expand_1d_to_3d(
x: &FlexTensor,
weight: &FlexTensor,
options: &ConvOptions<1>,
) -> (FlexTensor, FlexTensor, ConvOptions<3>) {
let x_shape = x.layout().shape();
let x_3d = x.reshape(Shape::from(vec![x_shape[0], x_shape[1], 1, 1, x_shape[2]]));
let w_shape = weight.layout().shape();
let weight_3d = weight.reshape(Shape::from(vec![w_shape[0], w_shape[1], 1, 1, w_shape[2]]));
let options_3d = ConvOptions::new(
[1, 1, options.stride[0]],
[0, 0, options.padding[0]],
[1, 1, options.dilation[0]],
options.groups,
);
(x_3d, weight_3d, options_3d)
}
conv_nd_via_3d!(
conv2d_f32,
conv3d_f32,
expand_2d_to_3d,
squeeze_3d_to_2d,
2,
ConvOptions
);
conv_nd_via_3d!(
conv2d_f64,
conv3d_f64,
expand_2d_to_3d,
squeeze_3d_to_2d,
2,
ConvOptions
);
conv_nd_via_3d!(
conv2d_f16,
conv3d_f16,
expand_2d_to_3d,
squeeze_3d_to_2d,
2,
ConvOptions
);
bf16_via_f32!(conv2d_bf16, conv2d_f32, 2, ConvOptions);
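/// Lifts a 2D convolution `[N, C, H, W]` to the equivalent 3D problem
/// `[N, C, 1, H, W]`, moving the options onto the last two axes.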
fn expand_2d_to_3d(
x: &FlexTensor,
weight: &FlexTensor,
options: &ConvOptions<2>,
) -> (FlexTensor, FlexTensor, ConvOptions<3>) {
let x_shape = x.layout().shape();
let x_3d = x.reshape(Shape::from(vec![
x_shape[0], x_shape[1], 1, x_shape[2], x_shape[3],
]));
let w_shape = weight.layout().shape();
let weight_3d = weight.reshape(Shape::from(vec![
w_shape[0], w_shape[1], 1, w_shape[2], w_shape[3],
]));
let options_3d = ConvOptions::new(
[1, options.stride[0], options.stride[1]],
[0, options.padding[0], options.padding[1]],
[1, options.dilation[0], options.dilation[1]],
options.groups,
);
(x_3d, weight_3d, options_3d)
}
conv3d_typed!(
conv3d_f32,
f32,
DType::F32,
0.0f32,
gemm_f32,
|a, b| a + b,
conv3d_1x1_f32,
conv3d_depthwise_f32,
conv3d_small_channel_f32,
conv3d_direct_f32
);
conv3d_typed!(
conv3d_f64,
f64,
DType::F64,
0.0f64,
gemm_f64,
|a, b| a + b,
conv3d_1x1_f64,
conv3d_depthwise_f64,
conv3d_small_channel_f64,
conv3d_direct_f64
);
conv3d_typed!(
conv3d_f16,
f16,
DType::F16,
f16::from_f32(0.0),
gemm_f16,
|a: f16, b: f16| f16::from_f32(a.to_f32() + b.to_f32()),
conv3d_1x1_f16,
conv3d_depthwise_f16,
conv3d_small_channel_f16
);
bf16_via_f32!(conv3d_bf16, conv3d_f32, 3, ConvOptions);
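/// Generic 3D convolution via im2col + GEMM: the input is transposed to a
/// channels-last layout, output positions are processed in tiles of
/// `TILE_SIZE` columns, and each (batch, tile, group) block is one
/// `out_channels_per_group x col_len` by `col_len x tile_size` GEMM.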
#[allow(clippy::too_many_arguments, clippy::type_complexity)]
fn conv3d_impl<T: bytemuck::Pod + Clone + Copy + burn_backend::Element + Send + Sync>(
x: FlexTensor,
weight: FlexTensor,
bias: Option<FlexTensor>,
options: &ConvOptions<3>,
dtype: DType,
zero: T,
gemm_fn: fn(&[T], &[T], usize, usize, usize) -> Vec<T>,
add_fn: fn(T, T) -> T,
) -> FlexTensor {
let x = x.to_contiguous();
let weight = weight.to_contiguous();
let x_shape = x.layout().shape();
let w_shape = weight.layout().shape();
let batch_size = x_shape[0];
let channels_in = x_shape[1];
let in_d = x_shape[2];
let in_h = x_shape[3];
let in_w = x_shape[4];
let channels_out = w_shape[0];
let channels_per_group = w_shape[1];
let kernel_d = w_shape[2];
let kernel_h = w_shape[3];
let kernel_w = w_shape[4];
let [stride_d, stride_h, stride_w] = options.stride;
let [pad_d, pad_h, pad_w] = options.padding;
let groups = options.groups;
let out_channels_per_group = channels_out / groups;
let out_d = calculate_conv_output_size(kernel_d, stride_d, pad_d, options.dilation[0], in_d);
let out_h = calculate_conv_output_size(kernel_h, stride_h, pad_h, options.dilation[1], in_h);
let out_w = calculate_conv_output_size(kernel_w, stride_w, pad_w, options.dilation[2], in_w);
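// Overflow guards only: the products are discarded, but `expect` aborts early
// if the output or im2col column sizes would overflow usize arithmetic.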
let _total = [batch_size, channels_out, out_d, out_h, out_w]
.iter()
.try_fold(1usize, |acc, &x| acc.checked_mul(x))
.expect("conv: output tensor dimensions would overflow index calculations");
let _col_total = [channels_per_group, kernel_d, kernel_h, kernel_w]
.iter()
.try_fold(1usize, |acc, &x| acc.checked_mul(x))
.expect("conv: kernel dimensions would overflow index calculations");
let x_data: &[T] = x.storage();
let w_data: &[T] = weight.storage();
let col_len = channels_per_group * kernel_d * kernel_h * kernel_w;
let spatial_out = out_d * out_h * out_w;
let [dilation_d, dilation_h, dilation_w] = options.dilation;
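// Tile the flattened output spatial dimension so each per-tile im2col buffer
// stays small and cache-resident instead of materializing the full matrix.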
const TILE_SIZE: usize = 512;
let num_tiles = spatial_out.div_ceil(TILE_SIZE);
let k_spatial = kernel_d * kernel_h * kernel_w;
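// Repack weights from [c_out][c_in][kd*kh*kw] to [c_out][kd*kh*kw][c_in] so
// each weight row matches the channel-innermost column layout produced by
// `im2col_3d_tile`.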
let mut w_flat = vec![zero; channels_out * col_len];
for c_out in 0..channels_out {
let src_base = c_out * channels_per_group * k_spatial;
let dst_base = c_out * col_len;
for c_in in 0..channels_per_group {
let src_row = src_base + c_in * k_spatial;
for k in 0..k_spatial {
w_flat[dst_base + k * channels_per_group + c_in] = w_data[src_row + k];
}
}
}
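// Transpose the input from NCDHW to channels-last (NDHWC) so the innermost
// per-channel gather in `im2col_3d_tile` reads contiguous memory; the tuple
// holds the (batch, d, h, w, channel) strides of the transposed buffer.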
let nhwc_stride = (
in_d * in_h * in_w * channels_in,
in_h * in_w * channels_in,
in_w * channels_in,
channels_in,
1,
);
let mut x_nhwc = vec![zero; batch_size * in_d * in_h * in_w * channels_in];
for b in 0..batch_size {
for d in 0..in_d {
for h in 0..in_h {
for w in 0..in_w {
for c in 0..channels_in {
let src_idx = b * channels_in * in_d * in_h * in_w
+ c * in_d * in_h * in_w
+ d * in_h * in_w
+ h * in_w
+ w;
let dst_idx = b * nhwc_stride.0
+ d * nhwc_stride.1
+ h * nhwc_stride.2
+ w * nhwc_stride.3
+ c;
x_nhwc[dst_idx] = x_data[src_idx];
}
}
}
}
}
let mut output = {
#[cfg(feature = "rayon")]
{
use rayon::prelude::*;
let mut dst = vec![zero; batch_size * channels_out * spatial_out];
let dst_ptr = crate::ops::SendMutPtr::new(dst.as_mut_ptr());
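// SAFETY: every (batch, tile, group, c_out) combination below writes a
// disjoint set of `dst` indices, so the unsynchronized writes never alias.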
(0..batch_size).into_par_iter().for_each(|b| {
(0..num_tiles).into_par_iter().for_each(|tile_idx| {
let tile_start = tile_idx * TILE_SIZE;
let tile_end = (tile_start + TILE_SIZE).min(spatial_out);
let tile_size = tile_end - tile_start;
for g in 0..groups {
let in_c_start = g * channels_per_group;
let out_c_start = g * out_channels_per_group;
let mut col_tile = vec![zero; col_len * tile_size];
im2col_3d_tile(
&mut col_tile,
&x_nhwc,
tile_start,
tile_end,
out_h,
out_w,
kernel_d,
kernel_h,
kernel_w,
stride_d,
stride_h,
stride_w,
dilation_d,
dilation_h,
dilation_w,
pad_d,
pad_h,
pad_w,
in_d,
in_h,
in_w,
channels_per_group,
col_len,
b,
in_c_start,
nhwc_stride,
);
let w_start = out_c_start * col_len;
let w_end = w_start + out_channels_per_group * col_len;
let w_group = &w_flat[w_start..w_end];
let result = gemm_fn(
w_group,
&col_tile,
out_channels_per_group,
col_len,
tile_size,
);
for (local_idx, global_idx) in (tile_start..tile_end).enumerate() {
for c_out in 0..out_channels_per_group {
let dst_idx = b * channels_out * spatial_out
+ (out_c_start + c_out) * spatial_out
+ global_idx;
let res_idx = c_out * tile_size + local_idx;
unsafe {
debug_assert!(
dst_idx < batch_size * channels_out * spatial_out
);
dst_ptr.write(dst_idx, result[res_idx]);
}
}
}
}
});
});
dst
}
#[cfg(not(feature = "rayon"))]
{
let mut output = vec![zero; batch_size * channels_out * spatial_out];
for b in 0..batch_size {
for tile_idx in 0..num_tiles {
let tile_start = tile_idx * TILE_SIZE;
let tile_end = (tile_start + TILE_SIZE).min(spatial_out);
let tile_size = tile_end - tile_start;
for g in 0..groups {
let in_c_start = g * channels_per_group;
let out_c_start = g * out_channels_per_group;
let mut col_tile = vec![zero; col_len * tile_size];
im2col_3d_tile(
&mut col_tile,
&x_nhwc,
tile_start,
tile_end,
out_h,
out_w,
kernel_d,
kernel_h,
kernel_w,
stride_d,
stride_h,
stride_w,
dilation_d,
dilation_h,
dilation_w,
pad_d,
pad_h,
pad_w,
in_d,
in_h,
in_w,
channels_per_group,
col_len,
b,
in_c_start,
nhwc_stride,
);
let w_start = out_c_start * col_len;
let w_end = w_start + out_channels_per_group * col_len;
let w_group = &w_flat[w_start..w_end];
let result = gemm_fn(
w_group,
&col_tile,
out_channels_per_group,
col_len,
tile_size,
);
for (local_idx, global_idx) in (tile_start..tile_end).enumerate() {
for c_out in 0..out_channels_per_group {
let dst_idx = b * channels_out * spatial_out
+ (out_c_start + c_out) * spatial_out
+ global_idx;
let res_idx = c_out * tile_size + local_idx;
output[dst_idx] = result[res_idx];
}
}
}
}
}
output
}
};
if let Some(bias) = bias {
let bias = bias.to_contiguous();
let bias_data: &[T] = bias.storage();
add_bias(
&mut output,
bias_data,
batch_size,
channels_out,
spatial_out,
add_fn,
);
}
let out_shape = Shape::from(vec![batch_size, channels_out, out_d, out_h, out_w]);
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(out_shape),
dtype,
)
}
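/// Gathers one tile of im2col columns. For each output position in
/// `tile_start..tile_end`, writes a contiguous column of length `col_len`
/// laid out as [kd][kh][kw][channel]; taps that fall into padding are
/// skipped, leaving the pre-zeroed buffer untouched.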
#[allow(clippy::too_many_arguments)]
fn im2col_3d_tile<T: bytemuck::Pod + Copy>(
col_tile: &mut [T],
x_nhwc: &[T],
tile_start: usize,
tile_end: usize,
out_h: usize,
out_w: usize,
kernel_d: usize,
kernel_h: usize,
kernel_w: usize,
stride_d: usize,
stride_h: usize,
stride_w: usize,
dilation_d: usize,
dilation_h: usize,
dilation_w: usize,
pad_d: usize,
pad_h: usize,
pad_w: usize,
in_d: usize,
in_h: usize,
in_w: usize,
channels_per_group: usize,
col_len: usize,
b: usize,
in_c_start: usize,
nhwc_stride: (usize, usize, usize, usize, usize),
) {
for (local_idx, global_idx) in (tile_start..tile_end).enumerate() {
let out_d_idx = global_idx / (out_h * out_w);
let rem = global_idx % (out_h * out_w);
let out_h_idx = rem / out_w;
let out_w_idx = rem % out_w;
let mut col_offset = 0;
for kd in 0..kernel_d {
let id = (out_d_idx * stride_d + kd * dilation_d) as isize - pad_d as isize;
for kh in 0..kernel_h {
let ih = (out_h_idx * stride_h + kh * dilation_h) as isize - pad_h as isize;
for kw in 0..kernel_w {
let iw = (out_w_idx * stride_w + kw * dilation_w) as isize - pad_w as isize;
if id >= 0
&& id < in_d as isize
&& ih >= 0
&& ih < in_h as isize
&& iw >= 0
&& iw < in_w as isize
{
let id = id as usize;
let ih = ih as usize;
let iw = iw as usize;
let inp_base = b * nhwc_stride.0
+ id * nhwc_stride.1
+ ih * nhwc_stride.2
+ iw * nhwc_stride.3
+ in_c_start;
for c in 0..channels_per_group {
col_tile[local_idx * col_len + col_offset] = x_nhwc[inp_base + c];
col_offset += 1;
}
} else {
col_offset += channels_per_group;
}
}
}
}
}
}
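/// A 1x1x1 kernel with unit stride and no padding reduces convolution to a
/// per-group matrix product over channels (dilation is irrelevant for a
/// single-tap kernel).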
fn is_1x1_conv(
kernel_d: usize,
kernel_h: usize,
kernel_w: usize,
options: &ConvOptions<3>,
) -> bool {
kernel_d == 1
&& kernel_h == 1
&& kernel_w == 1
&& options.stride == [1, 1, 1]
&& options.padding == [0, 0, 0]
}
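/// Pointwise (1x1x1) convolution: for each batch and group, computes
/// `out = W * X` directly on the NCDHW buffers, with
/// W: m x k (out_channels_per_group x channels_per_group) and
/// X: k x n (channels_per_group x spatial).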
#[allow(clippy::too_many_arguments, clippy::type_complexity)]
fn conv3d_1x1_impl<T: bytemuck::Pod + Clone + Copy + burn_backend::Element + Send + Sync>(
x: FlexTensor,
weight: FlexTensor,
bias: Option<FlexTensor>,
options: &ConvOptions<3>,
dtype: DType,
zero: T,
one: T,
add_fn: fn(T, T) -> T,
) -> FlexTensor {
let x = x.to_contiguous();
let weight = weight.to_contiguous();
let x_shape = x.layout().shape();
let w_shape = weight.layout().shape();
let batch_size = x_shape[0];
let channels_in = x_shape[1];
let spatial = x_shape[2] * x_shape[3] * x_shape[4];
let channels_out = w_shape[0];
let channels_per_group = w_shape[1];
let groups = options.groups;
let out_channels_per_group = channels_out / groups;
let m = out_channels_per_group;
let k = channels_per_group;
let n = spatial;
let total_output = [batch_size, channels_out, spatial]
.iter()
.try_fold(1usize, |acc, &x| acc.checked_mul(x))
.expect("conv 1x1: output tensor dimensions would overflow index calculations");
let x_data: &[T] = x.storage();
let w_data: &[T] = weight.storage();
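// Enable the GEMM backend's internal Rayon parallelism only once the flop
// count is large enough to amortize the fork/join overhead.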
#[cfg(feature = "rayon")]
let parallelism = if m.saturating_mul(n).saturating_mul(k) >= 192 * 192 * 192 {
gemm::Parallelism::Rayon(0)
} else {
gemm::Parallelism::None
};
#[cfg(not(feature = "rayon"))]
let parallelism = gemm::Parallelism::None;
let mut output = vec![zero; total_output];
{
#[cfg(feature = "rayon")]
{
use rayon::prelude::*;
let dst_ptr = crate::ops::SendMutPtr::new(output.as_mut_ptr());
(0..batch_size).into_par_iter().for_each(|b| {
for g in 0..groups {
let x_offset = b * channels_in * spatial + g * k * spatial;
let w_offset = g * out_channels_per_group * k;
let out_offset =
b * channels_out * spatial + g * out_channels_per_group * spatial;
unsafe {
gemm::gemm(
m,
n,
k,
dst_ptr.ptr_add(out_offset),
1,
n as isize,
false,
w_data.as_ptr().add(w_offset),
1,
k as isize,
x_data.as_ptr().add(x_offset),
1,
n as isize,
zero,
one,
false,
false,
false,
parallelism,
);
}
}
});
}
#[cfg(not(feature = "rayon"))]
{
for b in 0..batch_size {
for g in 0..groups {
let x_offset = b * channels_in * spatial + g * k * spatial;
let w_offset = g * out_channels_per_group * k;
let out_offset =
b * channels_out * spatial + g * out_channels_per_group * spatial;
unsafe {
gemm::gemm(
m,
n,
k,
output.as_mut_ptr().add(out_offset),
1,
n as isize,
false,
w_data.as_ptr().add(w_offset),
1,
k as isize,
x_data.as_ptr().add(x_offset),
1,
n as isize,
zero,
one,
false,
false,
false,
parallelism,
);
}
}
}
}
}
if let Some(bias) = bias {
let bias = bias.to_contiguous();
let bias_data: &[T] = bias.storage();
add_bias(
&mut output,
bias_data,
batch_size,
channels_out,
spatial,
add_fn,
);
}
let out_shape = Shape::from(vec![
batch_size,
channels_out,
x_shape[2],
x_shape[3],
x_shape[4],
]);
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(out_shape),
dtype,
)
}
conv3d_1x1_typed!(
conv3d_1x1_f32,
f32,
DType::F32,
0.0f32,
1.0f32,
|a, b| a + b
);
conv3d_1x1_typed!(
conv3d_1x1_f64,
f64,
DType::F64,
0.0f64,
1.0f64,
|a, b| a + b
);
conv3d_1x1_typed!(
conv3d_1x1_f16,
f16,
DType::F16,
f16::from_f32(0.0),
f16::from_f32(1.0),
|a: f16, b: f16| f16::from_f32(a.to_f32() + b.to_f32())
);
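/// Depthwise path: one filter per input channel (groups == channels_in,
/// channels_per_group == 1) with a degenerate depth axis, i.e. an effectively
/// 2D convolution.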
fn should_use_depthwise_conv(
x_shape: &[usize],
w_shape: &[usize],
options: &ConvOptions<3>,
) -> bool {
let channels_in = x_shape[1];
let channels_out = w_shape[0];
let channels_per_group = w_shape[1];
let groups = options.groups;
if channels_per_group != 1 || groups != channels_in || channels_out != channels_in {
return false;
}
if w_shape[2] != 1 || x_shape[2] != 1 {
return false;
}
if options.stride[0] != 1 || options.padding[0] != 0 || options.dilation[0] != 1 {
return false;
}
true
}
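/// For a fixed kernel tap `k`, returns the half-open range of output indices
/// `o` for which `o * stride + k * dilation - pad` lands inside
/// `[0, in_size)`; e.g. k=0, dilation=1, pad=1, stride=1, in_size=5,
/// out_size=5 yields (1, 5). Inner loops can then skip bounds checks.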
#[inline]
fn valid_out_range(
k: usize,
dilation: usize,
pad: usize,
stride: usize,
in_size: usize,
out_size: usize,
) -> (usize, usize) {
debug_assert!(stride >= 1, "stride must be >= 1");
let offset = k * dilation;
let out_start = if offset >= pad {
0
} else {
(pad - offset).div_ceil(stride)
};
let threshold = in_size + pad;
let out_end = if offset >= threshold {
0
} else {
(threshold - offset).div_ceil(stride)
};
let out_end = out_end.min(out_size);
let out_start = out_start.min(out_end);
(out_start, out_end)
}
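// Output-plane size above which the row-outermost loop ordering is used in
// preference to the kernel-tap-outermost one.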
const CONV_PLANE_OH_OUTER_THRESHOLD: usize = 8192;
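/// Accumulates one 2D convolution plane into `out_plane`, selecting between
/// the two loop orderings below based on the output-plane size.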
#[inline]
#[allow(clippy::too_many_arguments)]
fn conv_plane_accumulate<T: num_traits::Float + Copy>(
out_plane: &mut [T],
in_plane: &[T],
w_plane: &[T],
kernel_h: usize,
kernel_w: usize,
in_w: usize,
out_w: usize,
stride_h: usize,
stride_w: usize,
pad_h: usize,
pad_w: usize,
dilation_h: usize,
dilation_w: usize,
oh_ranges: &[(usize, usize)],
ow_ranges: &[(usize, usize)],
) {
if out_plane.len() > CONV_PLANE_OH_OUTER_THRESHOLD {
conv_plane_accumulate_oh_outer(
out_plane, in_plane, w_plane, kernel_h, kernel_w, in_w, out_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, oh_ranges, ow_ranges,
);
} else {
conv_plane_accumulate_kh_outer(
out_plane, in_plane, w_plane, kernel_h, kernel_w, in_w, out_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, oh_ranges, ow_ranges,
);
}
}
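// Row-outermost variant: each output row is finished in one pass, keeping
// writes to a large output plane cache-friendly.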
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn conv_plane_accumulate_oh_outer<T: num_traits::Float + Copy>(
out_plane: &mut [T],
in_plane: &[T],
w_plane: &[T],
kernel_h: usize,
kernel_w: usize,
in_w: usize,
out_w: usize,
stride_h: usize,
stride_w: usize,
pad_h: usize,
pad_w: usize,
dilation_h: usize,
dilation_w: usize,
oh_ranges: &[(usize, usize)],
ow_ranges: &[(usize, usize)],
) {
if out_plane.is_empty() || out_w == 0 {
return;
}
debug_assert_eq!(
out_plane.len() % out_w,
0,
"out_plane length must be a whole number of rows"
);
let out_h = out_plane.len() / out_w;
for oh in 0..out_h {
let out_row = &mut out_plane[oh * out_w..(oh + 1) * out_w];
for kh in 0..kernel_h {
let (oh_start, oh_end) = oh_ranges[kh];
if oh < oh_start || oh >= oh_end {
continue;
}
let ih = oh * stride_h + kh * dilation_h - pad_h;
let in_row = &in_plane[ih * in_w..(ih + 1) * in_w];
for kw in 0..kernel_w {
let (ow_start, ow_end) = ow_ranges[kw];
if ow_start >= ow_end {
continue;
}
let w_val = w_plane[kh * kernel_w + kw];
let iw_start = ow_start * stride_w + kw * dilation_w - pad_w;
if stride_w == 1 {
let run_len = ow_end - ow_start;
let in_slice = &in_row[iw_start..iw_start + run_len];
let out_slice = &mut out_row[ow_start..ow_end];
for (o, &xv) in out_slice.iter_mut().zip(in_slice.iter()) {
*o = *o + w_val * xv;
}
} else {
let mut iw = iw_start;
for o in &mut out_row[ow_start..ow_end] {
*o = *o + w_val * in_row[iw];
iw += stride_w;
}
}
}
}
}
}
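// Kernel-tap-outermost variant: each weight value is loaded once and
// streamed across its whole valid output region; effective for small planes.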
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn conv_plane_accumulate_kh_outer<T: num_traits::Float + Copy>(
out_plane: &mut [T],
in_plane: &[T],
w_plane: &[T],
kernel_h: usize,
kernel_w: usize,
in_w: usize,
out_w: usize,
stride_h: usize,
stride_w: usize,
pad_h: usize,
pad_w: usize,
dilation_h: usize,
dilation_w: usize,
oh_ranges: &[(usize, usize)],
ow_ranges: &[(usize, usize)],
) {
for kh in 0..kernel_h {
let (oh_start, oh_end) = oh_ranges[kh];
if oh_start >= oh_end {
continue;
}
for kw in 0..kernel_w {
let (ow_start, ow_end) = ow_ranges[kw];
if ow_start >= ow_end {
continue;
}
let w_val = w_plane[kh * kernel_w + kw];
let iw_start = ow_start * stride_w + kw * dilation_w - pad_w;
let run_len = ow_end - ow_start;
for oh in oh_start..oh_end {
let ih = oh * stride_h + kh * dilation_h - pad_h;
let in_row = &in_plane[ih * in_w..(ih + 1) * in_w];
let out_row = &mut out_plane[oh * out_w..(oh + 1) * out_w];
if stride_w == 1 {
let in_slice = &in_row[iw_start..iw_start + run_len];
let out_slice = &mut out_row[ow_start..ow_end];
for (o, &xv) in out_slice.iter_mut().zip(in_slice.iter()) {
*o = *o + w_val * xv;
}
} else {
let mut iw = iw_start;
for o in &mut out_row[ow_start..ow_end] {
*o = *o + w_val * in_row[iw];
iw += stride_w;
}
}
}
}
}
}
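// Typed wrappers for the depthwise specialization.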
macro_rules! conv3d_depthwise_typed {
($fn_name:ident, $T:ty, $dtype:expr) => {
fn $fn_name(
x: FlexTensor,
weight: FlexTensor,
bias: Option<FlexTensor>,
options: &ConvOptions<3>,
) -> FlexTensor {
conv3d_depthwise_impl::<$T>(x, weight, bias, options, $dtype)
}
};
}
conv3d_depthwise_typed!(conv3d_depthwise_f32, f32, DType::F32);
conv3d_depthwise_typed!(conv3d_depthwise_f64, f64, DType::F64);
conv3d_depthwise_typed!(conv3d_depthwise_f16, f16, DType::F16);
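/// Depthwise convolution: every (batch, channel) output plane is convolved
/// independently with its own kernel plane via `conv_plane_accumulate`,
/// parallelized across planes when the `rayon` feature is enabled.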
fn conv3d_depthwise_impl<T>(
x: FlexTensor,
weight: FlexTensor,
bias: Option<FlexTensor>,
options: &ConvOptions<3>,
dtype: DType,
) -> FlexTensor
where
T: num_traits::Float + bytemuck::Pod + Clone + Copy + burn_backend::Element + Send + Sync,
{
let zero = <T as num_traits::Zero>::zero();
let x = x.to_contiguous();
let weight = weight.to_contiguous();
let x_shape = x.layout().shape();
let w_shape = weight.layout().shape();
let batch_size = x_shape[0];
let channels = x_shape[1];
let in_h = x_shape[3];
let in_w = x_shape[4];
let kernel_h = w_shape[3];
let kernel_w = w_shape[4];
let [_, stride_h, stride_w] = options.stride;
let [_, pad_h, pad_w] = options.padding;
let [_, dilation_h, dilation_w] = options.dilation;
let out_h = calculate_conv_output_size(kernel_h, stride_h, pad_h, dilation_h, in_h);
let out_w = calculate_conv_output_size(kernel_w, stride_w, pad_w, dilation_w, in_w);
let total = [batch_size, channels, out_h, out_w]
.iter()
.try_fold(1usize, |acc, &x| acc.checked_mul(x))
.expect("conv depthwise: output dimensions would overflow");
let x_data: &[T] = x.storage();
let w_data: &[T] = weight.storage();
let in_spatial = in_h * in_w;
let out_spatial = out_h * out_w;
let k_spatial = kernel_h * kernel_w;
let oh_ranges: Vec<(usize, usize)> = (0..kernel_h)
.map(|kh| valid_out_range(kh, dilation_h, pad_h, stride_h, in_h, out_h))
.collect();
let ow_ranges: Vec<(usize, usize)> = (0..kernel_w)
.map(|kw| valid_out_range(kw, dilation_w, pad_w, stride_w, in_w, out_w))
.collect();
let mut output = vec![zero; total];
let plane_work = |bc: usize, out_plane: &mut [T]| {
let c = bc % channels;
let in_base = bc * in_spatial;
let w_base = c * k_spatial;
conv_plane_accumulate(
out_plane,
&x_data[in_base..in_base + in_spatial],
&w_data[w_base..w_base + k_spatial],
kernel_h,
kernel_w,
in_w,
out_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
&oh_ranges,
&ow_ranges,
);
};
#[cfg(feature = "rayon")]
{
use rayon::prelude::*;
let dst_ptr = crate::ops::SendMutPtr::new(output.as_mut_ptr());
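// SAFETY: each `bc` owns the disjoint output slice
// `[bc * out_spatial, (bc + 1) * out_spatial)`.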
(0..batch_size * channels).into_par_iter().for_each(|bc| {
let out_plane: &mut [T] = unsafe {
core::slice::from_raw_parts_mut(dst_ptr.ptr_add(bc * out_spatial), out_spatial)
};
plane_work(bc, out_plane);
});
}
#[cfg(not(feature = "rayon"))]
{
for bc in 0..batch_size * channels {
let out_base = bc * out_spatial;
plane_work(bc, &mut output[out_base..out_base + out_spatial]);
}
}
if let Some(bias) = bias {
let bias = bias.to_contiguous();
let bias_data: &[T] = bias.storage();
assert_eq!(
bias_data.len(),
channels,
"conv depthwise: bias length ({}) must equal channels ({channels})",
bias_data.len()
);
add_bias(
&mut output,
bias_data,
batch_size,
channels,
out_spatial,
|a, b| a + b,
);
}
let out_shape = Shape::from(vec![batch_size, channels, 1, out_h, out_w]);
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(out_shape),
dtype,
)
}
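// Heuristic bounds for the small-channel path: with very few input channels
// the im2col columns are mostly kernel overhead, and with few output channels
// the direct per-plane accumulation stays cheap.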
const SMALL_CHANNEL_IN_THRESHOLD: usize = 4;
const SMALL_CHANNEL_OUT_THRESHOLD: usize = 16;
fn should_use_small_channel_conv(
x_shape: &[usize],
w_shape: &[usize],
options: &ConvOptions<3>,
) -> bool {
if options.groups != 1 {
return false;
}
if w_shape[2] != 1 || x_shape[2] != 1 {
return false;
}
if options.stride[0] != 1 || options.padding[0] != 0 || options.dilation[0] != 1 {
return false;
}
let channels_in = x_shape[1];
let channels_out = w_shape[0];
channels_in > 0
&& channels_in <= SMALL_CHANNEL_IN_THRESHOLD
&& channels_out > 0
&& channels_out <= SMALL_CHANNEL_OUT_THRESHOLD
}
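// Typed wrappers for the small-channel specialization.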
macro_rules! conv3d_small_channel_typed {
($fn_name:ident, $T:ty, $dtype:expr) => {
fn $fn_name(
x: FlexTensor,
weight: FlexTensor,
bias: Option<FlexTensor>,
options: &ConvOptions<3>,
) -> FlexTensor {
conv3d_small_channel_impl::<$T>(x, weight, bias, options, $dtype)
}
};
}
conv3d_small_channel_typed!(conv3d_small_channel_f32, f32, DType::F32);
conv3d_small_channel_typed!(conv3d_small_channel_f64, f64, DType::F64);
conv3d_small_channel_typed!(conv3d_small_channel_f16, f16, DType::F16);
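/// Direct convolution for few input/output channels: for each
/// (batch, out_channel) plane, accumulates every input channel's contribution
/// via `conv_plane_accumulate`, skipping im2col entirely.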
fn conv3d_small_channel_impl<T>(
x: FlexTensor,
weight: FlexTensor,
bias: Option<FlexTensor>,
options: &ConvOptions<3>,
dtype: DType,
) -> FlexTensor
where
T: num_traits::Float + bytemuck::Pod + Clone + Copy + burn_backend::Element + Send + Sync,
{
let zero = <T as num_traits::Zero>::zero();
let x = x.to_contiguous();
let weight = weight.to_contiguous();
let x_shape = x.layout().shape();
let w_shape = weight.layout().shape();
let batch_size = x_shape[0];
let channels_in = x_shape[1];
let in_h = x_shape[3];
let in_w = x_shape[4];
let channels_out = w_shape[0];
let kernel_h = w_shape[3];
let kernel_w = w_shape[4];
let [_, stride_h, stride_w] = options.stride;
let [_, pad_h, pad_w] = options.padding;
let [_, dilation_h, dilation_w] = options.dilation;
let out_h = calculate_conv_output_size(kernel_h, stride_h, pad_h, dilation_h, in_h);
let out_w = calculate_conv_output_size(kernel_w, stride_w, pad_w, dilation_w, in_w);
let total = [batch_size, channels_out, out_h, out_w]
.iter()
.try_fold(1usize, |acc, &x| acc.checked_mul(x))
.expect("conv small-channel: output dimensions would overflow");
let x_data: &[T] = x.storage();
let w_data: &[T] = weight.storage();
let in_spatial = in_h * in_w;
let out_spatial = out_h * out_w;
let k_spatial = kernel_h * kernel_w;
let w_co_stride = channels_in * k_spatial;
let x_batch_stride = channels_in * in_spatial;
let oh_ranges: Vec<(usize, usize)> = (0..kernel_h)
.map(|kh| valid_out_range(kh, dilation_h, pad_h, stride_h, in_h, out_h))
.collect();
let ow_ranges: Vec<(usize, usize)> = (0..kernel_w)
.map(|kw| valid_out_range(kw, dilation_w, pad_w, stride_w, in_w, out_w))
.collect();
let mut output = vec![zero; total];
let plane_work = |b_co: usize, out_plane: &mut [T]| {
let b = b_co / channels_out;
let co = b_co % channels_out;
for ci in 0..channels_in {
let in_base = b * x_batch_stride + ci * in_spatial;
let w_base = co * w_co_stride + ci * k_spatial;
conv_plane_accumulate(
out_plane,
&x_data[in_base..in_base + in_spatial],
&w_data[w_base..w_base + k_spatial],
kernel_h,
kernel_w,
in_w,
out_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
&oh_ranges,
&ow_ranges,
);
}
};
#[cfg(feature = "rayon")]
{
use rayon::prelude::*;
let dst_ptr = crate::ops::SendMutPtr::new(output.as_mut_ptr());
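// SAFETY: each `b_co` owns the disjoint output slice
// `[b_co * out_spatial, (b_co + 1) * out_spatial)`.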
(0..batch_size * channels_out)
.into_par_iter()
.for_each(|b_co| {
let out_plane: &mut [T] = unsafe {
core::slice::from_raw_parts_mut(
dst_ptr.ptr_add(b_co * out_spatial),
out_spatial,
)
};
plane_work(b_co, out_plane);
});
}
#[cfg(not(feature = "rayon"))]
{
for b_co in 0..batch_size * channels_out {
let out_base = b_co * out_spatial;
plane_work(b_co, &mut output[out_base..out_base + out_spatial]);
}
}
if let Some(bias) = bias {
let bias = bias.to_contiguous();
let bias_data: &[T] = bias.storage();
assert_eq!(
bias_data.len(),
channels_out,
"conv small-channel: bias length ({}) must equal channels_out ({channels_out})",
bias_data.len()
);
add_bias(
&mut output,
bias_data,
batch_size,
channels_out,
out_spatial,
|a, b| a + b,
);
}
let out_shape = Shape::from(vec![batch_size, channels_out, 1, out_h, out_w]);
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(out_shape),
dtype,
)
}
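/// Direct GEMM path: a pure 1D convolution (unit depth and height) with no
/// padding or dilation, enough input channels to keep the GEMM busy, a short
/// kernel, and a modest output width.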
fn should_use_direct_conv(x_shape: &[usize], w_shape: &[usize], options: &ConvOptions<3>) -> bool {
if options.groups != 1 || options.padding != [0, 0, 0] || options.dilation != [1, 1, 1] {
return false;
}
let kernel_d = w_shape[2];
let kernel_h = w_shape[3];
let kernel_w = w_shape[4];
if kernel_d != 1 || kernel_h != 1 {
return false;
}
if x_shape[2] != 1 || x_shape[3] != 1 {
return false;
}
let channels_in = x_shape[1];
let in_w = x_shape[4];
let out_w = calculate_conv_output_size(kernel_w, options.stride[2], 0, 1, in_w);
channels_in >= 32 && kernel_w <= 8 && out_w <= 800
}
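// Typed wrappers for the direct GEMM specialization; only instantiated for
// f32 and f64 (the f16 entry point omits this path).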
macro_rules! conv3d_direct_typed {
($fn_name:ident, $T:ty, $dtype:expr, $zero:expr, $one:expr, $add_fn:expr) => {
fn $fn_name(
x: FlexTensor,
weight: FlexTensor,
bias: Option<FlexTensor>,
options: &ConvOptions<3>,
) -> FlexTensor {
conv3d_direct_impl::<$T>(x, weight, bias, options, $dtype, $zero, $one, $add_fn)
}
};
}
conv3d_direct_typed!(
conv3d_direct_f32,
f32,
DType::F32,
0.0f32,
1.0f32,
|a, b| a + b
);
conv3d_direct_typed!(
conv3d_direct_f64,
f64,
DType::F64,
0.0f64,
1.0f64,
|a, b| a + b
);
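/// 1D convolution computed as `kernel_w` accumulated GEMMs: for each tap `k`,
/// `out += W[:, :, k] * X[:, k..][::stride_w]`, expressed through strided
/// GEMM operands instead of a materialized im2col buffer. The first tap
/// overwrites the output (`read_dst` is `k > 0`); later taps accumulate.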
#[allow(clippy::too_many_arguments, clippy::type_complexity)]
fn conv3d_direct_impl<T: bytemuck::Pod + Clone + Copy + burn_backend::Element + Send + Sync>(
x: FlexTensor,
weight: FlexTensor,
bias: Option<FlexTensor>,
options: &ConvOptions<3>,
dtype: DType,
zero: T,
one: T,
add_fn: fn(T, T) -> T,
) -> FlexTensor {
let x = x.to_contiguous();
let weight = weight.to_contiguous();
let x_shape = x.layout().shape();
let w_shape = weight.layout().shape();
let batch_size = x_shape[0];
let channels_in = x_shape[1];
let in_w = x_shape[4];
let channels_out = w_shape[0];
let kernel_w = w_shape[4];
let stride_w = options.stride[2];
let out_w = calculate_conv_output_size(kernel_w, stride_w, 0, 1, in_w);
let x_data: &[T] = x.storage();
let w_data: &[T] = weight.storage();
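// GEMM operand strides: for a fixed tap, the LHS walks the weights as a
// c_out x c_in matrix (column stride kernel_w, row stride
// channels_in * kernel_w) and the RHS walks the input as a c_in x out_w
// matrix (column stride stride_w, row stride in_w).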
let lhs_rs = (channels_in * kernel_w) as isize;
let lhs_cs = kernel_w as isize;
let rhs_rs = in_w as isize;
let rhs_cs = stride_w as isize;
let m = channels_out;
let gemm_k = channels_in;
let n = out_w;
#[cfg(feature = "rayon")]
let parallelism = if m.saturating_mul(n).saturating_mul(gemm_k) >= 192 * 192 * 192 {
gemm::Parallelism::Rayon(0)
} else {
gemm::Parallelism::None
};
#[cfg(not(feature = "rayon"))]
let parallelism = gemm::Parallelism::None;
let total_output = batch_size
.checked_mul(channels_out)
.and_then(|x| x.checked_mul(out_w))
.expect("conv direct: output dimensions overflow");
let mut output = vec![zero; total_output];
let batch_x_len = channels_in * in_w;
{
#[cfg(feature = "rayon")]
{
use rayon::prelude::*;
let dst_ptr = crate::ops::SendMutPtr::new(output.as_mut_ptr());
(0..batch_size).into_par_iter().for_each(|b| {
let out_offset = b * channels_out * out_w;
for k in 0..kernel_w {
unsafe {
let x_base = x_data.as_ptr().add(b * batch_x_len);
gemm::gemm(
m,
n,
gemm_k,
dst_ptr.ptr_add(out_offset),
1,
n as isize,
k > 0,
w_data.as_ptr().add(k),
lhs_cs,
lhs_rs,
x_base.add(k),
rhs_cs,
rhs_rs,
one,
one,
false,
false,
false,
parallelism,
);
}
}
});
}
#[cfg(not(feature = "rayon"))]
{
for b in 0..batch_size {
let out_offset = b * channels_out * out_w;
for k in 0..kernel_w {
unsafe {
let x_base = x_data.as_ptr().add(b * batch_x_len);
gemm::gemm(
m,
n,
gemm_k,
output.as_mut_ptr().add(out_offset),
1,
n as isize,
k > 0,
w_data.as_ptr().add(k),
lhs_cs,
lhs_rs,
x_base.add(k),
rhs_cs,
rhs_rs,
one,
one,
false,
false,
false,
parallelism,
);
}
}
}
}
}
if let Some(bias) = bias {
let bias = bias.to_contiguous();
let bias_data: &[T] = bias.storage();
add_bias(
&mut output,
bias_data,
batch_size,
channels_out,
out_w,
add_fn,
);
}
let out_shape = Shape::from(vec![batch_size, channels_out, 1, 1, out_w]);
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(out_shape),
dtype,
)
}
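// Row-major GEMM helper over the `gemm` crate: C (m x n) = A (m x k) * B,
// with B supplied column-major (each im2col column is contiguous). Rayon
// parallelism kicks in only for sufficiently large products.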
macro_rules! gemm_typed {
($fn_name:ident, $T:ty, $zero:expr, $one:expr) => {
fn $fn_name(a: &[$T], b: &[$T], m: usize, k: usize, n: usize) -> Vec<$T> {
let mut c = vec![$zero; m * n];
#[cfg(feature = "rayon")]
let parallelism = if m * n * k >= 192 * 192 * 192 {
gemm::Parallelism::Rayon(0)
} else {
gemm::Parallelism::None
};
#[cfg(not(feature = "rayon"))]
let parallelism = gemm::Parallelism::None;
unsafe {
gemm::gemm(
m,
n,
k,
c.as_mut_ptr(),
1,
n as isize,
false,
a.as_ptr(),
1,
k as isize,
b.as_ptr(),
k as isize,
1,
$zero,
$one,
false,
false,
false,
parallelism,
);
}
c
}
};
}
gemm_typed!(gemm_f32, f32, 0.0f32, 1.0f32);
gemm_typed!(gemm_f64, f64, 0.0f64, 1.0f64);
gemm_typed!(gemm_f16, f16, f16::from_f32(0.0), f16::from_f32(1.0));
#[cfg(test)]
mod tests {
use super::*;
use burn_backend::TensorData;
use burn_std::bf16;
#[test]
fn test_conv1d_direct_path() {
let c_in = 64;
let c_out = 32;
let in_w = 100;
let kw = 3;
let stride = 2;
let out_w = (in_w - kw) / stride + 1;
let x_data: Vec<f32> = (0..c_in * in_w)
.map(|i| ((i % 100) as f32 / 100.0) - 0.5)
.collect();
let w_data: Vec<f32> = (0..c_out * c_in * kw)
.map(|i| ((i % 50) as f32 / 50.0) - 0.5)
.collect();
let x = FlexTensor::from_data(TensorData::new(x_data.clone(), vec![1, c_in, in_w]));
let weight = FlexTensor::from_data(TensorData::new(w_data.clone(), vec![c_out, c_in, kw]));
let options = ConvOptions::new([stride], [0], [1], 1);
let result = conv1d_f32(x, weight, None, &options);
let out: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(out.len(), c_out * out_w);
for co in 0..c_out {
for o in 0..out_w {
let mut expected = 0.0f32;
for ci in 0..c_in {
for k in 0..kw {
expected += w_data[co * c_in * kw + ci * kw + k]
* x_data[ci * in_w + o * stride + k];
}
}
let actual = out[co * out_w + o];
assert!(
(actual - expected).abs() < 1e-3,
"mismatch at co={co}, o={o}: expected {expected}, got {actual}"
);
}
}
}
#[test]
fn test_conv1d_direct_path_kw2() {
let c_in = 64;
let c_out = 32;
let in_w = 50;
let kw = 2;
let stride = 2;
let out_w = (in_w - kw) / stride + 1;
let x_data: Vec<f32> = (0..c_in * in_w)
.map(|i| ((i % 100) as f32 / 100.0) - 0.5)
.collect();
let w_data: Vec<f32> = (0..c_out * c_in * kw)
.map(|i| ((i % 50) as f32 / 50.0) - 0.5)
.collect();
let x = FlexTensor::from_data(TensorData::new(x_data.clone(), vec![1, c_in, in_w]));
let weight = FlexTensor::from_data(TensorData::new(w_data.clone(), vec![c_out, c_in, kw]));
let options = ConvOptions::new([stride], [0], [1], 1);
let result = conv1d_f32(x, weight, None, &options);
let out: Vec<f32> = result.into_data().to_vec().unwrap();
for co in 0..c_out {
for o in 0..out_w {
let mut expected = 0.0f32;
for ci in 0..c_in {
for k in 0..kw {
expected += w_data[co * c_in * kw + ci * kw + k]
* x_data[ci * in_w + o * stride + k];
}
}
let actual = out[co * out_w + o];
assert!(
(actual - expected).abs() < 1e-3,
"mismatch at co={co}, o={o}: expected {expected}, got {actual}"
);
}
}
}
#[test]
fn test_conv1d_direct_path_f64() {
let c_in = 64;
let c_out = 16;
let in_w = 50;
let kw = 3;
let stride = 2;
let out_w = (in_w - kw) / stride + 1;
let x_data: Vec<f64> = (0..c_in * in_w)
.map(|i| ((i % 100) as f64 / 100.0) - 0.5)
.collect();
let w_data: Vec<f64> = (0..c_out * c_in * kw)
.map(|i| ((i % 50) as f64 / 50.0) - 0.5)
.collect();
let x = FlexTensor::from_data(TensorData::new(x_data.clone(), vec![1, c_in, in_w]));
let weight = FlexTensor::from_data(TensorData::new(w_data.clone(), vec![c_out, c_in, kw]));
let options = ConvOptions::new([stride], [0], [1], 1);
let result = conv1d_f64(x, weight, None, &options);
let out: Vec<f64> = result.into_data().to_vec().unwrap();
for co in 0..c_out {
for o in 0..out_w {
let mut expected = 0.0f64;
for ci in 0..c_in {
for k in 0..kw {
expected += w_data[co * c_in * kw + ci * kw + k]
* x_data[ci * in_w + o * stride + k];
}
}
let actual = out[co * out_w + o];
assert!(
(actual - expected).abs() < 1e-10,
"f64 mismatch at co={co}, o={o}: expected {expected}, got {actual}"
);
}
}
}
#[test]
fn test_conv1d_direct_path_with_bias() {
let c_in = 64;
let c_out = 32;
let in_w = 50;
let kw = 2;
let stride = 2;
let out_w = (in_w - kw) / stride + 1;
let x_data: Vec<f32> = (0..c_in * in_w)
.map(|i| ((i % 100) as f32 / 100.0) - 0.5)
.collect();
let w_data: Vec<f32> = (0..c_out * c_in * kw)
.map(|i| ((i % 50) as f32 / 50.0) - 0.5)
.collect();
let bias_data: Vec<f32> = (0..c_out).map(|i| i as f32 * 0.1).collect();
let x = FlexTensor::from_data(TensorData::new(x_data.clone(), vec![1, c_in, in_w]));
let weight = FlexTensor::from_data(TensorData::new(w_data.clone(), vec![c_out, c_in, kw]));
let bias = FlexTensor::from_data(TensorData::new(bias_data.clone(), vec![c_out]));
let options = ConvOptions::new([stride], [0], [1], 1);
let result = conv1d_f32(x, weight, Some(bias), &options);
let out: Vec<f32> = result.into_data().to_vec().unwrap();
for co in 0..c_out {
for o in 0..out_w {
let mut expected = bias_data[co];
for ci in 0..c_in {
for k in 0..kw {
expected += w_data[co * c_in * kw + ci * kw + k]
* x_data[ci * in_w + o * stride + k];
}
}
let actual = out[co * out_w + o];
assert!(
(actual - expected).abs() < 1e-3,
"bias mismatch at co={co}, o={o}: expected {expected}, got {actual}"
);
}
}
}
#[test]
fn test_conv2d_f64() {
let x_data: Vec<f64> = (1..=16).map(|x| x as f64).collect();
let x = FlexTensor::from_data(TensorData::new(x_data, vec![1, 1, 4, 4]));
let w_data = vec![1.0f64; 4];
let weight = FlexTensor::from_data(TensorData::new(w_data, vec![1, 1, 2, 2]));
let options = ConvOptions::new([1, 1], [0, 0], [1, 1], 1);
let result = conv2d_f64(x, weight, None, &options);
let out: Vec<f64> = result.into_data().to_vec().unwrap();
assert_eq!(
out,
vec![14.0, 18.0, 22.0, 30.0, 34.0, 38.0, 46.0, 50.0, 54.0]
);
}
#[test]
fn test_conv2d_f16() {
let x_data: Vec<f16> = (1..=16).map(|x| f16::from_f32(x as f32)).collect();
let x = FlexTensor::from_data(TensorData::new(x_data, vec![1, 1, 4, 4]));
let w_data: Vec<f16> = vec![f16::from_f32(1.0); 4];
let weight = FlexTensor::from_data(TensorData::new(w_data, vec![1, 1, 2, 2]));
let options = ConvOptions::new([1, 1], [0, 0], [1, 1], 1);
let result = conv2d_f16(x, weight, None, &options);
let out: Vec<f16> = result.into_data().to_vec().unwrap();
let expected = vec![14.0, 18.0, 22.0, 30.0, 34.0, 38.0, 46.0, 50.0, 54.0];
for (a, e) in out.iter().zip(expected.iter()) {
assert!((a.to_f32() - e).abs() < 0.5);
}
}
#[test]
fn test_conv2d_bf16() {
let x_data: Vec<bf16> = (1..=16).map(|x| bf16::from_f32(x as f32)).collect();
let x = FlexTensor::from_data(TensorData::new(x_data, vec![1, 1, 4, 4]));
let w_data: Vec<bf16> = vec![bf16::from_f32(1.0); 4];
let weight = FlexTensor::from_data(TensorData::new(w_data, vec![1, 1, 2, 2]));
let options = ConvOptions::new([1, 1], [0, 0], [1, 1], 1);
let result = conv2d_bf16(x, weight, None, &options);
let out: Vec<bf16> = result.into_data().to_vec().unwrap();
let expected = vec![14.0, 18.0, 22.0, 30.0, 34.0, 38.0, 46.0, 50.0, 54.0];
for (a, e) in out.iter().zip(expected.iter()) {
assert!((a.to_f32() - e).abs() < 0.5);
}
}
#[allow(clippy::too_many_arguments)]
fn naive_depthwise_conv2d_f32(
x: &[f32],
w: &[f32],
bias: Option<&[f32]>,
batch: usize,
channels: usize,
in_h: usize,
in_w: usize,
kernel_h: usize,
kernel_w: usize,
stride_h: usize,
stride_w: usize,
pad_h: usize,
pad_w: usize,
dilation_h: usize,
dilation_w: usize,
) -> (Vec<f32>, usize, usize) {
let out_h = (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1;
let out_w = (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1;
let mut out = vec![0.0f32; batch * channels * out_h * out_w];
for b in 0..batch {
for c in 0..channels {
for oh in 0..out_h {
for ow in 0..out_w {
let mut acc = 0.0f32;
for kh in 0..kernel_h {
let ih = oh as isize * stride_h as isize
+ kh as isize * dilation_h as isize
- pad_h as isize;
if ih < 0 || ih >= in_h as isize {
continue;
}
for kw in 0..kernel_w {
let iw = ow as isize * stride_w as isize
+ kw as isize * dilation_w as isize
- pad_w as isize;
if iw < 0 || iw >= in_w as isize {
continue;
}
let x_idx =
((b * channels + c) * in_h + ih as usize) * in_w + iw as usize;
let w_idx = (c * kernel_h + kh) * kernel_w + kw;
acc += x[x_idx] * w[w_idx];
}
}
if let Some(bias) = bias {
acc += bias[c];
}
let o_idx = ((b * channels + c) * out_h + oh) * out_w + ow;
out[o_idx] = acc;
}
}
}
}
(out, out_h, out_w)
}
fn seeded_vec_f32(n: usize, seed: u32) -> Vec<f32> {
(0..n)
.map(|i| {
let v = ((i as u32).wrapping_mul(2654435761).wrapping_add(seed)) & 0xffff;
(v as f32 / 32768.0) - 1.0
})
.collect()
}
#[allow(clippy::too_many_arguments)]
fn check_depthwise_conv2d_f32(
batch: usize,
channels: usize,
in_h: usize,
in_w: usize,
kernel_h: usize,
kernel_w: usize,
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
with_bias: bool,
) {
let x_vec = seeded_vec_f32(batch * channels * in_h * in_w, 1);
let w_vec = seeded_vec_f32(channels * kernel_h * kernel_w, 2);
let bias_vec = if with_bias {
Some(seeded_vec_f32(channels, 3))
} else {
None
};
let (expected, out_h, out_w) = naive_depthwise_conv2d_f32(
&x_vec,
&w_vec,
bias_vec.as_deref(),
batch,
channels,
in_h,
in_w,
kernel_h,
kernel_w,
stride[0],
stride[1],
padding[0],
padding[1],
dilation[0],
dilation[1],
);
let x = FlexTensor::from_data(TensorData::new(x_vec, vec![batch, channels, in_h, in_w]));
let weight = FlexTensor::from_data(TensorData::new(
w_vec,
vec![channels, 1, kernel_h, kernel_w],
));
let bias = bias_vec.map(|v| FlexTensor::from_data(TensorData::new(v, vec![channels])));
let options = ConvOptions::new(stride, padding, dilation, channels);
let result = conv2d_f32(x, weight, bias, &options);
assert_eq!(
result.layout().shape().to_vec(),
vec![batch, channels, out_h, out_w],
"output shape mismatch"
);
let out: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(out.len(), expected.len());
for (i, (a, e)) in out.iter().zip(expected.iter()).enumerate() {
assert!(
(a - e).abs() < 1e-4,
"mismatch at {i}: got {a}, expected {e}"
);
}
}
#[test]
fn test_conv2d_depthwise_3x3_no_pad() {
check_depthwise_conv2d_f32(2, 8, 16, 16, 3, 3, [1, 1], [0, 0], [1, 1], false);
}
#[test]
fn test_conv2d_depthwise_3x3_pad1() {
check_depthwise_conv2d_f32(2, 8, 16, 16, 3, 3, [1, 1], [1, 1], [1, 1], false);
}
#[test]
fn test_conv2d_depthwise_3x3_stride2_pad1() {
check_depthwise_conv2d_f32(1, 16, 32, 32, 3, 3, [2, 2], [1, 1], [1, 1], false);
}
#[test]
fn test_conv2d_depthwise_5x5_pad2() {
check_depthwise_conv2d_f32(2, 4, 10, 10, 5, 5, [1, 1], [2, 2], [1, 1], false);
}
#[test]
fn test_conv2d_depthwise_7x7_pad3() {
check_depthwise_conv2d_f32(2, 24, 14, 14, 7, 7, [1, 1], [3, 3], [1, 1], false);
}
#[test]
fn test_conv2d_depthwise_dilated() {
check_depthwise_conv2d_f32(1, 8, 12, 12, 3, 3, [1, 1], [2, 2], [2, 2], false);
}
#[test]
fn test_conv2d_depthwise_with_bias() {
check_depthwise_conv2d_f32(2, 8, 8, 8, 3, 3, [1, 1], [1, 1], [1, 1], true);
}
#[test]
fn test_conv2d_depthwise_single_channel() {
check_depthwise_conv2d_f32(1, 1, 5, 5, 3, 3, [1, 1], [1, 1], [1, 1], false);
}
#[test]
fn test_conv2d_depthwise_asymmetric_kernel() {
check_depthwise_conv2d_f32(1, 4, 8, 12, 3, 5, [1, 1], [1, 2], [1, 1], false);
}
#[test]
fn test_conv2d_depthwise_f64() {
let x_data: Vec<f64> = (0..2 * 4 * 5 * 5).map(|i| (i as f64) * 0.1).collect();
let w_data: Vec<f64> = (0..4 * 3 * 3).map(|i| (i as f64) * 0.01).collect();
let x = FlexTensor::from_data(TensorData::new(x_data.clone(), vec![2, 4, 5, 5]));
let weight = FlexTensor::from_data(TensorData::new(w_data.clone(), vec![4, 1, 3, 3]));
let options = ConvOptions::new([1, 1], [1, 1], [1, 1], 4);
let result = conv2d_f64(x, weight, None, &options);
let out: Vec<f64> = result.into_data().to_vec().unwrap();
let b = 0usize;
let c = 2usize;
let oh = 2usize;
let ow = 2usize;
let mut expected = 0.0f64;
for kh in 0..3 {
for kw in 0..3 {
let ih = oh as isize + kh as isize - 1;
let iw = ow as isize + kw as isize - 1;
if ih >= 0 && ih < 5 && iw >= 0 && iw < 5 {
let x_idx = ((b * 4 + c) * 5 + ih as usize) * 5 + iw as usize;
let w_idx = (c * 3 + kh) * 3 + kw;
expected += x_data[x_idx] * w_data[w_idx];
}
}
}
let out_idx = ((b * 4 + c) * 5 + oh) * 5 + ow;
assert!((out[out_idx] - expected).abs() < 1e-10);
}
#[test]
fn test_conv2d_depthwise_f16() {
use burn_std::f16;
let x_data_f32: Vec<f32> = (0..16).map(|i| i as f32 * 0.1).collect();
let w_data_f32: Vec<f32> = (0..16).map(|i| i as f32 * 0.01).collect();
let x_data: Vec<f16> = x_data_f32.iter().copied().map(f16::from_f32).collect();
let w_data: Vec<f16> = w_data_f32.iter().copied().map(f16::from_f32).collect();
let x = FlexTensor::from_data(TensorData::new(x_data, vec![1, 4, 2, 2]));
let weight = FlexTensor::from_data(TensorData::new(w_data, vec![4, 1, 2, 2]));
let options = ConvOptions::new([1, 1], [0, 0], [1, 1], 4);
let result = conv2d_f16(x, weight, None, &options);
assert_eq!(result.layout().shape().to_vec(), vec![1, 4, 1, 1]);
let out: Vec<f16> = result.into_data().to_vec().unwrap();
for c in 0..4 {
let mut expected = 0.0f32;
for k in 0..4 {
expected += x_data_f32[c * 4 + k] * w_data_f32[c * 4 + k];
}
let actual = out[c].to_f32();
assert!(
(actual - expected).abs() < 1e-2,
"f16 depthwise mismatch at c={c}: expected {expected}, got {actual}"
);
}
}
#[test]
fn test_conv1d_depthwise() {
let channels = 4;
let in_w = 16;
let kw = 3;
let x_data = seeded_vec_f32(channels * in_w, 10);
let w_data = seeded_vec_f32(channels * kw, 20);
let x = FlexTensor::from_data(TensorData::new(x_data.clone(), vec![1, channels, in_w]));
let weight = FlexTensor::from_data(TensorData::new(w_data.clone(), vec![channels, 1, kw]));
let options = ConvOptions::new([1], [1], [1], channels);
let result = conv1d_f32(x, weight, None, &options);
let out_w = in_w;
assert_eq!(result.layout().shape().to_vec(), vec![1, channels, out_w]);
let out: Vec<f32> = result.into_data().to_vec().unwrap();
for c in 0..channels {
for o in 0..out_w {
let mut expected = 0.0f32;
for k in 0..kw {
let i = o as isize + k as isize - 1;
if i >= 0 && i < in_w as isize {
expected += x_data[c * in_w + i as usize] * w_data[c * kw + k];
}
}
let actual = out[c * out_w + o];
assert!(
(actual - expected).abs() < 1e-5,
"conv1d depthwise mismatch at c={c}, o={o}: expected {expected}, got {actual}"
);
}
}
}
#[test]
fn test_conv1d_depthwise_stride_batch_bias() {
let batch = 3;
let channels = 6;
let in_w = 32;
let kw = 5;
let stride = 2;
let pad = 2;
let out_w = (in_w + 2 * pad - kw) / stride + 1;
let x_data = seeded_vec_f32(batch * channels * in_w, 30);
let w_data = seeded_vec_f32(channels * kw, 40);
let bias_data = seeded_vec_f32(channels, 50);
let x = FlexTensor::from_data(TensorData::new(x_data.clone(), vec![batch, channels, in_w]));
let weight = FlexTensor::from_data(TensorData::new(w_data.clone(), vec![channels, 1, kw]));
let bias = FlexTensor::from_data(TensorData::new(bias_data.clone(), vec![channels]));
let options = ConvOptions::new([stride], [pad], [1], channels);
let result = conv1d_f32(x, weight, Some(bias), &options);
assert_eq!(
result.layout().shape().to_vec(),
vec![batch, channels, out_w]
);
let out: Vec<f32> = result.into_data().to_vec().unwrap();
for b in 0..batch {
for c in 0..channels {
for o in 0..out_w {
let mut expected = bias_data[c];
for k in 0..kw {
let i = (o as isize * stride as isize) + k as isize - pad as isize;
if i >= 0 && i < in_w as isize {
let x_idx = (b * channels + c) * in_w + i as usize;
let w_idx = c * kw + k;
expected += x_data[x_idx] * w_data[w_idx];
}
}
let actual = out[(b * channels + c) * out_w + o];
assert!(
(actual - expected).abs() < 1e-4,
"conv1d depthwise mismatch at b={b}, c={c}, o={o}: expected {expected}, got {actual}"
);
}
}
}
}
#[allow(clippy::too_many_arguments)]
fn naive_conv2d_f32(
x: &[f32],
w: &[f32],
bias: Option<&[f32]>,
batch: usize,
channels_in: usize,
channels_out: usize,
in_h: usize,
in_w: usize,
kernel_h: usize,
kernel_w: usize,
stride_h: usize,
stride_w: usize,
pad_h: usize,
pad_w: usize,
dilation_h: usize,
dilation_w: usize,
) -> (Vec<f32>, usize, usize) {
let out_h = (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1;
let out_w = (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1;
let mut out = vec![0.0f32; batch * channels_out * out_h * out_w];
for b in 0..batch {
for co in 0..channels_out {
for oh in 0..out_h {
for ow in 0..out_w {
let mut acc = 0.0f32;
for ci in 0..channels_in {
for kh in 0..kernel_h {
let ih = oh as isize * stride_h as isize
+ kh as isize * dilation_h as isize
- pad_h as isize;
if ih < 0 || ih >= in_h as isize {
continue;
}
for kw in 0..kernel_w {
let iw = ow as isize * stride_w as isize
+ kw as isize * dilation_w as isize
- pad_w as isize;
if iw < 0 || iw >= in_w as isize {
continue;
}
let x_idx = ((b * channels_in + ci) * in_h + ih as usize)
* in_w
+ iw as usize;
let w_idx =
((co * channels_in + ci) * kernel_h + kh) * kernel_w + kw;
acc += x[x_idx] * w[w_idx];
}
}
}
if let Some(bias) = bias {
acc += bias[co];
}
let o_idx = ((b * channels_out + co) * out_h + oh) * out_w + ow;
out[o_idx] = acc;
}
}
}
}
(out, out_h, out_w)
}
#[allow(clippy::too_many_arguments)]
fn check_small_channel_conv2d_f32(
batch: usize,
channels_in: usize,
channels_out: usize,
in_h: usize,
in_w: usize,
kernel_h: usize,
kernel_w: usize,
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
with_bias: bool,
) {
let x_vec = seeded_vec_f32(batch * channels_in * in_h * in_w, 100);
let w_vec = seeded_vec_f32(channels_out * channels_in * kernel_h * kernel_w, 200);
let bias_vec = if with_bias {
Some(seeded_vec_f32(channels_out, 300))
} else {
None
};
let (expected, out_h, out_w) = naive_conv2d_f32(
&x_vec,
&w_vec,
bias_vec.as_deref(),
batch,
channels_in,
channels_out,
in_h,
in_w,
kernel_h,
kernel_w,
stride[0],
stride[1],
padding[0],
padding[1],
dilation[0],
dilation[1],
);
let x = FlexTensor::from_data(TensorData::new(x_vec, vec![batch, channels_in, in_h, in_w]));
let weight = FlexTensor::from_data(TensorData::new(
w_vec,
vec![channels_out, channels_in, kernel_h, kernel_w],
));
let bias = bias_vec.map(|v| FlexTensor::from_data(TensorData::new(v, vec![channels_out])));
let options = ConvOptions::new(stride, padding, dilation, 1);
let result = conv2d_f32(x, weight, bias, &options);
assert_eq!(
result.layout().shape().to_vec(),
vec![batch, channels_out, out_h, out_w],
"output shape mismatch"
);
let out: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(out.len(), expected.len());
for (i, (a, e)) in out.iter().zip(expected.iter()).enumerate() {
assert!(
(a - e).abs() < 1e-3,
"mismatch at {i}: got {a}, expected {e}"
);
}
}
#[test]
fn test_conv2d_small_channel_3in_8out_k3x3_pad1() {
check_small_channel_conv2d_f32(2, 3, 8, 16, 16, 3, 3, [1, 1], [1, 1], [1, 1], false);
}
#[test]
fn test_conv2d_small_channel_3in_3out_k3x3_no_pad() {
check_small_channel_conv2d_f32(1, 3, 3, 10, 10, 3, 3, [1, 1], [0, 0], [1, 1], false);
}
#[test]
fn test_conv2d_small_channel_3in_16out_k5x5_pad2() {
check_small_channel_conv2d_f32(2, 3, 16, 12, 12, 5, 5, [1, 1], [2, 2], [1, 1], false);
}
#[test]
fn test_conv2d_small_channel_4in_8out_k3x3_stride2() {
check_small_channel_conv2d_f32(1, 4, 8, 16, 16, 3, 3, [2, 2], [1, 1], [1, 1], false);
}
#[test]
fn test_conv2d_small_channel_2in_4out_dilated() {
check_small_channel_conv2d_f32(1, 2, 4, 12, 12, 3, 3, [1, 1], [2, 2], [2, 2], false);
}
#[test]
fn test_conv2d_small_channel_with_bias() {
check_small_channel_conv2d_f32(2, 3, 8, 8, 8, 3, 3, [1, 1], [1, 1], [1, 1], true);
}
#[test]
fn test_conv2d_small_channel_asymmetric_kernel() {
check_small_channel_conv2d_f32(1, 3, 6, 16, 16, 1, 3, [1, 1], [0, 1], [1, 1], false);
check_small_channel_conv2d_f32(1, 3, 6, 16, 16, 3, 1, [1, 1], [1, 0], [1, 1], false);
}
#[test]
fn test_conv2d_small_channel_f64() {
let x_data: Vec<f64> = (0..2 * 3 * 5 * 5).map(|i| (i as f64) * 0.1).collect();
let w_data: Vec<f64> = (0..4 * 3 * 3 * 3).map(|i| (i as f64) * 0.01).collect();
let x = FlexTensor::from_data(TensorData::new(x_data.clone(), vec![2, 3, 5, 5]));
let weight = FlexTensor::from_data(TensorData::new(w_data.clone(), vec![4, 3, 3, 3]));
let options = ConvOptions::new([1, 1], [1, 1], [1, 1], 1);
let result = conv2d_f64(x, weight, None, &options);
let out: Vec<f64> = result.into_data().to_vec().unwrap();
let b = 0usize;
let co = 2usize;
let oh = 2usize;
let ow = 2usize;
let mut expected = 0.0f64;
for ci in 0..3 {
for kh in 0..3 {
for kw in 0..3 {
let ih = oh as isize + kh as isize - 1;
let iw = ow as isize + kw as isize - 1;
if ih >= 0 && ih < 5 && iw >= 0 && iw < 5 {
let x_idx = ((b * 3 + ci) * 5 + ih as usize) * 5 + iw as usize;
let w_idx = ((co * 3 + ci) * 3 + kh) * 3 + kw;
expected += x_data[x_idx] * w_data[w_idx];
}
}
}
}
let out_idx = ((b * 4 + co) * 5 + oh) * 5 + ow;
assert!(
(out[out_idx] - expected).abs() < 1e-10,
"got {}, expected {expected}",
out[out_idx]
);
}
#[test]
fn test_conv2d_small_channel_f16() {
use burn_std::f16;
let x_data_f32: Vec<f32> = (0..3 * 4 * 4).map(|i| i as f32 * 0.1).collect();
let w_data_f32: Vec<f32> = (0..4 * 3 * 3 * 3).map(|i| i as f32 * 0.01).collect();
let x_data_f16: Vec<f16> = x_data_f32.iter().copied().map(f16::from_f32).collect();
let w_data_f16: Vec<f16> = w_data_f32.iter().copied().map(f16::from_f32).collect();
let x_f16 = FlexTensor::from_data(TensorData::new(x_data_f16, vec![1, 3, 4, 4]));
let weight_f16 = FlexTensor::from_data(TensorData::new(w_data_f16, vec![4, 3, 3, 3]));
let x_f32 = FlexTensor::from_data(TensorData::new(x_data_f32, vec![1, 3, 4, 4]));
let weight_f32 = FlexTensor::from_data(TensorData::new(w_data_f32, vec![4, 3, 3, 3]));
let options = ConvOptions::new([1, 1], [1, 1], [1, 1], 1);
let result_f16 = conv2d_f16(x_f16, weight_f16, None, &options);
let result_f32 = conv2d_f32(x_f32, weight_f32, None, &options);
assert_eq!(result_f16.layout().shape().to_vec(), vec![1, 4, 4, 4]);
assert_eq!(result_f32.layout().shape().to_vec(), vec![1, 4, 4, 4]);
let out_f16: Vec<f16> = result_f16.into_data().to_vec().unwrap();
let out_f32: Vec<f32> = result_f32.into_data().to_vec().unwrap();
assert_eq!(out_f16.len(), out_f32.len());
let rel_tol = 3e-3f32;
let abs_tol = 1e-2f32;
for (i, (actual, expected)) in out_f16.iter().zip(out_f32.iter()).enumerate() {
let actual_f32 = actual.to_f32();
let bound = (expected.abs() * rel_tol).max(abs_tol);
assert!(
!actual_f32.is_nan() && (actual_f32 - expected).abs() <= bound,
"f16 small-channel mismatch at {i}: got {actual_f32}, expected {expected}, bound {bound}"
);
}
}
#[test]
fn test_conv2d_small_channel_bias_length_mismatch_panics() {
let x = FlexTensor::from_data(TensorData::new(vec![0.0f32; 48], vec![1, 3, 4, 4]));
let weight = FlexTensor::from_data(TensorData::new(vec![0.0f32; 108], vec![4, 3, 3, 3]));
let bias = FlexTensor::from_data(TensorData::new(vec![0.0f32, 0.0], vec![2]));
let options = ConvOptions::new([1, 1], [1, 1], [1, 1], 1);
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
conv2d_f32(x, weight, Some(bias), &options)
}));
assert!(result.is_err(), "expected panic on bias length mismatch");
}
#[test]
fn test_conv2d_depthwise_bias_length_mismatch_panics() {
let x = FlexTensor::from_data(TensorData::new(vec![0.0f32; 48], vec![1, 3, 4, 4]));
let weight = FlexTensor::from_data(TensorData::new(vec![0.0f32; 27], vec![3, 1, 3, 3]));
let bias = FlexTensor::from_data(TensorData::new(vec![0.0f32, 0.0], vec![2]));
let options = ConvOptions::new([1, 1], [1, 1], [1, 1], 3);
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
conv2d_f32(x, weight, Some(bias), &options)
}));
assert!(result.is_err(), "expected panic on bias length mismatch");
}
#[test]
fn test_conv_plane_accumulate_accumulates_into_prefilled() {
let x = vec![1.0f32, 2.0, 3.0, 4.0];
let w = vec![1.0f32];
let mut out_plane = vec![10.0f32, 20.0, 30.0, 40.0];
let oh_ranges = [(0usize, 2usize)];
let ow_ranges = [(0usize, 2usize)];
super::conv_plane_accumulate::<f32>(
&mut out_plane,
&x,
&w,
1,
1,
2,
2,
1,
1,
0,
0,
1,
1,
&oh_ranges,
&ow_ranges,
);
assert_eq!(out_plane, vec![11.0f32, 22.0, 33.0, 44.0]);
}
#[test]
fn test_conv1d_small_channel_3in() {
let batch = 2;
let channels_in = 3;
let channels_out = 5;
let in_w = 24;
let kw = 5;
let stride = 1;
let pad = 2;
let out_w = in_w;
let x_data = seeded_vec_f32(batch * channels_in * in_w, 400);
let w_data = seeded_vec_f32(channels_out * channels_in * kw, 500);
let x = FlexTensor::from_data(TensorData::new(
x_data.clone(),
vec![batch, channels_in, in_w],
));
let weight = FlexTensor::from_data(TensorData::new(
w_data.clone(),
vec![channels_out, channels_in, kw],
));
let options = ConvOptions::new([stride], [pad], [1], 1);
let result = conv1d_f32(x, weight, None, &options);
assert_eq!(
result.layout().shape().to_vec(),
vec![batch, channels_out, out_w]
);
let out: Vec<f32> = result.into_data().to_vec().unwrap();
for b in 0..batch {
for co in 0..channels_out {
for o in 0..out_w {
let mut expected = 0.0f32;
for ci in 0..channels_in {
for k in 0..kw {
let i = o as isize + k as isize - pad as isize;
if i >= 0 && i < in_w as isize {
let x_idx = (b * channels_in + ci) * in_w + i as usize;
let w_idx = (co * channels_in + ci) * kw + k;
expected += x_data[x_idx] * w_data[w_idx];
}
}
}
let actual = out[(b * channels_out + co) * out_w + o];
assert!(
(actual - expected).abs() < 1e-4,
"mismatch at b={b}, co={co}, o={o}: expected {expected}, got {actual}"
);
}
}
}
}
#[test]
fn test_conv2d_small_channel_single_input() {
check_small_channel_conv2d_f32(1, 1, 4, 8, 8, 3, 3, [1, 1], [1, 1], [1, 1], false);
}
#[test]
fn test_conv2d_small_channel_threshold_exact() {
assert!(should_use_small_channel_conv(
&[1, 4, 1, 8, 8],
&[16, 4, 1, 3, 3],
&ConvOptions::new([1, 1, 1], [0, 1, 1], [1, 1, 1], 1),
));
assert!(!should_use_small_channel_conv(
&[1, 5, 1, 8, 8],
&[2, 5, 1, 3, 3],
&ConvOptions::new([1, 1, 1], [0, 1, 1], [1, 1, 1], 1),
));
assert!(!should_use_small_channel_conv(
&[1, 3, 1, 8, 8],
&[17, 3, 1, 3, 3],
&ConvOptions::new([1, 1, 1], [0, 1, 1], [1, 1, 1], 1),
));
assert!(!should_use_small_channel_conv(
&[1, 3, 1, 224, 224],
&[64, 3, 1, 7, 7],
&ConvOptions::new([1, 2, 2], [0, 3, 3], [1, 1, 1], 1),
));
assert!(!should_use_small_channel_conv(
&[1, 4, 1, 8, 8],
&[4, 1, 1, 3, 3],
&ConvOptions::new([1, 1, 1], [0, 1, 1], [1, 1, 1], 4),
));
check_small_channel_conv2d_f32(1, 5, 4, 8, 8, 3, 3, [1, 1], [1, 1], [1, 1], false);
}
#[test]
fn test_conv2d_depthwise_oh_outer_k3x3() {
check_depthwise_conv2d_f32(1, 2, 96, 96, 3, 3, [1, 1], [1, 1], [1, 1], false);
}
#[test]
fn test_conv2d_depthwise_oh_outer_k3x3_stride2() {
check_depthwise_conv2d_f32(1, 2, 192, 192, 3, 3, [2, 2], [1, 1], [1, 1], false);
}
#[test]
fn test_conv2d_depthwise_oh_outer_k5x1() {
check_depthwise_conv2d_f32(1, 2, 97, 97, 5, 1, [1, 1], [2, 0], [1, 1], false);
}
#[test]
fn test_conv2d_depthwise_oh_outer_k1x5() {
check_depthwise_conv2d_f32(1, 2, 97, 97, 1, 5, [1, 1], [0, 2], [1, 1], false);
}
#[test]
fn test_conv2d_small_channel_oh_outer_k3x3() {
check_small_channel_conv2d_f32(1, 3, 4, 96, 96, 3, 3, [1, 1], [1, 1], [1, 1], false);
}
#[test]
fn test_conv2d_small_channel_oh_outer_k3x3_stride2() {
check_small_channel_conv2d_f32(1, 3, 4, 192, 192, 3, 3, [2, 2], [1, 1], [1, 1], false);
}
#[test]
fn test_conv2d_small_channel_oh_outer_k5x1_sobel() {
check_small_channel_conv2d_f32(1, 3, 3, 97, 97, 5, 1, [1, 1], [2, 0], [1, 1], false);
}
#[test]
fn test_conv2d_small_channel_oh_outer_k1x5_sobel() {
check_small_channel_conv2d_f32(1, 3, 3, 97, 97, 1, 5, [1, 1], [0, 2], [1, 1], false);
}
#[test]
fn test_conv2d_small_channel_oh_outer_with_bias_and_dilation() {
check_small_channel_conv2d_f32(1, 3, 8, 100, 100, 3, 3, [1, 1], [2, 2], [2, 2], true);
}
#[test]
fn test_conv2d_depthwise_predicate_triggers() {
assert!(should_use_depthwise_conv(
&[4, 32, 1, 56, 56],
&[32, 1, 1, 3, 3],
&ConvOptions::new([1, 1, 1], [0, 1, 1], [1, 1, 1], 32),
));
assert!(should_use_depthwise_conv(
&[4, 48, 1, 56, 56],
&[48, 1, 1, 7, 7],
&ConvOptions::new([1, 1, 1], [0, 3, 3], [1, 1, 1], 48),
));
assert!(should_use_depthwise_conv(
&[8, 64, 1, 1, 1024],
&[64, 1, 1, 1, 3],
&ConvOptions::new([1, 1, 1], [0, 0, 1], [1, 1, 1], 64),
));
assert!(should_use_depthwise_conv(
&[1, 16, 1, 32, 32],
&[16, 1, 1, 3, 3],
&ConvOptions::new([1, 2, 2], [0, 1, 1], [1, 2, 2], 16),
));
assert!(!should_use_depthwise_conv(
&[1, 8, 1, 16, 16],
&[16, 8, 1, 3, 3],
&ConvOptions::new([1, 1, 1], [0, 1, 1], [1, 1, 1], 1),
));
assert!(!should_use_depthwise_conv(
&[1, 8, 1, 16, 16],
&[8, 4, 1, 3, 3],
&ConvOptions::new([1, 1, 1], [0, 1, 1], [1, 1, 1], 2),
));
assert!(!should_use_depthwise_conv(
&[1, 8, 1, 16, 16],
&[16, 1, 1, 3, 3],
&ConvOptions::new([1, 1, 1], [0, 1, 1], [1, 1, 1], 8),
));
assert!(!should_use_depthwise_conv(
&[1, 8, 4, 16, 16],
&[8, 1, 3, 3, 3],
&ConvOptions::new([1, 1, 1], [1, 1, 1], [1, 1, 1], 8),
));
}
#[test]
fn test_valid_out_range_basics() {
let (s, e) = valid_out_range(0, 1, 0, 1, 5, 3);
assert_eq!((s, e), (0, 3));
let (s, e) = valid_out_range(2, 1, 0, 1, 5, 3);
assert_eq!((s, e), (0, 3));
let (s, e) = valid_out_range(0, 1, 1, 1, 5, 5);
assert_eq!((s, e), (1, 5));
let (s, e) = valid_out_range(2, 1, 1, 1, 5, 5);
assert_eq!((s, e), (0, 4));
let (s, e) = valid_out_range(0, 1, 0, 2, 5, 3);
assert_eq!((s, e), (0, 3));
let (s, e) = valid_out_range(0, 2, 2, 1, 5, 5);
assert_eq!((s, e), (2, 5));
}
}