use crate::autograd::{self, Variable};
use crate::tensor::Result;
use super::Module;
use super::parameter::Parameter;
/// 2-D max pooling layer.
///
/// Slides a pooling window over the spatial dimensions of the input and keeps
/// the largest value in each window. For example, `MaxPool2d::new(2)` turns a
/// `[1, 1, 4, 4]` input into a `[1, 1, 2, 2]` output.
///
/// Every window parameter is stored as a two-element array, one entry per
/// spatial dimension. The constructors and builder methods in this impl always
/// set both entries to the same value (square windows / uniform strides).
///
/// The configuration is forwarded unchanged to `autograd::max_pool2d`. That
/// call is where invalid values (e.g. a zero kernel size) would be reported.
///
/// # Examples
///
/// ```ignore
/// // 3x3 windows, stride 2, one element of padding on every side.
/// let pool = MaxPool2d::with_stride(3, 2).padding(1);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MaxPool2d {
    /// Pooling window size, one entry per spatial dimension.
    kernel_size: [i64; 2],
    /// Step between consecutive windows, one entry per spatial dimension.
    stride: [i64; 2],
    /// Implicit padding, one entry per spatial dimension (default 0).
    padding: [i64; 2],
    /// Spacing between window elements, one entry per spatial dimension (default 1).
    dilation: [i64; 2],
    /// Output-size rounding flag (default `false`). In the PyTorch convention
    /// this rounds the output size up instead of down.
    ceil_mode: bool,
}

impl MaxPool2d {
    /// Creates a square `kernel_size x kernel_size` pool.
    ///
    /// The stride equals the kernel size. Padding is 0, dilation is 1 and
    /// `ceil_mode` is off.
    #[must_use]
    pub fn new(kernel_size: i64) -> Self {
        // A stride equal to the kernel size is the conventional default;
        // reuse the general constructor so the defaults live in one place.
        Self::with_stride(kernel_size, kernel_size)
    }

    /// Creates a square `kernel_size x kernel_size` pool with an explicit
    /// `stride`, applied to both spatial dimensions.
    ///
    /// Padding is 0, dilation is 1 and `ceil_mode` is off.
    #[must_use]
    pub fn with_stride(kernel_size: i64, stride: i64) -> Self {
        Self {
            kernel_size: [kernel_size; 2],
            stride: [stride; 2],
            padding: [0; 2],
            dilation: [1; 2],
            ceil_mode: false,
        }
    }

    /// Returns this pool with `padding` applied to both spatial dimensions.
    #[must_use]
    pub fn padding(mut self, padding: i64) -> Self {
        self.padding = [padding; 2];
        self
    }

    /// Returns this pool with `dilation` applied to both spatial dimensions.
    #[must_use]
    pub fn dilation(mut self, dilation: i64) -> Self {
        self.dilation = [dilation; 2];
        self
    }

    /// Returns this pool with the given output-size rounding mode.
    #[must_use]
    pub fn ceil_mode(mut self, ceil_mode: bool) -> Self {
        self.ceil_mode = ceil_mode;
        self
    }
}
impl Module for MaxPool2d {
    /// Identifier reported for this layer.
    fn name(&self) -> &str {
        "maxpool2d"
    }

    /// Runs 2-D max pooling on `input` with this layer's stored configuration.
    ///
    /// Any error produced by `autograd::max_pool2d` is returned unchanged.
    fn forward(&self, input: &Variable) -> Result<Variable> {
        // Every field is a Copy value, so destructuring through `*self` copies
        // them out individually without moving the layer.
        let Self {
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
        } = *self;
        autograd::max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
    }

    /// Max pooling holds no learnable state, so the list is always empty.
    fn parameters(&self) -> Vec<Parameter> {
        Vec::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    #[test]
    fn test_max_pool2d_basic() {
        let opts = crate::tensor::test_opts();
        let x = Variable::new(
            Tensor::randn(&[1, 1, 4, 4], opts).unwrap(),
            false,
        );
        let pool = MaxPool2d::new(2);
        let y = pool.forward(&x).unwrap();
        assert_eq!(y.shape(), vec![1, 1, 2, 2]);
    }

    #[test]
    fn test_max_pool2d_with_padding() {
        let opts = crate::tensor::test_opts();
        let x = Variable::new(
            Tensor::randn(&[2, 3, 8, 8], opts).unwrap(),
            false,
        );
        let pool = MaxPool2d::with_stride(3, 2).padding(1);
        let y = pool.forward(&x).unwrap();
        assert_eq!(y.shape(), vec![2, 3, 4, 4]);
    }

    #[test]
    fn test_max_pool2d_gradient() {
        let opts = crate::tensor::test_opts();
        let x = Variable::new(
            Tensor::randn(&[2, 1, 4, 4], opts).unwrap(),
            true,
        );
        let pool = MaxPool2d::new(2);
        let y = pool.forward(&x).unwrap();
        let loss = y.sum().unwrap();
        loss.backward().unwrap();
        let grad = x.grad().unwrap();
        assert_eq!(grad.shape(), vec![2, 1, 4, 4]);
    }

    #[test]
    fn test_max_pool2d_values() {
        let device = crate::tensor::test_device();
        let data = vec![
            1.0_f32, 2.0, 3.0, 4.0,
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
            13.0, 14.0, 15.0, 16.0,
        ];
        let x = Variable::new(
            Tensor::from_f32(&data, &[1, 1, 4, 4], device).unwrap(),
            false,
        );
        let pool = MaxPool2d::new(2);
        let y = pool.forward(&x).unwrap();
        let y_data = y.data().to_f32_vec().unwrap();
        assert_eq!(y_data, vec![6.0, 8.0, 14.0, 16.0]);
    }

    /// Padding must contribute no values of its own. On the 4x4 grid 1..=16,
    /// a 2x2 / stride-2 pool with padding 1 gives a 3x3 output. The edge
    /// windows contain only the real elements they overlap.
    #[test]
    fn test_max_pool2d_padding_values() {
        let device = crate::tensor::test_device();
        let data: Vec<f32> = (1..=16_u8).map(f32::from).collect();
        let x = Variable::new(
            Tensor::from_f32(&data, &[1, 1, 4, 4], device).unwrap(),
            false,
        );
        let y = MaxPool2d::new(2).padding(1).forward(&x).unwrap();
        assert_eq!(y.shape(), vec![1, 1, 3, 3]);
        assert_eq!(
            y.data().to_f32_vec().unwrap(),
            vec![1.0, 3.0, 4.0, 9.0, 11.0, 12.0, 13.0, 15.0, 16.0]
        );
    }

    /// With dilation 2, a 2x2 kernel samples every other element (effective
    /// extent 3). On the 4x4 grid 1..=16 with stride 1, that gives a 2x2 output.
    #[test]
    fn test_max_pool2d_dilation() {
        let device = crate::tensor::test_device();
        let data: Vec<f32> = (1..=16_u8).map(f32::from).collect();
        let x = Variable::new(
            Tensor::from_f32(&data, &[1, 1, 4, 4], device).unwrap(),
            false,
        );
        let y = MaxPool2d::with_stride(2, 1).dilation(2).forward(&x).unwrap();
        assert_eq!(y.shape(), vec![1, 1, 2, 2]);
        assert_eq!(y.data().to_f32_vec().unwrap(), vec![11.0, 12.0, 15.0, 16.0]);
    }

    /// On a 5x5 grid, a 2x2 / stride-2 pool drops the last row and column in
    /// floor mode. With `ceil_mode(true)` it keeps them as partial windows.
    #[test]
    fn test_max_pool2d_ceil_mode() {
        let device = crate::tensor::test_device();
        let data: Vec<f32> = (1..=25_u8).map(f32::from).collect();
        let x = Variable::new(
            Tensor::from_f32(&data, &[1, 1, 5, 5], device).unwrap(),
            false,
        );

        let floor = MaxPool2d::new(2).forward(&x).unwrap();
        assert_eq!(floor.shape(), vec![1, 1, 2, 2]);
        assert_eq!(floor.data().to_f32_vec().unwrap(), vec![7.0, 9.0, 17.0, 19.0]);

        let ceil = MaxPool2d::new(2).ceil_mode(true).forward(&x).unwrap();
        assert_eq!(ceil.shape(), vec![1, 1, 3, 3]);
        assert_eq!(
            ceil.data().to_f32_vec().unwrap(),
            vec![7.0, 9.0, 10.0, 17.0, 19.0, 20.0, 22.0, 24.0, 25.0]
        );
    }

    /// The layer reports its fixed name and owns no trainable parameters.
    #[test]
    fn test_max_pool2d_module_metadata() {
        let pool = MaxPool2d::new(2);
        assert_eq!(pool.name(), "maxpool2d");
        assert!(pool.parameters().is_empty());
    }
}