burn_backend/backend/ops/modules/
unfold.rs1use super::{ConvOptions, UnfoldOptions};
2use crate::tensor::FloatTensor;
3use crate::{Backend, TensorData, TensorMetadata, element::ElementConversion};
4use alloc::vec;
5use alloc::vec::Vec;
6use burn_std::{DType, Shape};
7
8pub(crate) fn create_unfolding_weight<B: Backend>(
17 in_channels: usize,
18 kernel_size: [usize; 2],
19 device: &B::Device,
20 dtype: DType,
21) -> FloatTensor<B> {
22 let shape = Shape::new([
23 in_channels * kernel_size[0] * kernel_size[1],
24 in_channels,
25 kernel_size[0],
26 kernel_size[1],
27 ]);
28
29 let mut strides = [0; 4];
30 let mut current = 1;
31 shape.iter().enumerate().rev().for_each(|(index, val)| {
32 strides[index] = current;
33 current *= val;
34 });
35
36 let num_elements = shape.num_elements();
37
38 let mut weight: Vec<B::FloatElem> = vec![0.0.elem(); num_elements];
39
40 for k in 0..in_channels {
41 for i in 0..kernel_size[0] {
42 for j in 0..kernel_size[1] {
43 let output_channel = k * kernel_size[0] * kernel_size[1] + i * kernel_size[1] + j;
44 let index =
45 output_channel * strides[0] + k * strides[1] + i * strides[2] + j * strides[3];
46
47 weight[index] = 1.elem();
48 }
49 }
50 }
51
52 B::float_from_data(TensorData::new(weight, shape).convert_dtype(dtype), device)
53}
54
55pub(crate) fn unfold4d_using_conv2d<B: Backend>(
57 x: FloatTensor<B>,
58 kernel_size: [usize; 2],
59 options: UnfoldOptions,
60) -> FloatTensor<B> {
61 let [_batch_size, in_channels, _in_height, _in_width] = x.shape().dims();
62 let weight =
63 create_unfolding_weight::<B>(in_channels, kernel_size, &B::float_device(&x), x.dtype());
64 let unfolded = B::conv2d(
65 x,
66 weight,
67 None,
68 ConvOptions::new(options.stride, options.padding, options.dilation, 1),
69 );
70
71 let [batch_size, channels_out, out_height, out_width] = unfolded.shape().dims();
72
73 B::float_reshape(
74 unfolded,
75 Shape::new([batch_size, channels_out, out_height * out_width]),
76 )
77}
78
/// Returns how many complete windows of length `window_size` fit in a dimension of length
/// `dim_size` when consecutive windows start `step_size` elements apart.
///
/// The result is `(dim_size - window_size) / step_size + 1` when the window fits, and `0`
/// when `window_size > dim_size`.
///
/// # Panics
///
/// Panics if `step_size` is zero. With `window_size == 0` and `dim_size == usize::MAX`, the
/// true count (`usize::MAX + 1`) is not representable, so that call overflows.
pub fn calculate_unfold_windows(dim_size: usize, window_size: usize, step_size: usize) -> usize {
    assert!(step_size > 0, "unfold step size must be greater than zero");
    // Subtract first so no intermediate value exceeds `dim_size`. The former
    // `dim_size + step_size` formulation overflowed for large (but valid) dimensions.
    dim_size
        .checked_sub(window_size)
        .map_or(0, |slack| slack / step_size + 1)
}
89
90pub fn calculate_unfold_shape<S: Into<Shape>>(
108 shape: S,
109 dim: usize,
110 size: usize,
111 step: usize,
112) -> Shape {
113 let mut shape = shape.into();
114 let d_shape = shape[dim];
115 let windows = calculate_unfold_windows(d_shape, size, step);
116 shape[dim] = windows;
117 shape.push(size);
118
119 shape
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calculate_unfold_windows() {
        // Each row is (dim_size, window_size, step_size, expected window count).
        let cases = [
            // Window larger than the dimension: nothing fits.
            (2, 5, 1, 0),
            // Unit step.
            (2, 3, 1, 0),
            (3, 3, 1, 1),
            (4, 3, 1, 2),
            (5, 3, 1, 3),
            // Step of two: trailing partial windows are dropped.
            (2, 3, 2, 0),
            (3, 3, 2, 1),
            (4, 3, 2, 1),
            (5, 3, 2, 2),
        ];

        for (dim_size, window_size, step_size, expected) in cases {
            assert_eq!(
                calculate_unfold_windows(dim_size, window_size, step_size),
                expected
            );
        }
    }

    #[test]
    fn test_calculate_unfold_shape() {
        // Unfolding dim 1 (length 6) with size 3, step 2 yields 2 windows plus a trailing axis of 3.
        let unfolded = calculate_unfold_shape([2, 6, 6], 1, 3, 2);
        assert_eq!(unfolded, Shape::new([2, 2, 6, 3]));
    }
}