1use crate::{
2 Bool, Int, Tensor, TensorPrimitive,
3 backend::Backend,
4 check,
5 check::TensorCheck,
6 ops::{
7 AttentionModuleOptions, ConvOptions, ConvTransposeOptions, InterpolateOptions, PadMode,
8 PaddedConvOptions, UnfoldOptions,
9 },
10};
11
12use super::ops::DeformConvOptions;
13
14pub fn ctc_loss<B>(
28 log_probs: Tensor<B, 3>,
29 targets: Tensor<B, 2, Int>,
30 input_lengths: Tensor<B, 1, Int>,
31 target_lengths: Tensor<B, 1, Int>,
32 blank: usize,
33) -> Tensor<B, 1>
34where
35 B: Backend,
36{
37 Tensor::new(TensorPrimitive::Float(B::ctc_loss(
38 log_probs.primitive.tensor(),
39 targets.primitive,
40 input_lengths.primitive,
41 target_lengths.primitive,
42 blank,
43 )))
44}
45
46pub fn embedding<B>(weights: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 3>
48where
49 B: Backend,
50{
51 Tensor::new(TensorPrimitive::Float(B::embedding(
52 weights.primitive.tensor(),
53 indices.primitive,
54 )))
55}
56
57pub fn conv1d<B>(
63 x: Tensor<B, 3>,
64 weight: Tensor<B, 3>,
65 bias: Option<Tensor<B, 1>>,
66 options: impl Into<PaddedConvOptions<1>>,
67) -> Tensor<B, 3>
68where
69 B: Backend,
70{
71 let padded_options = options.into();
72 check!(TensorCheck::conv(
73 "conv1d",
74 x.dims(),
75 weight.dims(),
76 padded_options.options.groups,
77 ));
78
79 if let Some(padding_end) = padded_options.padding_end {
80 let left = padded_options.options.padding[0];
81 let right = padding_end[0];
82 let padded = x.pad((left, right, 0, 0), PadMode::Constant(0.0));
84 let zero_options = ConvOptions::new(
85 padded_options.options.stride,
86 [0],
87 padded_options.options.dilation,
88 padded_options.options.groups,
89 );
90 Tensor::new(TensorPrimitive::Float(B::conv1d(
91 padded.primitive.tensor(),
92 weight.primitive.tensor(),
93 bias.map(|b| b.primitive.tensor()),
94 zero_options,
95 )))
96 } else {
97 Tensor::new(TensorPrimitive::Float(B::conv1d(
98 x.primitive.tensor(),
99 weight.primitive.tensor(),
100 bias.map(|b| b.primitive.tensor()),
101 padded_options.options,
102 )))
103 }
104}
105
106pub fn conv2d<B>(
112 x: Tensor<B, 4>,
113 weight: Tensor<B, 4>,
114 bias: Option<Tensor<B, 1>>,
115 options: impl Into<PaddedConvOptions<2>>,
116) -> Tensor<B, 4>
117where
118 B: Backend,
119{
120 let padded_options = options.into();
121 check!(TensorCheck::conv(
122 "conv2d",
123 x.dims(),
124 weight.dims(),
125 padded_options.options.groups,
126 ));
127
128 if let Some(padding_end) = padded_options.padding_end {
129 let top = padded_options.options.padding[0];
130 let left = padded_options.options.padding[1];
131 let bottom = padding_end[0];
132 let right = padding_end[1];
133 let padded = x.pad((left, right, top, bottom), PadMode::Constant(0.0));
135 let zero_options = ConvOptions::new(
136 padded_options.options.stride,
137 [0, 0],
138 padded_options.options.dilation,
139 padded_options.options.groups,
140 );
141 Tensor::new(TensorPrimitive::Float(B::conv2d(
142 padded.primitive.tensor(),
143 weight.primitive.tensor(),
144 bias.map(|b| b.primitive.tensor()),
145 zero_options,
146 )))
147 } else {
148 Tensor::new(TensorPrimitive::Float(B::conv2d(
149 x.primitive.tensor(),
150 weight.primitive.tensor(),
151 bias.map(|b| b.primitive.tensor()),
152 padded_options.options,
153 )))
154 }
155}
156
157pub fn conv3d<B>(
162 x: Tensor<B, 5>,
163 weight: Tensor<B, 5>,
164 bias: Option<Tensor<B, 1>>,
165 options: impl Into<PaddedConvOptions<3>>,
166) -> Tensor<B, 5>
167where
168 B: Backend,
169{
170 let padded_options = options.into();
171 check!(TensorCheck::conv(
172 "conv3d",
173 x.dims(),
174 weight.dims(),
175 padded_options.options.groups,
176 ));
177
178 if padded_options.is_asymmetric() {
179 panic!("Asymmetric padding is not yet supported for conv3d");
180 }
181
182 Tensor::new(TensorPrimitive::Float(B::conv3d(
183 x.primitive.tensor(),
184 weight.primitive.tensor(),
185 bias.map(|b| b.primitive.tensor()),
186 padded_options.options,
187 )))
188}
189
190pub fn deform_conv2d<B>(
192 x: Tensor<B, 4>,
193 offset: Tensor<B, 4>,
194 weight: Tensor<B, 4>,
195 mask: Option<Tensor<B, 4>>,
196 bias: Option<Tensor<B, 1>>,
197 options: DeformConvOptions<2>,
198) -> Tensor<B, 4>
199where
200 B: Backend,
201{
202 check!(TensorCheck::conv(
203 "deform_conv2d",
204 x.dims(),
205 weight.dims(),
206 options.weight_groups,
207 ));
208 Tensor::new(TensorPrimitive::Float(B::deform_conv2d(
209 x.primitive.tensor(),
210 offset.primitive.tensor(),
211 weight.primitive.tensor(),
212 mask.map(|m| m.primitive.tensor()),
213 bias.map(|b| b.primitive.tensor()),
214 options,
215 )))
216}
217
218pub fn conv_transpose1d<B>(
220 x: Tensor<B, 3>,
221 weight: Tensor<B, 3>,
222 bias: Option<Tensor<B, 1>>,
223 options: ConvTransposeOptions<1>,
224) -> Tensor<B, 3>
225where
226 B: Backend,
227{
228 check!(TensorCheck::conv_transpose(
229 "conv_transpose1d",
230 x.dims(),
231 weight.dims(),
232 ));
233 Tensor::new(TensorPrimitive::Float(B::conv_transpose1d(
234 x.primitive.tensor(),
235 weight.primitive.tensor(),
236 bias.map(|b| b.primitive.tensor()),
237 options,
238 )))
239}
240
241pub fn conv_transpose2d<B>(
243 x: Tensor<B, 4>,
244 weight: Tensor<B, 4>,
245 bias: Option<Tensor<B, 1>>,
246 options: ConvTransposeOptions<2>,
247) -> Tensor<B, 4>
248where
249 B: Backend,
250{
251 check!(TensorCheck::conv_transpose(
252 "conv_transpose2d",
253 x.dims(),
254 weight.dims(),
255 ));
256 Tensor::new(TensorPrimitive::Float(B::conv_transpose2d(
257 x.primitive.tensor(),
258 weight.primitive.tensor(),
259 bias.map(|b| b.primitive.tensor()),
260 options,
261 )))
262}
263
264pub fn conv_transpose3d<B>(
266 x: Tensor<B, 5>,
267 weight: Tensor<B, 5>,
268 bias: Option<Tensor<B, 1>>,
269 options: ConvTransposeOptions<3>,
270) -> Tensor<B, 5>
271where
272 B: Backend,
273{
274 check!(TensorCheck::conv_transpose(
275 "conv_transpose3d",
276 x.dims(),
277 weight.dims(),
278 ));
279 Tensor::new(TensorPrimitive::Float(B::conv_transpose3d(
280 x.primitive.tensor(),
281 weight.primitive.tensor(),
282 bias.map(|b| b.primitive.tensor()),
283 options,
284 )))
285}
286
287pub fn unfold4d<B>(x: Tensor<B, 4>, kernel_size: [usize; 2], options: UnfoldOptions) -> Tensor<B, 3>
289where
290 B: Backend,
291{
292 Tensor::new(TensorPrimitive::Float(B::unfold4d(
293 x.primitive.tensor(),
294 kernel_size,
295 options,
296 )))
297}
298
299pub fn max_pool1d<B>(
301 x: Tensor<B, 3>,
302 kernel_size: usize,
303 stride: usize,
304 padding: usize,
305 dilation: usize,
306 ceil_mode: bool,
307) -> Tensor<B, 3>
308where
309 B: Backend,
310{
311 Tensor::new(TensorPrimitive::Float(B::max_pool1d(
312 x.primitive.tensor(),
313 kernel_size,
314 stride,
315 padding,
316 dilation,
317 ceil_mode,
318 )))
319}
320
321pub fn max_pool2d<B>(
323 x: Tensor<B, 4>,
324 kernel_size: [usize; 2],
325 stride: [usize; 2],
326 padding: [usize; 2],
327 dilation: [usize; 2],
328 ceil_mode: bool,
329) -> Tensor<B, 4>
330where
331 B: Backend,
332{
333 Tensor::new(TensorPrimitive::Float(B::max_pool2d(
334 x.primitive.tensor(),
335 kernel_size,
336 stride,
337 padding,
338 dilation,
339 ceil_mode,
340 )))
341}
342
343pub fn avg_pool2d<B>(
345 x: Tensor<B, 4>,
346 kernel_size: [usize; 2],
347 stride: [usize; 2],
348 padding: [usize; 2],
349 count_include_pad: bool,
350 ceil_mode: bool,
351) -> Tensor<B, 4>
352where
353 B: Backend,
354{
355 Tensor::new(TensorPrimitive::Float(B::avg_pool2d(
356 x.primitive.tensor(),
357 kernel_size,
358 stride,
359 padding,
360 count_include_pad,
361 ceil_mode,
362 )))
363}
364
365pub fn avg_pool1d<B>(
367 x: Tensor<B, 3>,
368 kernel_size: usize,
369 stride: usize,
370 padding: usize,
371 count_include_pad: bool,
372 ceil_mode: bool,
373) -> Tensor<B, 3>
374where
375 B: Backend,
376{
377 Tensor::new(TensorPrimitive::Float(B::avg_pool1d(
378 x.primitive.tensor(),
379 kernel_size,
380 stride,
381 padding,
382 count_include_pad,
383 ceil_mode,
384 )))
385}
386
387pub fn max_pool1d_with_indices<B>(
389 x: Tensor<B, 3>,
390 kernel_size: usize,
391 stride: usize,
392 padding: usize,
393 dilation: usize,
394 ceil_mode: bool,
395) -> (Tensor<B, 3>, Tensor<B, 3, Int>)
396where
397 B: Backend,
398{
399 let output = B::max_pool1d_with_indices(
400 x.primitive.tensor(),
401 kernel_size,
402 stride,
403 padding,
404 dilation,
405 ceil_mode,
406 );
407
408 (
409 Tensor::new(TensorPrimitive::Float(output.output)),
410 Tensor::new(output.indices),
411 )
412}
413
414pub fn max_pool2d_with_indices<B>(
416 x: Tensor<B, 4>,
417 kernel_size: [usize; 2],
418 stride: [usize; 2],
419 padding: [usize; 2],
420 dilation: [usize; 2],
421 ceil_mode: bool,
422) -> (Tensor<B, 4>, Tensor<B, 4, Int>)
423where
424 B: Backend,
425{
426 let output = B::max_pool2d_with_indices(
427 x.primitive.tensor(),
428 kernel_size,
429 stride,
430 padding,
431 dilation,
432 ceil_mode,
433 );
434
435 (
436 Tensor::new(TensorPrimitive::Float(output.output)),
437 Tensor::new(output.indices),
438 )
439}
440
441pub fn adaptive_avg_pool2d<B>(x: Tensor<B, 4>, output_size: [usize; 2]) -> Tensor<B, 4>
443where
444 B: Backend,
445{
446 Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool2d(
447 x.primitive.tensor(),
448 output_size,
449 )))
450}
451
452pub fn adaptive_avg_pool1d<B>(x: Tensor<B, 3>, output_size: usize) -> Tensor<B, 3>
454where
455 B: Backend,
456{
457 Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool1d(
458 x.primitive.tensor(),
459 output_size,
460 )))
461}
462
463pub fn interpolate<B>(
465 x: Tensor<B, 4>,
466 output_size: [usize; 2],
467 options: InterpolateOptions,
468) -> Tensor<B, 4>
469where
470 B: Backend,
471{
472 Tensor::new(TensorPrimitive::Float(B::interpolate(
473 x.primitive.tensor(),
474 output_size,
475 options,
476 )))
477}
478
479pub fn linear<B: Backend, const D: usize>(
505 input: Tensor<B, D>,
506 weight: Tensor<B, 2>,
507 bias: Option<Tensor<B, 1>>,
508) -> Tensor<B, D> {
509 if D == 1 {
510 let input = input.unsqueeze::<2>();
512 let output = linear(input, weight, bias);
513 return output.squeeze_dim(0);
514 }
515
516 Tensor::new(TensorPrimitive::Float(B::linear(
517 input.primitive.tensor(),
518 weight.primitive.tensor(),
519 bias.map(|b| b.primitive.tensor()),
520 )))
521}
522
523pub fn attention<B: Backend>(
545 query: Tensor<B, 4>,
546 key: Tensor<B, 4>,
547 value: Tensor<B, 4>,
548 mask: Option<Tensor<B, 4, Bool>>,
549 attn_bias: Option<Tensor<B, 4>>,
550 options: AttentionModuleOptions,
551) -> Tensor<B, 4> {
552 Tensor::new(TensorPrimitive::Float(B::attention(
553 query.primitive.tensor(),
554 key.primitive.tensor(),
555 value.primitive.tensor(),
556 mask.map(|mask| mask.primitive),
557 attn_bias.map(|bias| bias.primitive.tensor()),
558 options,
559 )))
560}
561
562pub fn attention_fallback<B: Backend>(
564 query: Tensor<B, 4>,
565 key: Tensor<B, 4>,
566 value: Tensor<B, 4>,
567 mask: Option<Tensor<B, 4, Bool>>,
568 attn_bias: Option<Tensor<B, 4>>,
569 options: AttentionModuleOptions,
570) -> Tensor<B, 4> {
571 Tensor::new(TensorPrimitive::Float(
572 crate::ops::attention::attention_fallback::<B>(
573 query.primitive.tensor(),
574 key.primitive.tensor(),
575 value.primitive.tensor(),
576 mask.map(|mask| mask.primitive),
577 attn_bias.map(|bias| bias.primitive.tensor()),
578 options,
579 ),
580 ))
581}