use burn::nn::{Gelu, Linear, LinearConfig};
use burn::prelude::*;
/// Hyperparameters for building a [`DenoisingNetwork`].
///
/// `input_dim` and `cond_dim` have no default; the remaining fields fall back to
/// the value given in their `#[config(default = ...)]` attribute.
#[derive(Config, Debug)]
pub struct DenoisingNetworkConfig {
/// Width of the noisy input and of the network output.
pub input_dim: usize,
/// Width of the raw conditioning tensor passed to `DenoisingNetwork::forward`.
pub cond_dim: usize,
/// Width of the hidden representation carried through the residual blocks.
#[config(default = 256)]
pub hidden_dim: usize,
/// Width of the time-embedding tensor passed to `DenoisingNetwork::forward`.
#[config(default = 128)]
pub time_embed_dim: usize,
/// Number of residual blocks stacked between the input and output projections.
#[config(default = 3)]
pub num_blocks: usize,
}
/// Residual MLP block that mixes in a conditioning tensor and a time embedding.
///
/// Computes `x + linear2(gelu(linear1(x) + cond_proj(cond) + time_proj(time_emb)))`.
#[derive(Module, Debug)]
pub struct ResidualBlock<B: Backend> {
/// First `hidden_dim -> hidden_dim` layer, applied to the block input.
linear1: Linear<B>,
/// Second `hidden_dim -> hidden_dim` layer, applied after the activation.
linear2: Linear<B>,
/// Projects the conditioning tensor from `cond_dim` to `hidden_dim`.
cond_proj: Linear<B>,
/// Projects the time embedding from `time_dim` to `hidden_dim`.
time_proj: Linear<B>,
/// GELU applied between `linear1` (plus injected terms) and `linear2`.
activation: Gelu,
}
/// Layer widths for building a [`ResidualBlock`].
#[derive(Config, Debug)]
pub struct ResidualBlockConfig {
/// Width of the block input/output and of both internal hidden layers.
pub hidden_dim: usize,
/// Width of the conditioning tensor passed to `ResidualBlock::forward`.
pub cond_dim: usize,
/// Width of the time-embedding tensor passed to `ResidualBlock::forward`.
pub time_dim: usize,
}
impl ResidualBlockConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> ResidualBlock<B> {
ResidualBlock {
linear1: LinearConfig::new(self.hidden_dim, self.hidden_dim).init(device),
linear2: LinearConfig::new(self.hidden_dim, self.hidden_dim).init(device),
cond_proj: LinearConfig::new(self.cond_dim, self.hidden_dim).init(device),
time_proj: LinearConfig::new(self.time_dim, self.hidden_dim).init(device),
activation: Gelu::new(),
}
}
}
impl<B: Backend> ResidualBlock<B> {
    /// Runs the block and adds its output back onto `x` (residual connection).
    ///
    /// Returns `x + linear2(gelu(linear1(x) + cond_proj(cond) + time_proj(time_emb)))`.
    /// `cond` and `time_emb` are taken by reference so a caller can share them
    /// across several blocks.
    ///
    /// NOTE(review): the residual sum needs `linear2`'s output to match `x`, so
    /// `x` is presumably `[batch, hidden_dim]`; confirm with callers.
    pub fn forward(
        &self,
        x: Tensor<B, 2>,
        cond: &Tensor<B, 2>,
        time_emb: &Tensor<B, 2>,
    ) -> Tensor<B, 2> {
        let pre_activation = self.linear1.forward(x.clone())
            + self.cond_proj.forward(cond.clone())
            + self.time_proj.forward(time_emb.clone());
        let update = self.linear2.forward(self.activation.forward(pre_activation));
        x + update
    }
}
/// Conditional denoising MLP.
///
/// Projects the noisy input to `hidden_dim`, passes it through a stack of
/// [`ResidualBlock`]s that receive the projected conditioning tensor and the
/// processed time embedding, then projects back to `input_dim`.
#[derive(Module, Debug)]
pub struct DenoisingNetwork<B: Backend> {
/// Maps the noisy input (`input_dim`) to the hidden width (`hidden_dim`).
input_proj: Linear<B>,
/// Residual blocks, applied in order.
blocks: Vec<ResidualBlock<B>>,
/// Maps the hidden width back to `input_dim`.
output_proj: Linear<B>,
/// `time_embed_dim -> time_embed_dim` layer applied to the time embedding,
/// followed by `time_activation`.
time_mlp: Linear<B>,
/// NOTE(review): despite its name, this layer projects the *conditioning*
/// tensor (`cond_dim -> hidden_dim`), not the time embedding. Renaming the
/// field likely affects compatibility with previously saved records; confirm
/// before changing it.
time_mlp2: Linear<B>,
/// GELU applied to the output of `time_mlp`.
time_activation: Gelu,
}
impl DenoisingNetworkConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> DenoisingNetwork<B> {
let blocks = (0..self.num_blocks)
.map(|_| {
ResidualBlockConfig {
hidden_dim: self.hidden_dim,
cond_dim: self.hidden_dim,
time_dim: self.time_embed_dim,
}
.init(device)
})
.collect();
DenoisingNetwork {
input_proj: LinearConfig::new(self.input_dim, self.hidden_dim).init(device),
blocks,
output_proj: LinearConfig::new(self.hidden_dim, self.input_dim).init(device),
time_mlp: LinearConfig::new(self.time_embed_dim, self.time_embed_dim).init(device),
time_mlp2: LinearConfig::new(self.cond_dim, self.hidden_dim).init(device),
time_activation: Gelu::new(),
}
}
}
impl<B: Backend> DenoisingNetwork<B> {
    /// Maps a noisy input, a conditioning tensor and a time embedding to an
    /// output of the same width as the noisy input.
    ///
    /// The time embedding (via `time_mlp` + GELU) and the conditioning tensor
    /// (via `time_mlp2`) are each processed once. Both results are then shared
    /// by every residual block.
    ///
    /// NOTE(review): the layer widths suggest `noisy_actions: [batch, input_dim]`,
    /// `cond: [batch, cond_dim]` and `time_emb: [batch, time_embed_dim]`;
    /// confirm with callers.
    pub fn forward(
        &self,
        noisy_actions: Tensor<B, 2>,
        cond: Tensor<B, 2>,
        time_emb: Tensor<B, 2>,
    ) -> Tensor<B, 2> {
        let time_features = self
            .time_activation
            .forward(self.time_mlp.forward(time_emb));
        let cond_features = self.time_mlp2.forward(cond);

        let hidden = self.blocks.iter().fold(
            self.input_proj.forward(noisy_actions),
            |acc, block| block.forward(acc, &cond_features, &time_features),
        );

        self.output_proj.forward(hidden)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type TestBackend = NdArray;

    /// Output width must equal `input_dim`.
    ///
    /// Every width is distinct (24 / 40 / 64 / 16) so that wiring one
    /// dimension in place of another makes the test fail. The previous
    /// version used 32 for both `input_dim` and `time_embed_dim`.
    #[test]
    fn test_denoising_network_output_shape() {
        let device = <TestBackend as Backend>::Device::default();
        let config = DenoisingNetworkConfig {
            input_dim: 24,
            cond_dim: 40,
            hidden_dim: 64,
            time_embed_dim: 16,
            num_blocks: 2,
        };
        let network = config.init::<TestBackend>(&device);
        let batch_size = 4;
        let noisy = Tensor::<TestBackend, 2>::zeros([batch_size, 24], &device);
        let cond = Tensor::<TestBackend, 2>::zeros([batch_size, 40], &device);
        let time = Tensor::<TestBackend, 2>::zeros([batch_size, 16], &device);
        let output = network.forward(noisy, cond, time);
        assert_eq!(output.dims(), [batch_size, 24]);
    }

    /// A block keeps the `[batch, hidden_dim]` shape even when the
    /// conditioning width differs from the hidden width.
    #[test]
    fn test_residual_block_preserves_shape() {
        let device = <TestBackend as Backend>::Device::default();
        let block = ResidualBlockConfig {
            hidden_dim: 64,
            cond_dim: 48,
            time_dim: 32,
        }
        .init::<TestBackend>(&device);
        let x = Tensor::<TestBackend, 2>::zeros([2, 64], &device);
        let cond = Tensor::<TestBackend, 2>::zeros([2, 48], &device);
        let time = Tensor::<TestBackend, 2>::zeros([2, 32], &device);
        let out = block.forward(x, &cond, &time);
        assert_eq!(out.dims(), [2, 64]);
    }

    /// The generated constructor fills in the documented defaults.
    #[test]
    fn test_default_config_values() {
        let config = DenoisingNetworkConfig::new(8, 12);
        assert_eq!(config.input_dim, 8);
        assert_eq!(config.cond_dim, 12);
        assert_eq!(config.hidden_dim, 256);
        assert_eq!(config.time_embed_dim, 128);
        assert_eq!(config.num_blocks, 3);
    }

    /// Edge case: with zero residual blocks the network is just
    /// `output_proj(input_proj(x))`, and must still produce the right shape.
    #[test]
    fn test_denoising_network_without_blocks() {
        let device = <TestBackend as Backend>::Device::default();
        let network = DenoisingNetworkConfig::new(6, 10)
            .with_hidden_dim(16)
            .with_time_embed_dim(8)
            .with_num_blocks(0)
            .init::<TestBackend>(&device);
        assert!(network.blocks.is_empty());

        let noisy = Tensor::<TestBackend, 2>::zeros([3, 6], &device);
        let cond = Tensor::<TestBackend, 2>::zeros([3, 10], &device);
        let time = Tensor::<TestBackend, 2>::zeros([3, 8], &device);
        let output = network.forward(noisy, cond, time);
        assert_eq!(output.dims(), [3, 6]);
    }
}