burn_core/nn/
unfold.rs

1use crate as burn;
2
3use crate::config::Config;
4use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
5
6use burn_tensor::Tensor;
7use burn_tensor::backend::Backend;
8use burn_tensor::module::unfold4d;
9use burn_tensor::ops::UnfoldOptions;
10
11/// Configuration to create an [unfold 4d](Unfold4d) layer using the [init function](Unfold4dConfig::init).
12#[derive(Config, Debug)]
13pub struct Unfold4dConfig {
14    /// The size of the kernel.
15    pub kernel_size: [usize; 2],
16    /// The stride of the convolution.
17    #[config(default = "[1, 1]")]
18    pub stride: [usize; 2],
19    /// Spacing between kernel elements.
20    #[config(default = "[1, 1]")]
21    pub dilation: [usize; 2],
22    /// The padding configuration.
23    #[config(default = "[0, 0]")]
24    pub padding: [usize; 2],
25}
26
27/// Four-dimensional unfolding.
28///
29/// Should be created with [Unfold4dConfig].
30#[derive(Module, Clone, Debug)]
31#[module(custom_display)]
32pub struct Unfold4d {
33    /// The size of the kernel.
34    pub kernel_size: [usize; 2],
35    /// The stride of the convolution.
36    pub stride: [usize; 2],
37    /// Spacing between kernel elements.
38    pub dilation: [usize; 2],
39    /// The padding configuration.
40    pub padding: [usize; 2],
41}
42
43impl ModuleDisplay for Unfold4d {
44    fn custom_settings(&self) -> Option<DisplaySettings> {
45        DisplaySettings::new()
46            .with_new_line_after_attribute(false)
47            .optional()
48    }
49
50    fn custom_content(&self, content: Content) -> Option<Content> {
51        content
52            .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
53            .add("stride", &alloc::format!("{:?}", &self.stride))
54            .add("dilation", &alloc::format!("{:?}", &self.dilation))
55            .add("padding", &alloc::format!("{:?}", &self.padding))
56            .optional()
57    }
58}
59
60impl Unfold4dConfig {
61    /// Initializes a new [Unfold4d] module.
62    pub fn init(&self) -> Unfold4d {
63        Unfold4d {
64            kernel_size: self.kernel_size,
65            stride: self.stride,
66            dilation: self.dilation,
67            padding: self.padding,
68        }
69    }
70}
71
72impl Unfold4d {
73    /// Applies the forward pass on the input tensor.
74    ///
75    /// See [unfold4d](crate::tensor::module::unfold4d) for more information.
76    ///
77    /// # Shapes
78    ///
79    /// input:   `[batch_size, channels_in, height, width]`
80    /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`
81    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 3> {
82        unfold4d(
83            input,
84            self.kernel_size,
85            UnfoldOptions::new(self.stride, self.padding, self.dilation),
86        )
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration renders every attribute on a single line.
    #[test]
    fn display() {
        let config = Unfold4dConfig::new([3, 3]);
        let unfold = config.init();

        assert_eq!(
            alloc::format!("{unfold}"),
            "Unfold4d {kernel_size: [3, 3], stride: [1, 1], dilation: [1, 1], padding: [0, 0]}"
        );
    }

    /// Non-default settings must reach the module unchanged and in the right fields.
    /// Every field uses a distinct value, so a stride/dilation/padding mix-up in `init`
    /// is caught.
    #[test]
    fn init_forwards_non_default_config() {
        let unfold = Unfold4dConfig::new([2, 3])
            .with_stride([2, 1])
            .with_dilation([1, 2])
            .with_padding([1, 0])
            .init();

        assert_eq!(unfold.kernel_size, [2, 3]);
        assert_eq!(unfold.stride, [2, 1]);
        assert_eq!(unfold.dilation, [1, 2]);
        assert_eq!(unfold.padding, [1, 0]);
        assert_eq!(
            alloc::format!("{unfold}"),
            "Unfold4d {kernel_size: [2, 3], stride: [2, 1], dilation: [1, 2], padding: [1, 0]}"
        );
    }
}