burn_tensor/tensor/module.rs

1use crate::{
2    backend::Backend,
3    ops::{ConvOptions, ConvTransposeOptions, InterpolateOptions, UnfoldOptions},
4    Int, Tensor, TensorPrimitive,
5};
6
7use super::ops::DeformConvOptions;
8
9/// Applies the [embedding module](crate::ops::ModuleOps::embedding).
10pub fn embedding<B>(weights: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 3>
11where
12    B: Backend,
13{
14    Tensor::new(TensorPrimitive::Float(B::embedding(
15        weights.primitive.tensor(),
16        indices.primitive,
17    )))
18}
19
20/// Applies a [1D convolution](crate::ops::ModuleOps::conv2d).
21pub fn conv1d<B>(
22    x: Tensor<B, 3>,
23    weight: Tensor<B, 3>,
24    bias: Option<Tensor<B, 1>>,
25    options: ConvOptions<1>,
26) -> Tensor<B, 3>
27where
28    B: Backend,
29{
30    Tensor::new(TensorPrimitive::Float(B::conv1d(
31        x.primitive.tensor(),
32        weight.primitive.tensor(),
33        bias.map(|b| b.primitive.tensor()),
34        options,
35    )))
36}
37
38/// Applies a [2D convolution](crate::ops::ModuleOps::conv2d).
39pub fn conv2d<B>(
40    x: Tensor<B, 4>,
41    weight: Tensor<B, 4>,
42    bias: Option<Tensor<B, 1>>,
43    options: ConvOptions<2>,
44) -> Tensor<B, 4>
45where
46    B: Backend,
47{
48    Tensor::new(TensorPrimitive::Float(B::conv2d(
49        x.primitive.tensor(),
50        weight.primitive.tensor(),
51        bias.map(|b| b.primitive.tensor()),
52        options,
53    )))
54}
55
56/// Applies a [3D convolution](crate::ops::ModuleOps::conv3d).
57pub fn conv3d<B>(
58    x: Tensor<B, 5>,
59    weight: Tensor<B, 5>,
60    bias: Option<Tensor<B, 1>>,
61    options: ConvOptions<3>,
62) -> Tensor<B, 5>
63where
64    B: Backend,
65{
66    Tensor::new(TensorPrimitive::Float(B::conv3d(
67        x.primitive.tensor(),
68        weight.primitive.tensor(),
69        bias.map(|b| b.primitive.tensor()),
70        options,
71    )))
72}
73
74/// Applies a [Deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d).
75pub fn deform_conv2d<B>(
76    x: Tensor<B, 4>,
77    offset: Tensor<B, 4>,
78    weight: Tensor<B, 4>,
79    mask: Option<Tensor<B, 4>>,
80    bias: Option<Tensor<B, 1>>,
81    options: DeformConvOptions<2>,
82) -> Tensor<B, 4>
83where
84    B: Backend,
85{
86    Tensor::new(TensorPrimitive::Float(B::deform_conv2d(
87        x.primitive.tensor(),
88        offset.primitive.tensor(),
89        weight.primitive.tensor(),
90        mask.map(|m| m.primitive.tensor()),
91        bias.map(|b| b.primitive.tensor()),
92        options,
93    )))
94}
95
96/// Applies a [1D transposed convolution](crate::ops::ModuleOps::conv_transpose1d).
97pub fn conv_transpose1d<B>(
98    x: Tensor<B, 3>,
99    weight: Tensor<B, 3>,
100    bias: Option<Tensor<B, 1>>,
101    options: ConvTransposeOptions<1>,
102) -> Tensor<B, 3>
103where
104    B: Backend,
105{
106    Tensor::new(TensorPrimitive::Float(B::conv_transpose1d(
107        x.primitive.tensor(),
108        weight.primitive.tensor(),
109        bias.map(|b| b.primitive.tensor()),
110        options,
111    )))
112}
113
114/// Applies a [2D transposed convolution](crate::ops::ModuleOps::conv_transpose2d).
115pub fn conv_transpose2d<B>(
116    x: Tensor<B, 4>,
117    weight: Tensor<B, 4>,
118    bias: Option<Tensor<B, 1>>,
119    options: ConvTransposeOptions<2>,
120) -> Tensor<B, 4>
121where
122    B: Backend,
123{
124    Tensor::new(TensorPrimitive::Float(B::conv_transpose2d(
125        x.primitive.tensor(),
126        weight.primitive.tensor(),
127        bias.map(|b| b.primitive.tensor()),
128        options,
129    )))
130}
131
132/// Applies a 3D transposed convolution](crate::ops::ModuleOps::conv_transpose3d).
133pub fn conv_transpose3d<B>(
134    x: Tensor<B, 5>,
135    weight: Tensor<B, 5>,
136    bias: Option<Tensor<B, 1>>,
137    options: ConvTransposeOptions<3>,
138) -> Tensor<B, 5>
139where
140    B: Backend,
141{
142    Tensor::new(TensorPrimitive::Float(B::conv_transpose3d(
143        x.primitive.tensor(),
144        weight.primitive.tensor(),
145        bias.map(|b| b.primitive.tensor()),
146        options,
147    )))
148}
149
150/// Applies a [4D to 3D unfold](crate::ops::ModuleOps::unfold4d).
151pub fn unfold4d<B>(x: Tensor<B, 4>, kernel_size: [usize; 2], options: UnfoldOptions) -> Tensor<B, 3>
152where
153    B: Backend,
154{
155    Tensor::new(TensorPrimitive::Float(B::unfold4d(
156        x.primitive.tensor(),
157        kernel_size,
158        options,
159    )))
160}
161
162/// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d).
163pub fn max_pool1d<B>(
164    x: Tensor<B, 3>,
165    kernel_size: usize,
166    stride: usize,
167    padding: usize,
168    dilation: usize,
169) -> Tensor<B, 3>
170where
171    B: Backend,
172{
173    Tensor::new(TensorPrimitive::Float(B::max_pool1d(
174        x.primitive.tensor(),
175        kernel_size,
176        stride,
177        padding,
178        dilation,
179    )))
180}
181
182/// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d).
183pub fn max_pool2d<B>(
184    x: Tensor<B, 4>,
185    kernel_size: [usize; 2],
186    stride: [usize; 2],
187    padding: [usize; 2],
188    dilation: [usize; 2],
189) -> Tensor<B, 4>
190where
191    B: Backend,
192{
193    Tensor::new(TensorPrimitive::Float(B::max_pool2d(
194        x.primitive.tensor(),
195        kernel_size,
196        stride,
197        padding,
198        dilation,
199    )))
200}
201
202/// Applies a [2D avg pooling](crate::ops::ModuleOps::avg_pool2d).
203pub fn avg_pool2d<B>(
204    x: Tensor<B, 4>,
205    kernel_size: [usize; 2],
206    stride: [usize; 2],
207    padding: [usize; 2],
208    count_include_pad: bool,
209) -> Tensor<B, 4>
210where
211    B: Backend,
212{
213    Tensor::new(TensorPrimitive::Float(B::avg_pool2d(
214        x.primitive.tensor(),
215        kernel_size,
216        stride,
217        padding,
218        count_include_pad,
219    )))
220}
221
222/// Applies a [1D avg pooling](crate::ops::ModuleOps::avg_pool1d).
223pub fn avg_pool1d<B>(
224    x: Tensor<B, 3>,
225    kernel_size: usize,
226    stride: usize,
227    padding: usize,
228    count_include_pad: bool,
229) -> Tensor<B, 3>
230where
231    B: Backend,
232{
233    Tensor::new(TensorPrimitive::Float(B::avg_pool1d(
234        x.primitive.tensor(),
235        kernel_size,
236        stride,
237        padding,
238        count_include_pad,
239    )))
240}
241
242/// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d).
243pub fn max_pool1d_with_indices<B>(
244    x: Tensor<B, 3>,
245    kernel_size: usize,
246    stride: usize,
247    padding: usize,
248    dilation: usize,
249) -> (Tensor<B, 3>, Tensor<B, 3, Int>)
250where
251    B: Backend,
252{
253    let output =
254        B::max_pool1d_with_indices(x.primitive.tensor(), kernel_size, stride, padding, dilation);
255
256    (
257        Tensor::new(TensorPrimitive::Float(output.output)),
258        Tensor::new(output.indices),
259    )
260}
261
262/// Applies a [2D max pooling with indices](crate::ops::ModuleOps::max_pool2d_with_indices).
263pub fn max_pool2d_with_indices<B>(
264    x: Tensor<B, 4>,
265    kernel_size: [usize; 2],
266    stride: [usize; 2],
267    padding: [usize; 2],
268    dilation: [usize; 2],
269) -> (Tensor<B, 4>, Tensor<B, 4, Int>)
270where
271    B: Backend,
272{
273    let output =
274        B::max_pool2d_with_indices(x.primitive.tensor(), kernel_size, stride, padding, dilation);
275
276    (
277        Tensor::new(TensorPrimitive::Float(output.output)),
278        Tensor::new(output.indices),
279    )
280}
281
282/// Applies a [2D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool2d).
283pub fn adaptive_avg_pool2d<B>(x: Tensor<B, 4>, output_size: [usize; 2]) -> Tensor<B, 4>
284where
285    B: Backend,
286{
287    Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool2d(
288        x.primitive.tensor(),
289        output_size,
290    )))
291}
292
293/// Applies a [1D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool1d).
294pub fn adaptive_avg_pool1d<B>(x: Tensor<B, 3>, output_size: usize) -> Tensor<B, 3>
295where
296    B: Backend,
297{
298    Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool1d(
299        x.primitive.tensor(),
300        output_size,
301    )))
302}
303
304/// Applies a [2D interpolation](crate::ops::ModuleOps::interpolate).
305pub fn interpolate<B>(
306    x: Tensor<B, 4>,
307    output_size: [usize; 2],
308    options: InterpolateOptions,
309) -> Tensor<B, 4>
310where
311    B: Backend,
312{
313    Tensor::new(TensorPrimitive::Float(B::interpolate(
314        x.primitive.tensor(),
315        output_size,
316        options,
317    )))
318}