burn_tensor/tensor/module.rs

use crate::{
    Bool, Int, Tensor, TensorPrimitive,
    backend::Backend,
    check,
    check::TensorCheck,
    ops::{ConvOptions, ConvTransposeOptions, InterpolateOptions, UnfoldOptions},
};

use super::ops::DeformConvOptions;

/// Applies the [embedding module](crate::ops::ModuleOps::embedding).
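///
/// # Example
///
/// A minimal usage sketch; the shape names below (`vocab_size`, `d_model`,
/// `batch_size`, `seq_length`) are illustrative assumptions:
///
/// ```ignore
/// // weights: [vocab_size, d_model]
/// // indices: [batch_size, seq_length] (Int)
/// // output:  [batch_size, seq_length, d_model]
/// let output = embedding(weights, indices);
/// ```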
pub fn embedding<B>(weights: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 3>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::embedding(
        weights.primitive.tensor(),
        indices.primitive,
    )))
}

/// Applies a [1D convolution](crate::ops::ModuleOps::conv1d).
pub fn conv1d<B>(
    x: Tensor<B, 3>,
    weight: Tensor<B, 3>,
    bias: Option<Tensor<B, 1>>,
    options: ConvOptions<1>,
) -> Tensor<B, 3>
where
    B: Backend,
{
    check!(TensorCheck::conv(
        "conv1d",
        x.dims(),
        weight.dims(),
        options.groups,
    ));
    Tensor::new(TensorPrimitive::Float(B::conv1d(
        x.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [2D convolution](crate::ops::ModuleOps::conv2d).
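///
/// # Example
///
/// A minimal usage sketch; the shapes below are illustrative assumptions, and
/// `ConvOptions::new` is assumed to take `(stride, padding, dilation, groups)`
/// in that order:
///
/// ```ignore
/// // x:      [batch_size, channels_in, height, width]
/// // weight: [channels_out, channels_in / groups, kernel_h, kernel_w]
/// let options = ConvOptions::new([1, 1], [0, 0], [1, 1], 1);
/// let output = conv2d(x, weight, Some(bias), options);
/// ```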
pub fn conv2d<B>(
    x: Tensor<B, 4>,
    weight: Tensor<B, 4>,
    bias: Option<Tensor<B, 1>>,
    options: ConvOptions<2>,
) -> Tensor<B, 4>
where
    B: Backend,
{
    check!(TensorCheck::conv(
        "conv2d",
        x.dims(),
        weight.dims(),
        options.groups,
    ));
    Tensor::new(TensorPrimitive::Float(B::conv2d(
        x.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [3D convolution](crate::ops::ModuleOps::conv3d).
pub fn conv3d<B>(
    x: Tensor<B, 5>,
    weight: Tensor<B, 5>,
    bias: Option<Tensor<B, 1>>,
    options: ConvOptions<3>,
) -> Tensor<B, 5>
where
    B: Backend,
{
    check!(TensorCheck::conv(
        "conv3d",
        x.dims(),
        weight.dims(),
        options.groups,
    ));
    Tensor::new(TensorPrimitive::Float(B::conv3d(
        x.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [Deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d).
pub fn deform_conv2d<B>(
    x: Tensor<B, 4>,
    offset: Tensor<B, 4>,
    weight: Tensor<B, 4>,
    mask: Option<Tensor<B, 4>>,
    bias: Option<Tensor<B, 1>>,
    options: DeformConvOptions<2>,
) -> Tensor<B, 4>
where
    B: Backend,
{
    check!(TensorCheck::conv(
        "deform_conv2d",
        x.dims(),
        weight.dims(),
        options.weight_groups,
    ));
    Tensor::new(TensorPrimitive::Float(B::deform_conv2d(
        x.primitive.tensor(),
        offset.primitive.tensor(),
        weight.primitive.tensor(),
        mask.map(|m| m.primitive.tensor()),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [1D transposed convolution](crate::ops::ModuleOps::conv_transpose1d).
pub fn conv_transpose1d<B>(
    x: Tensor<B, 3>,
    weight: Tensor<B, 3>,
    bias: Option<Tensor<B, 1>>,
    options: ConvTransposeOptions<1>,
) -> Tensor<B, 3>
where
    B: Backend,
{
    check!(TensorCheck::conv_transpose(
        "conv_transpose1d",
        x.dims(),
        weight.dims(),
    ));
    Tensor::new(TensorPrimitive::Float(B::conv_transpose1d(
        x.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [2D transposed convolution](crate::ops::ModuleOps::conv_transpose2d).
pub fn conv_transpose2d<B>(
    x: Tensor<B, 4>,
    weight: Tensor<B, 4>,
    bias: Option<Tensor<B, 1>>,
    options: ConvTransposeOptions<2>,
) -> Tensor<B, 4>
where
    B: Backend,
{
    check!(TensorCheck::conv_transpose(
        "conv_transpose2d",
        x.dims(),
        weight.dims(),
    ));
    Tensor::new(TensorPrimitive::Float(B::conv_transpose2d(
        x.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [3D transposed convolution](crate::ops::ModuleOps::conv_transpose3d).
pub fn conv_transpose3d<B>(
    x: Tensor<B, 5>,
    weight: Tensor<B, 5>,
    bias: Option<Tensor<B, 1>>,
    options: ConvTransposeOptions<3>,
) -> Tensor<B, 5>
where
    B: Backend,
{
    check!(TensorCheck::conv_transpose(
        "conv_transpose3d",
        x.dims(),
        weight.dims(),
    ));
    Tensor::new(TensorPrimitive::Float(B::conv_transpose3d(
        x.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [4D to 3D unfold](crate::ops::ModuleOps::unfold4d).
pub fn unfold4d<B>(x: Tensor<B, 4>, kernel_size: [usize; 2], options: UnfoldOptions) -> Tensor<B, 3>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::unfold4d(
        x.primitive.tensor(),
        kernel_size,
        options,
    )))
}

/// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d).
pub fn max_pool1d<B>(
    x: Tensor<B, 3>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    ceil_mode: bool,
) -> Tensor<B, 3>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::max_pool1d(
        x.primitive.tensor(),
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
    )))
}

/// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d).
pub fn max_pool2d<B>(
    x: Tensor<B, 4>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
    ceil_mode: bool,
) -> Tensor<B, 4>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::max_pool2d(
        x.primitive.tensor(),
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
    )))
}

/// Applies a [2D avg pooling](crate::ops::ModuleOps::avg_pool2d).
pub fn avg_pool2d<B>(
    x: Tensor<B, 4>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    count_include_pad: bool,
    ceil_mode: bool,
) -> Tensor<B, 4>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::avg_pool2d(
        x.primitive.tensor(),
        kernel_size,
        stride,
        padding,
        count_include_pad,
        ceil_mode,
    )))
}

/// Applies a [1D avg pooling](crate::ops::ModuleOps::avg_pool1d).
pub fn avg_pool1d<B>(
    x: Tensor<B, 3>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    count_include_pad: bool,
    ceil_mode: bool,
) -> Tensor<B, 3>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::avg_pool1d(
        x.primitive.tensor(),
        kernel_size,
        stride,
        padding,
        count_include_pad,
        ceil_mode,
    )))
}

/// Applies a [1D max pooling with indices](crate::ops::ModuleOps::max_pool1d_with_indices).
pub fn max_pool1d_with_indices<B>(
    x: Tensor<B, 3>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    ceil_mode: bool,
) -> (Tensor<B, 3>, Tensor<B, 3, Int>)
where
    B: Backend,
{
    let output = B::max_pool1d_with_indices(
        x.primitive.tensor(),
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
    );

    (
        Tensor::new(TensorPrimitive::Float(output.output)),
        Tensor::new(output.indices),
    )
}

/// Applies a [2D max pooling with indices](crate::ops::ModuleOps::max_pool2d_with_indices).
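///
/// # Example
///
/// A minimal usage sketch; the 2x2 kernel with matching stride, zero padding,
/// unit dilation, and `ceil_mode = false` are illustrative assumptions:
///
/// ```ignore
/// // x: [batch_size, channels, height, width]
/// let (pooled, indices) = max_pool2d_with_indices(x, [2, 2], [2, 2], [0, 0], [1, 1], false);
/// ```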
pub fn max_pool2d_with_indices<B>(
    x: Tensor<B, 4>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
    ceil_mode: bool,
) -> (Tensor<B, 4>, Tensor<B, 4, Int>)
where
    B: Backend,
{
    let output = B::max_pool2d_with_indices(
        x.primitive.tensor(),
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
    );

    (
        Tensor::new(TensorPrimitive::Float(output.output)),
        Tensor::new(output.indices),
    )
}

/// Applies a [2D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool2d).
pub fn adaptive_avg_pool2d<B>(x: Tensor<B, 4>, output_size: [usize; 2]) -> Tensor<B, 4>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool2d(
        x.primitive.tensor(),
        output_size,
    )))
}

/// Applies a [1D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool1d).
pub fn adaptive_avg_pool1d<B>(x: Tensor<B, 3>, output_size: usize) -> Tensor<B, 3>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool1d(
        x.primitive.tensor(),
        output_size,
    )))
}

/// Applies a [2D interpolation](crate::ops::ModuleOps::interpolate).
pub fn interpolate<B>(
    x: Tensor<B, 4>,
    output_size: [usize; 2],
    options: InterpolateOptions,
) -> Tensor<B, 4>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::interpolate(
        x.primitive.tensor(),
        output_size,
        options,
    )))
}

/// Applies a linear transformation to the input tensor using the given weight and bias.
///
/// ```math
/// y = x @ weight + [bias]
/// ```
///
/// # Arguments
///
/// - `input` is the input tensor, ``[..., d_input]``.
/// - `weight` is the weight tensor, ``[d_input, d_output]``.
/// - `bias` is the bias tensor (optional), ``[d_output]``.
///
/// # Returns
///
/// The transformed tensor, ``[..., d_output]``.
///
/// # Compatibility
///
/// This function differs from PyTorch's ``torch.nn.functional.linear`` in that it does not
/// transpose the weight matrix. In PyTorch, the weight matrix is transposed before
/// multiplication:
///
/// ```math
/// y = x @ weight^T + [bias]
/// ```
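///
/// # Example
///
/// A minimal usage sketch; the shape names are illustrative assumptions:
///
/// ```ignore
/// // input: [batch_size, d_input], weight: [d_input, d_output], bias: [d_output]
/// // output: [batch_size, d_output]
/// let output = linear(input, weight, Some(bias));
///
/// // A weight stored in PyTorch's [d_output, d_input] layout must be transposed first.
/// let output = linear(input, torch_weight.transpose(), Some(bias));
/// ```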
pub fn linear<B: Backend, const D: usize>(
    input: Tensor<B, D>,
    weight: Tensor<B, 2>,
    bias: Option<Tensor<B, 1>>,
) -> Tensor<B, D> {
    if D == 1 {
        // Insert and remove an extra batch dimension for the batch matmul to work.
        let input = input.unsqueeze::<2>();
        let output = linear(input, weight, bias);
        return output.squeeze_dim(0);
    }

    // Perform broadcasting.
    //
    // This must happen before the other operations so they can be fused easily.
    let weight = weight.unsqueeze::<D>();
    let bias = bias.map(|bias| bias.unsqueeze::<D>());

    let output = input.matmul(weight);
    match bias {
        Some(bias) => output.add(bias),
        None => output,
    }
}

/// Computes scaled dot-product attention: softmax(QKᵗ / √d) · V,
/// optionally applying a 4D mask to the attention scores.
///
/// # Arguments
/// - `query`: Query tensor of shape `[batch_size, num_heads, seq_len_q, head_dim]`
/// - `key`: Key tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]`
/// - `value`: Value tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]`
/// - `mask`: Optional boolean mask of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`,
///   where `true` indicates positions to mask (i.e. set to -∞ before softmax).
///
/// # Returns
/// A tensor of shape `[batch_size, num_heads, seq_len_q, head_dim]`
/// representing the attended context per head.
///
/// # Note
/// This implementation does not support dropout and is intended for inference or
/// use cases where dropout is not needed.
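///
/// # Example
///
/// A minimal usage sketch; the shape names and the causal-mask interpretation
/// below are illustrative assumptions:
///
/// ```ignore
/// // query/key/value: [batch_size, num_heads, seq_len, head_dim]
/// // mask is `true` wherever a query position must not attend to a key position.
/// let context = attention(query, key, value, Some(mask));
/// ```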
pub fn attention<B: Backend>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
    mask: Option<Tensor<B, 4, Bool>>,
) -> Tensor<B, 4> {
    Tensor::new(TensorPrimitive::Float(B::attention(
        query.primitive.tensor(),
        key.primitive.tensor(),
        value.primitive.tensor(),
        mask.map(|mask| mask.primitive),
    )))
}

/// Naive reference implementation of attention, exported so a backend's `attention`
/// implementation can be tested against it.
pub fn naive_attention<B: Backend>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
    mask: Option<Tensor<B, 4, Bool>>,
) -> Tensor<B, 4> {
    Tensor::new(TensorPrimitive::Float(
        crate::ops::attention::naive_attention::<B>(
            query.primitive.tensor(),
            key.primitive.tensor(),
            value.primitive.tensor(),
            mask.map(|mask| mask.primitive),
        ),
    ))
}