// burn_tensor/tensor/module.rs

use crate::{
    Bool, Int, Tensor, TensorPrimitive,
    backend::Backend,
    check,
    check::TensorCheck,
    ops::{
        AttentionModuleOptions, ConvOptions, ConvTransposeOptions, InterpolateOptions, PadMode,
        PaddedConvOptions, UnfoldOptions,
    },
};

use super::ops::DeformConvOptions;

/// Computes the [CTC loss](crate::ops::ModuleOps::ctc_loss).
///
/// # Arguments
///
/// * `log_probs` - Log-probabilities of shape `[T, N, C]`
/// * `targets` - Target label indices of shape `[N, S]`
/// * `input_lengths` - Actual input sequence lengths per batch element `[N]`
/// * `target_lengths` - Actual target lengths per batch element `[N]`
/// * `blank` - Index of the blank label
///
/// # Returns
///
/// Per-sample loss of shape `[N]`
pub fn ctc_loss<B>(
    log_probs: Tensor<B, 3>,
    targets: Tensor<B, 2, Int>,
    input_lengths: Tensor<B, 1, Int>,
    target_lengths: Tensor<B, 1, Int>,
    blank: usize,
) -> Tensor<B, 1>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::ctc_loss(
        log_probs.primitive.tensor(),
        targets.primitive,
        input_lengths.primitive,
        target_lengths.primitive,
        blank,
    )))
}
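
// Illustrative sketch: shows the expected shape relationships for `ctc_loss`.
// The blank label index of 0 used here is a common convention, not a
// requirement; parameter names are placeholders.
#[allow(dead_code)]
fn _example_ctc_loss<B: Backend>(
    log_probs: Tensor<B, 3>,           // [T, N, C]: time steps, batch, classes
    targets: Tensor<B, 2, Int>,        // [N, S]: padded target sequences
    input_lengths: Tensor<B, 1, Int>,  // [N]: valid length of each input
    target_lengths: Tensor<B, 1, Int>, // [N]: valid length of each target
) -> Tensor<B, 1> {
    // Returns one loss value per batch element: shape [N].
    ctc_loss(log_probs, targets, input_lengths, target_lengths, 0)
}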

/// Applies the [embedding module](crate::ops::ModuleOps::embedding).
pub fn embedding<B>(weights: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 3>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::embedding(
        weights.primitive.tensor(),
        indices.primitive,
    )))
}
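
// Illustrative sketch: each index in `token_ids` selects a row of `weights`,
// so the output has shape `[batch, seq_len, d_model]`. Parameter names are
// placeholders.
#[allow(dead_code)]
fn _example_embedding<B: Backend>(
    weights: Tensor<B, 2>,        // [vocab_size, d_model]
    token_ids: Tensor<B, 2, Int>, // [batch, seq_len]
) -> Tensor<B, 3> {
    embedding(weights, token_ids)
}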

/// Applies a [1D convolution](crate::ops::ModuleOps::conv1d).
///
/// Accepts [`ConvOptions`] for symmetric padding, or [`PaddedConvOptions`] for
/// asymmetric padding. When asymmetric padding is specified, an explicit pad
/// operation is applied before the convolution backend op.
pub fn conv1d<B>(
    x: Tensor<B, 3>,
    weight: Tensor<B, 3>,
    bias: Option<Tensor<B, 1>>,
    options: impl Into<PaddedConvOptions<1>>,
) -> Tensor<B, 3>
where
    B: Backend,
{
    let padded_options = options.into();
    check!(TensorCheck::conv(
        "conv1d",
        x.dims(),
        weight.dims(),
        padded_options.options.groups,
    ));

    if let Some(padding_end) = padded_options.padding_end {
        let left = padded_options.options.padding[0];
        let right = padding_end[0];
        // For 1D (NCL format), pad the length dimension
        let padded = x.pad((left, right, 0, 0), PadMode::Constant(0.0));
        let zero_options = ConvOptions::new(
            padded_options.options.stride,
            [0],
            padded_options.options.dilation,
            padded_options.options.groups,
        );
        Tensor::new(TensorPrimitive::Float(B::conv1d(
            padded.primitive.tensor(),
            weight.primitive.tensor(),
            bias.map(|b| b.primitive.tensor()),
            zero_options,
        )))
    } else {
        Tensor::new(TensorPrimitive::Float(B::conv1d(
            x.primitive.tensor(),
            weight.primitive.tensor(),
            bias.map(|b| b.primitive.tensor()),
            padded_options.options,
        )))
    }
}
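
// Illustrative sketch: asymmetric padding in `conv1d` is equivalent to an
// explicit `pad` followed by a convolution with zero padding, which is
// exactly the lowering performed above. The padding amounts here are
// arbitrary example values.
#[allow(dead_code)]
fn _example_conv1d_asymmetric<B: Backend>(
    x: Tensor<B, 3>,      // [batch, channels_in, length]
    weight: Tensor<B, 3>, // [channels_out, channels_in / groups, kernel_size]
) -> Tensor<B, 3> {
    // Pad one element on the left and two on the right of the length axis.
    let padded = x.pad((1, 2, 0, 0), PadMode::Constant(0.0));
    // Then convolve with stride 1, zero padding, dilation 1, one group.
    conv1d(padded, weight, None, ConvOptions::new([1], [0], [1], 1))
}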

/// Applies a [2D convolution](crate::ops::ModuleOps::conv2d).
///
/// Accepts [`ConvOptions`] for symmetric padding, or [`PaddedConvOptions`] for
/// asymmetric padding. When asymmetric padding is specified, an explicit pad
/// operation is applied before the convolution backend op.
pub fn conv2d<B>(
    x: Tensor<B, 4>,
    weight: Tensor<B, 4>,
    bias: Option<Tensor<B, 1>>,
    options: impl Into<PaddedConvOptions<2>>,
) -> Tensor<B, 4>
where
    B: Backend,
{
    let padded_options = options.into();
    check!(TensorCheck::conv(
        "conv2d",
        x.dims(),
        weight.dims(),
        padded_options.options.groups,
    ));

    if let Some(padding_end) = padded_options.padding_end {
        let top = padded_options.options.padding[0];
        let left = padded_options.options.padding[1];
        let bottom = padding_end[0];
        let right = padding_end[1];
        // For 2D (NCHW format), pad height and width
        let padded = x.pad((left, right, top, bottom), PadMode::Constant(0.0));
        let zero_options = ConvOptions::new(
            padded_options.options.stride,
            [0, 0],
            padded_options.options.dilation,
            padded_options.options.groups,
        );
        Tensor::new(TensorPrimitive::Float(B::conv2d(
            padded.primitive.tensor(),
            weight.primitive.tensor(),
            bias.map(|b| b.primitive.tensor()),
            zero_options,
        )))
    } else {
        Tensor::new(TensorPrimitive::Float(B::conv2d(
            x.primitive.tensor(),
            weight.primitive.tensor(),
            bias.map(|b| b.primitive.tensor()),
            padded_options.options,
        )))
    }
}
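
// Illustrative sketch: with a 3x3 kernel, stride 1, dilation 1, and symmetric
// padding of 1 on each spatial side, the convolution preserves the input's
// height and width ("same" padding): (H + 2*1 - 3) / 1 + 1 = H.
#[allow(dead_code)]
fn _example_conv2d_same<B: Backend>(
    x: Tensor<B, 4>,      // [batch, channels_in, height, width]
    weight: Tensor<B, 4>, // [channels_out, channels_in, 3, 3]
    bias: Option<Tensor<B, 1>>,
) -> Tensor<B, 4> {
    conv2d(x, weight, bias, ConvOptions::new([1, 1], [1, 1], [1, 1], 1))
}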

/// Applies a [3D convolution](crate::ops::ModuleOps::conv3d).
///
/// Accepts [`ConvOptions`] for symmetric padding, or [`PaddedConvOptions`] for
/// asymmetric padding. Asymmetric 3D padding is not yet supported.
pub fn conv3d<B>(
    x: Tensor<B, 5>,
    weight: Tensor<B, 5>,
    bias: Option<Tensor<B, 1>>,
    options: impl Into<PaddedConvOptions<3>>,
) -> Tensor<B, 5>
where
    B: Backend,
{
    let padded_options = options.into();
    check!(TensorCheck::conv(
        "conv3d",
        x.dims(),
        weight.dims(),
        padded_options.options.groups,
    ));

    if padded_options.is_asymmetric() {
        panic!("Asymmetric padding is not yet supported for conv3d");
    }

    Tensor::new(TensorPrimitive::Float(B::conv3d(
        x.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
        padded_options.options,
    )))
}

/// Applies a [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d).
pub fn deform_conv2d<B>(
    x: Tensor<B, 4>,
    offset: Tensor<B, 4>,
    weight: Tensor<B, 4>,
    mask: Option<Tensor<B, 4>>,
    bias: Option<Tensor<B, 1>>,
    options: DeformConvOptions<2>,
) -> Tensor<B, 4>
where
    B: Backend,
{
    check!(TensorCheck::conv(
        "deform_conv2d",
        x.dims(),
        weight.dims(),
        options.weight_groups,
    ));
    Tensor::new(TensorPrimitive::Float(B::deform_conv2d(
        x.primitive.tensor(),
        offset.primitive.tensor(),
        weight.primitive.tensor(),
        mask.map(|m| m.primitive.tensor()),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [1D transposed convolution](crate::ops::ModuleOps::conv_transpose1d).
pub fn conv_transpose1d<B>(
    x: Tensor<B, 3>,
    weight: Tensor<B, 3>,
    bias: Option<Tensor<B, 1>>,
    options: ConvTransposeOptions<1>,
) -> Tensor<B, 3>
where
    B: Backend,
{
    check!(TensorCheck::conv_transpose(
        "conv_transpose1d",
        x.dims(),
        weight.dims(),
    ));
    Tensor::new(TensorPrimitive::Float(B::conv_transpose1d(
        x.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [2D transposed convolution](crate::ops::ModuleOps::conv_transpose2d).
pub fn conv_transpose2d<B>(
    x: Tensor<B, 4>,
    weight: Tensor<B, 4>,
    bias: Option<Tensor<B, 1>>,
    options: ConvTransposeOptions<2>,
) -> Tensor<B, 4>
where
    B: Backend,
{
    check!(TensorCheck::conv_transpose(
        "conv_transpose2d",
        x.dims(),
        weight.dims(),
    ));
    Tensor::new(TensorPrimitive::Float(B::conv_transpose2d(
        x.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}
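
// Illustrative sketch: transposed convolution upsamples its input. Per the
// standard transposed-convolution formula (an assumption about this API, not
// a statement from it), with kernel size k, stride s, padding p, output
// padding op, and dilation d, each spatial dimension maps as
//     out = (in - 1) * s - 2 * p + d * (k - 1) + op + 1.
// The weight layout comment follows the usual convention; the options are
// taken as a parameter rather than constructed here.
#[allow(dead_code)]
fn _example_conv_transpose2d<B: Backend>(
    x: Tensor<B, 4>,      // [batch, channels_in, height, width]
    weight: Tensor<B, 4>, // [channels_in, channels_out / groups, kh, kw]
    options: ConvTransposeOptions<2>,
) -> Tensor<B, 4> {
    conv_transpose2d(x, weight, None, options)
}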

/// Applies a [3D transposed convolution](crate::ops::ModuleOps::conv_transpose3d).
pub fn conv_transpose3d<B>(
    x: Tensor<B, 5>,
    weight: Tensor<B, 5>,
    bias: Option<Tensor<B, 1>>,
    options: ConvTransposeOptions<3>,
) -> Tensor<B, 5>
where
    B: Backend,
{
    check!(TensorCheck::conv_transpose(
        "conv_transpose3d",
        x.dims(),
        weight.dims(),
    ));
    Tensor::new(TensorPrimitive::Float(B::conv_transpose3d(
        x.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
        options,
    )))
}

/// Applies a [4D to 3D unfold](crate::ops::ModuleOps::unfold4d).
pub fn unfold4d<B>(x: Tensor<B, 4>, kernel_size: [usize; 2], options: UnfoldOptions) -> Tensor<B, 3>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::unfold4d(
        x.primitive.tensor(),
        kernel_size,
        options,
    )))
}
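
// Illustrative sketch: `unfold4d` extracts sliding local patches (im2col).
// For a `[batch, channels, height, width]` input and a `kh` x `kw` kernel,
// the output is `[batch, channels * kh * kw, n_patches]`, where `n_patches`
// depends on the stride, padding, and dilation carried by `options` (taken
// here as a parameter).
#[allow(dead_code)]
fn _example_unfold4d<B: Backend>(x: Tensor<B, 4>, options: UnfoldOptions) -> Tensor<B, 3> {
    // 3x3 patches: each output column holds channels * 9 values.
    unfold4d(x, [3, 3], options)
}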

/// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d).
pub fn max_pool1d<B>(
    x: Tensor<B, 3>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    ceil_mode: bool,
) -> Tensor<B, 3>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::max_pool1d(
        x.primitive.tensor(),
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
    )))
}

/// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d).
pub fn max_pool2d<B>(
    x: Tensor<B, 4>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
    ceil_mode: bool,
) -> Tensor<B, 4>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::max_pool2d(
        x.primitive.tensor(),
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
    )))
}

/// Applies a [2D avg pooling](crate::ops::ModuleOps::avg_pool2d).
pub fn avg_pool2d<B>(
    x: Tensor<B, 4>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    count_include_pad: bool,
    ceil_mode: bool,
) -> Tensor<B, 4>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::avg_pool2d(
        x.primitive.tensor(),
        kernel_size,
        stride,
        padding,
        count_include_pad,
        ceil_mode,
    )))
}

/// Applies a [1D avg pooling](crate::ops::ModuleOps::avg_pool1d).
pub fn avg_pool1d<B>(
    x: Tensor<B, 3>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    count_include_pad: bool,
    ceil_mode: bool,
) -> Tensor<B, 3>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::avg_pool1d(
        x.primitive.tensor(),
        kernel_size,
        stride,
        padding,
        count_include_pad,
        ceil_mode,
    )))
}

/// Applies a [1D max pooling with indices](crate::ops::ModuleOps::max_pool1d_with_indices).
pub fn max_pool1d_with_indices<B>(
    x: Tensor<B, 3>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    ceil_mode: bool,
) -> (Tensor<B, 3>, Tensor<B, 3, Int>)
where
    B: Backend,
{
    let output = B::max_pool1d_with_indices(
        x.primitive.tensor(),
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
    );

    (
        Tensor::new(TensorPrimitive::Float(output.output)),
        Tensor::new(output.indices),
    )
}

/// Applies a [2D max pooling with indices](crate::ops::ModuleOps::max_pool2d_with_indices).
pub fn max_pool2d_with_indices<B>(
    x: Tensor<B, 4>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    dilation: [usize; 2],
    ceil_mode: bool,
) -> (Tensor<B, 4>, Tensor<B, 4, Int>)
where
    B: Backend,
{
    let output = B::max_pool2d_with_indices(
        x.primitive.tensor(),
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
    );

    (
        Tensor::new(TensorPrimitive::Float(output.output)),
        Tensor::new(output.indices),
    )
}
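
// Illustrative sketch: a 2x2 window with stride 2, no padding, no dilation,
// and floor rounding halves each spatial dimension; the second tensor holds
// the index of each selected maximum (useful e.g. for unpooling).
#[allow(dead_code)]
fn _example_max_pool2d_with_indices<B: Backend>(
    x: Tensor<B, 4>,
) -> (Tensor<B, 4>, Tensor<B, 4, Int>) {
    max_pool2d_with_indices(x, [2, 2], [2, 2], [0, 0], [1, 1], false)
}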

/// Applies a [2D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool2d).
pub fn adaptive_avg_pool2d<B>(x: Tensor<B, 4>, output_size: [usize; 2]) -> Tensor<B, 4>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool2d(
        x.primitive.tensor(),
        output_size,
    )))
}
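
// Illustrative sketch: an output size of [1, 1] turns adaptive average
// pooling into global average pooling over the spatial dimensions.
#[allow(dead_code)]
fn _example_global_avg_pool2d<B: Backend>(x: Tensor<B, 4>) -> Tensor<B, 4> {
    adaptive_avg_pool2d(x, [1, 1])
}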

/// Applies a [1D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool1d).
pub fn adaptive_avg_pool1d<B>(x: Tensor<B, 3>, output_size: usize) -> Tensor<B, 3>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool1d(
        x.primitive.tensor(),
        output_size,
    )))
}

/// Applies a [2D interpolation](crate::ops::ModuleOps::interpolate).
pub fn interpolate<B>(
    x: Tensor<B, 4>,
    output_size: [usize; 2],
    options: InterpolateOptions,
) -> Tensor<B, 4>
where
    B: Backend,
{
    Tensor::new(TensorPrimitive::Float(B::interpolate(
        x.primitive.tensor(),
        output_size,
        options,
    )))
}
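
// Illustrative sketch: doubling the spatial resolution. The interpolation
// algorithm (nearest, bilinear, ...) is selected through `options`, taken
// here as a parameter.
#[allow(dead_code)]
fn _example_upsample_2x<B: Backend>(x: Tensor<B, 4>, options: InterpolateOptions) -> Tensor<B, 4> {
    let [_, _, height, width] = x.dims();
    interpolate(x, [height * 2, width * 2], options)
}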

/// Applies a linear transformation to the input tensor using the given weight and bias.
///
/// ```math
/// y = x @ weight + [bias]
/// ```
///
/// # Arguments:
///
/// - `input` is the input tensor, ``[..., d_input]``.
/// - `weight` is the weight tensor, ``[d_input, d_output]``.
/// - `bias` is the bias tensor (optional), ``[d_output]``.
///
/// # Returns:
///
/// The transformed tensor, ``[..., d_output]``.
///
/// # Compatibility
///
/// This function differs from PyTorch's ``torch.nn.functional.linear`` in that it does not
/// transpose the weight matrix. In PyTorch, the weight matrix is transposed before
/// multiplication:
///
/// ```math
/// y = x @ weight^T + [bias]
/// ```
pub fn linear<B: Backend, const D: usize>(
    input: Tensor<B, D>,
    weight: Tensor<B, 2>,
    bias: Option<Tensor<B, 1>>,
) -> Tensor<B, D> {
    if D == 1 {
        // Insert and remove an extra batch dimension for the batch matmul to work.
        let input = input.unsqueeze::<2>();
        let output = linear(input, weight, bias);
        return output.squeeze_dim(0);
    }

    Tensor::new(TensorPrimitive::Float(B::linear(
        input.primitive.tensor(),
        weight.primitive.tensor(),
        bias.map(|b| b.primitive.tensor()),
    )))
}
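
// Illustrative sketch: because this `linear` multiplies by `weight` without
// transposing, a weight exported from PyTorch (laid out `[d_output, d_input]`)
// must be transposed to the `[d_input, d_output]` layout expected here. This
// assumes the usual `Tensor::transpose`, which swaps the last two dimensions.
#[allow(dead_code)]
fn _example_linear_from_pytorch<B: Backend>(
    input: Tensor<B, 2>,          // [batch, d_input]
    pytorch_weight: Tensor<B, 2>, // [d_output, d_input], PyTorch layout
    bias: Option<Tensor<B, 1>>,   // [d_output]
) -> Tensor<B, 2> {
    linear(input, pytorch_weight.transpose(), bias)
}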

/// Computes scaled dot-product attention: softmax(QKᵗ * scale) · V,
/// where scale defaults to 1/sqrt(head_dim) (configurable via `options.scale`).
/// Optionally applies masking, additive bias, causal masking, and softcap.
///
/// # Arguments
/// - `query`: Query tensor of shape `[batch_size, num_heads, seq_len_q, head_dim]`
/// - `key`: Key tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]`
/// - `value`: Value tensor of shape `[batch_size, num_heads, seq_len_k, val_dim]`
/// - `mask`: Optional boolean mask of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`,
///   where `true` indicates positions to mask (i.e. set to -inf before softmax).
/// - `attn_bias`: Optional float tensor of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`
///   added to the attention scores before softmax (e.g. ALiBi, relative position biases).
/// - `options`: Additional attention options (custom scale, softcap, causal masking).
///
/// # Returns
/// A tensor of shape `[batch_size, num_heads, seq_len_q, val_dim]`
/// representing the attended context per head.
///
/// # Note
/// This implementation does not support dropout and is intended for inference or
/// use cases where dropout is not needed.
pub fn attention<B: Backend>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
    mask: Option<Tensor<B, 4, Bool>>,
    attn_bias: Option<Tensor<B, 4>>,
    options: AttentionModuleOptions,
) -> Tensor<B, 4> {
    Tensor::new(TensorPrimitive::Float(B::attention(
        query.primitive.tensor(),
        key.primitive.tensor(),
        value.primitive.tensor(),
        mask.map(|mask| mask.primitive),
        attn_bias.map(|bias| bias.primitive.tensor()),
        options,
    )))
}
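
// Illustrative sketch: plain multi-head self-attention over equal-length
// query and key/value sequences, with no explicit mask or additive bias;
// scale, softcap, and causal masking are configured through `options`,
// taken here as a parameter.
#[allow(dead_code)]
fn _example_attention<B: Backend>(
    query: Tensor<B, 4>, // [batch, heads, seq_len, head_dim]
    key: Tensor<B, 4>,   // [batch, heads, seq_len, head_dim]
    value: Tensor<B, 4>, // [batch, heads, seq_len, head_dim]
    options: AttentionModuleOptions,
) -> Tensor<B, 4> {
    attention(query, key, value, None, None, options)
}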

/// Exposes the reference attention implementation so that backends can test
/// their `attention` against it.
pub fn attention_fallback<B: Backend>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
    mask: Option<Tensor<B, 4, Bool>>,
    attn_bias: Option<Tensor<B, 4>>,
    options: AttentionModuleOptions,
) -> Tensor<B, 4> {
    Tensor::new(TensorPrimitive::Float(
        crate::ops::attention::attention_fallback::<B>(
            query.primitive.tensor(),
            key.primitive.tensor(),
            value.primitive.tensor(),
            mask.map(|mask| mask.primitive),
            attn_bias.map(|bias| bias.primitive.tensor()),
            options,
        ),
    ))
}