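//! Default implementations of 1D pooling in terms of the 2D pooling
//! primitives. Each helper adds a unit-sized spatial dimension, calls the
//! corresponding 2D backend op, and reshapes the result back to 3D, so a
//! backend only has to provide the 2D kernels.
//!
//! For the fixed-kernel ops, the output length follows the usual formula
//! `length_out = (length_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1`
//! (with `dilation = 1` for average pooling).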
use crate::{backend::Backend, Shape};

use super::{MaxPool1dBackward, MaxPool1dWithIndices};

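/// Computes `avg_pool1d` by reshaping the input from `[batch_size, channels, length]`
/// to `[batch_size, channels, length, 1]` and delegating to `B::avg_pool2d` with a
/// unit-sized kernel, stride, and padding on the extra dimension.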
pub(crate) fn avg_pool1d_from_2d<B: Backend>(
    x: B::TensorPrimitive<3>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    count_include_pad: bool,
) -> B::TensorPrimitive<3> {
    let [batch_size, channels, length] = B::shape(&x).dims;

    let x = B::reshape(x, Shape::from([batch_size, channels, length, 1]));
    let x = B::avg_pool2d(
        x,
        [kernel_size, 1],
        [stride, 1],
        [padding, 0],
        count_include_pad,
    );

    let [batch_size, channels, length, _] = B::shape(&x).dims;

    B::reshape(x, Shape::from([batch_size, channels, length]))
}

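/// Computes the `avg_pool1d` backward pass by reshaping the input and the output
/// gradient to 4D and delegating to `B::avg_pool2d_backward`. Returns the gradient
/// with respect to the input, reshaped back to `[batch_size, channels, length_in]`.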
pub(crate) fn avg_pool1d_backward_from_2d<B: Backend>(
    x: B::TensorPrimitive<3>,
    grad: B::TensorPrimitive<3>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    count_include_pad: bool,
) -> B::TensorPrimitive<3> {
    let [batch_size, channels, length_in] = B::shape(&x).dims;
    let [_, _, length_out] = B::shape(&grad).dims;

    let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1]));
    let grad_output = B::reshape(grad, Shape::from([batch_size, channels, length_out, 1]));

    let grad_x = B::avg_pool2d_backward(
        x,
        grad_output,
        [kernel_size, 1],
        [stride, 1],
        [padding, 0],
        count_include_pad,
    );

    B::reshape(grad_x, Shape::from([batch_size, channels, length_in]))
}

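/// Computes `adaptive_avg_pool1d` by reshaping the input to
/// `[batch_size, channels, length, 1]` and delegating to
/// `B::adaptive_avg_pool2d` with an output size of `[output_size, 1]`.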
pub(crate) fn adaptive_avg_pool1d_from_2d<B: Backend>(
    x: B::TensorPrimitive<3>,
    output_size: usize,
) -> B::TensorPrimitive<3> {
    let [batch_size, channels, length] = B::shape(&x).dims;

    let x = B::reshape(x, Shape::from([batch_size, channels, length, 1]));
    let x = B::adaptive_avg_pool2d(x, [output_size, 1]);

    let [batch_size, channels, length, _] = B::shape(&x).dims;

    B::reshape(x, Shape::from([batch_size, channels, length]))
}

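/// Computes the `adaptive_avg_pool1d` backward pass by reshaping the input and
/// the output gradient to 4D and delegating to `B::adaptive_avg_pool2d_backward`.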
pub(crate) fn adaptive_avg_pool1d_backward_from_2d<B: Backend>(
    x: B::TensorPrimitive<3>,
    grad: B::TensorPrimitive<3>,
) -> B::TensorPrimitive<3> {
    let [batch_size, channels, length_in] = B::shape(&x).dims;
    let [_, _, length_out] = B::shape(&grad).dims;

    let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1]));
    let grad_output = B::reshape(grad, Shape::from([batch_size, channels, length_out, 1]));

    let grad_x = B::adaptive_avg_pool2d_backward(x, grad_output);

    B::reshape(grad_x, Shape::from([batch_size, channels, length_in]))
}

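/// Computes `max_pool1d` by reshaping the input to
/// `[batch_size, channels, length, 1]` and delegating to `B::max_pool2d`.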
pub(crate) fn max_pool1d_from_2d<B: Backend>(
    x: B::TensorPrimitive<3>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> B::TensorPrimitive<3> {
    let [batch_size, channels, length] = B::shape(&x).dims;

    let x = B::reshape(x, Shape::from([batch_size, channels, length, 1]));
    let x = B::max_pool2d(
        x,
        [kernel_size, 1],
        [stride, 1],
        [padding, 0],
        [dilation, 1],
    );

    let [batch_size, channels, length, _] = B::shape(&x).dims;

    B::reshape(x, Shape::from([batch_size, channels, length]))
}

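/// Computes `max_pool1d` together with the flat indices of the maxima by
/// delegating to `B::max_pool2d_with_indices` on a 4D view of the input.
/// Returns a `MaxPool1dWithIndices` holding the pooled output and the indices.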
pub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(
    x: B::TensorPrimitive<3>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> MaxPool1dWithIndices<B> {
    let [batch_size, channels, length] = B::shape(&x).dims;

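    // Unlike the other helpers in this module, this one pools along the width
    // dimension (a `[batch_size, channels, 1, length]` view). Assuming the
    // returned indices are flat over the two spatial dimensions, the values are
    // identical either way: with a unit dimension, `h * 1 + 0 == 0 * length + w`,
    // so the indices remain valid for the `[length, 1]` view used by the
    // backward pass below.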
    let x = B::reshape(x, Shape::from([batch_size, channels, 1, length]));
    let pooled = B::max_pool2d_with_indices(
        x,
        [1, kernel_size],
        [1, stride],
        [0, padding],
        [1, dilation],
    );
    let [batch_size, channels, _, length] = B::shape(&pooled.output).dims;
    let output = B::reshape(pooled.output, Shape::from([batch_size, channels, length]));
    let indices = B::int_reshape(pooled.indices, Shape::from([batch_size, channels, length]));
    MaxPool1dWithIndices::new(output, indices)
}

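/// Computes the `max_pool1d` backward pass from the saved indices by delegating
/// to `B::max_pool2d_with_indices_backward` on 4D views of the input, the output
/// gradient, and the indices.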
pub(crate) fn max_pool1d_with_indices_backward_from_2d<B: Backend>(
    x: B::TensorPrimitive<3>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    output_grad: B::TensorPrimitive<3>,
    indices: B::IntTensorPrimitive<3>,
) -> MaxPool1dBackward<B> {
    let [batch_size, channels, length_in] = B::shape(&x).dims;
    let [_, _, length_out] = B::shape(&output_grad).dims;

    let x = B::reshape(x, Shape::from([batch_size, channels, length_in, 1]));
    let grad_output = B::reshape(
        output_grad,
        Shape::from([batch_size, channels, length_out, 1]),
    );
    let indices = B::int_reshape(indices, Shape::from([batch_size, channels, length_out, 1]));

    let grad_x = B::max_pool2d_with_indices_backward(
        x,
        [kernel_size, 1],
        [stride, 1],
        [padding, 0],
        [dilation, 1],
        grad_output,
        indices,
    )
    .x_grad;

    MaxPool1dBackward::new(B::reshape(
        grad_x,
        Shape::from([batch_size, channels, length_in]),
    ))
}