1use crate::{
2 Int, Tensor, TensorPrimitive,
3 backend::Backend,
4 check,
5 check::TensorCheck,
6 ops::{ConvOptions, ConvTransposeOptions, InterpolateOptions, UnfoldOptions},
7};
8
9use super::ops::DeformConvOptions;
10
11pub fn embedding<B>(weights: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 3>
13where
14 B: Backend,
15{
16 Tensor::new(TensorPrimitive::Float(B::embedding(
17 weights.primitive.tensor(),
18 indices.primitive,
19 )))
20}
21
22pub fn conv1d<B>(
24 x: Tensor<B, 3>,
25 weight: Tensor<B, 3>,
26 bias: Option<Tensor<B, 1>>,
27 options: ConvOptions<1>,
28) -> Tensor<B, 3>
29where
30 B: Backend,
31{
32 check!(TensorCheck::conv(
33 "conv1d",
34 x.dims(),
35 weight.dims(),
36 options.groups,
37 ));
38 Tensor::new(TensorPrimitive::Float(B::conv1d(
39 x.primitive.tensor(),
40 weight.primitive.tensor(),
41 bias.map(|b| b.primitive.tensor()),
42 options,
43 )))
44}
45
46pub fn conv2d<B>(
48 x: Tensor<B, 4>,
49 weight: Tensor<B, 4>,
50 bias: Option<Tensor<B, 1>>,
51 options: ConvOptions<2>,
52) -> Tensor<B, 4>
53where
54 B: Backend,
55{
56 check!(TensorCheck::conv(
57 "conv2d",
58 x.dims(),
59 weight.dims(),
60 options.groups,
61 ));
62 Tensor::new(TensorPrimitive::Float(B::conv2d(
63 x.primitive.tensor(),
64 weight.primitive.tensor(),
65 bias.map(|b| b.primitive.tensor()),
66 options,
67 )))
68}
69
70pub fn conv3d<B>(
72 x: Tensor<B, 5>,
73 weight: Tensor<B, 5>,
74 bias: Option<Tensor<B, 1>>,
75 options: ConvOptions<3>,
76) -> Tensor<B, 5>
77where
78 B: Backend,
79{
80 check!(TensorCheck::conv(
81 "conv3d",
82 x.dims(),
83 weight.dims(),
84 options.groups,
85 ));
86 Tensor::new(TensorPrimitive::Float(B::conv3d(
87 x.primitive.tensor(),
88 weight.primitive.tensor(),
89 bias.map(|b| b.primitive.tensor()),
90 options,
91 )))
92}
93
94pub fn deform_conv2d<B>(
96 x: Tensor<B, 4>,
97 offset: Tensor<B, 4>,
98 weight: Tensor<B, 4>,
99 mask: Option<Tensor<B, 4>>,
100 bias: Option<Tensor<B, 1>>,
101 options: DeformConvOptions<2>,
102) -> Tensor<B, 4>
103where
104 B: Backend,
105{
106 check!(TensorCheck::conv(
107 "deform_conv2d",
108 x.dims(),
109 weight.dims(),
110 options.weight_groups,
111 ));
112 Tensor::new(TensorPrimitive::Float(B::deform_conv2d(
113 x.primitive.tensor(),
114 offset.primitive.tensor(),
115 weight.primitive.tensor(),
116 mask.map(|m| m.primitive.tensor()),
117 bias.map(|b| b.primitive.tensor()),
118 options,
119 )))
120}
121
122pub fn conv_transpose1d<B>(
124 x: Tensor<B, 3>,
125 weight: Tensor<B, 3>,
126 bias: Option<Tensor<B, 1>>,
127 options: ConvTransposeOptions<1>,
128) -> Tensor<B, 3>
129where
130 B: Backend,
131{
132 check!(TensorCheck::conv_transpose(
133 "conv_transpose1d",
134 x.dims(),
135 weight.dims(),
136 ));
137 Tensor::new(TensorPrimitive::Float(B::conv_transpose1d(
138 x.primitive.tensor(),
139 weight.primitive.tensor(),
140 bias.map(|b| b.primitive.tensor()),
141 options,
142 )))
143}
144
145pub fn conv_transpose2d<B>(
147 x: Tensor<B, 4>,
148 weight: Tensor<B, 4>,
149 bias: Option<Tensor<B, 1>>,
150 options: ConvTransposeOptions<2>,
151) -> Tensor<B, 4>
152where
153 B: Backend,
154{
155 check!(TensorCheck::conv_transpose(
156 "conv_transpose2d",
157 x.dims(),
158 weight.dims(),
159 ));
160 Tensor::new(TensorPrimitive::Float(B::conv_transpose2d(
161 x.primitive.tensor(),
162 weight.primitive.tensor(),
163 bias.map(|b| b.primitive.tensor()),
164 options,
165 )))
166}
167
168pub fn conv_transpose3d<B>(
170 x: Tensor<B, 5>,
171 weight: Tensor<B, 5>,
172 bias: Option<Tensor<B, 1>>,
173 options: ConvTransposeOptions<3>,
174) -> Tensor<B, 5>
175where
176 B: Backend,
177{
178 check!(TensorCheck::conv_transpose(
179 "conv_transpose3d",
180 x.dims(),
181 weight.dims(),
182 ));
183 Tensor::new(TensorPrimitive::Float(B::conv_transpose3d(
184 x.primitive.tensor(),
185 weight.primitive.tensor(),
186 bias.map(|b| b.primitive.tensor()),
187 options,
188 )))
189}
190
191pub fn unfold4d<B>(x: Tensor<B, 4>, kernel_size: [usize; 2], options: UnfoldOptions) -> Tensor<B, 3>
193where
194 B: Backend,
195{
196 Tensor::new(TensorPrimitive::Float(B::unfold4d(
197 x.primitive.tensor(),
198 kernel_size,
199 options,
200 )))
201}
202
203pub fn max_pool1d<B>(
205 x: Tensor<B, 3>,
206 kernel_size: usize,
207 stride: usize,
208 padding: usize,
209 dilation: usize,
210) -> Tensor<B, 3>
211where
212 B: Backend,
213{
214 Tensor::new(TensorPrimitive::Float(B::max_pool1d(
215 x.primitive.tensor(),
216 kernel_size,
217 stride,
218 padding,
219 dilation,
220 )))
221}
222
223pub fn max_pool2d<B>(
225 x: Tensor<B, 4>,
226 kernel_size: [usize; 2],
227 stride: [usize; 2],
228 padding: [usize; 2],
229 dilation: [usize; 2],
230) -> Tensor<B, 4>
231where
232 B: Backend,
233{
234 Tensor::new(TensorPrimitive::Float(B::max_pool2d(
235 x.primitive.tensor(),
236 kernel_size,
237 stride,
238 padding,
239 dilation,
240 )))
241}
242
243pub fn avg_pool2d<B>(
245 x: Tensor<B, 4>,
246 kernel_size: [usize; 2],
247 stride: [usize; 2],
248 padding: [usize; 2],
249 count_include_pad: bool,
250) -> Tensor<B, 4>
251where
252 B: Backend,
253{
254 Tensor::new(TensorPrimitive::Float(B::avg_pool2d(
255 x.primitive.tensor(),
256 kernel_size,
257 stride,
258 padding,
259 count_include_pad,
260 )))
261}
262
263pub fn avg_pool1d<B>(
265 x: Tensor<B, 3>,
266 kernel_size: usize,
267 stride: usize,
268 padding: usize,
269 count_include_pad: bool,
270) -> Tensor<B, 3>
271where
272 B: Backend,
273{
274 Tensor::new(TensorPrimitive::Float(B::avg_pool1d(
275 x.primitive.tensor(),
276 kernel_size,
277 stride,
278 padding,
279 count_include_pad,
280 )))
281}
282
283pub fn max_pool1d_with_indices<B>(
285 x: Tensor<B, 3>,
286 kernel_size: usize,
287 stride: usize,
288 padding: usize,
289 dilation: usize,
290) -> (Tensor<B, 3>, Tensor<B, 3, Int>)
291where
292 B: Backend,
293{
294 let output =
295 B::max_pool1d_with_indices(x.primitive.tensor(), kernel_size, stride, padding, dilation);
296
297 (
298 Tensor::new(TensorPrimitive::Float(output.output)),
299 Tensor::new(output.indices),
300 )
301}
302
303pub fn max_pool2d_with_indices<B>(
305 x: Tensor<B, 4>,
306 kernel_size: [usize; 2],
307 stride: [usize; 2],
308 padding: [usize; 2],
309 dilation: [usize; 2],
310) -> (Tensor<B, 4>, Tensor<B, 4, Int>)
311where
312 B: Backend,
313{
314 let output =
315 B::max_pool2d_with_indices(x.primitive.tensor(), kernel_size, stride, padding, dilation);
316
317 (
318 Tensor::new(TensorPrimitive::Float(output.output)),
319 Tensor::new(output.indices),
320 )
321}
322
323pub fn adaptive_avg_pool2d<B>(x: Tensor<B, 4>, output_size: [usize; 2]) -> Tensor<B, 4>
325where
326 B: Backend,
327{
328 Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool2d(
329 x.primitive.tensor(),
330 output_size,
331 )))
332}
333
334pub fn adaptive_avg_pool1d<B>(x: Tensor<B, 3>, output_size: usize) -> Tensor<B, 3>
336where
337 B: Backend,
338{
339 Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool1d(
340 x.primitive.tensor(),
341 output_size,
342 )))
343}
344
345pub fn interpolate<B>(
347 x: Tensor<B, 4>,
348 output_size: [usize; 2],
349 options: InterpolateOptions,
350) -> Tensor<B, 4>
351where
352 B: Backend,
353{
354 Tensor::new(TensorPrimitive::Float(B::interpolate(
355 x.primitive.tensor(),
356 output_size,
357 options,
358 )))
359}
360
361pub fn linear<B: Backend, const D: usize>(
387 input: Tensor<B, D>,
388 weight: Tensor<B, 2>,
389 bias: Option<Tensor<B, 1>>,
390) -> Tensor<B, D> {
391 if D == 1 {
392 let input = input.unsqueeze::<2>();
394 let output = linear(input, weight, bias);
395 return output.squeeze_dim(0);
396 }
397
398 let weight = weight.unsqueeze::<D>();
402 let bias = bias.map(|bias| bias.unsqueeze::<D>());
403
404 let output = input.matmul(weight);
405 match bias {
406 Some(bias) => output.add(bias),
407 None => output,
408 }
409}