use burn::nn::{Linear, LinearConfig};
use burn::tensor::activation::{silu, softmax};
use burn::tensor::backend::Backend;
use burn::tensor::{ElementConversion, IndexingUpdateOp, Int, Tensor};
#[derive(Clone)]
pub struct ExpertFFN<B: Backend> {
pub(crate) gate_proj: Linear<B>,
pub(crate) up_proj: Linear<B>,
pub(crate) down_proj: Linear<B>,
}
impl<B: Backend> ExpertFFN<B> {
pub fn new(device: &B::Device, hidden_size: usize, intermediate_size: usize) -> Self {
let gate_proj = LinearConfig::new(hidden_size, intermediate_size).init(device);
let up_proj = LinearConfig::new(hidden_size, intermediate_size).init(device);
let down_proj = LinearConfig::new(intermediate_size, hidden_size).init(device);
Self {
gate_proj,
up_proj,
down_proj,
}
}
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let gate = silu(self.gate_proj.forward(x.clone()));
let up = self.up_proj.forward(x);
self.down_proj.forward(gate * up)
}
}
#[derive(Clone)]
pub struct MoERouter<B: Backend> {
pub(crate) gate: Linear<B>,
pub(crate) num_experts: usize,
pub(crate) num_experts_per_tok: usize,
}
impl<B: Backend> MoERouter<B> {
pub fn new(
device: &B::Device,
hidden_size: usize,
num_experts: usize,
num_experts_per_tok: usize,
) -> Self {
let gate = LinearConfig::new(hidden_size, num_experts).init(device);
Self {
gate,
num_experts,
num_experts_per_tok,
}
}
pub fn forward(&self, hidden_states: Tensor<B, 3>) -> (Tensor<B, 3, Int>, Tensor<B, 3>) {
let logits = self.gate.forward(hidden_states);
let (values, indices) = logits.topk_with_indices(self.num_experts_per_tok, 2);
let weights = softmax(values, 2);
(indices, weights)
}
}
#[derive(Clone)]
pub struct MoELayer<B: Backend> {
pub(crate) router: MoERouter<B>,
pub(crate) experts: Vec<ExpertFFN<B>>,
pub(crate) shared_expert: Option<ExpertFFN<B>>,
}
impl<B: Backend> MoELayer<B> {
pub fn new(
device: &B::Device,
hidden_size: usize,
intermediate_size: usize,
num_experts: usize,
num_experts_per_tok: usize,
n_shared_experts: usize,
) -> Self {
let router = MoERouter::new(device, hidden_size, num_experts, num_experts_per_tok);
let mut experts = Vec::with_capacity(num_experts);
for _ in 0..num_experts {
experts.push(ExpertFFN::new(device, hidden_size, intermediate_size));
}
let shared_expert = (n_shared_experts > 0)
.then(|| ExpertFFN::new(device, hidden_size, intermediate_size));
Self {
router,
experts,
shared_expert,
}
}
pub fn forward(&self, hidden_states: Tensor<B, 3>) -> Tensor<B, 3> {
let device = hidden_states.device();
let [batch_size, seq_len, hidden_size] = hidden_states.dims();
let tokens = batch_size.saturating_mul(seq_len);
let top_k = self.router.num_experts_per_tok;
let assignments = tokens.saturating_mul(top_k);
if assignments == 0 {
return hidden_states;
}
let (expert_indices, expert_weights) = self.router.forward(hidden_states.clone());
let expert_indices = expert_indices.reshape([assignments]);
let expert_weights = expert_weights.reshape([assignments]);
let token_indices = Tensor::<B, 1, Int>::arange(0..tokens as i64, &device)
.unsqueeze_dim::<2>(1)
.repeat(&[1, top_k])
.reshape([assignments]);
let (sorted_experts, sort_indices) = expert_indices.sort_with_indices(0);
let sorted_tokens = token_indices.select(0, sort_indices.clone());
let sorted_weights = expert_weights.select(0, sort_indices);
let ones = Tensor::<B, 1, Int>::ones([assignments], &device);
let counts = Tensor::<B, 1, Int>::zeros([self.router.num_experts], &device)
.scatter(0, sorted_experts, ones, IndexingUpdateOp::Add);
let offsets = counts.cumsum(0);
let offsets_data = offsets
.into_data()
.into_vec::<B::IntElem>()
.expect("MoE expert offsets should match backend int element type");
let hidden_states_flat = hidden_states.clone().reshape([tokens, hidden_size]);
let mut output = Tensor::<B, 2>::zeros([tokens, hidden_size], &device);
let mut start = 0usize;
for (expert_idx, expert) in self.experts.iter().enumerate() {
let end = offsets_data
.get(expert_idx)
.map(|v| v.elem::<i64>() as usize)
.unwrap_or(start);
if start != end {
let token_slice = sorted_tokens.clone().slice([start..end]);
let weight_slice = sorted_weights.clone().slice([start..end]);
let selected = hidden_states_flat.clone().select(0, token_slice.clone());
let selected = selected.reshape([1, end - start, hidden_size]);
let expert_output = expert.forward(selected).reshape([end - start, hidden_size]);
let weighted_output = expert_output * weight_slice.reshape([end - start, 1]);
output = output.select_assign(0, token_slice, weighted_output, IndexingUpdateOp::Add);
}
start = end;
}
let mut output = output.reshape([batch_size, seq_len, hidden_size]);
if let Some(shared_expert) = &self.shared_expert {
output = output + shared_expert.forward(hidden_states);
}
output
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn::backend::NdArray;
type TestBackend = NdArray<f32>;
#[test]
fn moe_layer_forward_preserves_shape() {
let device = <TestBackend as Backend>::Device::default();
let hidden_size = 64;
let intermediate_size = 128;
let num_experts = 4;
let num_experts_per_tok = 2;
let n_shared_experts = 1;
let moe = MoELayer::<TestBackend>::new(
&device,
hidden_size,
intermediate_size,
num_experts,
num_experts_per_tok,
n_shared_experts,
);
let input = Tensor::<TestBackend, 3>::random(
[2, 8, hidden_size],
burn::tensor::Distribution::Normal(0.0, 1.0),
&device,
);
let output = moe.forward(input.clone());
let [batch, seq, hidden] = output.dims();
assert_eq!(batch, 2);
assert_eq!(seq, 8);
assert_eq!(hidden, hidden_size);
}
#[test]
fn moe_router_selects_top_k_experts() {
let device = <TestBackend as Backend>::Device::default();
let hidden_size = 32;
let num_experts = 8;
let num_experts_per_tok = 2;
let router =
MoERouter::<TestBackend>::new(&device, hidden_size, num_experts, num_experts_per_tok);
let input = Tensor::<TestBackend, 3>::random(
[1, 4, hidden_size],
burn::tensor::Distribution::Normal(0.0, 1.0),
&device,
);
let (indices, weights) = router.forward(input);
let [b, s, k] = indices.dims();
assert_eq!(b, 1);
assert_eq!(s, 4);
assert_eq!(k, num_experts_per_tok);
let [wb, ws, wk] = weights.dims();
assert_eq!(wb, 1);
assert_eq!(ws, 4);
assert_eq!(wk, num_experts_per_tok);
}
#[test]
fn expert_ffn_forward_preserves_shape() {
let device = <TestBackend as Backend>::Device::default();
let hidden_size = 64;
let intermediate_size = 256;
let expert = ExpertFFN::<TestBackend>::new(&device, hidden_size, intermediate_size);
let input = Tensor::<TestBackend, 3>::random(
[1, 16, hidden_size],
burn::tensor::Distribution::Normal(0.0, 1.0),
&device,
);
let output = expert.forward(input);
let [batch, seq, hidden] = output.dims();
assert_eq!(batch, 1);
assert_eq!(seq, 16);
assert_eq!(hidden, hidden_size);
}
}