use burn::module::{Module, Param};
use burn::nn::{Dropout, DropoutConfig, Embedding, EmbeddingConfig};
use burn::tensor::backend::Backend;
use burn::tensor::{Distribution as TensorDistribution, Int, Tensor, TensorData, activation};
use rand::distributions::{Distribution, WeightedIndex};
use rand::prelude::*;
use std::cmp::Ordering;
use crate::kernel::{BlockPattern1d, relu_lowrank};
use super::attention::Attention;
use super::config::{BDHConfig, FusedKernelConfig};
#[cfg(feature = "viz")]
use super::state::LayerVizState;
use super::state::{LayerState, ModelState};
const LAYER_NORM_EPS: f32 = 1e-5;
#[derive(Module, Debug)]
pub struct BDH<B: Backend> {
n_layer: usize,
n_embd: usize,
n_head: usize,
mlp_internal_dim_multiplier: usize,
vocab_size: usize,
kernel: FusedKernelConfig,
embed: Embedding<B>,
dropout: Dropout,
attention: Attention<B>,
encoder: Param<Tensor<B, 3>>,
encoder_v: Param<Tensor<B, 3>>,
decoder: Param<Tensor<B, 2>>,
lm_head: Param<Tensor<B, 2>>,
}
impl<B: Backend> BDH<B> {
pub fn new(config: BDHConfig, device: &B::Device) -> Self {
let embed = EmbeddingConfig::new(config.vocab_size, config.n_embd).init(device);
let dropout = DropoutConfig::new(config.dropout).init();
let latent_per_head = config.latent_per_head();
let latent_total = config.latent_total();
let attention = Attention::new(
latent_per_head,
config.n_head,
device,
&config.fused_kernels,
);
let weight_init = |shape: [usize; 2]| {
Tensor::<B, 2>::random(shape, TensorDistribution::Normal(0.0, 0.02), device)
};
let encoder = Param::from_tensor(Tensor::<B, 3>::random(
[config.n_head, config.n_embd, latent_per_head],
TensorDistribution::Normal(0.0, 0.02),
device,
));
let encoder_v = Param::from_tensor(Tensor::<B, 3>::random(
[config.n_head, config.n_embd, latent_per_head],
TensorDistribution::Normal(0.0, 0.02),
device,
));
let decoder = Param::from_tensor(weight_init([latent_total, config.n_embd]));
let lm_head = Param::from_tensor(weight_init([config.n_embd, config.vocab_size]));
Self {
n_layer: config.n_layer,
n_embd: config.n_embd,
n_head: config.n_head,
mlp_internal_dim_multiplier: config.mlp_internal_dim_multiplier,
vocab_size: config.vocab_size,
kernel: config.fused_kernels,
embed,
dropout,
attention,
encoder,
encoder_v,
decoder,
lm_head,
}
}
fn layer_norm<const D: usize>(&self, tensor: Tensor<B, D>) -> Tensor<B, D> {
let (var, mean) = tensor.clone().var_mean_bias(D - 1);
tensor.sub(mean).div(var.add_scalar(LAYER_NORM_EPS).sqrt())
}
pub fn forward(&self, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
let mut state = self.init_state();
self.forward_with_state(tokens, &mut state)
}
pub fn forward_fast(&self, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
let embedded = self.embed.forward(tokens);
let [batch, time, embd] = embedded.shape().dims::<3>();
let mut current = embedded.reshape([batch, 1, time, embd]);
current = self.layer_norm(current);
let encoder_raw = self.encoder.val();
let [heads, embd_enc, latent] = encoder_raw.shape().dims::<3>();
let encoder = encoder_raw.reshape([1, heads, embd_enc, latent]);
let encoder_v_raw = self.encoder_v.val();
let [heads_v, embd_v, latent_v] = encoder_v_raw.shape().dims::<3>();
let encoder_v = encoder_v_raw.reshape([1, heads_v, embd_v, latent_v]);
let decoder = self.decoder.val();
let fused = self.kernel.enabled;
let latent_pattern: &BlockPattern1d = &self.kernel.block_sparse.latent;
for _ in 0..self.n_layer {
let x_sparse = if fused {
relu_lowrank::fused_forward(
current.clone(),
encoder.clone(),
None,
self.kernel.relu_threshold,
latent_pattern,
)
} else {
let mut x_latent = current.clone().matmul(encoder.clone());
if self.kernel.relu_threshold != 0.0 {
x_latent = x_latent.sub_scalar(self.kernel.relu_threshold);
}
activation::relu(x_latent)
};
let attn = self.attention.forward(x_sparse.clone(), current.clone());
let attn = self.layer_norm(attn);
let y_sparse = if fused {
relu_lowrank::fused_forward(
attn.clone(),
encoder_v.clone(),
None,
self.kernel.relu_threshold,
latent_pattern,
)
} else {
let mut y_latent = attn.matmul(encoder_v.clone());
if self.kernel.relu_threshold != 0.0 {
y_latent = y_latent.sub_scalar(self.kernel.relu_threshold);
}
activation::relu(y_latent)
};
let xy_sparse = x_sparse.clone() * y_sparse;
let xy_sparse = self.dropout.forward(xy_sparse);
let mixed = xy_sparse.clone().swap_dims(1, 2);
let [batch, time, heads, latent] = mixed.shape().dims();
let mixed_flat = mixed.reshape([batch * time, heads * latent]);
let mlp_flat = mixed_flat.matmul(decoder.clone());
let mlp_out = mlp_flat.reshape([batch, 1, time, self.n_embd]);
let mlp_out = self.layer_norm(mlp_out);
current = self.layer_norm(current + mlp_out);
}
let [batch, _, time, dim] = current.shape().dims();
current
.reshape([batch * time, dim])
.matmul(self.lm_head.val())
.reshape([batch, time, self.vocab_size])
}
pub fn generate(
&self,
mut indices: Tensor<B, 2, Int>,
max_new_tokens: usize,
temperature: f32,
top_k: Option<usize>,
) -> Tensor<B, 2, Int> {
let [batch, _] = indices.shape().dims();
assert_eq!(batch, 1, "generation currently supports batch size 1");
let mut state = self.init_state();
let mut logits = self.forward_with_state(indices.clone(), &mut state);
let [_, mut time, vocab] = logits.shape().dims();
assert_eq!(time, indices.shape().dims::<2>()[1]);
let mut last_logits = logits
.slice_dim(1, (time - 1)..time)
.reshape([vocab])
.div_scalar(temperature);
for _ in 0..max_new_tokens {
let mut logits_values = last_logits
.clone()
.to_data()
.convert::<f32>()
.into_vec::<f32>()
.expect("logits to vec");
if let Some(k) = top_k
&& k > 0
&& k < vocab
{
let mut sorted = logits_values.clone();
sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(Ordering::Equal));
let threshold = sorted[k - 1];
for value in logits_values.iter_mut() {
if *value < threshold {
*value = f32::NEG_INFINITY;
}
}
}
let max_logit = logits_values
.iter()
.copied()
.fold(f32::NEG_INFINITY, f32::max);
let mut probs: Vec<f32> = logits_values
.iter()
.map(|value| (value - max_logit).exp())
.collect();
let sum: f32 = probs.iter().sum();
if sum == 0.0 || sum.is_nan() {
let uniform = 1.0 / vocab as f32;
for p in probs.iter_mut() {
*p = uniform;
}
} else {
for p in probs.iter_mut() {
*p /= sum;
}
}
let dist = WeightedIndex::new(&probs).expect("valid probability distribution");
let mut rng = thread_rng();
let next = dist.sample(&mut rng) as i64;
let next_token = Tensor::<B, 2, Int>::from_data(
TensorData::new(vec![next], [1, 1]),
&indices.device(),
);
indices = Tensor::cat(vec![indices, next_token.clone()], 1);
logits = self.forward_with_state(next_token, &mut state);
let [_, new_time, _] = logits.shape().dims();
time = new_time;
last_logits = logits
.slice_dim(1, (time - 1)..time)
.reshape([vocab])
.div_scalar(temperature);
}
indices
}
pub fn init_state(&self) -> ModelState<B> {
ModelState::new(self.n_layer)
}
fn recurrent_attention(
&self,
query: Tensor<B, 4>,
value: Tensor<B, 4>,
layer_state: &mut LayerState<B>,
position: usize,
) -> Tensor<B, 4> {
let query = self.attention.rotate_positions(query, position);
let [batch, heads, time, latent] = query.shape().dims();
let n_embd = value.shape().dims::<4>()[3];
let device = value.device();
let decay = self
.attention
.alibi_decay()
.map(|tensor| tensor.reshape([1, heads, 1, 1]));
let mut rho = match layer_state.rho.take() {
Some(existing) => {
let dims = existing.shape().dims::<4>();
if dims == [batch, heads, latent, n_embd] {
existing
} else {
Tensor::<B, 4>::zeros([batch, heads, latent, n_embd], &device)
}
}
None => Tensor::<B, 4>::zeros([batch, heads, latent, n_embd], &device),
};
let mut outputs: Vec<Tensor<B, 4>> = Vec::with_capacity(time);
for t in 0..time {
let x_t = query.clone().slice_dim(2, t..t + 1);
let v_t = value.clone().slice_dim(2, t..t + 1).repeat_dim(1, heads);
let x_t_latent = x_t.swap_dims(2, 3);
let attn_t = (rho.clone() * x_t_latent.clone())
.sum_dim(2)
.reshape([batch, heads, 1, n_embd]);
outputs.push(attn_t);
rho = rho + x_t_latent * v_t;
if let Some(decay) = &decay {
rho = rho * decay.clone();
}
}
layer_state.rho = Some(rho);
Tensor::cat(outputs, 2)
}
pub fn forward_with_state(
&self,
tokens: Tensor<B, 2, Int>,
state: &mut ModelState<B>,
) -> Tensor<B, 3> {
assert_eq!(
state.layers.len(),
self.n_layer,
"model state layers mismatch"
);
let embedded = self.embed.forward(tokens);
let [batch, time, embd] = embedded.shape().dims::<3>();
let mut current = embedded.reshape([batch, 1, time, embd]);
current = self.layer_norm(current);
let encoder_raw = self.encoder.val();
let [heads, embd_enc, latent] = encoder_raw.shape().dims::<3>();
let encoder = encoder_raw.reshape([1, heads, embd_enc, latent]);
let encoder_v_raw = self.encoder_v.val();
let [heads_v, embd_v, latent_v] = encoder_v_raw.shape().dims::<3>();
let encoder_v = encoder_v_raw.reshape([1, heads_v, embd_v, latent_v]);
let decoder = self.decoder.val();
let fused = self.kernel.enabled;
let latent_pattern: &BlockPattern1d = &self.kernel.block_sparse.latent;
let start_pos = state.position;
for layer_state in &mut state.layers {
let x_sparse = if fused {
relu_lowrank::fused_forward(
current.clone(),
encoder.clone(),
None,
self.kernel.relu_threshold,
latent_pattern,
)
} else {
let mut x_latent = current.clone().matmul(encoder.clone());
if self.kernel.relu_threshold != 0.0 {
x_latent = x_latent.sub_scalar(self.kernel.relu_threshold);
}
activation::relu(x_latent)
};
let attn =
self.recurrent_attention(x_sparse.clone(), current.clone(), layer_state, start_pos);
let attn = self.layer_norm(attn);
let y_sparse = if fused {
relu_lowrank::fused_forward(
attn.clone(),
encoder_v.clone(),
None,
self.kernel.relu_threshold,
latent_pattern,
)
} else {
let mut y_latent = attn.matmul(encoder_v.clone());
if self.kernel.relu_threshold != 0.0 {
y_latent = y_latent.sub_scalar(self.kernel.relu_threshold);
}
activation::relu(y_latent)
};
#[cfg(feature = "viz")]
let xy_sparse = x_sparse.clone() * y_sparse.clone();
#[cfg(not(feature = "viz"))]
let xy_sparse = x_sparse * y_sparse;
let xy_sparse = self.dropout.forward(xy_sparse);
let mixed = xy_sparse.clone().swap_dims(1, 2);
let [batch, time, heads, latent] = mixed.shape().dims();
#[cfg(feature = "viz")]
if time > 0 {
let last = time - 1;
let x_last = x_sparse
.clone()
.slice_dim(2, last..time)
.reshape([batch, heads, latent])
.slice_dim(0, 0..1)
.reshape([heads, latent]);
let y_last = y_sparse
.clone()
.slice_dim(2, last..time)
.reshape([batch, heads, latent])
.slice_dim(0, 0..1)
.reshape([heads, latent]);
let xy_last = xy_sparse
.clone()
.slice_dim(2, last..time)
.reshape([batch, heads, latent])
.slice_dim(0, 0..1)
.reshape([heads, latent]);
let device = x_last.device();
let rho_last = match layer_state.rho.as_ref() {
Some(rho) => {
let dims = rho.shape().dims::<4>();
if dims == [batch, heads, latent, self.n_embd] {
let rho_energy = rho
.clone()
.abs()
.sum_dim(3)
.div_scalar(self.n_embd as f32)
.reshape([batch, heads, latent])
.sum_dim(0)
.div_scalar(batch as f32);
rho_energy.reshape([heads, latent])
} else {
Tensor::<B, 2>::zeros([heads, latent], &device)
}
}
None => Tensor::<B, 2>::zeros([heads, latent], &device),
};
layer_state.viz = Some(LayerVizState {
x_last,
y_last,
xy_last,
rho_last,
});
}
let mixed_flat = mixed.reshape([batch * time, heads * latent]);
let mlp_flat = mixed_flat.matmul(decoder.clone());
let mlp_out = mlp_flat.reshape([batch, 1, time, self.n_embd]);
let mlp_out = self.layer_norm(mlp_out);
current = self.layer_norm(current + mlp_out);
}
let [batch, _, time, dim] = current.shape().dims();
state.position = state.position.saturating_add(time);
current
.reshape([batch * time, dim])
.matmul(self.lm_head.val())
.reshape([batch, time, self.vocab_size])
}
}