1use crate::{
2 backend::Backend,
3 ops::{ConvOptions, ConvTransposeOptions, InterpolateOptions, UnfoldOptions},
4 Int, Tensor, TensorPrimitive,
5};
6
7use super::ops::DeformConvOptions;
8
9pub fn embedding<B>(weights: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 3>
11where
12 B: Backend,
13{
14 Tensor::new(TensorPrimitive::Float(B::embedding(
15 weights.primitive.tensor(),
16 indices.primitive,
17 )))
18}
19
20pub fn conv1d<B>(
22 x: Tensor<B, 3>,
23 weight: Tensor<B, 3>,
24 bias: Option<Tensor<B, 1>>,
25 options: ConvOptions<1>,
26) -> Tensor<B, 3>
27where
28 B: Backend,
29{
30 Tensor::new(TensorPrimitive::Float(B::conv1d(
31 x.primitive.tensor(),
32 weight.primitive.tensor(),
33 bias.map(|b| b.primitive.tensor()),
34 options,
35 )))
36}
37
38pub fn conv2d<B>(
40 x: Tensor<B, 4>,
41 weight: Tensor<B, 4>,
42 bias: Option<Tensor<B, 1>>,
43 options: ConvOptions<2>,
44) -> Tensor<B, 4>
45where
46 B: Backend,
47{
48 Tensor::new(TensorPrimitive::Float(B::conv2d(
49 x.primitive.tensor(),
50 weight.primitive.tensor(),
51 bias.map(|b| b.primitive.tensor()),
52 options,
53 )))
54}
55
56pub fn conv3d<B>(
58 x: Tensor<B, 5>,
59 weight: Tensor<B, 5>,
60 bias: Option<Tensor<B, 1>>,
61 options: ConvOptions<3>,
62) -> Tensor<B, 5>
63where
64 B: Backend,
65{
66 Tensor::new(TensorPrimitive::Float(B::conv3d(
67 x.primitive.tensor(),
68 weight.primitive.tensor(),
69 bias.map(|b| b.primitive.tensor()),
70 options,
71 )))
72}
73
74pub fn deform_conv2d<B>(
76 x: Tensor<B, 4>,
77 offset: Tensor<B, 4>,
78 weight: Tensor<B, 4>,
79 mask: Option<Tensor<B, 4>>,
80 bias: Option<Tensor<B, 1>>,
81 options: DeformConvOptions<2>,
82) -> Tensor<B, 4>
83where
84 B: Backend,
85{
86 Tensor::new(TensorPrimitive::Float(B::deform_conv2d(
87 x.primitive.tensor(),
88 offset.primitive.tensor(),
89 weight.primitive.tensor(),
90 mask.map(|m| m.primitive.tensor()),
91 bias.map(|b| b.primitive.tensor()),
92 options,
93 )))
94}
95
96pub fn conv_transpose1d<B>(
98 x: Tensor<B, 3>,
99 weight: Tensor<B, 3>,
100 bias: Option<Tensor<B, 1>>,
101 options: ConvTransposeOptions<1>,
102) -> Tensor<B, 3>
103where
104 B: Backend,
105{
106 Tensor::new(TensorPrimitive::Float(B::conv_transpose1d(
107 x.primitive.tensor(),
108 weight.primitive.tensor(),
109 bias.map(|b| b.primitive.tensor()),
110 options,
111 )))
112}
113
114pub fn conv_transpose2d<B>(
116 x: Tensor<B, 4>,
117 weight: Tensor<B, 4>,
118 bias: Option<Tensor<B, 1>>,
119 options: ConvTransposeOptions<2>,
120) -> Tensor<B, 4>
121where
122 B: Backend,
123{
124 Tensor::new(TensorPrimitive::Float(B::conv_transpose2d(
125 x.primitive.tensor(),
126 weight.primitive.tensor(),
127 bias.map(|b| b.primitive.tensor()),
128 options,
129 )))
130}
131
132pub fn conv_transpose3d<B>(
134 x: Tensor<B, 5>,
135 weight: Tensor<B, 5>,
136 bias: Option<Tensor<B, 1>>,
137 options: ConvTransposeOptions<3>,
138) -> Tensor<B, 5>
139where
140 B: Backend,
141{
142 Tensor::new(TensorPrimitive::Float(B::conv_transpose3d(
143 x.primitive.tensor(),
144 weight.primitive.tensor(),
145 bias.map(|b| b.primitive.tensor()),
146 options,
147 )))
148}
149
150pub fn unfold4d<B>(x: Tensor<B, 4>, kernel_size: [usize; 2], options: UnfoldOptions) -> Tensor<B, 3>
152where
153 B: Backend,
154{
155 Tensor::new(TensorPrimitive::Float(B::unfold4d(
156 x.primitive.tensor(),
157 kernel_size,
158 options,
159 )))
160}
161
162pub fn max_pool1d<B>(
164 x: Tensor<B, 3>,
165 kernel_size: usize,
166 stride: usize,
167 padding: usize,
168 dilation: usize,
169) -> Tensor<B, 3>
170where
171 B: Backend,
172{
173 Tensor::new(TensorPrimitive::Float(B::max_pool1d(
174 x.primitive.tensor(),
175 kernel_size,
176 stride,
177 padding,
178 dilation,
179 )))
180}
181
182pub fn max_pool2d<B>(
184 x: Tensor<B, 4>,
185 kernel_size: [usize; 2],
186 stride: [usize; 2],
187 padding: [usize; 2],
188 dilation: [usize; 2],
189) -> Tensor<B, 4>
190where
191 B: Backend,
192{
193 Tensor::new(TensorPrimitive::Float(B::max_pool2d(
194 x.primitive.tensor(),
195 kernel_size,
196 stride,
197 padding,
198 dilation,
199 )))
200}
201
202pub fn avg_pool2d<B>(
204 x: Tensor<B, 4>,
205 kernel_size: [usize; 2],
206 stride: [usize; 2],
207 padding: [usize; 2],
208 count_include_pad: bool,
209) -> Tensor<B, 4>
210where
211 B: Backend,
212{
213 Tensor::new(TensorPrimitive::Float(B::avg_pool2d(
214 x.primitive.tensor(),
215 kernel_size,
216 stride,
217 padding,
218 count_include_pad,
219 )))
220}
221
222pub fn avg_pool1d<B>(
224 x: Tensor<B, 3>,
225 kernel_size: usize,
226 stride: usize,
227 padding: usize,
228 count_include_pad: bool,
229) -> Tensor<B, 3>
230where
231 B: Backend,
232{
233 Tensor::new(TensorPrimitive::Float(B::avg_pool1d(
234 x.primitive.tensor(),
235 kernel_size,
236 stride,
237 padding,
238 count_include_pad,
239 )))
240}
241
242pub fn max_pool1d_with_indices<B>(
244 x: Tensor<B, 3>,
245 kernel_size: usize,
246 stride: usize,
247 padding: usize,
248 dilation: usize,
249) -> (Tensor<B, 3>, Tensor<B, 3, Int>)
250where
251 B: Backend,
252{
253 let output =
254 B::max_pool1d_with_indices(x.primitive.tensor(), kernel_size, stride, padding, dilation);
255
256 (
257 Tensor::new(TensorPrimitive::Float(output.output)),
258 Tensor::new(output.indices),
259 )
260}
261
262pub fn max_pool2d_with_indices<B>(
264 x: Tensor<B, 4>,
265 kernel_size: [usize; 2],
266 stride: [usize; 2],
267 padding: [usize; 2],
268 dilation: [usize; 2],
269) -> (Tensor<B, 4>, Tensor<B, 4, Int>)
270where
271 B: Backend,
272{
273 let output =
274 B::max_pool2d_with_indices(x.primitive.tensor(), kernel_size, stride, padding, dilation);
275
276 (
277 Tensor::new(TensorPrimitive::Float(output.output)),
278 Tensor::new(output.indices),
279 )
280}
281
282pub fn adaptive_avg_pool2d<B>(x: Tensor<B, 4>, output_size: [usize; 2]) -> Tensor<B, 4>
284where
285 B: Backend,
286{
287 Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool2d(
288 x.primitive.tensor(),
289 output_size,
290 )))
291}
292
293pub fn adaptive_avg_pool1d<B>(x: Tensor<B, 3>, output_size: usize) -> Tensor<B, 3>
295where
296 B: Backend,
297{
298 Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool1d(
299 x.primitive.tensor(),
300 output_size,
301 )))
302}
303
304pub fn interpolate<B>(
306 x: Tensor<B, 4>,
307 output_size: [usize; 2],
308 options: InterpolateOptions,
309) -> Tensor<B, 4>
310where
311 B: Backend,
312{
313 Tensor::new(TensorPrimitive::Float(B::interpolate(
314 x.primitive.tensor(),
315 output_size,
316 options,
317 )))
318}