burn_tensor/tensor/module.rs

use crate::{
    Int, Tensor, TensorPrimitive,
    backend::Backend,
    check,
    check::TensorCheck,
    ops::{ConvOptions, ConvTransposeOptions, InterpolateOptions, UnfoldOptions},
};

use super::ops::DeformConvOptions;

/// Applies the [embedding module](crate::ops::ModuleOps::embedding).
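///
/// # Example
///
/// A minimal shape sketch, marked `ignore` since it needs a concrete backend.
/// The sizes (vocabulary 10, embedding dim 4, batch 2, sequence length 3) and
/// the `demo` scaffold are illustrative:
///
/// ```rust,ignore
/// use burn_tensor::{Int, Tensor, backend::Backend, module::embedding};
///
/// fn demo<B: Backend>(device: &B::Device) {
///     // Lookup table: [num_embeddings, d_model].
///     let weights = Tensor::<B, 2>::zeros([10, 4], device);
///     // Token ids: [batch_size, seq_length].
///     let indices = Tensor::<B, 2, Int>::zeros([2, 3], device);
///     // One embedding row is gathered per index: [batch_size, seq_length, d_model].
///     let output = embedding(weights, indices);
///     assert_eq!(output.dims(), [2, 3, 4]);
/// }
/// ```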
pub fn embedding<B>(weights: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 3>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::embedding(
        weights.primitive.tensor(),
        indices.primitive,
    )))
}

/// Applies a [1D convolution](crate::ops::ModuleOps::conv1d).
pub fn conv1d<B>(
    x: Tensor<B, 3>,
    weight: Tensor<B, 3>,
    bias: Option<Tensor<B, 1>>,
    options: ConvOptions<1>,
) -> Tensor<B, 3>
where
    B: Backend,
{
    check!(TensorCheck::conv(
        "conv1d",
        x.dims(),
        weight.dims(),
        options.groups,
    ));
    Tensor::new(TensorPrimitive::Float(B::conv1d(
        x.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [2D convolution](crate::ops::ModuleOps::conv2d).
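///
/// # Example
///
/// A shape sketch with illustrative sizes, assuming the
/// `ConvOptions::new(stride, padding, dilation, groups)` constructor:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, backend::Backend, module::conv2d, ops::ConvOptions};
///
/// fn demo<B: Backend>(device: &B::Device) {
///     // Input: [batch_size, channels_in, height, width].
///     let x = Tensor::<B, 4>::zeros([1, 3, 32, 32], device);
///     // Weight: [channels_out, channels_in / groups, kernel_h, kernel_w].
///     let weight = Tensor::<B, 4>::zeros([8, 3, 3, 3], device);
///     // Stride 1, padding 1, dilation 1, one group.
///     let options = ConvOptions::new([1, 1], [1, 1], [1, 1], 1);
///     let output = conv2d(x, weight, None, options);
///     // Padding 1 with a 3x3 kernel preserves the spatial size.
///     assert_eq!(output.dims(), [1, 8, 32, 32]);
/// }
/// ```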
pub fn conv2d<B>(
    x: Tensor<B, 4>,
    weight: Tensor<B, 4>,
    bias: Option<Tensor<B, 1>>,
    options: ConvOptions<2>,
) -> Tensor<B, 4>
where
    B: Backend,
{
    check!(TensorCheck::conv(
        "conv2d",
        x.dims(),
        weight.dims(),
        options.groups,
    ));
    Tensor::new(TensorPrimitive::Float(B::conv2d(
        x.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [3D convolution](crate::ops::ModuleOps::conv3d).
pub fn conv3d<B>(
    x: Tensor<B, 5>,
    weight: Tensor<B, 5>,
    bias: Option<Tensor<B, 1>>,
    options: ConvOptions<3>,
) -> Tensor<B, 5>
where
    B: Backend,
{
    check!(TensorCheck::conv(
        "conv3d",
        x.dims(),
        weight.dims(),
        options.groups,
    ));
    Tensor::new(TensorPrimitive::Float(B::conv3d(
        x.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d).
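///
/// # Example
///
/// A shape sketch with illustrative sizes, assuming the
/// `DeformConvOptions::new(stride, padding, dilation, weight_groups, offset_groups)`
/// constructor and the torchvision-style offset layout:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, backend::Backend, module::deform_conv2d, ops::DeformConvOptions};
///
/// fn demo<B: Backend>(device: &B::Device) {
///     let (kh, kw) = (3, 3);
///     // Input: [batch_size, channels_in, height, width].
///     let x = Tensor::<B, 4>::zeros([1, 3, 16, 16], device);
///     // Offsets: [batch_size, 2 * offset_groups * kh * kw, out_h, out_w],
///     // i.e. one (dy, dx) pair per kernel position and output location.
///     let offset = Tensor::<B, 4>::zeros([1, 2 * kh * kw, 16, 16], device);
///     // Weight: [channels_out, channels_in / weight_groups, kh, kw].
///     let weight = Tensor::<B, 4>::zeros([8, 3, kh, kw], device);
///     let options = DeformConvOptions::new([1, 1], [1, 1], [1, 1], 1, 1);
///     // No modulation mask and no bias.
///     let output = deform_conv2d(x, offset, weight, None, None, options);
///     assert_eq!(output.dims(), [1, 8, 16, 16]);
/// }
/// ```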
pub fn deform_conv2d<B>(
    x: Tensor<B, 4>,
    offset: Tensor<B, 4>,
    weight: Tensor<B, 4>,
    mask: Option<Tensor<B, 4>>,
    bias: Option<Tensor<B, 1>>,
    options: DeformConvOptions<2>,
) -> Tensor<B, 4>
where
    B: Backend,
{
    check!(TensorCheck::conv(
        "deform_conv2d",
        x.dims(),
        weight.dims(),
        options.weight_groups,
    ));
    Tensor::new(TensorPrimitive::Float(B::deform_conv2d(
        x.primitive.tensor(),
        offset.primitive.tensor(),
        weight.primitive.tensor(),
        mask.map(|m| m.primitive.tensor()),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [1D transposed convolution](crate::ops::ModuleOps::conv_transpose1d).
pub fn conv_transpose1d<B>(
    x: Tensor<B, 3>,
    weight: Tensor<B, 3>,
    bias: Option<Tensor<B, 1>>,
    options: ConvTransposeOptions<1>,
) -> Tensor<B, 3>
where
    B: Backend,
{
    check!(TensorCheck::conv_transpose(
        "conv_transpose1d",
        x.dims(),
        weight.dims(),
    ));
    Tensor::new(TensorPrimitive::Float(B::conv_transpose1d(
        x.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [2D transposed convolution](crate::ops::ModuleOps::conv_transpose2d).
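///
/// # Example
///
/// A shape sketch with illustrative sizes, assuming the
/// `ConvTransposeOptions::new(stride, padding, padding_out, dilation, groups)`
/// constructor:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, backend::Backend, module::conv_transpose2d, ops::ConvTransposeOptions};
///
/// fn demo<B: Backend>(device: &B::Device) {
///     // Input: [batch_size, channels_in, height, width].
///     let x = Tensor::<B, 4>::zeros([1, 8, 16, 16], device);
///     // Weight: [channels_in, channels_out / groups, kernel_h, kernel_w].
///     let weight = Tensor::<B, 4>::zeros([8, 3, 2, 2], device);
///     // out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + padding_out + 1
///     let options = ConvTransposeOptions::new([2, 2], [0, 0], [0, 0], [1, 1], 1);
///     let output = conv_transpose2d(x, weight, None, options);
///     // Stride 2 with a 2x2 kernel doubles the spatial size.
///     assert_eq!(output.dims(), [1, 3, 32, 32]);
/// }
/// ```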
pub fn conv_transpose2d<B>(
    x: Tensor<B, 4>,
    weight: Tensor<B, 4>,
    bias: Option<Tensor<B, 1>>,
    options: ConvTransposeOptions<2>,
) -> Tensor<B, 4>
where
    B: Backend,
{
    check!(TensorCheck::conv_transpose(
        "conv_transpose2d",
        x.dims(),
        weight.dims(),
    ));
    Tensor::new(TensorPrimitive::Float(B::conv_transpose2d(
        x.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [3D transposed convolution](crate::ops::ModuleOps::conv_transpose3d).
pub fn conv_transpose3d<B>(
    x: Tensor<B, 5>,
    weight: Tensor<B, 5>,
    bias: Option<Tensor<B, 1>>,
    options: ConvTransposeOptions<3>,
) -> Tensor<B, 5>
where
    B: Backend,
{
    check!(TensorCheck::conv_transpose(
        "conv_transpose3d",
        x.dims(),
        weight.dims(),
    ));
    Tensor::new(TensorPrimitive::Float(B::conv_transpose3d(
        x.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [4D to 3D unfold](crate::ops::ModuleOps::unfold4d).
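///
/// # Example
///
/// A shape sketch with illustrative sizes, assuming the
/// `UnfoldOptions::new(stride, padding, dilation)` constructor:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, backend::Backend, module::unfold4d, ops::UnfoldOptions};
///
/// fn demo<B: Backend>(device: &B::Device) {
///     // Input: [batch_size, channels, height, width].
///     let x = Tensor::<B, 4>::zeros([1, 3, 8, 8], device);
///     let options = UnfoldOptions::new([1, 1], [0, 0], [1, 1]);
///     // Each column flattens one 3x3 patch across all channels:
///     // [batch_size, channels * kh * kw, n_patches] = [1, 27, 36].
///     let output = unfold4d(x, [3, 3], options);
///     assert_eq!(output.dims(), [1, 27, 36]);
/// }
/// ```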
pub fn unfold4d<B>(x: Tensor<B, 4>, kernel_size: [usize; 2], options: UnfoldOptions) -> Tensor<B, 3>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::unfold4d(
        x.primitive.tensor(),
        kernel_size,
        options,
    )))
}

/// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d).
pub fn max_pool1d<B>(
    x: Tensor<B, 3>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> Tensor<B, 3>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::max_pool1d(
        x.primitive.tensor(),
        kernel_size,
        stride,
        padding,
        dilation,
    )))
}

/// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d).
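///
/// # Example
///
/// A shape sketch with illustrative sizes:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, backend::Backend, module::max_pool2d};
///
/// fn demo<B: Backend>(device: &B::Device) {
///     // Input: [batch_size, channels, height, width].
///     let x = Tensor::<B, 4>::zeros([1, 3, 32, 32], device);
///     // 2x2 window, stride 2, no padding, dilation 1: halves each spatial dim.
///     let output = max_pool2d(x, [2, 2], [2, 2], [0, 0], [1, 1]);
///     assert_eq!(output.dims(), [1, 3, 16, 16]);
/// }
/// ```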
pub fn max_pool2d<B>(
    x: Tensor<B, 4>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
) -> Tensor<B, 4>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::max_pool2d(
        x.primitive.tensor(),
        kernel_size,
        stride,
        padding,
        dilation,
    )))
}

/// Applies a [2D avg pooling](crate::ops::ModuleOps::avg_pool2d).
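///
/// # Example
///
/// A shape sketch with illustrative sizes; `count_include_pad` only changes the
/// result when padding is non-zero (it decides whether padded zeros count
/// toward the divisor):
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, backend::Backend, module::avg_pool2d};
///
/// fn demo<B: Backend>(device: &B::Device) {
///     let x = Tensor::<B, 4>::zeros([1, 3, 32, 32], device);
///     // 2x2 window, stride 2, no padding: averages each window.
///     let output = avg_pool2d(x, [2, 2], [2, 2], [0, 0], true);
///     assert_eq!(output.dims(), [1, 3, 16, 16]);
/// }
/// ```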
pub fn avg_pool2d<B>(
    x: Tensor<B, 4>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    count_include_pad: bool,
) -> Tensor<B, 4>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::avg_pool2d(
        x.primitive.tensor(),
        kernel_size,
        stride,
        padding,
        count_include_pad,
    )))
}

/// Applies a [1D avg pooling](crate::ops::ModuleOps::avg_pool1d).
pub fn avg_pool1d<B>(
    x: Tensor<B, 3>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    count_include_pad: bool,
) -> Tensor<B, 3>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::avg_pool1d(
        x.primitive.tensor(),
        kernel_size,
        stride,
        padding,
        count_include_pad,
    )))
}

/// Applies a [1D max pooling with indices](crate::ops::ModuleOps::max_pool1d_with_indices).
pub fn max_pool1d_with_indices<B>(
    x: Tensor<B, 3>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> (Tensor<B, 3>, Tensor<B, 3, Int>)
where
    B: Backend,
{
    let output =
        B::max_pool1d_with_indices(x.primitive.tensor(), kernel_size, stride, padding, dilation);

    (
        Tensor::new(TensorPrimitive::Float(output.output)),
        Tensor::new(output.indices),
    )
}

/// Applies a [2D max pooling with indices](crate::ops::ModuleOps::max_pool2d_with_indices).
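///
/// # Example
///
/// A shape sketch with illustrative sizes; the second tensor records the
/// position each maximum was taken from, shaped like the pooled output:
///
/// ```rust,ignore
/// use burn_tensor::{Int, Tensor, backend::Backend, module::max_pool2d_with_indices};
///
/// fn demo<B: Backend>(device: &B::Device) {
///     let x = Tensor::<B, 4>::zeros([1, 3, 32, 32], device);
///     let (output, indices): (Tensor<B, 4>, Tensor<B, 4, Int>) =
///         max_pool2d_with_indices(x, [2, 2], [2, 2], [0, 0], [1, 1]);
///     assert_eq!(output.dims(), [1, 3, 16, 16]);
///     assert_eq!(indices.dims(), output.dims());
/// }
/// ```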
pub fn max_pool2d_with_indices<B>(
    x: Tensor<B, 4>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
) -> (Tensor<B, 4>, Tensor<B, 4, Int>)
where
    B: Backend,
{
    let output =
        B::max_pool2d_with_indices(x.primitive.tensor(), kernel_size, stride, padding, dilation);

    (
        Tensor::new(TensorPrimitive::Float(output.output)),
        Tensor::new(output.indices),
    )
}

/// Applies a [2D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool2d).
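///
/// # Example
///
/// A shape sketch: the output spatial size is fixed regardless of the input
/// size, which is handy right before a classifier head. Sizes are illustrative:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, backend::Backend, module::adaptive_avg_pool2d};
///
/// fn demo<B: Backend>(device: &B::Device) {
///     // Odd input sizes are fine; the pooling windows are derived internally.
///     let x = Tensor::<B, 4>::zeros([1, 64, 17, 23], device);
///     let output = adaptive_avg_pool2d(x, [1, 1]);
///     assert_eq!(output.dims(), [1, 64, 1, 1]);
/// }
/// ```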
pub fn adaptive_avg_pool2d<B>(x: Tensor<B, 4>, output_size: [usize; 2]) -> Tensor<B, 4>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool2d(
        x.primitive.tensor(),
        output_size,
    )))
}

/// Applies a [1D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool1d).
pub fn adaptive_avg_pool1d<B>(x: Tensor<B, 3>, output_size: usize) -> Tensor<B, 3>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool1d(
        x.primitive.tensor(),
        output_size,
    )))
}

/// Applies a [2D interpolation](crate::ops::ModuleOps::interpolate).
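///
/// # Example
///
/// A shape sketch with illustrative sizes, assuming an `InterpolateMode` enum
/// with a `Nearest` variant and the `InterpolateOptions::new(mode)` constructor:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, backend::Backend, module::interpolate};
/// use burn_tensor::ops::{InterpolateMode, InterpolateOptions};
///
/// fn demo<B: Backend>(device: &B::Device) {
///     let x = Tensor::<B, 4>::zeros([1, 3, 16, 16], device);
///     // Resize to the requested spatial size.
///     let options = InterpolateOptions::new(InterpolateMode::Nearest);
///     let output = interpolate(x, [32, 32], options);
///     assert_eq!(output.dims(), [1, 3, 32, 32]);
/// }
/// ```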
pub fn interpolate<B>(
    x: Tensor<B, 4>,
    output_size: [usize; 2],
    options: InterpolateOptions,
) -> Tensor<B, 4>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::interpolate(
        x.primitive.tensor(),
        output_size,
        options,
    )))
}

/// Applies a linear transformation to the input tensor using the given weight and bias.
///
/// ```math
/// y = x @ weight + [bias]
/// ```
///
/// # Arguments
///
/// - `input` is the input tensor, ``[..., d_input]``.
/// - `weight` is the weight tensor, ``[d_input, d_output]``.
/// - `bias` is the bias tensor (optional), ``[d_output]``.
///
/// # Returns
///
/// The transformed tensor, ``[..., d_output]``.
///
/// # Compatibility
///
/// This function differs from PyTorch's ``torch.nn.functional.linear`` in that it does not
/// transpose the weight matrix. In PyTorch, the weight matrix is transposed before
/// multiplication:
///
/// ```math
/// y = x @ weight^T + [bias]
/// ```
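///
/// # Example
///
/// A minimal sketch with illustrative shapes; since the weight is stored as
/// ``[d_input, d_output]``, no transpose is needed before the matmul:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, backend::Backend, module::linear};
///
/// fn demo<B: Backend>(device: &B::Device) {
///     // Input: [batch_size, d_input].
///     let input = Tensor::<B, 2>::zeros([8, 64], device);
///     // Weight: [d_input, d_output]; bias: [d_output].
///     let weight = Tensor::<B, 2>::zeros([64, 32], device);
///     let bias = Tensor::<B, 1>::zeros([32], device);
///     let output = linear(input, weight, Some(bias));
///     assert_eq!(output.dims(), [8, 32]);
/// }
/// ```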
pub fn linear<B: Backend, const D: usize>(
    input: Tensor<B, D>,
    weight: Tensor<B, 2>,
    bias: Option<Tensor<B, 1>>,
) -> Tensor<B, D> {
    if D == 1 {
        // Insert and remove an extra batch dimension for the batch matmul to work.
        let input = input.unsqueeze::<2>();
        let output = linear(input, weight, bias);
        return output.squeeze_dim(0);
    }

    // Broadcast the weight and bias up front.
    //
    // Important to do before the other operations so they can fuse easily.
    let weight = weight.unsqueeze::<D>();
    let bias = bias.map(|bias| bias.unsqueeze::<D>());

    let output = input.matmul(weight);
    match bias {
        Some(bias) => output.add(bias),
        None => output,
    }
}