use bimm_contracts::unpack_shape_contract;
use burn::nn::{Initializer, PaddingConfig2d};
/// Build a fixed-size array of length `D` whose every element is a copy of `v`.
///
/// Useful for expanding a scalar hyper-parameter (e.g. a kernel size) into a
/// per-dimension array.
pub fn scalar_to_array<const D: usize, T: Copy>(v: T) -> [T; D] {
    // The index is irrelevant: every slot receives the same copied value.
    core::array::from_fn(|_index| v)
}
/// Compute the spatial resolution produced by downsampling `input_resolution` by `stride`.
///
/// Solves the shape contract
/// `in_height = out_height * stride` and `in_width = out_width * stride`
/// for `[out_height, out_width]`, where `input_resolution = [in_height, in_width]`.
///
/// # Panics
///
/// NOTE(review): presumably panics (inside `unpack_shape_contract!`) when an input
/// dimension is not an exact multiple of `stride`. Confirm against the
/// `bimm_contracts` documentation.
pub fn stride_div_output_resolution(
    input_resolution: [usize; 2],
    stride: usize,
) -> [usize; 2] {
    unpack_shape_contract!(
        [
            "in_height" = "out_height" * "stride",
            "in_width" = "out_width" * "stride"
        ],
        &input_resolution,
        // The unknowns to solve for, returned in this order.
        &["out_height", "out_width"],
        // `stride` is bound up front so the contract can be solved for the outputs.
        &[("stride", stride)]
    )
}
/// Kaiming (He) normal initializer for convolution weights that feed into a ReLU.
///
/// Uses a gain of `sqrt(2)`, the conventional gain for ReLU activations. It also
/// sets `fan_out_only`, so the scaling fan is the layer's fan-out; see burn's
/// `Initializer` docs.
pub static CONV_INTO_RELU_INITIALIZER: Initializer = Initializer::KaimingNormal {
    fan_out_only: true,
    gain: core::f64::consts::SQRT_2,
};
/// Compute the symmetric padding for a square convolution kernel.
///
/// Returns `((stride - 1) + dilation * (kernel - 1)) / 2`, using integer division.
/// With `stride == 1` this is "same" padding: under the standard convolution
/// output-size formula, the spatial size is preserved. With `stride >= 3` the
/// `stride - 1` term also contributes extra padding.
///
/// # Arguments
///
/// * `kernel` - side length of the square kernel; must be odd.
/// * `stride` - convolution stride; must be at least 1.
/// * `dilation` - kernel dilation; must be at least 1.
///
/// # Panics
///
/// Panics if `kernel` is even, or if `stride` or `dilation` is zero.
pub fn get_square_conv2d_padding(
    kernel: usize,
    stride: usize,
    dilation: usize,
) -> usize {
    assert_eq!(kernel % 2, 1, "Kernel size must be odd");
    assert!(stride != 0, "Stride must be >= 1");
    assert!(dilation != 0, "Dilation must be >= 1");

    // An odd kernel is at least 1, so `kernel - 1` cannot underflow here.
    let dilated_reach = dilation * (kernel - 1);
    let stride_slack = stride - 1;
    (dilated_reach + stride_slack) / 2
}
/// Build an explicit, symmetric 2-D padding config for a square convolution kernel.
///
/// Both dimensions get the same amount, as computed by
/// [`get_square_conv2d_padding`].
///
/// # Panics
///
/// Panics under the same conditions as [`get_square_conv2d_padding`]: an even
/// `kernel`, or a zero `stride` or `dilation`.
pub fn build_square_conv2d_padding_config(
    kernel: usize,
    stride: usize,
    dilation: usize,
) -> PaddingConfig2d {
    // Square kernel => identical padding in both spatial dimensions.
    let [pad_a, pad_b] = [get_square_conv2d_padding(kernel, stride, dilation); 2];
    PaddingConfig2d::Explicit(pad_a, pad_b)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_to_narray() {
        assert_eq!(scalar_to_array::<4, usize>(1), [1, 1, 1, 1]);
        assert_eq!(scalar_to_array::<2, i32>(-3), [-3, -3]);
    }

    #[test]
    fn test_stride_div_output_resolution() {
        assert_eq!(stride_div_output_resolution([224, 224], 4), [56, 56]);
        assert_eq!(stride_div_output_resolution([32, 64], 2), [16, 32]);
        assert_eq!(stride_div_output_resolution([7, 9], 1), [7, 9]);
    }

    #[test]
    fn test_get_padding() {
        // (kernel, stride, dilation, expected padding)
        let cases = [
            (1, 1, 1, 0),
            (3, 1, 1, 1),
            (5, 1, 1, 2),
            (1, 2, 1, 0),
            (3, 2, 1, 1),
            (5, 2, 1, 2),
            (1, 1, 2, 0),
            (3, 1, 2, 2),
            (5, 1, 2, 4),
            (1, 2, 2, 0),
            (3, 2, 2, 2),
            (5, 2, 2, 4),
            // With stride >= 3 the `stride - 1` term survives the integer division.
            (1, 3, 1, 1),
            (3, 3, 1, 2),
        ];
        for (kernel, stride, dilation, expected) in cases {
            assert_eq!(
                get_square_conv2d_padding(kernel, stride, dilation),
                expected,
                "kernel={kernel}, stride={stride}, dilation={dilation}"
            );
        }
    }

    #[test]
    #[should_panic(expected = "Kernel size must be odd")]
    fn test_get_padding_panic() {
        get_square_conv2d_padding(2, 1, 1);
    }

    #[test]
    #[should_panic(expected = "Stride must be >= 1")]
    fn test_get_padding_panic_stride() {
        get_square_conv2d_padding(1, 0, 1);
    }

    #[test]
    #[should_panic(expected = "Dilation must be >= 1")]
    fn test_get_padding_panic_dilation() {
        get_square_conv2d_padding(1, 1, 0);
    }

    #[test]
    fn test_build_square_conv2d_padding_config() {
        assert_eq!(
            build_square_conv2d_padding_config(1, 1, 1),
            PaddingConfig2d::Explicit(0, 0)
        );
        assert_eq!(
            build_square_conv2d_padding_config(3, 1, 1),
            PaddingConfig2d::Explicit(1, 1)
        );
        assert_eq!(
            build_square_conv2d_padding_config(3, 2, 2),
            PaddingConfig2d::Explicit(2, 2)
        );
    }
}