//! A small projection head built from Burn primitives: LayerNorm, a single
//! Linear layer, a GELU activation, and Dropout.

use burn::{
    module::Module,
    nn::{Dropout, DropoutConfig, LayerNorm, LayerNormConfig, Linear, LinearConfig},
    prelude::Backend,
    tensor::{activation, Tensor},
};

/// Configuration for [`MlpProjector`].
#[derive(Debug, Clone)]
pub struct MlpProjectorConfig {
    /// Width of the incoming features.
    pub input_dim: usize,
    /// Width of the projected features.
    pub output_dim: usize,
    /// Dropout probability applied after the activation; defaults to 0.0.
    pub dropout: f64,
}

impl MlpProjectorConfig {
    /// Creates a config with the given dimensions and dropout disabled.
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
        Self {
            input_dim,
            output_dim,
            dropout: 0.0,
        }
    }

    /// Initializes an [`MlpProjector`] from this config on the given device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> MlpProjector<B> {
        MlpProjector {
            norm: LayerNormConfig::new(self.input_dim).init::<B>(device),
            linear: LinearConfig::new(self.input_dim, self.output_dim).init::<B>(device),
            dropout: DropoutConfig::new(self.dropout).init(),
        }
    }
}
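
// A convenience setter, sketched here as an addition (the original config
// exposes `dropout` as a public field instead); it mirrors the `with_*`
// builders that Burn's `#[derive(Config)]` macro would generate.
impl MlpProjectorConfig {
    /// Returns this config with the dropout probability replaced.
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }
}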

/// LayerNorm -> Linear -> GELU -> Dropout projection module.
#[derive(Module, Debug)]
pub struct MlpProjector<B: Backend> {
    norm: LayerNorm<B>,
    linear: Linear<B>,
    dropout: Dropout,
}

impl<B: Backend> MlpProjector<B> {
    /// Projects `x` from `[batch, seq, input_dim]` to `[batch, seq, output_dim]`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let x = self.norm.forward(x);
        let x = self.linear.forward(x);
        let x = activation::gelu(x);
        self.dropout.forward(x)
    }
}
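
// A minimal usage sketch, assuming the crate enables Burn's `ndarray`
// feature so the `NdArray` CPU backend is available; the dimensions and
// test name here are illustrative only.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn forward_projects_the_last_dimension() {
        let device = Default::default();
        let projector = MlpProjectorConfig::new(8, 4).init::<NdArray>(&device);

        // [batch = 2, seq = 16, input_dim = 8] -> [2, 16, output_dim = 4].
        let input = Tensor::<NdArray, 3>::zeros([2, 16, 8], &device);
        let output = projector.forward(input);
        assert_eq!(output.dims(), [2, 16, 4]);
    }
}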