use burn::prelude::*;
use burn::tensor::Distribution;
use crate::model::encoder::EncoderTransformer;
use crate::model::decoder::DecoderTransformer;
use crate::model::rope::RotaryEmbedding;
/// Encoder/decoder pair used for iterative sampling.
///
/// The encoder output is handed to the decoder as conditioning on every
/// step of [`EncoderDecoder::sample`].
#[derive(Module, Debug)]
pub struct EncoderDecoder<B: Backend> {
/// Encoder run once per `sample` call on the encoder input.
pub encoder: EncoderTransformer<B>,
/// Decoder evaluated on every sampling step with the current state,
/// the encoder output and the current time value.
pub decoder: DecoderTransformer<B>,
/// Standard deviation of the zero-mean normal noise that `sample`
/// draws as its starting state.
pub global_sigma: f32,
}
impl<B: Backend> EncoderDecoder<B> {
    /// Builds the encoder and decoder sub-modules on `device`.
    ///
    /// `input_dim`, `encoder_output_dim`, `dim`, `n_layers`, `head_dim`,
    /// `n_heads`, `n_kv_heads`, `hidden_dim` and `norm_eps` are forwarded to
    /// both sub-modules. `downsample_factor` is passed to the encoder only and
    /// `t_dim` to the decoder only. `global_sigma` is stored unchanged and
    /// later used by [`Self::sample`] as the standard deviation of the
    /// initial noise.
    pub fn new(
        input_dim: usize,
        encoder_output_dim: usize,
        dim: usize,
        t_dim: usize,
        n_layers: usize,
        head_dim: usize,
        n_heads: usize,
        n_kv_heads: usize,
        hidden_dim: usize,
        norm_eps: f64,
        downsample_factor: usize,
        global_sigma: f32,
        device: &B::Device,
    ) -> Self {
        let encoder = EncoderTransformer::new(
            input_dim,
            encoder_output_dim,
            dim,
            n_layers,
            head_dim,
            n_heads,
            n_kv_heads,
            hidden_dim,
            norm_eps,
            downsample_factor,
            device,
        );
        let decoder = DecoderTransformer::new(
            input_dim,
            encoder_output_dim,
            dim,
            t_dim,
            n_layers,
            head_dim,
            n_heads,
            n_kv_heads,
            hidden_dim,
            norm_eps,
            device,
        );

        Self {
            encoder,
            decoder,
            global_sigma,
        }
    }

    /// Generates a sample by repeatedly stepping a noise tensor with the
    /// decoder's output.
    ///
    /// The state `z` starts as zero-mean normal noise with standard deviation
    /// `global_sigma` and the same dimensions as `encoder_input`. For
    /// `i = sample_steps, ..., 1` the decoder is evaluated at time
    /// `t = i / sample_steps` and the state is updated as `z <- z - v * dt`
    /// with `dt = 1 / sample_steps`.
    ///
    /// # Arguments
    ///
    /// * `encoder_input` - rank-3 input fed to the encoder; its dimensions
    ///   and device also determine those of the initial noise.
    /// * `tok_idx` - index tensor passed unchanged to the encoder and to
    ///   every decoder call.
    /// * `rope` - rotary embedding shared by the encoder and decoder calls.
    /// * `sample_steps` - number of update steps. With `0` the loop never
    ///   runs and the initial noise is returned as is.
    /// * `cfg` - guidance scale. When it differs from `1.0` by more than
    ///   `1e-4`, the decoder is additionally evaluated with an all-zero
    ///   conditioning tensor and the two outputs are combined as
    ///   `v_uncond + cfg * (v_cond - v_uncond)`. Otherwise only the
    ///   conditioned output is used and the second decoder pass is skipped.
    ///
    /// # Returns
    ///
    /// The state tensor after the final update step.
    pub fn sample(
        &self,
        encoder_input: Tensor<B, 3>,
        tok_idx: Tensor<B, 2, Int>,
        rope: &RotaryEmbedding<B>,
        sample_steps: usize,
        cfg: f32,
    ) -> Tensor<B, 3> {
        let device = encoder_input.device();
        let [b, s, d] = encoder_input.dims();
        let dt = 1.0_f32 / sample_steps as f32;

        // `encoder_input` is not used after this call, so it is moved into
        // the encoder instead of being cloned.
        let enc_out = self.encoder.forward(encoder_input, tok_idx.clone(), rope);

        // The guidance decision and the zeroed conditioning do not change
        // between steps, so both are prepared once. `zeros_like` guarantees
        // the unconditional input has exactly the encoder output's shape,
        // even if the encoder returns a different sequence length than `s`.
        let use_guidance = (cfg - 1.0).abs() > 1e-4;
        let enc_null = use_guidance.then(|| enc_out.zeros_like());

        let mut z = Tensor::<B, 3>::random(
            [b, s, d],
            Distribution::Normal(0.0, f64::from(self.global_sigma)),
            &device,
        );

        for i in (1..=sample_steps).rev() {
            let t_val = dt * i as f32;
            let time_t = Tensor::<B, 3>::full([b, 1, 1], t_val, &device);

            let v_cond = self.decoder.forward(
                z.clone(),
                enc_out.clone(),
                time_t.clone(),
                tok_idx.clone(),
                rope,
            );

            let v = match &enc_null {
                Some(null_cond) => {
                    let v_uncond = self.decoder.forward(
                        z.clone(),
                        null_cond.clone(),
                        time_t,
                        tok_idx.clone(),
                        rope,
                    );
                    v_uncond.clone() + (v_cond - v_uncond).mul_scalar(cfg)
                }
                None => v_cond,
            };

            z = z - v.mul_scalar(dt);
        }

        z
    }
}