auto_diff/op/
pooling.rs

1use tensor_rs::tensor::Tensor;
2use super::{OpTrait, OpHandle};
3
4#[cfg(feature = "use-serde")]
5use serde::{Serialize, Deserialize};
6#[cfg(feature = "use-serde")]
7use std::any::Any;
8
9// MaxPool1d
// MaxPool2d
11#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
12pub struct MaxPool2d {
13    #[cfg_attr(feature = "use-serde", serde(skip))]
14    handle: OpHandle,
15    kernel_size: (usize, usize),
16    stride: (usize, usize),
17    padding: Tensor,
18    dilation: (usize, usize),
19    return_indices: bool,
20    ceil_mode: bool,
21}
22impl MaxPool2d {
23    pub fn new(kernel_size: Option<(usize, usize)>,
24               stride: Option<(usize, usize)>,
25               padding: Option<Tensor>,
26               dilation: Option<(usize, usize)>,
27               return_indices: Option<bool>,
28               ceil_mode: Option<bool>,) -> MaxPool2d {
29        let kernel_size = if let Some(v) = kernel_size {v} else {(2, 2)};
30        let stride = if let Some(v) = stride {v} else {(2, 2)};
31        let padding = if let Some(v) = padding {v} else {Tensor::zeros(&[1])};
32        let dilation = if let Some(v) = dilation {v} else {(2, 2)};
33        let return_indices = if let Some(v) = return_indices {v} else {false};
34        let ceil_mode = if let Some(v) = ceil_mode {v} else {false};
35        MaxPool2d {
36            handle: OpHandle::new(),
37            kernel_size,
38            stride,
39            padding,
40            dilation,
41            return_indices,
42            ceil_mode,
43        }
44    }
45    fn get_handle(&self) -> &OpHandle {
46        &self.handle
47    }
48    fn get_handle_mut(&mut self) -> &mut OpHandle {
49        &mut self.handle
50    }
51}
52impl OpTrait for MaxPool2d {
53     
54    fn get_name(&self) -> &'static str {
55        "MaxPool2d"
56    }
57    fn get_input_size(&self) -> usize {
58        1
59    }
60    fn get_output_size(&self) -> usize {
61        1
62    }
63    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
64        unimplemented!();
65    }
66    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
67        unimplemented!();
68    }
69    fn get_values(&self) -> Vec<Tensor> {
70        Vec::new()
71    }
72    fn get_grads(&self) -> Vec<Tensor> {
73        Vec::new()
74    }
75    fn set_values(&self, _v: &[Tensor]) {
76    }
77    #[cfg(feature = "use-serde")]
78    fn as_any(&self) -> &dyn Any {
79	self
80    }
81}
82
83// MaxPool3d
84// MaxUnpool1d
85// MaxUnpool2d
86// MaxUnpool3d
87// AvgPool1d
88// AvgPool2d
89// AvgPool3d
90// FractionalMaxPool2d
91// LPPool1d
92// LPPool2d
93// AdaptiveMaxPool1d
94// AdaptiveMaxPool2d
95// AdaptiveMaxPool3d
96// AdaptiveAvgPool1d
97// AdaptiveAvgPool2d
98// AdaptiveAvgPool3d
99//