// burn_dinov2/layers/patch_embed.rs

use burn::prelude::*;

4#[derive(Config)]
5pub struct PatchEmbedConfig {
6    pub image_size: usize,
7    pub patch_size: usize,
8    pub input_channels: usize,
9    pub embedding_dimension: usize,
10}
11
12impl Default for PatchEmbedConfig {
13    fn default() -> Self {
14        Self {
15            image_size: 224,
16            patch_size: 16,
17            input_channels: 3,
18            embedding_dimension: 768,
19        }
20    }
21}
22
23impl PatchEmbedConfig {
24    pub fn init<B: Backend>(&self, device: &B::Device) -> PatchEmbed<B> {
25        PatchEmbed::new(device, self.clone())
26    }
27}
28
29
30#[derive(Module, Debug)]
31pub struct PatchEmbed<B: Backend> {
32    proj: nn::conv::Conv2d<B>,
33}
34
35impl<B: Backend> PatchEmbed<B> {
36    pub fn new(
37        device: &B::Device,
38        config: PatchEmbedConfig,
39    ) -> Self {
40        let kernel_size = [config.patch_size, config.patch_size];
41        let proj = nn::conv::Conv2dConfig::new(
42                [config.input_channels, config.embedding_dimension],
43                kernel_size,
44            )
45            .with_stride(kernel_size)
46            .init(device);
47
48        Self {
49            proj,
50        }
51    }
52
53    #[allow(non_snake_case)]
54    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 3> {
55        self.proj.forward(x)
56            .flatten(2, 3)
57            .swap_dims(1, 2)
58    }
59}
60