burn_tensor/tensor/ops/modules/pool.rs
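//! Helpers that implement the 1D pooling module ops by reshaping the input to
//! a degenerate 2D tensor (one extra singleton spatial dimension) and
//! delegating to the backend's 2D pooling kernels.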

use crate::{
    backend::Backend,
    ops::{FloatTensor, IntTensor},
    Shape, TensorMetadata,
};

use super::{MaxPool1dBackward, MaxPool1dWithIndices};

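/// 1D average pooling implemented on top of `B::avg_pool2d` by appending a
/// singleton spatial dimension, so the 2D window `[kernel_size, 1]` only
/// slides along the length axis.
///
/// Shape sketch under the usual pooling arithmetic, assuming for example an
/// input of `[2, 3, 8]` with `kernel_size = 3`, `stride = 2`, `padding = 1`:
/// `[2, 3, 8] -> [2, 3, 8, 1] -> avg_pool2d -> [2, 3, 4, 1] -> [2, 3, 4]`,
/// since `length_out = (length + 2 * padding - kernel_size) / stride + 1 = 4`.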
pub(crate) fn avg_pool1d_from_2d<B: Backend>(
    x: FloatTensor<B>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    count_include_pad: bool,
) -> FloatTensor<B> {
    let [batch_size, channels, length] = x.shape().dims();

    let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1]));
    let x = B::avg_pool2d(
        x,
        [kernel_size, 1],
        [stride, 1],
        [padding, 0],
        count_include_pad,
    );

    let [batch_size, channels, length, _] = x.shape().dims();

    B::float_reshape(x, Shape::from([batch_size, channels, length]))
}

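/// Backward of the reshape-based 1D average pooling: the incoming gradient
/// (`[batch_size, channels, length_out]`) is lifted to 4D the same way as the
/// forward input, routed through `B::avg_pool2d_backward`, then squeezed back
/// to the input shape `[batch_size, channels, length_in]`.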
pub(crate) fn avg_pool1d_backward_from_2d<B: Backend>(
    x: FloatTensor<B>,
    grad: FloatTensor<B>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    count_include_pad: bool,
) -> FloatTensor<B> {
    let [batch_size, channels, length_in] = x.shape().dims();
    let [_, _, length_out] = grad.shape().dims();

    let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1]));
    let grad_x = B::float_reshape(grad, Shape::from([batch_size, channels, length_out, 1]));

    let grad_x = B::avg_pool2d_backward(
        x,
        grad_x,
        [kernel_size, 1],
        [stride, 1],
        [padding, 0],
        count_include_pad,
    );

    B::float_reshape(grad_x, Shape::from([batch_size, channels, length_in]))
}

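/// 1D adaptive average pooling via `B::adaptive_avg_pool2d`, targeting
/// `[output_size, 1]` so the added singleton dimension is preserved.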
pub(crate) fn adaptive_avg_pool1d_from_2d<B: Backend>(
    x: FloatTensor<B>,
    output_size: usize,
) -> FloatTensor<B> {
    let [batch_size, channels, length] = x.shape().dims();

    let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1]));
    let x = B::adaptive_avg_pool2d(x, [output_size, 1]);

    let [batch_size, channels, length, _] = x.shape().dims();

    B::float_reshape(x, Shape::from([batch_size, channels, length]))
}

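/// Backward of the reshape-based adaptive average pooling. Adaptive pooling
/// derives its window geometry from the input and output sizes, so only the
/// two tensors (and no kernel parameters) need lifting to 4D.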
pub(crate) fn adaptive_avg_pool1d_backward_from_2d<B: Backend>(
    x: FloatTensor<B>,
    grad: FloatTensor<B>,
) -> FloatTensor<B> {
    let [batch_size, channels, length_in] = x.shape().dims();
    let [_, _, length_out] = grad.shape().dims();

    let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1]));
    let grad_x = B::float_reshape(grad, Shape::from([batch_size, channels, length_out, 1]));

    let grad_x = B::adaptive_avg_pool2d_backward(x, grad_x);

    B::float_reshape(grad_x, Shape::from([batch_size, channels, length_in]))
}

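/// 1D max pooling on top of `B::max_pool2d`, again pooling only along the
/// length axis. Under the usual dilated-pooling arithmetic the output length
/// is `(length + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1`.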
pub(crate) fn max_pool1d_from_2d<B: Backend>(
    x: FloatTensor<B>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> FloatTensor<B> {
    let [batch_size, channels, length] = x.shape().dims();

    let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1]));
    let x = B::max_pool2d(
        x,
        [kernel_size, 1],
        [stride, 1],
        [padding, 0],
        [dilation, 1],
    );

    let [batch_size, channels, length, _] = x.shape().dims();

    B::float_reshape(x, Shape::from([batch_size, channels, length]))
}

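/// Like `max_pool1d_from_2d`, but also returns the argmax indices consumed by
/// the backward pass.
///
/// Note the layout quirk: this helper pools along the *width* axis
/// (`[batch, channels, 1, length]` with a `[1, kernel_size]` window), while
/// the backward below lifts its tensors along the *height* axis
/// (`[length_in, 1]`). Assuming the backend reports indices as flat positions
/// within the spatial plane, the two layouts agree: in both a `1 x L` and an
/// `L x 1` plane the flat index reduces to the position along the length axis.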
pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
    x: FloatTensor<B>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> MaxPool1dWithIndices<B> {
    let [batch_size, channels, length] = x.shape().dims();

    let x = B::float_reshape(x, Shape::from([batch_size, channels, 1, length]));
    let x = B::max_pool2d_with_indices(
        x,
        [1, kernel_size],
        [1, stride],
        [0, padding],
        [1, dilation],
    );
    let [batch_size, channels, _, length] = x.output.shape().dims();
    let output = B::float_reshape(x.output, Shape::from([batch_size, channels, length]));
    let indices = B::int_reshape(x.indices, Shape::from([batch_size, channels, length]));
    MaxPool1dWithIndices::new(output, indices)
}

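/// Backward of `max_pool1d_with_indices_from_2d`: lifts the input, the output
/// gradient, and the indices to 4D in the `[.., length, 1]` layout, delegates
/// to `B::max_pool2d_with_indices_backward`, and squeezes the resulting
/// `x_grad` back to `[batch_size, channels, length_in]`.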
pub(crate) fn max_pool1d_with_indices_backward_from_2d<B: Backend>(
    x: FloatTensor<B>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    output_grad: FloatTensor<B>,
    indices: IntTensor<B>,
) -> MaxPool1dBackward<B> {
    let [batch_size, channels, length_in] = x.shape().dims();
    let [_, _, length_out] = output_grad.shape().dims();

    let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1]));
    let grad_x = B::float_reshape(
        output_grad,
        Shape::from([batch_size, channels, length_out, 1]),
    );
    let indices = B::int_reshape(indices, Shape::from([batch_size, channels, length_out, 1]));

    let grad_x = B::max_pool2d_with_indices_backward(
        x,
        [kernel_size, 1],
        [stride, 1],
        [padding, 0],
        [dilation, 1],
        grad_x,
        indices,
    )
    .x_grad;

    MaxPool1dBackward::new(B::float_reshape(
        grad_x,
        Shape::from([batch_size, channels, length_in]),
    ))
}