// burn_dinov2/layers/patch_embed.rs
use burn::prelude::*;
2
3
/// Configuration for the [`PatchEmbed`] layer.
#[derive(Config)]
pub struct PatchEmbedConfig {
    /// Side length of the input image in pixels (a single value, so a
    /// square input is presumably assumed — confirm against callers).
    pub image_size: usize,
    /// Side length of each square patch in pixels.
    pub patch_size: usize,
    /// Number of channels in the input image (e.g. 3 for RGB).
    pub input_channels: usize,
    /// Dimensionality of each output patch embedding.
    pub embedding_dimension: usize,
}
11
impl Default for PatchEmbedConfig {
    /// Default configuration: 224x224 3-channel input, 16-pixel patches,
    /// 768-dimensional embeddings.
    fn default() -> Self {
        Self {
            image_size: 224,
            patch_size: 16,
            input_channels: 3,
            embedding_dimension: 768,
        }
    }
}
22
23impl PatchEmbedConfig {
24 pub fn init<B: Backend>(&self, device: &B::Device) -> PatchEmbed<B> {
25 PatchEmbed::new(device, self.clone())
26 }
27}
28
29
/// Splits an image into non-overlapping patches and linearly projects each
/// patch into an embedding vector.
#[derive(Module, Debug)]
pub struct PatchEmbed<B: Backend> {
    // Strided convolution whose kernel size equals its stride (see `new`),
    // so each output position covers exactly one non-overlapping patch.
    proj: nn::conv::Conv2d<B>,
}
34
35impl<B: Backend> PatchEmbed<B> {
36 pub fn new(
37 device: &B::Device,
38 config: PatchEmbedConfig,
39 ) -> Self {
40 let kernel_size = [config.patch_size, config.patch_size];
41 let proj = nn::conv::Conv2dConfig::new(
42 [config.input_channels, config.embedding_dimension],
43 kernel_size,
44 )
45 .with_stride(kernel_size)
46 .init(device);
47
48 Self {
49 proj,
50 }
51 }
52
53 #[allow(non_snake_case)]
54 pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 3> {
55 self.proj.forward(x)
56 .flatten(2, 3)
57 .swap_dims(1, 2)
58 }
59}
60