use crate::device::Device;
use crate::errors::Result;
use crate::layers::Linear;
use crate::ops::activations::gelu;
use crate::tensor::Tensor;
use crate::traits::Layer;
/// Two-layer feed-forward block computing `output(gelu(dense(x)))`.
///
/// Both projections are created in [`FeedForward::new_with_device`]. `dense` is
/// built with `(hidden_size, intermediate_size)` and `output` with
/// `(intermediate_size, hidden_size)`, so the block presumably maps back to
/// `hidden_size` (assumes `Linear` takes `(in, out)` — TODO confirm).
#[derive(Debug, Clone)]
pub struct FeedForward {
    /// First projection, constructed as `Linear(hidden_size, intermediate_size)`.
    dense: Linear,
    /// Second projection, constructed as `Linear(intermediate_size, hidden_size)`.
    output: Linear,
    // Kept so the configured dropout rate travels with the layer. `forward`
    // does not apply dropout, so no method reads this field and `dead_code`
    // would otherwise fire. Remove the `allow` once dropout is implemented.
    #[allow(dead_code)]
    dropout_prob: f32,
}
impl FeedForward {
    /// Creates a feed-forward block whose two linear layers are placed on `device`.
    ///
    /// `dense` is created as `Linear(hidden_size, intermediate_size)` and `output`
    /// as `Linear(intermediate_size, hidden_size)`, both with the bias flag set to
    /// `true`. `dropout_prob` is stored as-is.
    pub fn new_with_device(
        hidden_size: usize,
        intermediate_size: usize,
        dropout_prob: f32,
        device: Device,
    ) -> Self {
        // `dense` is built before `output`, matching field order, in case layer
        // initialisation draws from shared state (e.g. an RNG) — order matters then.
        let dense = Linear::new_with_device(hidden_size, intermediate_size, true, device);
        let output = Linear::new_with_device(intermediate_size, hidden_size, true, device);
        Self {
            dense,
            output,
            dropout_prob,
        }
    }

    /// Creates a feed-forward block on [`Device::CPU`].
    ///
    /// # Errors
    /// This constructor currently always succeeds. The `Result` return type is
    /// part of the public API and is kept for compatibility.
    pub fn new(hidden_size: usize, intermediate_size: usize, dropout_prob: f32) -> Result<Self> {
        let block =
            Self::new_with_device(hidden_size, intermediate_size, dropout_prob, Device::CPU);
        Ok(block)
    }

    /// Total number of parameters across the `dense` and `output` layers.
    pub fn parameter_count(&self) -> usize {
        [&self.dense, &self.output]
            .iter()
            .map(|layer| layer.parameter_count())
            .sum()
    }

    /// Replaces the weight of the first (`dense`) projection.
    ///
    /// # Errors
    /// Propagates whatever error `Linear::set_weight` reports.
    pub fn set_dense_weight(&mut self, weight: Tensor) -> Result<()> {
        self.dense.set_weight(weight)
    }

    /// Replaces the bias of the first (`dense`) projection.
    ///
    /// # Errors
    /// Propagates whatever error `Linear::set_bias` reports.
    pub fn set_dense_bias(&mut self, bias: Tensor) -> Result<()> {
        self.dense.set_bias(bias)
    }

    /// Replaces the weight of the second (`output`) projection.
    ///
    /// # Errors
    /// Propagates whatever error `Linear::set_weight` reports.
    pub fn set_output_weight(&mut self, weight: Tensor) -> Result<()> {
        self.output.set_weight(weight)
    }

    /// Replaces the bias of the second (`output`) projection.
    ///
    /// # Errors
    /// Propagates whatever error `Linear::set_bias` reports.
    pub fn set_output_bias(&mut self, bias: Tensor) -> Result<()> {
        self.output.set_bias(bias)
    }
}
impl Layer for FeedForward {
    type Input = Tensor;
    type Output = Tensor;

    /// Runs `output(gelu(dense(input)))`.
    ///
    /// No dropout is applied here, regardless of the stored `dropout_prob`.
    ///
    /// # Errors
    /// Returns the first error produced by the `dense` projection, the GELU
    /// activation, or the `output` projection, in that order.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let activated = gelu(&self.dense.forward(input)?)?;
        self.output.forward(activated)
    }
}