use crate::tensor::Tensor;
/// Two-layer feed-forward block: a projection through `w1`/`b1`, a GELU
/// activation, and a projection back through `w2`/`b2`.
#[derive(Debug, Clone)]
pub struct FeedForward {
/// Weight of the first projection; `new` allocates it with shape `[dim, inner_dim]`.
pub w1: Tensor,
/// Bias added after the first projection; `new` allocates it with shape `[inner_dim]`.
pub b1: Tensor,
/// Weight of the second projection; `new` allocates it with shape `[inner_dim, dim]`.
pub w2: Tensor,
/// Bias added after the second projection; `new` allocates it with shape `[dim]`.
pub b2: Tensor,
/// Feature width passed to `new`.
pub dim: usize,
/// Width of the intermediate activation; `new` sets it to `dim * mult`.
pub inner_dim: usize,
}
impl FeedForward {
    /// Builds a feed-forward block whose intermediate width is `dim * mult`.
    ///
    /// Every parameter tensor is allocated through `Tensor::zeros`:
    /// `w1` as `[dim, dim * mult]`, `b1` as `[dim * mult]`,
    /// `w2` as `[dim * mult, dim]` and `b2` as `[dim]`.
    pub fn new(dim: usize, mult: usize) -> Self {
        let hidden = dim * mult;

        let w1 = Tensor::zeros(&[dim, hidden]);
        let b1 = Tensor::zeros(&[hidden]);
        let w2 = Tensor::zeros(&[hidden, dim]);
        let b2 = Tensor::zeros(&[dim]);

        Self {
            w1,
            b1,
            w2,
            b2,
            dim,
            inner_dim: hidden,
        }
    }

    /// Applies the block to `x` and returns a tensor reshaped to the first
    /// three dimensions of `x`.
    ///
    /// The input is flattened to `[shape[0] * shape[1], shape[2]]`, multiplied
    /// by `w1`, offset by `b1`, passed through `gelu`, multiplied by `w2`,
    /// offset by `b2`, and finally reshaped back to
    /// `[shape[0], shape[1], shape[2]]`.
    ///
    /// The first three entries of `x.shape` are read by index, so an input
    /// with fewer dimensions fails at that point.
    pub fn forward(&self, x: &Tensor) -> Tensor {
        let batch = x.shape[0];
        let seq = x.shape[1];
        let width = x.shape[2];

        // Collapse the two leading axes so each row is a single feature vector.
        let tokens = x.reshape(&[batch * seq, width]);

        let hidden = tokens.matmul(&self.w1).add_bias(&self.b1).gelu();
        let projected = hidden.matmul(&self.w2).add_bias(&self.b2);

        // Restore the original leading axes.
        projected.reshape(&[batch, seq, width])
    }
}