use burn_core as burn;
use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn_nn::PaddingConfig2d;
use burn_nn::conv::{Conv2d, Conv2dConfig};
/// Configuration for [`L2Pool2d`]: Hann-windowed L2 (RMS-style) pooling
/// realized as a fixed-kernel depthwise convolution.
#[derive(Debug, Clone)]
pub struct L2Pool2dConfig {
    /// Side length of the square pooling kernel.
    pub kernel_size: usize,
    /// Stride applied in both spatial dimensions.
    pub stride: usize,
    /// Symmetric padding applied to all four sides of the input.
    pub padding: usize,
}
impl Default for L2Pool2dConfig {
    /// Default pooling parameters: 5x5 kernel, stride 2, padding 2.
    fn default() -> Self {
        Self::new(5, 2, 2)
    }
}
impl L2Pool2dConfig {
    /// Creates a configuration with explicit kernel size, stride and padding.
    #[allow(dead_code)]
    pub fn new(kernel_size: usize, stride: usize, padding: usize) -> Self {
        Self {
            kernel_size,
            stride,
            padding,
        }
    }

    /// Instantiates an [`L2Pool2d`] module for `channels` input channels on
    /// the given device, using this configuration's pooling parameters.
    pub fn init<B: Backend>(&self, channels: usize, device: &B::Device) -> L2Pool2d<B> {
        let Self {
            kernel_size,
            stride,
            padding,
        } = *self;
        L2Pool2d::new(channels, kernel_size, stride, padding, device)
    }
}
/// L2 pooling layer: squares the input, smooths it with a fixed normalized
/// Hann-window depthwise convolution, then takes the square root
/// (see [`L2Pool2d::forward`]).
#[derive(Module, Debug)]
pub struct L2Pool2d<B: Backend> {
    // Depthwise conv (groups == channels, no bias) whose weights are
    // overwritten with the fixed Hann kernel in `L2Pool2d::new`.
    conv: Conv2d<B>,
}
impl<B: Backend> L2Pool2d<B> {
    /// Builds an L2 pooling layer: a per-channel (grouped, bias-free)
    /// convolution whose weights are a fixed, normalized 2-D Hann window.
    ///
    /// * `channels` - number of input (and output) channels.
    /// * `kernel_size` - side length of the square pooling kernel.
    /// * `stride` - spatial stride in both dimensions.
    /// * `padding` - symmetric padding on all four sides.
    pub fn new(
        channels: usize,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        device: &B::Device,
    ) -> Self {
        let kernel = Self::create_hanning_kernel(channels, kernel_size, device);
        let mut conv = Conv2dConfig::new([channels, channels], [kernel_size, kernel_size])
            .with_stride([stride, stride])
            .with_padding(PaddingConfig2d::Explicit(
                padding, padding, padding, padding,
            ))
            // groups == channels: each channel is pooled independently.
            .with_groups(channels)
            .with_bias(false)
            .init(device);
        // Replace the randomly initialized weights with the fixed window.
        // NOTE(review): `Param::from_tensor` registers the kernel as a
        // module parameter; if the window must stay fixed during training,
        // confirm it is excluded from optimization (or mark it non-trainable).
        conv.weight = burn::module::Param::from_tensor(kernel);
        Self { conv }
    }

    /// Builds the `[channels, 1, kernel_size, kernel_size]` depthwise kernel:
    /// the outer product of a 1-D symmetric Hann window with itself,
    /// normalized to sum to 1.
    ///
    /// The symmetric Hann window `0.5 * (1 - cos(2*pi*n / (k - 1)))` is zero
    /// at both endpoints, so for `k == 2` every tap is zero and the window sum
    /// is zero. Previously that divided by zero and produced an all-NaN
    /// kernel; this now falls back to a uniform averaging kernel in that
    /// degenerate case. `k == 1` short-circuits to a single unit tap.
    fn create_hanning_kernel<B2: Backend>(
        channels: usize,
        kernel_size: usize,
        device: &B2::Device,
    ) -> Tensor<B2, 4> {
        // 1-D symmetric Hann window.
        let denom = kernel_size.saturating_sub(1) as f32;
        let hanning_1d: Vec<f32> = (0..kernel_size)
            .map(|i| {
                if denom == 0.0 {
                    // kernel_size == 1: single unit tap.
                    1.0
                } else {
                    0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / denom).cos())
                }
            })
            .collect();

        // Outer product -> 2-D window, flattened row-major.
        let mut hanning_2d = Vec::with_capacity(kernel_size * kernel_size);
        for &wi in &hanning_1d {
            for &wj in &hanning_1d {
                hanning_2d.push(wi * wj);
            }
        }

        // Normalize so the kernel sums to 1 (unit DC gain).
        let sum: f32 = hanning_2d.iter().sum();
        if sum > 0.0 {
            for v in hanning_2d.iter_mut() {
                *v /= sum;
            }
        } else {
            // Degenerate window (all taps zero, e.g. kernel_size == 2):
            // use uniform averaging instead of dividing by zero.
            let uniform = 1.0 / (kernel_size * kernel_size) as f32;
            for v in hanning_2d.iter_mut() {
                *v = uniform;
            }
        }

        let kernel_single = Tensor::<B2, 1>::from_floats(hanning_2d.as_slice(), device)
            .reshape([1, 1, kernel_size, kernel_size]);
        // One copy of the same window per channel (depthwise weight layout).
        kernel_single.repeat_dim(0, channels)
    }

    /// L2 pooling forward pass: `sqrt(conv(x^2))`.
    ///
    /// Computes the Hann-weighted mean of squared activations over each
    /// kernel footprint, then takes the square root — a windowed RMS.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let x_sq = x.clone().mul(x);
        let pooled = self.conv.forward(x_sq);
        // Clamp guards against tiny negative values from floating-point
        // error and keeps the gradient of sqrt finite near zero.
        pooled.clamp_min(1e-10).sqrt()
    }
}