1use crate::{
2 backend::Backend,
3 ops::{FloatTensor, IntTensor},
4 Shape, TensorMetadata,
5};
6
7use super::{MaxPool1dBackward, MaxPool1dWithIndices};
8
9pub(crate) fn avg_pool1d_from_2d<B: Backend>(
10 x: FloatTensor<B>,
11 kernel_size: usize,
12 stride: usize,
13 padding: usize,
14 count_include_pad: bool,
15) -> FloatTensor<B> {
16 let [batch_size, channels, length] = x.shape().dims();
17
18 let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1]));
19 let x = B::avg_pool2d(
20 x,
21 [kernel_size, 1],
22 [stride, 1],
23 [padding, 0],
24 count_include_pad,
25 );
26
27 let [batch_size, channels, length, _] = x.shape().dims();
28
29 B::float_reshape(x, Shape::from([batch_size, channels, length]))
30}
31
32pub(crate) fn avg_pool1d_backward_from_2d<B: Backend>(
33 x: FloatTensor<B>,
34 grad: FloatTensor<B>,
35 kernel_size: usize,
36 stride: usize,
37 padding: usize,
38 count_include_pad: bool,
39) -> FloatTensor<B> {
40 let [batch_size, channels, length_in] = x.shape().dims();
41 let [_, _, length_out] = grad.shape().dims();
42
43 let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1]));
44 let grad_x = B::float_reshape(grad, Shape::from([batch_size, channels, length_out, 1]));
45
46 let grad_x = B::avg_pool2d_backward(
47 x,
48 grad_x,
49 [kernel_size, 1],
50 [stride, 1],
51 [padding, 0],
52 count_include_pad,
53 );
54
55 B::float_reshape(grad_x, Shape::from([batch_size, channels, length_in]))
56}
57
58pub(crate) fn adaptive_avg_pool1d_from_2d<B: Backend>(
59 x: FloatTensor<B>,
60 output_size: usize,
61) -> FloatTensor<B> {
62 let [batch_size, channels, length] = x.shape().dims();
63
64 let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1]));
65 let x = B::adaptive_avg_pool2d(x, [output_size, 1]);
66
67 let [batch_size, channels, length, _] = x.shape().dims();
68
69 B::float_reshape(x, Shape::from([batch_size, channels, length]))
70}
71
72pub(crate) fn adaptive_avg_pool1d_backward_from_2d<B: Backend>(
73 x: FloatTensor<B>,
74 grad: FloatTensor<B>,
75) -> FloatTensor<B> {
76 let [batch_size, channels, length_in] = x.shape().dims();
77 let [_, _, length_out] = grad.shape().dims();
78
79 let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1]));
80 let grad_x = B::float_reshape(grad, Shape::from([batch_size, channels, length_out, 1]));
81
82 let grad_x = B::adaptive_avg_pool2d_backward(x, grad_x);
83
84 B::float_reshape(grad_x, Shape::from([batch_size, channels, length_in]))
85}
86
87pub(crate) fn max_pool1d_from_2d<B: Backend>(
88 x: FloatTensor<B>,
89 kernel_size: usize,
90 stride: usize,
91 padding: usize,
92 dilation: usize,
93) -> FloatTensor<B> {
94 let [batch_size, channels, length] = x.shape().dims();
95
96 let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1]));
97 let x = B::max_pool2d(
98 x,
99 [kernel_size, 1],
100 [stride, 1],
101 [padding, 0],
102 [dilation, 1],
103 );
104
105 let [batch_size, channels, length, _] = x.shape().dims();
106
107 B::float_reshape(x, Shape::from([batch_size, channels, length]))
108}
109
110pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
111 x: FloatTensor<B>,
112 kernel_size: usize,
113 stride: usize,
114 padding: usize,
115 dilation: usize,
116) -> MaxPool1dWithIndices<B> {
117 let [batch_size, channels, length] = x.shape().dims();
118
119 let x = B::float_reshape(x, Shape::from([batch_size, channels, 1, length]));
120 let x = B::max_pool2d_with_indices(
121 x,
122 [1, kernel_size],
123 [1, stride],
124 [0, padding],
125 [1, dilation],
126 );
127 let [batch_size, channels, _, length] = x.output.shape().dims();
128 let output = B::float_reshape(x.output, Shape::from([batch_size, channels, length]));
129 let indices = B::int_reshape(x.indices, Shape::from([batch_size, channels, length]));
130 MaxPool1dWithIndices::new(output, indices)
131}
132
133pub(crate) fn max_pool1d_with_indices_backward_from_2d<B: Backend>(
134 x: FloatTensor<B>,
135 kernel_size: usize,
136 stride: usize,
137 padding: usize,
138 dilation: usize,
139 output_grad: FloatTensor<B>,
140 indices: IntTensor<B>,
141) -> MaxPool1dBackward<B> {
142 let [batch_size, channels, length_in] = x.shape().dims();
143 let [_, _, length_out] = output_grad.shape().dims();
144
145 let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1]));
146 let grad_x = B::float_reshape(
147 output_grad,
148 Shape::from([batch_size, channels, length_out, 1]),
149 );
150 let indices = B::int_reshape(indices, Shape::from([batch_size, channels, length_out, 1]));
151
152 let grad_x = B::max_pool2d_with_indices_backward(
153 x,
154 [kernel_size, 1],
155 [stride, 1],
156 [padding, 0],
157 [dilation, 1],
158 grad_x,
159 indices,
160 )
161 .x_grad;
162
163 MaxPool1dBackward::new(B::float_reshape(
164 grad_x,
165 Shape::from([batch_size, channels, length_in]),
166 ))
167}