1use crate::{
2 Bool, Int, Tensor, TensorPrimitive,
3 backend::Backend,
4 check,
5 check::TensorCheck,
6 ops::{ConvOptions, ConvTransposeOptions, InterpolateOptions, UnfoldOptions},
7};
8
9use super::ops::DeformConvOptions;
10
11pub fn embedding<B>(weights: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 3>
13where
14 B: Backend,
15{
16 Tensor::new(TensorPrimitive::Float(B::embedding(
17 weights.primitive.tensor(),
18 indices.primitive,
19 )))
20}
21
22pub fn conv1d<B>(
24 x: Tensor<B, 3>,
25 weight: Tensor<B, 3>,
26 bias: Option<Tensor<B, 1>>,
27 options: ConvOptions<1>,
28) -> Tensor<B, 3>
29where
30 B: Backend,
31{
32 check!(TensorCheck::conv(
33 "conv1d",
34 x.dims(),
35 weight.dims(),
36 options.groups,
37 ));
38 Tensor::new(TensorPrimitive::Float(B::conv1d(
39 x.primitive.tensor(),
40 weight.primitive.tensor(),
41 bias.map(|b| b.primitive.tensor()),
42 options,
43 )))
44}
45
46pub fn conv2d<B>(
48 x: Tensor<B, 4>,
49 weight: Tensor<B, 4>,
50 bias: Option<Tensor<B, 1>>,
51 options: ConvOptions<2>,
52) -> Tensor<B, 4>
53where
54 B: Backend,
55{
56 check!(TensorCheck::conv(
57 "conv2d",
58 x.dims(),
59 weight.dims(),
60 options.groups,
61 ));
62 Tensor::new(TensorPrimitive::Float(B::conv2d(
63 x.primitive.tensor(),
64 weight.primitive.tensor(),
65 bias.map(|b| b.primitive.tensor()),
66 options,
67 )))
68}
69
70pub fn conv3d<B>(
72 x: Tensor<B, 5>,
73 weight: Tensor<B, 5>,
74 bias: Option<Tensor<B, 1>>,
75 options: ConvOptions<3>,
76) -> Tensor<B, 5>
77where
78 B: Backend,
79{
80 check!(TensorCheck::conv(
81 "conv3d",
82 x.dims(),
83 weight.dims(),
84 options.groups,
85 ));
86 Tensor::new(TensorPrimitive::Float(B::conv3d(
87 x.primitive.tensor(),
88 weight.primitive.tensor(),
89 bias.map(|b| b.primitive.tensor()),
90 options,
91 )))
92}
93
94pub fn deform_conv2d<B>(
96 x: Tensor<B, 4>,
97 offset: Tensor<B, 4>,
98 weight: Tensor<B, 4>,
99 mask: Option<Tensor<B, 4>>,
100 bias: Option<Tensor<B, 1>>,
101 options: DeformConvOptions<2>,
102) -> Tensor<B, 4>
103where
104 B: Backend,
105{
106 check!(TensorCheck::conv(
107 "deform_conv2d",
108 x.dims(),
109 weight.dims(),
110 options.weight_groups,
111 ));
112 Tensor::new(TensorPrimitive::Float(B::deform_conv2d(
113 x.primitive.tensor(),
114 offset.primitive.tensor(),
115 weight.primitive.tensor(),
116 mask.map(|m| m.primitive.tensor()),
117 bias.map(|b| b.primitive.tensor()),
118 options,
119 )))
120}
121
122pub fn conv_transpose1d<B>(
124 x: Tensor<B, 3>,
125 weight: Tensor<B, 3>,
126 bias: Option<Tensor<B, 1>>,
127 options: ConvTransposeOptions<1>,
128) -> Tensor<B, 3>
129where
130 B: Backend,
131{
132 check!(TensorCheck::conv_transpose(
133 "conv_transpose1d",
134 x.dims(),
135 weight.dims(),
136 ));
137 Tensor::new(TensorPrimitive::Float(B::conv_transpose1d(
138 x.primitive.tensor(),
139 weight.primitive.tensor(),
140 bias.map(|b| b.primitive.tensor()),
141 options,
142 )))
143}
144
145pub fn conv_transpose2d<B>(
147 x: Tensor<B, 4>,
148 weight: Tensor<B, 4>,
149 bias: Option<Tensor<B, 1>>,
150 options: ConvTransposeOptions<2>,
151) -> Tensor<B, 4>
152where
153 B: Backend,
154{
155 check!(TensorCheck::conv_transpose(
156 "conv_transpose2d",
157 x.dims(),
158 weight.dims(),
159 ));
160 Tensor::new(TensorPrimitive::Float(B::conv_transpose2d(
161 x.primitive.tensor(),
162 weight.primitive.tensor(),
163 bias.map(|b| b.primitive.tensor()),
164 options,
165 )))
166}
167
168pub fn conv_transpose3d<B>(
170 x: Tensor<B, 5>,
171 weight: Tensor<B, 5>,
172 bias: Option<Tensor<B, 1>>,
173 options: ConvTransposeOptions<3>,
174) -> Tensor<B, 5>
175where
176 B: Backend,
177{
178 check!(TensorCheck::conv_transpose(
179 "conv_transpose3d",
180 x.dims(),
181 weight.dims(),
182 ));
183 Tensor::new(TensorPrimitive::Float(B::conv_transpose3d(
184 x.primitive.tensor(),
185 weight.primitive.tensor(),
186 bias.map(|b| b.primitive.tensor()),
187 options,
188 )))
189}
190
191pub fn unfold4d<B>(x: Tensor<B, 4>, kernel_size: [usize; 2], options: UnfoldOptions) -> Tensor<B, 3>
193where
194 B: Backend,
195{
196 Tensor::new(TensorPrimitive::Float(B::unfold4d(
197 x.primitive.tensor(),
198 kernel_size,
199 options,
200 )))
201}
202
203pub fn max_pool1d<B>(
205 x: Tensor<B, 3>,
206 kernel_size: usize,
207 stride: usize,
208 padding: usize,
209 dilation: usize,
210 ceil_mode: bool,
211) -> Tensor<B, 3>
212where
213 B: Backend,
214{
215 Tensor::new(TensorPrimitive::Float(B::max_pool1d(
216 x.primitive.tensor(),
217 kernel_size,
218 stride,
219 padding,
220 dilation,
221 ceil_mode,
222 )))
223}
224
225pub fn max_pool2d<B>(
227 x: Tensor<B, 4>,
228 kernel_size: [usize; 2],
229 stride: [usize; 2],
230 padding: [usize; 2],
231 dilation: [usize; 2],
232 ceil_mode: bool,
233) -> Tensor<B, 4>
234where
235 B: Backend,
236{
237 Tensor::new(TensorPrimitive::Float(B::max_pool2d(
238 x.primitive.tensor(),
239 kernel_size,
240 stride,
241 padding,
242 dilation,
243 ceil_mode,
244 )))
245}
246
247pub fn avg_pool2d<B>(
249 x: Tensor<B, 4>,
250 kernel_size: [usize; 2],
251 stride: [usize; 2],
252 padding: [usize; 2],
253 count_include_pad: bool,
254 ceil_mode: bool,
255) -> Tensor<B, 4>
256where
257 B: Backend,
258{
259 Tensor::new(TensorPrimitive::Float(B::avg_pool2d(
260 x.primitive.tensor(),
261 kernel_size,
262 stride,
263 padding,
264 count_include_pad,
265 ceil_mode,
266 )))
267}
268
269pub fn avg_pool1d<B>(
271 x: Tensor<B, 3>,
272 kernel_size: usize,
273 stride: usize,
274 padding: usize,
275 count_include_pad: bool,
276 ceil_mode: bool,
277) -> Tensor<B, 3>
278where
279 B: Backend,
280{
281 Tensor::new(TensorPrimitive::Float(B::avg_pool1d(
282 x.primitive.tensor(),
283 kernel_size,
284 stride,
285 padding,
286 count_include_pad,
287 ceil_mode,
288 )))
289}
290
291pub fn max_pool1d_with_indices<B>(
293 x: Tensor<B, 3>,
294 kernel_size: usize,
295 stride: usize,
296 padding: usize,
297 dilation: usize,
298 ceil_mode: bool,
299) -> (Tensor<B, 3>, Tensor<B, 3, Int>)
300where
301 B: Backend,
302{
303 let output = B::max_pool1d_with_indices(
304 x.primitive.tensor(),
305 kernel_size,
306 stride,
307 padding,
308 dilation,
309 ceil_mode,
310 );
311
312 (
313 Tensor::new(TensorPrimitive::Float(output.output)),
314 Tensor::new(output.indices),
315 )
316}
317
318pub fn max_pool2d_with_indices<B>(
320 x: Tensor<B, 4>,
321 kernel_size: [usize; 2],
322 stride: [usize; 2],
323 padding: [usize; 2],
324 dilation: [usize; 2],
325 ceil_mode: bool,
326) -> (Tensor<B, 4>, Tensor<B, 4, Int>)
327where
328 B: Backend,
329{
330 let output = B::max_pool2d_with_indices(
331 x.primitive.tensor(),
332 kernel_size,
333 stride,
334 padding,
335 dilation,
336 ceil_mode,
337 );
338
339 (
340 Tensor::new(TensorPrimitive::Float(output.output)),
341 Tensor::new(output.indices),
342 )
343}
344
345pub fn adaptive_avg_pool2d<B>(x: Tensor<B, 4>, output_size: [usize; 2]) -> Tensor<B, 4>
347where
348 B: Backend,
349{
350 Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool2d(
351 x.primitive.tensor(),
352 output_size,
353 )))
354}
355
356pub fn adaptive_avg_pool1d<B>(x: Tensor<B, 3>, output_size: usize) -> Tensor<B, 3>
358where
359 B: Backend,
360{
361 Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool1d(
362 x.primitive.tensor(),
363 output_size,
364 )))
365}
366
367pub fn interpolate<B>(
369 x: Tensor<B, 4>,
370 output_size: [usize; 2],
371 options: InterpolateOptions,
372) -> Tensor<B, 4>
373where
374 B: Backend,
375{
376 Tensor::new(TensorPrimitive::Float(B::interpolate(
377 x.primitive.tensor(),
378 output_size,
379 options,
380 )))
381}
382
383pub fn linear<B: Backend, const D: usize>(
409 input: Tensor<B, D>,
410 weight: Tensor<B, 2>,
411 bias: Option<Tensor<B, 1>>,
412) -> Tensor<B, D> {
413 if D == 1 {
414 let input = input.unsqueeze::<2>();
416 let output = linear(input, weight, bias);
417 return output.squeeze_dim(0);
418 }
419
420 let weight = weight.unsqueeze::<D>();
424 let bias = bias.map(|bias| bias.unsqueeze::<D>());
425
426 let output = input.matmul(weight);
427 match bias {
428 Some(bias) => output.add(bias),
429 None => output,
430 }
431}
432
433pub fn attention<B: Backend>(
451 query: Tensor<B, 4>,
452 key: Tensor<B, 4>,
453 value: Tensor<B, 4>,
454 mask: Option<Tensor<B, 4, Bool>>,
455) -> Tensor<B, 4> {
456 Tensor::new(TensorPrimitive::Float(B::attention(
457 query.primitive.tensor(),
458 key.primitive.tensor(),
459 value.primitive.tensor(),
460 mask.map(|mask| mask.primitive),
461 )))
462}
463
464pub fn naive_attention<B: Backend>(
466 query: Tensor<B, 4>,
467 key: Tensor<B, 4>,
468 value: Tensor<B, 4>,
469 mask: Option<Tensor<B, 4, Bool>>,
470) -> Tensor<B, 4> {
471 Tensor::new(TensorPrimitive::Float(
472 crate::ops::attention::naive_attention::<B>(
473 query.primitive.tensor(),
474 key.primitive.tensor(),
475 value.primitive.tensor(),
476 mask.map(|mask| mask.primitive),
477 ),
478 ))
479}