#![allow(dead_code)]
use candle_core::{Result, Tensor, D};
use candle_nn::{Linear, Module};
/// A mixture-of-experts layer: one gating projection plus a list of expert projections.
pub struct MoeLayer {
    /// Gating projection. `forward` applies it to the flattened tokens and turns its
    /// output into routing probabilities with a softmax over the last dimension.
    pub gate: Linear,
    /// Expert projections. The position of an expert in this vector is the index the
    /// gate output refers to.
    pub experts: Vec<Linear>,
    /// Requested number of experts per token. Stored exactly as passed to `new`; how it
    /// is applied is decided by `forward`.
    pub top_k: usize,
}
impl MoeLayer {
pub fn new(gate: Linear, experts: Vec<Linear>, top_k: usize) -> Self {
Self {
gate,
experts,
top_k,
}
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let (batch_size, seq_len, hidden_dim) = x.dims3()?;
let x_flat = x.flatten(0, 1)?;
let gate_logits = self.gate.forward(&x_flat)?;
let gate_probs = candle_nn::ops::softmax(&gate_logits, D::Minus1)?;
let top_k_indices = gate_probs.argmax(D::Minus1)?;
let mut final_output = Tensor::zeros_like(&x_flat)?;
let top_expert_idx = top_k_indices.get(0)?.to_scalar::<i64>()? as usize;
if top_expert_idx < self.experts.len() {
final_output = self.experts[top_expert_idx].forward(&x_flat)?;
}
final_output.reshape((batch_size, seq_len, hidden_dim))
}
}