use crate::error::Result;
use crate::tensor::Tensor;
/// Hyperparameters for a position-wise feed-forward block.
#[derive(Debug, Clone)]
pub struct FeedForwardConfig {
/// Width of the model (input/output) dimension.
pub model_dim: usize,
/// Width of the hidden (expansion) layer.
pub inner_dim: usize,
}
impl FeedForwardConfig {
    /// Config with the conventional 4x hidden expansion.
    pub fn standard(model_dim: usize) -> Self {
        Self::custom(model_dim, 4 * model_dim)
    }

    /// Config with an explicitly chosen hidden width.
    pub fn custom(model_dim: usize, inner_dim: usize) -> Self {
        Self {
            model_dim,
            inner_dim,
        }
    }
}
/// Two-layer position-wise feed-forward network (linear -> GELU -> linear).
#[derive(Debug, Clone)]
pub struct FeedForward {
/// First projection weights, shape [model_dim, inner_dim] (see `new`).
pub w1: Tensor,
/// First-layer bias, length inner_dim.
pub b1: Tensor,
/// Second projection weights, shape [inner_dim, model_dim].
pub w2: Tensor,
/// Second-layer bias, length model_dim.
pub b2: Tensor,
/// Dimensions this layer was built with.
pub config: FeedForwardConfig,
}
impl FeedForward {
    /// Builds a feed-forward block with Xavier-uniform weights and
    /// zero-initialized biases, sized from `config`.
    ///
    /// # Errors
    /// Propagates any error from `Tensor::xavier_uniform`.
    pub fn new(config: FeedForwardConfig, rng: &mut impl rand::Rng) -> Result<Self> {
        let w1 = Tensor::xavier_uniform(&[config.model_dim, config.inner_dim], rng)?;
        let b1 = Tensor::zeros(&[config.inner_dim]);
        let w2 = Tensor::xavier_uniform(&[config.inner_dim, config.model_dim], rng)?;
        let b2 = Tensor::zeros(&[config.model_dim]);
        Ok(Self {
            w1,
            b1,
            w2,
            b2,
            config,
        })
    }

    /// Applies `gelu(x @ w1 + b1) @ w2 + b2` row-wise.
    ///
    /// Assumes `x` is 2-D with shape `[seq_len, model_dim]`; `shape()[0]`
    /// would panic on a rank-0 tensor — TODO confirm callers guarantee this.
    ///
    /// # Errors
    /// Propagates errors from the matmuls and tensor construction
    /// (e.g. on a shape mismatch).
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let seq_len = x.shape()[0];
        let hidden = x.matmul(&self.w1)?;
        let hidden = self.add_bias_rows(&hidden, &self.b1, seq_len)?;
        let activated = hidden.gelu();
        let output = activated.matmul(&self.w2)?;
        self.add_bias_rows(&output, &self.b2, seq_len)
    }

    /// Returns `matrix` with `bias` added to each of the first `rows` rows,
    /// reshaped to `[rows, bias.numel()]`.
    ///
    /// Iterates row chunks with `chunks_exact_mut` + `zip` instead of
    /// computing `r * cols + c` indices, which removes the per-element
    /// bounds check and lets the loop vectorize.
    fn add_bias_rows(&self, matrix: &Tensor, bias: &Tensor, rows: usize) -> Result<Tensor> {
        let cols = bias.numel();
        let bias_data = bias.data(); // hoisted out of the inner loop
        let mut data = matrix.data().to_vec();
        for row in data.chunks_exact_mut(cols).take(rows) {
            for (value, b) in row.iter_mut().zip(bias_data.iter()) {
                *value += *b;
            }
        }
        Tensor::new(data, vec![rows, cols])
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The standard (4x) config must map [seq, model] back to [seq, model].
    #[test]
    fn test_ffn_shape() {
        let mut rng = rand::rng();
        let ffn = FeedForward::new(FeedForwardConfig::standard(8), &mut rng).unwrap();
        let input = Tensor::randn(&[3, 8], &mut rng);
        let output = ffn.forward(&input).unwrap();
        assert_eq!(output.shape(), &[3, 8]);
    }

    /// A non-default inner width must still preserve the model dimension.
    #[test]
    fn test_ffn_custom_inner() {
        let mut rng = rand::rng();
        let ffn = FeedForward::new(FeedForwardConfig::custom(8, 16), &mut rng).unwrap();
        let input = Tensor::randn(&[2, 8], &mut rng);
        let output = ffn.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 8]);
    }
}