use crate as burn;
use crate::{
    config::Config,
    module::Module,
    nn::{Dropout, DropoutConfig, Linear, LinearConfig, GELU},
    tensor::{backend::Backend, Tensor},
};
/// Configuration for building a [position-wise feed-forward](PositionWiseFeedForward) block.
#[derive(Config)]
pub struct PositionWiseFeedForwardConfig {
    /// Size of the model input and output features (outer dimension of the block).
    pub d_model: usize,
    /// Size of the hidden inner features (the expanded intermediate dimension).
    pub d_ff: usize,
    /// Probability of an element being zeroed during dropout. Defaults to 0.1.
    #[config(default = 0.1)]
    pub dropout: f64,
}
/// Position-wise feed-forward network, applied identically at every position.
///
/// Computes `linear_outer(dropout(gelu(linear_inner(x))))`, projecting from
/// `d_model` to `d_ff` and back to `d_model`.
#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
    // Projection d_model -> d_ff.
    linear_inner: Linear<B>,
    // Projection d_ff -> d_model.
    linear_outer: Linear<B>,
    dropout: Dropout,
    gelu: GELU,
}
impl PositionWiseFeedForwardConfig {
    /// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module
    /// with freshly initialized parameters.
    pub fn init<B: Backend>(&self) -> PositionWiseFeedForward<B> {
        // Expand to the inner dimension, then project back to the model dimension.
        let linear_inner = LinearConfig::new(self.d_model, self.d_ff).init();
        let linear_outer = LinearConfig::new(self.d_ff, self.d_model).init();

        PositionWiseFeedForward {
            linear_inner,
            linear_outer,
            dropout: DropoutConfig::new(self.dropout).init(),
            gelu: GELU::new(),
        }
    }

    /// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module,
    /// loading the linear-layer parameters from the given `record`.
    pub fn init_with<B: Backend>(
        &self,
        record: PositionWiseFeedForwardRecord<B>,
    ) -> PositionWiseFeedForward<B> {
        // Linear layers are restored from the record; dropout and GELU are
        // stateless and rebuilt from the config.
        let linear_inner =
            LinearConfig::new(self.d_model, self.d_ff).init_with(record.linear_inner);
        let linear_outer =
            LinearConfig::new(self.d_ff, self.d_model).init_with(record.linear_outer);

        PositionWiseFeedForward {
            linear_inner,
            linear_outer,
            dropout: DropoutConfig::new(self.dropout).init(),
            gelu: GELU::new(),
        }
    }
}
impl<B: Backend> PositionWiseFeedForward<B> {
    /// Apply the feed-forward transformation to `input`.
    ///
    /// The tensor flows through `linear_inner`, GELU, dropout, and finally
    /// `linear_outer`; the last dimension goes `d_model -> d_ff -> d_model`.
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        self.linear_outer.forward(
            self.dropout
                .forward(self.gelu.forward(self.linear_inner.forward(input))),
        )
    }
}