use bimm_contracts::assert_shape_contract_periodically;
use burn::{
config::Config,
module::Module,
nn::{
LayerNorm,
LayerNormConfig,
conv::{
Conv2d,
Conv2dConfig,
},
},
prelude::{
Backend,
Tensor,
},
};
/// Shape metadata shared by patch-embedding configurations and modules.
///
/// Implementors provide the input resolution, the input/output feature sizes,
/// the patch size and the normalization flag; every patch-grid quantity is
/// derived from those through the default methods below.
pub trait PatchEmbedMeta {
    /// Input resolution as `[height, width]`.
    fn input_resolution(&self) -> [usize; 2];

    /// Input height: the first entry of [`input_resolution`](Self::input_resolution).
    fn input_height(&self) -> usize {
        let [height, _] = self.input_resolution();
        height
    }

    /// Input width: the second entry of [`input_resolution`](Self::input_resolution).
    fn input_width(&self) -> usize {
        let [_, width] = self.input_resolution();
        width
    }

    /// Size of the input feature dimension.
    fn d_input(&self) -> usize;

    /// Patch side length; the same value divides both height and width.
    fn patch_size(&self) -> usize;

    /// Patch grid as `[patches_height, patches_width]`.
    fn patches_resolution(&self) -> [usize; 2] {
        let rows = self.patches_height();
        let cols = self.patches_width();
        [rows, cols]
    }

    /// Number of patches along the height (integer division, so any remainder
    /// is discarded).
    ///
    /// # Panics
    ///
    /// Panics if [`patch_size`](Self::patch_size) is zero.
    fn patches_height(&self) -> usize {
        let height = self.input_height();
        height / self.patch_size()
    }

    /// Number of patches along the width (integer division, so any remainder
    /// is discarded).
    ///
    /// # Panics
    ///
    /// Panics if [`patch_size`](Self::patch_size) is zero.
    fn patches_width(&self) -> usize {
        let width = self.input_width();
        width / self.patch_size()
    }

    /// Total number of patches: `patches_height * patches_width`.
    fn num_patches(&self) -> usize {
        let rows = self.patches_height();
        let cols = self.patches_width();
        rows * cols
    }

    /// Size of the output feature dimension.
    fn d_output(&self) -> usize;

    /// Whether patch normalization is enabled.
    fn enable_patch_norm(&self) -> bool;
}
/// Configuration for a [`PatchEmbed`] module.
///
/// The values are exposed through [`PatchEmbedMeta`], and
/// [`PatchEmbedConfig::init`] turns the configuration into a module.
#[derive(Config, Debug, Copy)]
pub struct PatchEmbedConfig {
/// Expected input resolution as `[height, width]`.
input_resolution: [usize; 2],
/// Patch side length; `init` requires both input dimensions to be
/// divisible by it.
patch_size: usize,
/// Number of input channels fed to the projection convolution.
d_input: usize,
/// Number of output channels produced by the projection convolution (and
/// the size of the optional layer norm).
d_output: usize,
/// Whether `init` attaches a `LayerNorm` to the module; defaults to `true`.
#[config(default = true)]
enable_patch_norm: bool,
}
/// Exposes the configured values through [`PatchEmbedMeta`].
///
/// Every accessor returns the matching field unchanged; the derived patch-grid
/// quantities come from the trait's default methods.
impl PatchEmbedMeta for PatchEmbedConfig {
    /// Configured input resolution as `[height, width]`.
    fn input_resolution(&self) -> [usize; 2] {
        self.input_resolution
    }

    /// Configured number of input channels.
    fn d_input(&self) -> usize {
        self.d_input
    }

    /// Configured patch side length.
    fn patch_size(&self) -> usize {
        self.patch_size
    }

    /// Configured number of output channels.
    fn d_output(&self) -> usize {
        self.d_output
    }

    /// Whether the configuration asks for a layer norm after the projection.
    fn enable_patch_norm(&self) -> bool {
        self.enable_patch_norm
    }
}
impl PatchEmbedConfig {
    /// Builds a [`PatchEmbed`] module on `device`.
    ///
    /// The projection is a [`Conv2d`] mapping `d_input` channels to `d_output`
    /// channels whose kernel size and stride are both
    /// `[patch_size, patch_size]`. A [`LayerNorm`] of size `d_output` is
    /// attached only when `enable_patch_norm` is set.
    ///
    /// # Panics
    ///
    /// - If `patch_size` is zero.
    /// - If either dimension of `input_resolution` is not divisible by
    ///   `patch_size`.
    #[must_use]
    pub fn init<B: Backend>(
        &self,
        device: &B::Device,
    ) -> PatchEmbed<B> {
        // Reject a zero patch size up front: the remainder check below would
        // otherwise abort with an opaque divide-by-zero panic rather than a
        // configuration error.
        assert!(
            self.patch_size > 0,
            "Patch size must be non-zero: {:?}",
            self.patch_size
        );

        let [h, w] = self.input_resolution;
        assert!(
            h % self.patch_size == 0 && w % self.patch_size == 0,
            "Input resolution must be divisible by patch size: {:?}",
            self.input_resolution
        );

        // Kernel and stride share one extent so that every patch is covered
        // by exactly one kernel application.
        let patch_extent = [self.patch_size, self.patch_size];
        let projection = Conv2dConfig::new([self.d_input, self.d_output], patch_extent)
            .with_stride(patch_extent)
            .init(device);

        let norm = self
            .enable_patch_norm
            .then(|| LayerNormConfig::new(self.d_output).init(device));

        PatchEmbed {
            input_resolution: self.input_resolution,
            patch_size: self.patch_size,
            projection,
            norm,
        }
    }
}
/// Patch-embedding module: a convolutional projection followed by an optional
/// layer norm.
///
/// Normally built with `PatchEmbedConfig::init`.
#[derive(Module, Debug)]
pub struct PatchEmbed<B: Backend> {
/// Expected input resolution as `[height, width]`; `forward` checks inputs
/// against it.
pub input_resolution: [usize; 2],
/// Patch side length used to derive the patch grid.
pub patch_size: usize,
/// Convolution that projects the input to `d_output` channels
/// (`PatchEmbedConfig::init` sets its kernel size and stride to the patch size).
pub projection: Conv2d<B>,
/// Layer norm applied to the embedded patches, or `None` when patch
/// normalization is disabled.
pub norm: Option<LayerNorm<B>>,
}
/// Metadata for an initialized module.
///
/// Resolution and patch size come from the stored fields; the channel counts
/// are read back from the projection weights and the normalization flag from
/// the presence of the norm layer.
impl<B: Backend> PatchEmbedMeta for PatchEmbed<B> {
    /// Stored input resolution as `[height, width]`.
    fn input_resolution(&self) -> [usize; 2] {
        self.input_resolution
    }

    /// Dimension 1 of the projection weight tensor.
    fn d_input(&self) -> usize {
        let [_, in_channels, ..] = self.projection.weight.dims();
        in_channels
    }

    /// Stored patch side length.
    fn patch_size(&self) -> usize {
        self.patch_size
    }

    /// Dimension 0 of the projection weight tensor.
    fn d_output(&self) -> usize {
        let [out_channels, ..] = self.projection.weight.dims();
        out_channels
    }

    /// `true` exactly when a norm layer is present.
    fn enable_patch_norm(&self) -> bool {
        self.norm.is_some()
    }
}
impl<B: Backend> PatchEmbed<B> {
    /// Embeds an image batch into a sequence of patch tokens.
    ///
    /// # Arguments
    ///
    /// - `x`: tensor of shape `[batch, d_input, height, width]`, where
    ///   `height` and `width` match this module's input resolution.
    ///
    /// # Returns
    ///
    /// A tensor of shape `[batch, num_patches, d_output]`.
    ///
    /// # Panics
    ///
    /// The shapes at every stage are verified with
    /// `assert_shape_contract_periodically!`; a violated contract is expected
    /// to panic (see `bimm_contracts` for the exact checking schedule).
    #[must_use]
    pub fn forward(
        &self,
        x: Tensor<B, 4>,
    ) -> Tensor<B, 3> {
        assert_shape_contract_periodically!(
            ["batch", "d_input", "height", "width"],
            &x.dims(),
            &[
                ("d_input", self.d_input()),
                ("height", self.input_height()),
                ("width", self.input_width()),
            ]
        );
        // Remember the batch size so later contracts can pin it.
        let [batch, ..] = x.dims();

        // One projection output per patch: [batch, d_output, ph, pw].
        let feature_map = self.projection.forward(x);
        assert_shape_contract_periodically!(
            ["batch", "d_output", "patches_height", "patches_width"],
            &feature_map.dims(),
            &[
                ("batch", batch),
                ("d_output", self.d_output()),
                ("patches_height", self.patches_height()),
                ("patches_width", self.patches_width()),
            ],
        );

        // Merge the two patch-grid axes, then move features last:
        // [batch, d_output, ph, pw] -> [batch, ph * pw, d_output].
        let tokens = feature_map.flatten(2, 3).swap_dims(1, 2);
        assert_shape_contract_periodically!(
            ["batch", "num_patches", "d_output"],
            &tokens.dims(),
            &[
                ("batch", batch),
                ("num_patches", self.num_patches()),
                ("d_output", self.d_output()),
            ],
        );

        let embedded = match &self.norm {
            Some(norm) => norm.forward(tokens),
            None => tokens,
        };
        // The optional norm must leave the shape untouched.
        assert_shape_contract_periodically!(
            ["batch", "num_patches", "d_output"],
            &embedded.dims(),
            &[
                ("batch", batch),
                ("num_patches", self.num_patches()),
                ("d_output", self.d_output()),
            ],
        );

        embedded
    }
}
#[cfg(test)]
mod tests {
    use burn::{
        backend::NdArray,
        tensor::TensorData,
    };

    use super::*;

    /// Reference configuration shared by the tests: 224x224 input with 3
    /// channels, 16-pixel patches and 768 output features.
    fn reference_config(enable_patch_norm: bool) -> PatchEmbedConfig {
        PatchEmbedConfig {
            input_resolution: [224, 224],
            patch_size: 16,
            d_input: 3,
            d_output: 768,
            enable_patch_norm,
        }
    }

    /// Checks every `PatchEmbedMeta` accessor against the reference
    /// configuration with patch normalization enabled.
    fn assert_reference_meta<M: PatchEmbedMeta>(meta: &M) {
        assert_eq!(meta.input_resolution(), [224, 224]);
        assert_eq!(meta.patch_size(), 16);
        assert_eq!(meta.d_input(), 3);
        assert_eq!(meta.d_output(), 768);
        assert!(meta.enable_patch_norm());

        assert_eq!(meta.patches_resolution(), [14, 14]);
        assert_eq!(meta.patches_height(), 14);
        assert_eq!(meta.patches_width(), 14);
        assert_eq!(meta.num_patches(), 196);
    }

    /// Runs a single all-ones image through the module and checks the output
    /// shape.
    fn assert_forward_shape(enable_patch_norm: bool) {
        let device = Default::default();
        let patch_embed = reference_config(enable_patch_norm).init::<NdArray>(&device);

        let pixels = vec![1.0; 3 * 224 * 224];
        let input = Tensor::<NdArray, 4>::from_data(
            TensorData::new(pixels, [1, 3, 224, 224]),
            &device,
        );

        let output = patch_embed.forward(input);
        assert_eq!(output.dims(), [1, 196, 768]);
    }

    #[test]
    fn test_patch_embed_meta() {
        let config = reference_config(true);
        assert_reference_meta(&config);

        // The initialized module must report the same metadata as its config.
        let device = Default::default();
        let patch_embed = config.init::<NdArray>(&device);
        assert_reference_meta(&patch_embed);
    }

    #[test]
    #[should_panic(expected = "Input resolution must be divisible by patch size")]
    fn test_patch_embed_invalid_resolution() {
        // A width of 223 is not a multiple of the 16-pixel patch size.
        let config = PatchEmbedConfig {
            input_resolution: [224, 223],
            ..reference_config(true)
        };

        let device = Default::default();
        let _d = config.init::<NdArray>(&device);
    }

    #[test]
    fn test_patch_embed_forward() {
        assert_forward_shape(true);
    }

    #[test]
    fn test_patch_embed_without_norm() {
        assert_forward_shape(false);
    }
}