use crate::tensor::DenseTensor;
use crate::tensor::traits::{TensorOps, TensorBase};
/// Selects which feed-forward computation [`FeedForward::forward`] runs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FeedForwardType {
    /// Two projections (`fc1`, `fc2`) with a configurable activation in between.
    Standard,
    /// Gated variant whose gate branch goes through `silu()`.
    SwiGLU,
    /// Gated variant whose gate branch goes through `gelu()`.
    GeGLU,
}
/// Feed-forward layer holding the weights for either the standard or a gated variant.
///
/// Which group of optional fields is read depends on `ff_type`: the standard
/// path uses `fc1`, `fc2` and `activation`; the SwiGLU/GeGLU paths use
/// `gate_proj`, `up_proj` and `down_proj`. The constructors fill exactly one
/// group and leave the other as `None`.
#[derive(Clone, Debug)]
pub struct FeedForward {
    /// Variant that `forward` dispatches on.
    pub ff_type: FeedForwardType,

    // --- Standard variant ---
    /// First projection (applied to the input).
    pub fc1: Option<DenseTensor>,
    /// Second projection (applied after the activation).
    pub fc2: Option<DenseTensor>,
    /// Activation applied between `fc1` and `fc2`.
    pub activation: Option<Activation>,

    // --- Gated variants (SwiGLU / GeGLU) ---
    /// Projection feeding the activated gate branch.
    pub gate_proj: Option<DenseTensor>,
    /// Projection feeding the non-activated branch.
    pub up_proj: Option<DenseTensor>,
    /// Projection applied to the product of the two branches.
    pub down_proj: Option<DenseTensor>,

    /// Second entry of the first projection's shape, as recorded by the constructors.
    pub intermediate_dim: usize,
    /// First entry of the first projection's shape, as recorded by the constructors.
    pub hidden_dim: usize,
}
/// Non-linearity used by the standard feed-forward path.
///
/// Each variant maps to the tensor method of the same (lower-case) name.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Activation {
    /// Applies the tensor's `relu()`.
    ReLU,
    /// Applies the tensor's `gelu()`.
    GELU,
    /// Applies the tensor's `silu()`.
    SiLU,
}
impl FeedForward {
pub fn standard(fc1: DenseTensor, fc2: DenseTensor, activation: Activation) -> Self {
let hidden_dim = fc1.shape()[0];
let intermediate_dim = fc1.shape()[1];
Self {
ff_type: FeedForwardType::Standard,
fc1: Some(fc1),
fc2: Some(fc2),
activation: Some(activation),
gate_proj: None,
up_proj: None,
down_proj: None,
intermediate_dim,
hidden_dim,
}
}
pub fn swiglu(gate_proj: DenseTensor, up_proj: DenseTensor, down_proj: DenseTensor) -> Self {
let hidden_dim = gate_proj.shape()[0];
let intermediate_dim = gate_proj.shape()[1];
Self {
ff_type: FeedForwardType::SwiGLU,
fc1: None,
fc2: None,
activation: None,
gate_proj: Some(gate_proj),
up_proj: Some(up_proj),
down_proj: Some(down_proj),
intermediate_dim,
hidden_dim,
}
}
pub fn geglu(gate_proj: DenseTensor, up_proj: DenseTensor, down_proj: DenseTensor) -> Self {
let hidden_dim = gate_proj.shape()[0];
let intermediate_dim = gate_proj.shape()[1];
Self {
ff_type: FeedForwardType::GeGLU,
fc1: None,
fc2: None,
activation: None,
gate_proj: Some(gate_proj),
up_proj: Some(up_proj),
down_proj: Some(down_proj),
intermediate_dim,
hidden_dim,
}
}
pub fn forward(&self, x: &DenseTensor) -> DenseTensor {
match self.ff_type {
FeedForwardType::Standard => self.forward_standard(x),
FeedForwardType::SwiGLU => self.forward_swiglu(x),
FeedForwardType::GeGLU => self.forward_geglu(x),
}
}
fn forward_standard(&self, x: &DenseTensor) -> DenseTensor {
let fc1 = self.fc1.as_ref().expect("FC1 not initialized");
let fc2 = self.fc2.as_ref().expect("FC2 not initialized");
let activation = self.activation.expect("Activation not set");
let hidden = x.bmm_broadcast_weight(fc1);
let activated = match activation {
Activation::ReLU => hidden.relu(),
Activation::GELU => hidden.gelu(),
Activation::SiLU => hidden.silu(),
};
activated.bmm_broadcast_weight(fc2)
}
fn forward_swiglu(&self, x: &DenseTensor) -> DenseTensor {
let gate_proj = self.gate_proj.as_ref().expect("gate_proj not initialized");
let up_proj = self.up_proj.as_ref().expect("up_proj not initialized");
let down_proj = self.down_proj.as_ref().expect("down_proj not initialized");
let gate = x.bmm_broadcast_weight(gate_proj);
let gate = gate.silu();
let up = x.bmm_broadcast_weight(up_proj);
let intermediate = gate.mul(&up);
intermediate.bmm_broadcast_weight(down_proj)
}
fn forward_geglu(&self, x: &DenseTensor) -> DenseTensor {
let gate_proj = self.gate_proj.as_ref().expect("gate_proj not initialized");
let up_proj = self.up_proj.as_ref().expect("up_proj not initialized");
let down_proj = self.down_proj.as_ref().expect("down_proj not initialized");
let gate = x.bmm_broadcast_weight(gate_proj);
let gate = gate.gelu();
let up = x.bmm_broadcast_weight(up_proj);
let intermediate = gate.mul(&up);
intermediate.bmm_broadcast_weight(down_proj)
}
pub fn intermediate_dim(&self) -> usize {
self.intermediate_dim
}
pub fn hidden_dim(&self) -> usize {
self.hidden_dim
}
pub fn num_parameters(&self) -> usize {
let mut total = 0;
match self.ff_type {
FeedForwardType::Standard => {
if let Some(ref fc1) = self.fc1 {
total += fc1.shape().iter().product::<usize>();
}
if let Some(ref fc2) = self.fc2 {
total += fc2.shape().iter().product::<usize>();
}
}
FeedForwardType::SwiGLU | FeedForwardType::GeGLU => {
if let Some(ref gate_proj) = self.gate_proj {
total += gate_proj.shape().iter().product::<usize>();
}
if let Some(ref up_proj) = self.up_proj {
total += up_proj.shape().iter().product::<usize>();
}
if let Some(ref down_proj) = self.down_proj {
total += down_proj.shape().iter().product::<usize>();
}
}
}
total
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Dimensions shared by every test case.
    const HIDDEN_DIM: usize = 8;
    const INTERMEDIATE_DIM: usize = 32;
    const BATCH_SIZE: usize = 2;
    const SEQ_LEN: usize = 4;

    /// All-ones `(fc1, fc2)` weights for the standard variant.
    fn standard_weights() -> (DenseTensor, DenseTensor) {
        (
            DenseTensor::ones(vec![HIDDEN_DIM, INTERMEDIATE_DIM]),
            DenseTensor::ones(vec![INTERMEDIATE_DIM, HIDDEN_DIM]),
        )
    }

    /// All-ones `(gate_proj, up_proj, down_proj)` weights for the gated variants.
    fn gated_weights() -> (DenseTensor, DenseTensor, DenseTensor) {
        (
            DenseTensor::ones(vec![HIDDEN_DIM, INTERMEDIATE_DIM]),
            DenseTensor::ones(vec![HIDDEN_DIM, INTERMEDIATE_DIM]),
            DenseTensor::ones(vec![INTERMEDIATE_DIM, HIDDEN_DIM]),
        )
    }

    /// Feeds an all-ones `[batch, seq, hidden]` input through `ffn` and checks
    /// that the output has the same shape as the input.
    fn assert_output_shape_matches_input(ffn: &FeedForward) {
        let x = DenseTensor::ones(vec![BATCH_SIZE, SEQ_LEN, HIDDEN_DIM]);
        let output = ffn.forward(&x);
        assert_eq!(output.shape(), &[BATCH_SIZE, SEQ_LEN, HIDDEN_DIM]);
    }

    #[test]
    fn test_standard_ffn() {
        let (fc1, fc2) = standard_weights();
        let ffn = FeedForward::standard(fc1, fc2, Activation::GELU);
        assert_output_shape_matches_input(&ffn);
    }

    #[test]
    fn test_swiglu_ffn() {
        let (gate_proj, up_proj, down_proj) = gated_weights();
        let ffn = FeedForward::swiglu(gate_proj, up_proj, down_proj);
        assert_output_shape_matches_input(&ffn);
    }

    #[test]
    fn test_geglu_ffn() {
        let (gate_proj, up_proj, down_proj) = gated_weights();
        let ffn = FeedForward::geglu(gate_proj, up_proj, down_proj);
        assert_eq!(ffn.ff_type, FeedForwardType::GeGLU);
        assert_output_shape_matches_input(&ffn);
    }

    #[test]
    fn test_dimensions_and_parameter_count() {
        let (fc1, fc2) = standard_weights();
        let standard = FeedForward::standard(fc1, fc2, Activation::ReLU);
        assert_eq!(standard.hidden_dim(), HIDDEN_DIM);
        assert_eq!(standard.intermediate_dim(), INTERMEDIATE_DIM);
        // Two weight matrices of HIDDEN_DIM * INTERMEDIATE_DIM elements each.
        assert_eq!(standard.num_parameters(), 2 * HIDDEN_DIM * INTERMEDIATE_DIM);

        let (gate_proj, up_proj, down_proj) = gated_weights();
        let gated = FeedForward::swiglu(gate_proj, up_proj, down_proj);
        assert_eq!(gated.hidden_dim(), HIDDEN_DIM);
        assert_eq!(gated.intermediate_dim(), INTERMEDIATE_DIM);
        // Three weight matrices of HIDDEN_DIM * INTERMEDIATE_DIM elements each.
        assert_eq!(gated.num_parameters(), 3 * HIDDEN_DIM * INTERMEDIATE_DIM);
    }
}