Skip to main content

jepa_vision/
patch.rs

1//! Patch embedding for images.
2//!
3//! Implements the patchification step from RFC-002 (Encoder Module).
4//!
//! Patch embedding is the first stage of a Vision Transformer: it converts
//! a raw image into a sequence of token vectors via a learned linear projection.
7//!
8//! ```text
9//! [B, C, H, W]  ──reshape──►  [B, grid_h·grid_w, C·patch_h·patch_w]  ──linear──►  [B, S, D]
10//! ```
11//!
12//! Steps:
13//! 1. Divide the image into non-overlapping patches of size `(patch_h, patch_w)`.
14//! 2. Flatten each patch to a vector of length `C × patch_h × patch_w`.
15//! 3. Project through a learned linear layer to `embed_dim`.
16//!
17//! For video, see [`crate::video`] which uses 3-D *tubelet* embedding instead.
18
19use burn::nn::{Linear, LinearConfig};
20use burn::prelude::*;
21use burn::tensor::backend::Backend;
22
23/// Configuration for patch embedding.
24///
25/// # Example
26///
27/// ```
28/// use jepa_vision::patch::PatchEmbeddingConfig;
29/// use burn_ndarray::NdArray;
30/// use burn::prelude::*;
31///
32/// type B = NdArray<f32>;
33/// let device = burn_ndarray::NdArrayDevice::Cpu;
34///
35/// let config = PatchEmbeddingConfig::new(3, 16, 16, 256);
36/// let patch_embed = config.init::<B>(&device);
37/// assert_eq!(patch_embed.num_patches(224, 224), 196);
38/// ```
39#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
40pub struct PatchEmbeddingConfig {
41    /// Number of input channels (e.g., 3 for RGB).
42    pub in_channels: usize,
43    /// Patch height in pixels.
44    pub patch_h: usize,
45    /// Patch width in pixels.
46    pub patch_w: usize,
47    /// Output embedding dimension.
48    pub embed_dim: usize,
49}
50
51impl PatchEmbeddingConfig {
52    /// Create a new config with the given parameters.
53    pub fn new(in_channels: usize, patch_h: usize, patch_w: usize, embed_dim: usize) -> Self {
54        Self {
55            in_channels,
56            patch_h,
57            patch_w,
58            embed_dim,
59        }
60    }
61
62    /// Initialize a [`PatchEmbedding`] module.
63    pub fn init<B: Backend>(&self, device: &B::Device) -> PatchEmbedding<B> {
64        let patch_dim = self.in_channels * self.patch_h * self.patch_w;
65        let projection = LinearConfig::new(patch_dim, self.embed_dim).init(device);
66        PatchEmbedding {
67            projection,
68            patch_h: self.patch_h,
69            patch_w: self.patch_w,
70            in_channels: self.in_channels,
71        }
72    }
73}
74
75/// Patch embedding module.
76///
77/// Splits an image into non-overlapping patches and projects each
78/// through a linear layer to produce patch embeddings.
79///
80/// This is the first stage of a Vision Transformer (ViT) encoder.
81#[derive(Module, Debug)]
82pub struct PatchEmbedding<B: Backend> {
83    /// Linear projection from flattened patch to embedding space.
84    projection: Linear<B>,
85    /// Patch height in pixels.
86    patch_h: usize,
87    /// Patch width in pixels.
88    patch_w: usize,
89    /// Number of input channels.
90    in_channels: usize,
91}
92
93impl<B: Backend> PatchEmbedding<B> {
94    /// Convert an image batch to patch embeddings.
95    ///
96    /// # Arguments
97    /// * `images` - Input images. Shape: `[batch, channels, height, width]`
98    ///
99    /// # Returns
100    /// Patch embeddings. Shape: `[batch, num_patches, embed_dim]`
101    ///
102    /// # Panics
103    /// If `height` is not divisible by `patch_h` or `width` is not divisible by `patch_w`.
104    pub fn forward(&self, images: Tensor<B, 4>) -> Tensor<B, 3> {
105        let [batch, _channels, height, width] = images.dims();
106
107        let grid_h = height / self.patch_h;
108        let grid_w = width / self.patch_w;
109        let num_patches = grid_h * grid_w;
110        let patch_dim = self.in_channels * self.patch_h * self.patch_w;
111
112        // Reshape: [batch, C, H, W] -> [batch, C, grid_h, patch_h, grid_w, patch_w]
113        let x = images.reshape([
114            batch,
115            self.in_channels,
116            grid_h,
117            self.patch_h,
118            grid_w,
119            self.patch_w,
120        ]);
121        // Permute to: [batch, grid_h, grid_w, C, patch_h, patch_w]
122        let x = x.permute([0, 2, 4, 1, 3, 5]);
123        // Flatten patches: [batch, num_patches, patch_dim]
124        let x = x.reshape([batch, num_patches, patch_dim]);
125
126        // Project: [batch, num_patches, embed_dim]
127        self.projection.forward(x)
128    }
129
130    /// Get the number of patches for a given image size.
131    pub fn num_patches(&self, height: usize, width: usize) -> usize {
132        (height / self.patch_h) * (width / self.patch_w)
133    }
134
135    /// Get the grid dimensions for a given image size.
136    pub fn grid_size(&self, height: usize, width: usize) -> (usize, usize) {
137        (height / self.patch_h, width / self.patch_w)
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;
    use proptest::prelude::*;

    type TestBackend = NdArray<f32>;

    /// CPU device shared by every test in this module.
    fn device() -> burn_ndarray::NdArrayDevice {
        burn_ndarray::NdArrayDevice::Cpu
    }

    #[test]
    fn test_patch_embedding_output_shape() {
        let config = PatchEmbeddingConfig::new(3, 16, 16, 256);
        let pe = config.init::<TestBackend>(&device());

        let images: Tensor<TestBackend, 4> = Tensor::zeros([2, 3, 224, 224], &device());
        let output = pe.forward(images);

        assert_eq!(output.dims(), [2, 196, 256]); // 224/16 = 14, 14*14 = 196
    }

    #[test]
    fn test_patch_embedding_small_image() {
        let config = PatchEmbeddingConfig::new(1, 2, 2, 8);
        let pe = config.init::<TestBackend>(&device());

        let images: Tensor<TestBackend, 4> = Tensor::zeros([1, 1, 4, 4], &device());
        let output = pe.forward(images);

        assert_eq!(output.dims(), [1, 4, 8]); // 4/2 = 2, 2*2 = 4 patches
    }

    #[test]
    fn test_num_patches() {
        let config = PatchEmbeddingConfig::new(3, 16, 16, 256);
        let pe = config.init::<TestBackend>(&device());
        assert_eq!(pe.num_patches(224, 224), 196);
        assert_eq!(pe.num_patches(32, 32), 4);
    }

    #[test]
    fn test_grid_size() {
        let config = PatchEmbeddingConfig::new(3, 16, 16, 256);
        let pe = config.init::<TestBackend>(&device());
        assert_eq!(pe.grid_size(224, 224), (14, 14));
    }

    #[test]
    fn test_patch_embedding_nonzero_output() {
        let config = PatchEmbeddingConfig::new(3, 16, 16, 64);
        let pe = config.init::<TestBackend>(&device());

        // Use ones instead of zeros so the randomly initialised projection
        // contributes to the output.
        let images: Tensor<TestBackend, 4> = Tensor::ones([1, 3, 32, 32], &device());
        let output = pe.forward(images);
        assert_eq!(output.dims(), [1, 4, 64]);

        // The projection weights are randomly initialised, so the output for a
        // non-zero input should not be identically zero.
        let magnitude: f32 = output.abs().sum().into_scalar();
        assert!(magnitude > 0.0, "expected a non-zero embedding, got all zeros");
    }

    #[test]
    #[should_panic]
    fn test_forward_panics_on_indivisible_size() {
        let config = PatchEmbeddingConfig::new(1, 2, 2, 8);
        let pe = config.init::<TestBackend>(&device());

        // Height 3 is not a multiple of the patch height 2.
        let images: Tensor<TestBackend, 4> = Tensor::zeros([1, 1, 3, 4], &device());
        let _ = pe.forward(images);
    }

    #[test]
    #[should_panic]
    fn test_forward_panics_on_channel_mismatch() {
        let config = PatchEmbeddingConfig::new(1, 2, 2, 8);
        let pe = config.init::<TestBackend>(&device());

        // Two channels supplied, but the module was configured for one.
        let images: Tensor<TestBackend, 4> = Tensor::zeros([1, 2, 4, 4], &device());
        let _ = pe.forward(images);
    }

    proptest! {
        #[test]
        fn prop_num_patches_equals_grid_product(
            grid_h in 1usize..8,
            grid_w in 1usize..8,
            patch_size in proptest::sample::select(vec![2usize, 4, 8]),
        ) {
            let config = PatchEmbeddingConfig::new(1, patch_size, patch_size, 16);
            let pe = config.init::<TestBackend>(&device());
            let h = grid_h * patch_size;
            let w = grid_w * patch_size;
            let np = pe.num_patches(h, w);
            prop_assert_eq!(np, grid_h * grid_w);
        }

        #[test]
        fn prop_patch_embedding_output_shape(
            grid_h in 1usize..4,
            grid_w in 1usize..4,
            batch in 1usize..3,
        ) {
            let patch_size = 2;
            let embed_dim = 8;
            let config = PatchEmbeddingConfig::new(1, patch_size, patch_size, embed_dim);
            let pe = config.init::<TestBackend>(&device());
            let h = grid_h * patch_size;
            let w = grid_w * patch_size;
            let images: Tensor<TestBackend, 4> = Tensor::zeros([batch, 1, h, w], &device());
            let output = pe.forward(images);
            let expected_patches = grid_h * grid_w;
            prop_assert_eq!(output.dims(), [batch, expected_patches, embed_dim]);
        }
    }
}