//! Unfold operations implemented on top of convolution.
//! (burn_backend/backend/ops/modules/unfold.rs)

use super::{ConvOptions, UnfoldOptions};
use crate::tensor::FloatTensor;
use crate::{Backend, TensorData, TensorMetadata, element::ElementConversion};
use alloc::vec;
use alloc::vec::Vec;
use burn_std::{DType, Shape};

8/// Constructs a special weight tensor used for unfolding.
9///
10/// # Notes
11///
12/// The idea behind using convolution for unfolding is to leverage the sliding window mechanism of
13/// convolution. By creating a weight tensor with ones in a particular pattern, we are able to borrow
14/// the convolution operation's mechanism as it moves across the input tensor, picking up the desired
15/// values in the pattern of the unfolding operation.
16pub(crate) fn create_unfolding_weight<B: Backend>(
17    in_channels: usize,
18    kernel_size: [usize; 2],
19    device: &B::Device,
20    dtype: DType,
21) -> FloatTensor<B> {
22    let shape = Shape::new([
23        in_channels * kernel_size[0] * kernel_size[1],
24        in_channels,
25        kernel_size[0],
26        kernel_size[1],
27    ]);
28
29    let mut strides = [0; 4];
30    let mut current = 1;
31    shape.iter().enumerate().rev().for_each(|(index, val)| {
32        strides[index] = current;
33        current *= val;
34    });
35
36    let num_elements = shape.num_elements();
37
38    let mut weight: Vec<B::FloatElem> = vec![0.0.elem(); num_elements];
39
40    for k in 0..in_channels {
41        for i in 0..kernel_size[0] {
42            for j in 0..kernel_size[1] {
43                let output_channel = k * kernel_size[0] * kernel_size[1] + i * kernel_size[1] + j;
44                let index =
45                    output_channel * strides[0] + k * strides[1] + i * strides[2] + j * strides[3];
46
47                weight[index] = 1.elem();
48            }
49        }
50    }
51
52    B::float_from_data(TensorData::new(weight, shape).convert_dtype(dtype), device)
53}
54
55/// Compute the unfold4d operation using the conv2d operations.
56pub(crate) fn unfold4d_using_conv2d<B: Backend>(
57    x: FloatTensor<B>,
58    kernel_size: [usize; 2],
59    options: UnfoldOptions,
60) -> FloatTensor<B> {
61    let [_batch_size, in_channels, _in_height, _in_width] = x.shape().dims();
62    let weight =
63        create_unfolding_weight::<B>(in_channels, kernel_size, &B::float_device(&x), x.dtype());
64    let unfolded = B::conv2d(
65        x,
66        weight,
67        None,
68        ConvOptions::new(options.stride, options.padding, options.dilation, 1),
69    );
70
71    let [batch_size, channels_out, out_height, out_width] = unfolded.shape().dims();
72
73    B::float_reshape(
74        unfolded,
75        Shape::new([batch_size, channels_out, out_height * out_width]),
76    )
77}
78
79/// Calculate the number of unfolding windows that can be extracted from a dimension of given size.
80pub fn calculate_unfold_windows(dim_size: usize, window_size: usize, step_size: usize) -> usize {
81    assert!(step_size > 0);
82    let x = dim_size + step_size;
83    if x < window_size {
84        0
85    } else {
86        (x - window_size) / step_size
87    }
88}
89
90/// Calculate the output shape for an unfold operation.
91///
92/// The operation yields a view with all complete windows of size `size` in dimension `dim`;
93/// where windows are advanced by `step` at each index.
94///
95/// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
96///
97/// # Arguments
98///
99/// * `shape` - The input shape to unfold; of shape ``[pre=..., dim shape, post=...]``
100/// * `dim` - the dimension to unfold.
101/// * `size` - the size of each unfolded window.
102/// * `step` - the step between each window.
103///
104/// # Returns
105///
106/// A shape with ``[pre=..., windows, post=..., size]``.
107pub fn calculate_unfold_shape<S: Into<Shape>>(
108    shape: S,
109    dim: usize,
110    size: usize,
111    step: usize,
112) -> Shape {
113    let mut shape = shape.into();
114    let d_shape = shape[dim];
115    let windows = calculate_unfold_windows(d_shape, size, step);
116    shape[dim] = windows;
117    shape.push(size);
118
119    shape
120}
121
122#[cfg(test)]
123mod tests {
124    use super::*;
125
126    #[test]
127    fn test_calculate_unfold_windows() {
128        assert_eq!(calculate_unfold_windows(2, 5, 1), 0);
129
130        assert_eq!(calculate_unfold_windows(2, 3, 1), 0);
131        assert_eq!(calculate_unfold_windows(3, 3, 1), 1);
132        assert_eq!(calculate_unfold_windows(4, 3, 1), 2);
133        assert_eq!(calculate_unfold_windows(5, 3, 1), 3);
134
135        assert_eq!(calculate_unfold_windows(2, 3, 2), 0);
136        assert_eq!(calculate_unfold_windows(3, 3, 2), 1);
137        assert_eq!(calculate_unfold_windows(4, 3, 2), 1);
138        assert_eq!(calculate_unfold_windows(5, 3, 2), 2);
139    }
140
141    #[test]
142    fn test_calculate_unfold_shape() {
143        assert_eq!(
144            calculate_unfold_shape([2, 6, 6], 1, 3, 2),
145            Shape::new([2, 2, 6, 3])
146        );
147    }
148}