use crate as burn;
use burn_tensor::ops::conv::calculate_conv_padding;
use crate::config::Config;
use crate::module::Module;
#[derive(Module, Config, Debug, PartialEq)]
pub enum PaddingConfig1d {
Same,
Valid,
Explicit(usize),
}
impl PaddingConfig1d {
pub(crate) fn calculate_padding_1d(
&self,
length: usize,
kernel_size: usize,
stride: usize,
) -> usize {
let same_padding = || calculate_conv_padding(kernel_size, stride, length, length);
match self {
Self::Valid => 0,
Self::Same => same_padding(),
Self::Explicit(value) => *value,
}
}
}
#[derive(Module, Config, Debug, PartialEq)]
pub enum PaddingConfig2d {
Same,
Valid,
Explicit(usize, usize),
}
impl PaddingConfig2d {
pub(crate) fn calculate_padding_2d(
&self,
height: usize,
width: usize,
kernel_size: &[usize; 2],
stride: &[usize; 2],
) -> [usize; 2] {
let same_padding = || {
let p1 = calculate_conv_padding(kernel_size[0], stride[0], height, height);
let p2 = calculate_conv_padding(kernel_size[1], stride[1], width, width);
[p1, p2]
};
match self {
Self::Same => same_padding(),
Self::Valid => [0, 0],
Self::Explicit(v1, v2) => [*v1, *v2],
}
}
}