1use crate as burn;
2
3use crate::config::Config;
4use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
5
6use burn_tensor::Tensor;
7use burn_tensor::backend::Backend;
8use burn_tensor::module::unfold4d;
9use burn_tensor::ops::UnfoldOptions;
10
/// Configuration used to build an [`Unfold4d`] layer via [`Unfold4dConfig::init`].
///
/// Only `kernel_size` is required (it is the sole argument of `Unfold4dConfig::new`);
/// every other field has a default, declared with `#[config(default = ...)]`.
#[derive(Config, Debug)]
pub struct Unfold4dConfig {
    /// Size of the kernel window along each of the two spatial dimensions.
    /// Required: this field has no default.
    pub kernel_size: [usize; 2],
    /// Stride of the window along each of the two spatial dimensions.
    /// Default: `[1, 1]`.
    #[config(default = "[1, 1]")]
    pub stride: [usize; 2],
    /// Dilation of the kernel along each of the two spatial dimensions.
    /// Default: `[1, 1]`.
    #[config(default = "[1, 1]")]
    pub dilation: [usize; 2],
    /// Padding along each of the two spatial dimensions, forwarded to `UnfoldOptions`
    /// (NOTE(review): exact padding semantics are defined by `unfold4d` — confirm there).
    /// Default: `[0, 0]`.
    #[config(default = "[0, 0]")]
    pub padding: [usize; 2],
}
26
/// Four-dimensional unfolding layer.
///
/// The layer stores only plain `[usize; 2]` hyper-parameters (no tensors). The actual
/// computation is delegated to [`unfold4d`] in [`Unfold4d::forward`].
///
/// Construct it with [`Unfold4dConfig::init`].
#[derive(Module, Clone, Debug)]
// Display output comes from the hand-written `ModuleDisplay` impl for this type.
#[module(custom_display)]
pub struct Unfold4d {
    /// Kernel window size along each spatial dimension (copied from the config).
    pub kernel_size: [usize; 2],
    /// Window stride along each spatial dimension (copied from the config).
    pub stride: [usize; 2],
    /// Kernel dilation along each spatial dimension (copied from the config).
    pub dilation: [usize; 2],
    /// Padding along each spatial dimension (copied from the config).
    pub padding: [usize; 2],
}
42
43impl ModuleDisplay for Unfold4d {
44 fn custom_settings(&self) -> Option<DisplaySettings> {
45 DisplaySettings::new()
46 .with_new_line_after_attribute(false)
47 .optional()
48 }
49
50 fn custom_content(&self, content: Content) -> Option<Content> {
51 content
52 .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
53 .add("stride", &alloc::format!("{:?}", &self.stride))
54 .add("dilation", &alloc::format!("{:?}", &self.dilation))
55 .add("padding", &alloc::format!("{:?}", &self.padding))
56 .optional()
57 }
58}
59
60impl Unfold4dConfig {
61 pub fn init(&self) -> Unfold4d {
63 Unfold4d {
64 kernel_size: self.kernel_size,
65 stride: self.stride,
66 dilation: self.dilation,
67 padding: self.padding,
68 }
69 }
70}
71
72impl Unfold4d {
73 pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 3> {
82 unfold4d(
83 input,
84 self.kernel_size,
85 UnfoldOptions::new(self.stride, self.padding, self.dilation),
86 )
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// A layer built from the default config with a 3x3 kernel renders all four
    /// hyper-parameters on a single line.
    #[test]
    fn display() {
        let unfold = Unfold4dConfig::new([3, 3]).init();
        let rendered = alloc::format!("{unfold}");

        let expected =
            "Unfold4d {kernel_size: [3, 3], stride: [1, 1], dilation: [1, 1], padding: [0, 0]}";
        assert_eq!(rendered, expected);
    }
}