use crate::nn::{Linear, Module, ReLU};
use crate::tensor::{GraphContext, Tensor};
use std::cell::RefCell;
use std::rc::Rc;
/// Feed-forward block made of two [`Linear`] layers with a [`ReLU`] applied
/// between them.
///
/// The `Module::forward` implementation runs the input through `linear1`,
/// then `relu`, then `linear2`, in that order.
pub struct FeedForward {
// First linear layer; receives the block's input directly.
linear1: Linear,
// Activation applied to the output of `linear1`.
relu: ReLU,
// Second linear layer; its output is the block's result.
linear2: Linear,
}
impl FeedForward {
    /// Builds a feed-forward block whose two linear layers are registered
    /// under `"{name}.linear1"` and `"{name}.linear2"`.
    ///
    /// * `context` - graph context handed unchanged to each `Linear::new` call.
    /// * `name` - prefix used to derive the two layer names.
    /// * `embed_dim` / `hidden_dim` - the first layer is created with
    ///   `(embed_dim, hidden_dim)` and the second with `(hidden_dim, embed_dim)`;
    ///   presumably these are the input/output feature sizes of each layer
    ///   (TODO confirm against `Linear::new`).
    pub fn new(
        context: &Rc<RefCell<GraphContext>>,
        name: &str,
        embed_dim: usize,
        hidden_dim: usize,
    ) -> Self {
        // Layers are constructed in declaration order: linear1, relu, linear2.
        let first_name = format!("{}.linear1", name);
        let linear1 = Linear::new(context, &first_name, embed_dim, hidden_dim);

        let relu = ReLU::new();

        let second_name = format!("{}.linear2", name);
        let linear2 = Linear::new(context, &second_name, hidden_dim, embed_dim);

        Self {
            linear1,
            relu,
            linear2,
        }
    }
}
impl Module for FeedForward {
    /// Applies `linear1`, then `relu`, then `linear2` to `inputs` and returns
    /// the output of `linear2`.
    fn forward(&self, inputs: &Tensor) -> Tensor {
        let hidden = self.linear1.forward(inputs);
        let activated = self.relu.forward(&hidden);
        self.linear2.forward(&activated)
    }

    /// Returns the parameters of `linear1` followed by those of `linear2`.
    ///
    /// `relu` is not queried for parameters.
    fn parameters(&self) -> Vec<Tensor> {
        let mut collected = Vec::new();
        // Visit the layers in forward order so the result keeps the
        // linear1-then-linear2 ordering.
        for layer in [&self.linear1, &self.linear2] {
            collected.extend(layer.parameters());
        }
        collected
    }
}