1use burn::nn::{Linear, LinearConfig};
20use burn::prelude::*;
21use burn::tensor::backend::Backend;
22
23#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
40pub struct PatchEmbeddingConfig {
41 pub in_channels: usize,
43 pub patch_h: usize,
45 pub patch_w: usize,
47 pub embed_dim: usize,
49}
50
51impl PatchEmbeddingConfig {
52 pub fn new(in_channels: usize, patch_h: usize, patch_w: usize, embed_dim: usize) -> Self {
54 Self {
55 in_channels,
56 patch_h,
57 patch_w,
58 embed_dim,
59 }
60 }
61
62 pub fn init<B: Backend>(&self, device: &B::Device) -> PatchEmbedding<B> {
64 let patch_dim = self.in_channels * self.patch_h * self.patch_w;
65 let projection = LinearConfig::new(patch_dim, self.embed_dim).init(device);
66 PatchEmbedding {
67 projection,
68 patch_h: self.patch_h,
69 patch_w: self.patch_w,
70 in_channels: self.in_channels,
71 }
72 }
73}
74
/// Patch-embedding layer: cuts an image into non-overlapping `patch_h` x `patch_w` patches,
/// flattens each patch and projects it with a linear layer.
///
/// Construct it with [`PatchEmbeddingConfig::init`].
#[derive(Module, Debug)]
pub struct PatchEmbedding<B: Backend> {
    /// Linear map from a flattened patch (`in_channels * patch_h * patch_w` values) to the
    /// embedding dimension.
    projection: Linear<B>,
    /// Height of one patch, in pixels.
    patch_h: usize,
    /// Width of one patch, in pixels.
    patch_w: usize,
    /// Channel count the layer was configured for; `forward` reshapes with this value.
    in_channels: usize,
}
92
93impl<B: Backend> PatchEmbedding<B> {
94 pub fn forward(&self, images: Tensor<B, 4>) -> Tensor<B, 3> {
105 let [batch, _channels, height, width] = images.dims();
106
107 let grid_h = height / self.patch_h;
108 let grid_w = width / self.patch_w;
109 let num_patches = grid_h * grid_w;
110 let patch_dim = self.in_channels * self.patch_h * self.patch_w;
111
112 let x = images.reshape([
114 batch,
115 self.in_channels,
116 grid_h,
117 self.patch_h,
118 grid_w,
119 self.patch_w,
120 ]);
121 let x = x.permute([0, 2, 4, 1, 3, 5]);
123 let x = x.reshape([batch, num_patches, patch_dim]);
125
126 self.projection.forward(x)
128 }
129
130 pub fn num_patches(&self, height: usize, width: usize) -> usize {
132 (height / self.patch_h) * (width / self.patch_w)
133 }
134
135 pub fn grid_size(&self, height: usize, width: usize) -> (usize, usize) {
137 (height / self.patch_h, width / self.patch_w)
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use burn::nn::Initializer;
    use burn::tensor::Int;
    use burn_ndarray::NdArray;
    use proptest::prelude::*;

    type TestBackend = NdArray<f32>;

    fn device() -> burn_ndarray::NdArrayDevice {
        burn_ndarray::NdArrayDevice::Cpu
    }

    #[test]
    fn test_patch_embedding_output_shape() {
        let config = PatchEmbeddingConfig::new(3, 16, 16, 256);
        let pe = config.init::<TestBackend>(&device());

        let images: Tensor<TestBackend, 4> = Tensor::zeros([2, 3, 224, 224], &device());
        let output = pe.forward(images);

        assert_eq!(output.dims(), [2, 196, 256]);
    }

    #[test]
    fn test_patch_embedding_small_image() {
        let config = PatchEmbeddingConfig::new(1, 2, 2, 8);
        let pe = config.init::<TestBackend>(&device());

        let images: Tensor<TestBackend, 4> = Tensor::zeros([1, 1, 4, 4], &device());
        let output = pe.forward(images);

        assert_eq!(output.dims(), [1, 4, 8]);
    }

    #[test]
    fn test_num_patches() {
        let config = PatchEmbeddingConfig::new(3, 16, 16, 256);
        let pe = config.init::<TestBackend>(&device());
        assert_eq!(pe.num_patches(224, 224), 196);
        assert_eq!(pe.num_patches(32, 32), 4);
    }

    #[test]
    fn test_grid_size() {
        let config = PatchEmbeddingConfig::new(3, 16, 16, 256);
        let pe = config.init::<TestBackend>(&device());
        assert_eq!(pe.grid_size(224, 224), (14, 14));
    }

    #[test]
    fn test_patch_embedding_nonzero_output() {
        let config = PatchEmbeddingConfig::new(3, 16, 16, 64);
        let pe = config.init::<TestBackend>(&device());

        let images: Tensor<TestBackend, 4> = Tensor::ones([1, 3, 32, 32], &device());
        let output = pe.forward(images);
        assert_eq!(output.dims(), [1, 4, 64]);

        // A randomly initialised projection applied to an all-ones input should not collapse
        // to an all-zero embedding.
        let magnitude = output.abs().sum().into_scalar();
        assert!(magnitude > 0.0, "expected a non-zero embedding, got total |x| = {magnitude}");
    }

    /// With an all-ones projection and no bias, every output feature equals the sum of the
    /// pixels in one patch. That pins down which pixels are grouped together and the order in
    /// which patches are emitted.
    #[test]
    fn test_patch_grouping_and_order() {
        let dev = device();
        let embed_dim = 3;
        let pe = PatchEmbedding::<TestBackend> {
            projection: LinearConfig::new(4, embed_dim)
                .with_bias(false)
                .with_initializer(Initializer::Constant { value: 1.0 })
                .init(&dev),
            patch_h: 2,
            patch_w: 2,
            in_channels: 1,
        };

        // 4x4 single-channel image holding 0..16 in row-major order.
        let images = Tensor::<TestBackend, 1, Int>::arange(0..16, &dev)
            .float()
            .reshape([1, 1, 4, 4]);
        let output = pe.forward(images);
        assert_eq!(output.dims(), [1, 4, embed_dim]);

        // Row-major patches: {0,1,4,5}, {2,3,6,7}, {8,9,12,13}, {10,11,14,15}.
        let expected = [10.0f32, 18.0, 42.0, 50.0];
        for (patch, &want) in expected.iter().enumerate() {
            for feature in 0..embed_dim {
                let got = output
                    .clone()
                    .slice([0..1, patch..patch + 1, feature..feature + 1])
                    .into_scalar();
                assert!(
                    (got - want).abs() < 1e-4,
                    "patch {patch}, feature {feature}: expected {want}, got {got}"
                );
            }
        }
    }

    proptest! {
        #[test]
        fn prop_num_patches_equals_grid_product(
            grid_h in 1usize..8,
            grid_w in 1usize..8,
            patch_size in proptest::sample::select(vec![2usize, 4, 8]),
        ) {
            let config = PatchEmbeddingConfig::new(1, patch_size, patch_size, 16);
            let pe = config.init::<TestBackend>(&device());
            let h = grid_h * patch_size;
            let w = grid_w * patch_size;
            let np = pe.num_patches(h, w);
            prop_assert_eq!(np, grid_h * grid_w);
        }

        #[test]
        fn prop_patch_embedding_output_shape(
            grid_h in 1usize..4,
            grid_w in 1usize..4,
            batch in 1usize..3,
        ) {
            let patch_size = 2;
            let embed_dim = 8;
            let config = PatchEmbeddingConfig::new(1, patch_size, patch_size, embed_dim);
            let pe = config.init::<TestBackend>(&device());
            let h = grid_h * patch_size;
            let w = grid_w * patch_size;
            let images: Tensor<TestBackend, 4> = Tensor::zeros([batch, 1, h, w], &device());
            let output = pe.forward(images);
            let expected_patches = grid_h * grid_w;
            prop_assert_eq!(output.dims(), [batch, expected_patches, embed_dim]);
        }
    }
}