use crate::nn::module::Module;
use crate::tensor::Tensor;
/// Configuration for a 2-D max-pooling layer.
///
/// The layer stores only the pooling window and step; its `Module::forward`
/// implementation hands both to `Tensor::max_pool2d`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MaxPool2d {
    /// Pooling window extents.
    // NOTE(review): presumably (height, width) — confirm against `Tensor::max_pool2d`.
    pub kernel_size: (usize, usize),
    /// Step between consecutive windows, one value per pooled axis.
    pub stride: (usize, usize),
}

impl MaxPool2d {
    /// Creates a layer from an explicit kernel size and stride.
    ///
    /// The values are stored as given; no validation is performed here.
    pub fn new(kernel_size: (usize, usize), stride: (usize, usize)) -> Self {
        Self {
            kernel_size,
            stride,
        }
    }

    /// Creates a layer whose kernel size and stride are both `(size, size)`.
    pub fn square(size: usize) -> Self {
        // Delegate so there is a single place that builds the struct.
        Self::new((size, size), (size, size))
    }
}
impl Module for MaxPool2d {
    /// Applies `Tensor::max_pool2d` to `inputs` using this layer's
    /// `kernel_size` and `stride`, returning the resulting tensor.
    fn forward(&self, inputs: &Tensor) -> Tensor {
        let window = self.kernel_size;
        let step = self.stride;
        inputs.max_pool2d(window, step)
    }

    /// Returns an empty list: the layer stores no tensors.
    fn parameters(&self) -> Vec<Tensor> {
        Vec::new()
    }
}
/// Configuration for a 2-D average-pooling layer.
///
/// The layer stores the pooling window, step and padding; its
/// `Module::forward` implementation hands all three to `Tensor::avg_pool2d`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AvgPool2d {
    /// Pooling window extents.
    // NOTE(review): presumably (height, width) — confirm against `Tensor::avg_pool2d`.
    pub kernel_size: (usize, usize),
    /// Step between consecutive windows, one value per pooled axis.
    pub stride: (usize, usize),
    /// Padding amounts, one value per pooled axis. Defaults to `(0, 0)`
    /// when the layer is built with [`AvgPool2d::new`] or [`AvgPool2d::square`].
    pub padding: (usize, usize),
}

impl AvgPool2d {
    /// Creates a layer from an explicit kernel size and stride, with
    /// padding set to `(0, 0)`.
    ///
    /// The values are stored as given; no validation is performed here.
    pub fn new(kernel_size: (usize, usize), stride: (usize, usize)) -> Self {
        Self {
            kernel_size,
            stride,
            padding: (0, 0),
        }
    }

    /// Creates a layer whose kernel size and stride are both `(size, size)`,
    /// with padding set to `(0, 0)`.
    pub fn square(size: usize) -> Self {
        // Delegate so the zero-padding default lives in one place.
        Self::new((size, size), (size, size))
    }

    /// Builder-style setter: consumes the layer and returns it with
    /// `padding` replaced.
    #[must_use = "with_padding returns the updated layer instead of modifying it in place"]
    pub fn with_padding(mut self, padding: (usize, usize)) -> Self {
        self.padding = padding;
        self
    }
}
impl Module for AvgPool2d {
    /// Applies `Tensor::avg_pool2d` to `inputs` using this layer's
    /// `kernel_size`, `stride` and `padding`, returning the resulting tensor.
    fn forward(&self, inputs: &Tensor) -> Tensor {
        let window = self.kernel_size;
        let step = self.stride;
        let pad = self.padding;
        inputs.avg_pool2d(window, step, pad)
    }

    /// Returns an empty list: the layer stores no tensors.
    fn parameters(&self) -> Vec<Tensor> {
        Vec::new()
    }
}
/// Configuration for an adaptive 2-D average-pooling layer.
///
/// The layer stores only the requested output size; its `Module::forward`
/// implementation hands it to `Tensor::adaptive_avg_pool2d`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AdaptiveAvgPool2d {
    /// Requested output extents.
    // NOTE(review): presumably (height, width) — confirm against `Tensor::adaptive_avg_pool2d`.
    pub output_size: (usize, usize),
}

impl AdaptiveAvgPool2d {
    /// Creates a layer that requests the given output size.
    ///
    /// The value is stored as given; no validation is performed here.
    pub fn new(output_size: (usize, usize)) -> Self {
        Self { output_size }
    }

    /// Creates a layer with `output_size` set to `(1, 1)`.
    pub fn global() -> Self {
        Self::new((1, 1))
    }
}
impl Module for AdaptiveAvgPool2d {
    /// Applies `Tensor::adaptive_avg_pool2d` to `inputs` with this layer's
    /// `output_size`, returning the resulting tensor.
    fn forward(&self, inputs: &Tensor) -> Tensor {
        let target = self.output_size;
        inputs.adaptive_avg_pool2d(target)
    }

    /// Returns an empty list: the layer stores no tensors.
    fn parameters(&self) -> Vec<Tensor> {
        Vec::new()
    }
}
pub type GlobalAvgPool2d = AdaptiveAvgPool2d;
impl GlobalAvgPool2d {
pub fn new_global() -> Self {
AdaptiveAvgPool2d::global()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::GraphContext;
    use std::cell::RefCell;
    use std::rc::Rc;

    /// `forward` can be called on a graph input and the layer reports no parameters.
    #[test]
    fn test_max_pool2d() {
        let layer = MaxPool2d::new((2, 2), (2, 2));
        let graph = Rc::new(RefCell::new(GraphContext::new()));
        let x = Tensor::new_input(&graph, "input");
        // Only checks that the call completes; the output is discarded.
        let _ = layer.forward(&x);
        let params = layer.parameters();
        assert!(params.is_empty());
    }

    /// `square` mirrors its argument into both kernel size and stride.
    #[test]
    fn test_max_pool2d_square() {
        let MaxPool2d {
            kernel_size,
            stride,
        } = MaxPool2d::square(2);
        assert_eq!(kernel_size, (2, 2));
        assert_eq!(stride, (2, 2));
    }

    /// `with_padding` overrides the default padding and `forward` still runs.
    #[test]
    fn test_avg_pool2d() {
        let layer = AvgPool2d::new((2, 2), (2, 2)).with_padding((1, 1));
        let graph = Rc::new(RefCell::new(GraphContext::new()));
        let x = Tensor::new_input(&graph, "input");
        // Only checks that the call completes; the output is discarded.
        let _ = layer.forward(&x);
        assert_eq!(layer.padding, (1, 1));
    }

    /// The requested output size is stored unchanged and `forward` runs.
    #[test]
    fn test_adaptive_avg_pool2d() {
        let layer = AdaptiveAvgPool2d::new((7, 7));
        let graph = Rc::new(RefCell::new(GraphContext::new()));
        let x = Tensor::new_input(&graph, "input");
        // Only checks that the call completes; the output is discarded.
        let _ = layer.forward(&x);
        assert_eq!(layer.output_size, (7, 7));
    }

    /// `global` requests a `(1, 1)` output.
    #[test]
    fn test_global_avg_pool() {
        let AdaptiveAvgPool2d { output_size } = AdaptiveAvgPool2d::global();
        assert_eq!(output_size, (1, 1));
    }
}