1use tensor_rs::tensor::Tensor;
2use super::{OpTrait, OpHandle};
3
4#[cfg(feature = "use-serde")]
5use serde::{Serialize, Deserialize};
6#[cfg(feature = "use-serde")]
7use std::any::Any;
8
/// 2-D max-pooling operator.
///
/// Holds the pooling hyper-parameters. The forward (`apply`) and backward
/// (`grad`) passes are currently stubs in the `OpTrait` impl.
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct MaxPool2d {
    // Graph/runtime handle for this op; not part of the serialized state.
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
    // Pooling window size; presumably (height, width) — TODO confirm once `apply` exists.
    kernel_size: (usize, usize),
    // Step between successive windows, in the same axis order as `kernel_size`.
    stride: (usize, usize),
    // Padding, stored as a tensor; exact interpretation is not yet defined by any
    // visible code — NOTE(review): confirm when the forward pass is implemented.
    padding: Tensor,
    // Spacing between elements inside a pooling window.
    dilation: (usize, usize),
    // Whether the op should also produce the argmax indices (currently unused).
    return_indices: bool,
    // Whether output size is computed with ceiling instead of floor (currently unused).
    ceil_mode: bool,
}
22impl MaxPool2d {
23 pub fn new(kernel_size: Option<(usize, usize)>,
24 stride: Option<(usize, usize)>,
25 padding: Option<Tensor>,
26 dilation: Option<(usize, usize)>,
27 return_indices: Option<bool>,
28 ceil_mode: Option<bool>,) -> MaxPool2d {
29 let kernel_size = if let Some(v) = kernel_size {v} else {(2, 2)};
30 let stride = if let Some(v) = stride {v} else {(2, 2)};
31 let padding = if let Some(v) = padding {v} else {Tensor::zeros(&[1])};
32 let dilation = if let Some(v) = dilation {v} else {(2, 2)};
33 let return_indices = if let Some(v) = return_indices {v} else {false};
34 let ceil_mode = if let Some(v) = ceil_mode {v} else {false};
35 MaxPool2d {
36 handle: OpHandle::new(),
37 kernel_size,
38 stride,
39 padding,
40 dilation,
41 return_indices,
42 ceil_mode,
43 }
44 }
45 fn get_handle(&self) -> &OpHandle {
46 &self.handle
47 }
48 fn get_handle_mut(&mut self) -> &mut OpHandle {
49 &mut self.handle
50 }
51}
52impl OpTrait for MaxPool2d {
53
54 fn get_name(&self) -> &'static str {
55 "MaxPool2d"
56 }
57 fn get_input_size(&self) -> usize {
58 1
59 }
60 fn get_output_size(&self) -> usize {
61 1
62 }
63 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
64 unimplemented!();
65 }
66 fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
67 unimplemented!();
68 }
69 fn get_values(&self) -> Vec<Tensor> {
70 Vec::new()
71 }
72 fn get_grads(&self) -> Vec<Tensor> {
73 Vec::new()
74 }
75 fn set_values(&self, _v: &[Tensor]) {
76 }
77 #[cfg(feature = "use-serde")]
78 fn as_any(&self) -> &dyn Any {
79 self
80 }
81}
82
83