use crate::fnet::attention::FNetLayer;
use crate::fnet::FNetConfig;
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Tensor};
/// Encoder made of a sequence of [`FNetLayer`]s that are applied one after another.
pub struct FNetEncoder {
/// Encoder layers, stored in the order in which they are applied.
layers: Vec<FNetLayer>,
/// When `true`, `forward_t` also collects a copy of every layer's output.
output_hidden_states: bool,
}
impl FNetEncoder {
    /// Builds a new `FNetEncoder`.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path under which the layers are registered. Layer `i` is
    ///   created under the sub-path `layer/i`.
    /// * `config` - Model configuration. `num_hidden_layers` sets the number of layers;
    ///   `output_hidden_states` (defaulting to `false` when unset) controls whether
    ///   `forward_t` collects the intermediate layer outputs.
    pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetEncoder
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        let p_layers = p / "layer";

        // Collecting from the range avoids casting `num_hidden_layers` to `usize`:
        // a non-positive value simply yields an empty encoder instead of an
        // oversized capacity request.
        let layers: Vec<FNetLayer> = (0..config.num_hidden_layers)
            .map(|layer_index| FNetLayer::new(&p_layers / layer_index, config))
            .collect();

        let output_hidden_states = config.output_hidden_states.unwrap_or(false);

        FNetEncoder {
            layers,
            output_hidden_states,
        }
    }

    /// Runs `hidden_states` through every encoder layer in order.
    ///
    /// # Arguments
    ///
    /// * `hidden_states` - Input tensor handed to the first layer.
    /// * `train` - Forwarded unchanged to each layer's `forward_t`.
    ///
    /// # Returns
    ///
    /// An [`FNetEncoderOutput`] holding the output of the last layer and, when the encoder
    /// was configured with `output_hidden_states`, a copy of each layer's output in layer
    /// order. If the encoder has no layers, the returned `hidden_states` is a copy of the
    /// input (and the collected list, if requested, is empty) rather than a panic.
    pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> FNetEncoderOutput {
        let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
            Some(Vec::with_capacity(self.layers.len()))
        } else {
            None
        };

        // Output of the most recently applied layer; `None` until the first layer has run.
        let mut x: Option<Tensor> = None;
        for layer in &self.layers {
            // The first layer consumes the caller's input, later ones the previous output.
            let layer_output = layer.forward_t(x.as_ref().unwrap_or(hidden_states), train);
            if let Some(collected) = all_hidden_states.borrow_mut() {
                collected.push(layer_output.copy());
            }
            x = Some(layer_output);
        }

        FNetEncoderOutput {
            // With zero layers there is no layer output: fall back to a copy of the input.
            hidden_states: x.unwrap_or_else(|| hidden_states.copy()),
            all_hidden_states,
        }
    }
}
/// Container for the values returned by `FNetEncoder::forward_t`.
pub struct FNetEncoderOutput {
/// Output of the last encoder layer.
pub hidden_states: Tensor,
/// Copies of each layer's output, in layer order; `None` when the encoder was not
/// configured to output hidden states.
pub all_hidden_states: Option<Vec<Tensor>>,
}