use crate::common::activations::{TensorFunction, _tanh};
use crate::common::dropout::Dropout;
use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
use crate::fnet::embeddings::FNetEmbeddings;
use crate::fnet::encoder::FNetEncoder;
use crate::{Activation, Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use tch::nn::LayerNormConfig;
use tch::{nn, Tensor};
/// Namespace type for pretrained FNet model weight resources (see its associated constants).
pub struct FNetModelResources;
/// Namespace type for pretrained FNet configuration resources (see its associated constants).
pub struct FNetConfigResources;
/// Namespace type for pretrained FNet vocabulary resources (see its associated constants).
pub struct FNetVocabResources;
/// Pretrained FNet model weight locations.
///
/// Each constant is a pair of strings: an identifier of the form `<model>/model` followed by
/// the URL of a `rust_model.ot` file hosted on huggingface.co.
impl FNetModelResources {
/// Weights from the `google/fnet-base` repository.
pub const BASE: (&'static str, &'static str) = (
"fnet-base/model",
"https://huggingface.co/google/fnet-base/resolve/main/rust_model.ot",
);
/// Weights from the `gchhablani/fnet-base-finetuned-sst2` repository.
pub const BASE_SST2: (&'static str, &'static str) = (
"fnet-base-sst2/model",
"https://huggingface.co/gchhablani/fnet-base-finetuned-sst2/resolve/main/rust_model.ot",
);
}
/// Pretrained FNet configuration locations.
///
/// Each constant is a pair of strings: an identifier of the form `<model>/config` followed by
/// the URL of a `config.json` file hosted on huggingface.co. The test module in this file
/// passes `BASE` to `RemoteResource::from_pretrained`.
impl FNetConfigResources {
/// Configuration from the `google/fnet-base` repository.
pub const BASE: (&'static str, &'static str) = (
"fnet-base/config",
"https://huggingface.co/google/fnet-base/resolve/main/config.json",
);
/// Configuration from the `gchhablani/fnet-base-finetuned-sst2` repository.
pub const BASE_SST2: (&'static str, &'static str) = (
"fnet-base-sst2/config",
"https://huggingface.co/gchhablani/fnet-base-finetuned-sst2/resolve/main/config.json",
);
}
/// Pretrained FNet vocabulary locations.
///
/// Each constant is a pair of strings: an identifier of the form `<model>/spiece` followed by
/// the URL of a `spiece.model` file hosted on huggingface.co.
impl FNetVocabResources {
/// Vocabulary from the `google/fnet-base` repository.
pub const BASE: (&'static str, &'static str) = (
"fnet-base/spiece",
"https://huggingface.co/google/fnet-base/resolve/main/spiece.model",
);
/// Uses its own identifier (`fnet-base-sst2/spiece`) but the same `google/fnet-base` URL as
/// `BASE`.
// NOTE(review): presumably the fine-tuned model shares the base vocabulary — confirm.
pub const BASE_SST2: (&'static str, &'static str) = (
"fnet-base-sst2/spiece",
"https://huggingface.co/google/fnet-base/resolve/main/spiece.model",
);
}
/// Configuration for the FNet models defined in this file; (de)serializable through serde.
///
/// Fields marked "not read in this file" are only passed along (via `&FNetConfig`) to the
/// embeddings and encoder constructors, whose use of them is not visible here.
#[derive(Debug, Serialize, Deserialize)]
pub struct FNetConfig {
/// Output size of the language-modeling decoder layer.
pub vocab_size: i64,
/// Input/output size of the pooler and prediction-head dense layers, size of the prediction
/// head layer normalization, and input size of the task-specific linear layers.
pub hidden_size: i64,
/// Not read in this file; presumably the number of encoder layers — confirm in the encoder.
pub num_hidden_layers: i64,
/// Not read in this file; presumably the feed-forward inner size — confirm in the encoder.
pub intermediate_size: i64,
/// Activation used by the prediction head transform.
pub hidden_act: Activation,
/// Dropout probability used by the classification heads in this file.
pub hidden_dropout_prob: f64,
/// Not read in this file; presumably the position embedding table size — confirm.
pub max_position_embeddings: i64,
/// Not read in this file; presumably the token type embedding table size — confirm.
pub type_vocab_size: i64,
/// Not read in this file.
pub initializer_range: f64,
/// Layer normalization epsilon; the prediction head falls back to `1e-12` when `None`.
pub layer_norm_eps: Option<f64>,
/// Not read in this file.
pub pad_token_id: Option<i64>,
/// Not read in this file.
pub bos_token_id: Option<i64>,
/// Not read in this file.
pub eos_token_id: Option<i64>,
/// Not read in this file.
pub decoder_start_token_id: Option<i64>,
/// Label map; its length is used as the number of output labels by the sequence and token
/// classification heads, which return an error when it is `None`.
pub id2label: Option<HashMap<i64, String>>,
/// Not read in this file.
pub label2id: Option<HashMap<String, i64>>,
/// Not read in this file.
pub output_attentions: Option<bool>,
/// Not read in this file.
pub output_hidden_states: Option<bool>,
}
/// Implements the crate's `Config` trait for `FNetConfig` without overriding anything; the test
/// module in this file calls `FNetConfig::from_file` with that trait in scope.
impl Config for FNetConfig {}
impl Default for FNetConfig {
    /// Builds a configuration populated with the hard-coded default values listed below.
    ///
    /// Label maps, the decoder start token and the output flags are left unset (`None`).
    fn default() -> Self {
        Self {
            // Model dimensions.
            vocab_size: 32000,
            hidden_size: 768,
            num_hidden_layers: 12,
            intermediate_size: 3072,
            hidden_act: Activation::gelu_new,
            hidden_dropout_prob: 0.1,
            max_position_embeddings: 512,
            type_vocab_size: 4,
            initializer_range: 0.02,
            layer_norm_eps: Some(1e-12),
            // Special token ids.
            pad_token_id: Some(3),
            bos_token_id: Some(1),
            eos_token_id: Some(2),
            decoder_start_token_id: None,
            // Optional task / output settings.
            id2label: None,
            label2id: None,
            output_attentions: None,
            output_hidden_states: None,
        }
    }
}
/// Pooling head: a linear projection followed by the `_tanh` activation.
struct FNetPooler {
    dense: nn::Linear,
    activation: TensorFunction,
}

impl FNetPooler {
    /// Creates the pooler, registering a `hidden_size` -> `hidden_size` linear layer under the
    /// `dense` sub-path of `p`.
    pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetPooler
    where
        P: Borrow<nn::Path<'p>>,
    {
        let root = p.borrow();
        let hidden_size = config.hidden_size;
        Self {
            dense: nn::linear(root / "dense", hidden_size, hidden_size, Default::default()),
            activation: TensorFunction::new(Box::new(_tanh)),
        }
    }

    /// Selects index 0 along dimension 1 of `hidden_states`, applies the dense layer to it and
    /// then the activation.
    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
        let first_position = hidden_states.select(1, 0);
        let projected = first_position.apply(&self.dense);
        (self.activation.get_fn())(&projected)
    }
}
/// Transform applied ahead of the language-modeling decoder: dense layer, configured
/// activation, then layer normalization.
struct FNetPredictionHeadTransform {
    dense: nn::Linear,
    activation: TensorFunction,
    layer_norm: nn::LayerNorm,
}

impl FNetPredictionHeadTransform {
    /// Creates the transform, registering its layers under the `dense` and `LayerNorm`
    /// sub-paths of `p`.
    ///
    /// The layer normalization epsilon comes from `config.layer_norm_eps`, defaulting to
    /// `1e-12` when it is unset.
    pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetPredictionHeadTransform
    where
        P: Borrow<nn::Path<'p>>,
    {
        let root = p.borrow();
        let hidden_size = config.hidden_size;

        let dense = nn::linear(root / "dense", hidden_size, hidden_size, Default::default());
        let activation = config.hidden_act.get_function();
        let layer_norm = nn::layer_norm(
            root / "LayerNorm",
            vec![hidden_size],
            LayerNormConfig {
                eps: config.layer_norm_eps.unwrap_or(1e-12),
                ..Default::default()
            },
        );

        Self {
            dense,
            activation,
            layer_norm,
        }
    }

    /// Runs `hidden_states` through the dense layer, the activation and the layer
    /// normalization, in that order.
    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
        let projected = hidden_states.apply(&self.dense);
        let activated: Tensor = (self.activation.get_fn())(&projected);
        activated.apply(&self.layer_norm)
    }
}
/// Language-modeling head: the prediction head transform followed by a linear decoder with
/// `vocab_size` outputs.
struct FNetLMPredictionHead {
    transform: FNetPredictionHeadTransform,
    decoder: nn::Linear,
}

impl FNetLMPredictionHead {
    /// Creates the head, registering the transform under the `transform` sub-path of `p` and
    /// a `hidden_size` -> `vocab_size` linear layer under `decoder`.
    pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetLMPredictionHead
    where
        P: Borrow<nn::Path<'p>>,
    {
        let root = p.borrow();
        Self {
            transform: FNetPredictionHeadTransform::new(root / "transform", config),
            decoder: nn::linear(
                root / "decoder",
                config.hidden_size,
                config.vocab_size,
                Default::default(),
            ),
        }
    }

    /// Applies the transform to `hidden_states`, then the decoder layer.
    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
        let transformed = self.transform.forward(hidden_states);
        transformed.apply(&self.decoder)
    }
}
/// Base FNet model: embeddings, encoder and an optional pooler.
pub struct FNetModel {
    embeddings: FNetEmbeddings,
    encoder: FNetEncoder,
    pooler: Option<FNetPooler>,
}

impl FNetModel {
    /// Builds a new `FNetModel`.
    ///
    /// # Arguments
    ///
    /// * `p` - Variable store path; sub-modules are registered under `embeddings`, `encoder`
    ///   and (when requested) `pooler`.
    /// * `config` - Model configuration, forwarded to every sub-module.
    /// * `add_pooling_layer` - When `true` a pooler is created; otherwise the model has none
    ///   and `forward_t` reports `pooled_output: None`.
    pub fn new<'p, P>(p: P, config: &FNetConfig, add_pooling_layer: bool) -> FNetModel
    where
        P: Borrow<nn::Path<'p>>,
    {
        let root = p.borrow();

        let embeddings = FNetEmbeddings::new(root / "embeddings", config);
        let encoder = FNetEncoder::new(root / "encoder", config);
        let pooler = add_pooling_layer.then(|| FNetPooler::new(root / "pooler", config));

        Self {
            embeddings,
            encoder,
            pooler,
        }
    }

    /// Forward pass: embeddings, then encoder, then the pooler if one was created.
    ///
    /// The four optional tensors are passed unchanged to the embeddings layer; `train` is
    /// passed to both the embeddings and the encoder.
    ///
    /// # Errors
    ///
    /// Propagates any `RustBertError` returned by the embeddings layer.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeddings: Option<&Tensor>,
        train: bool,
    ) -> Result<FNetModelOutput, RustBertError> {
        let embedded = self.embeddings.forward_t(
            input_ids,
            token_type_ids,
            position_ids,
            input_embeddings,
            train,
        )?;
        let encoder_output = self.encoder.forward_t(&embedded, train);

        // The pooler reads the final encoder hidden states; without a pooler this is `None`.
        let pooled_output = self
            .pooler
            .as_ref()
            .map(|pooler| pooler.forward(&encoder_output.hidden_states));

        Ok(FNetModelOutput {
            hidden_states: encoder_output.hidden_states,
            pooled_output,
            all_hidden_states: encoder_output.all_hidden_states,
        })
    }
}
/// FNet base model (without pooler) topped with a language-modeling prediction head.
pub struct FNetForMaskedLM {
    fnet: FNetModel,
    lm_head: FNetLMPredictionHead,
}

impl FNetForMaskedLM {
    /// Builds a new `FNetForMaskedLM`.
    ///
    /// The base model is registered under the `fnet` sub-path of `p` and is created without
    /// a pooling layer; the prediction head lives under `cls`, then `predictions`.
    pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetForMaskedLM
    where
        P: Borrow<nn::Path<'p>>,
    {
        let root = p.borrow();

        let fnet = FNetModel::new(root / "fnet", config, false);
        let head_path = root.sub("cls").sub("predictions");
        let lm_head = FNetLMPredictionHead::new(head_path, config);

        Self { fnet, lm_head }
    }

    /// Runs the base model and applies the prediction head to its final hidden states.
    ///
    /// All arguments are forwarded unchanged to `FNetModel::forward_t`.
    ///
    /// # Errors
    ///
    /// Propagates any `RustBertError` returned by the base model.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeddings: Option<&Tensor>,
        train: bool,
    ) -> Result<FNetMaskedLMOutput, RustBertError> {
        let base_output = self.fnet.forward_t(
            input_ids,
            token_type_ids,
            position_ids,
            input_embeddings,
            train,
        )?;

        Ok(FNetMaskedLMOutput {
            prediction_scores: self.lm_head.forward(&base_output.hidden_states),
            all_hidden_states: base_output.all_hidden_states,
        })
    }
}
/// FNet base model (with pooler) followed by dropout and a linear classifier applied to the
/// pooled output.
pub struct FNetForSequenceClassification {
    fnet: FNetModel,
    dropout: Dropout,
    classifier: nn::Linear,
}

impl FNetForSequenceClassification {
    /// Builds a new `FNetForSequenceClassification`.
    ///
    /// The base model is registered under the `fnet` sub-path of `p` (with a pooling layer)
    /// and the classifier under `classifier`. The number of output labels is the number of
    /// entries in `config.id2label`.
    ///
    /// # Errors
    ///
    /// Returns `RustBertError::InvalidConfigurationError` when `config.id2label` is `None`.
    pub fn new<'p, P>(
        p: P,
        config: &FNetConfig,
    ) -> Result<FNetForSequenceClassification, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let root = p.borrow();

        let fnet = FNetModel::new(root / "fnet", config, true);
        let dropout = Dropout::new(config.hidden_dropout_prob);

        let num_labels = match &config.id2label {
            Some(id2label) => id2label.len() as i64,
            None => {
                return Err(RustBertError::InvalidConfigurationError(
                    "num_labels not provided in configuration".to_string(),
                ))
            }
        };
        let classifier = nn::linear(
            root / "classifier",
            config.hidden_size,
            num_labels,
            Default::default(),
        );

        Ok(Self {
            fnet,
            dropout,
            classifier,
        })
    }

    /// Runs the base model, then applies dropout and the classifier to the pooled output.
    ///
    /// All arguments are forwarded unchanged to `FNetModel::forward_t`; `train` is also
    /// passed to the dropout layer.
    ///
    /// # Errors
    ///
    /// Propagates any `RustBertError` returned by the base model.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeddings: Option<&Tensor>,
        train: bool,
    ) -> Result<FNetSequenceClassificationOutput, RustBertError> {
        let base_model_output = self.fnet.forward_t(
            input_ids,
            token_type_ids,
            position_ids,
            input_embeddings,
            train,
        )?;

        // `new` always builds the base model with a pooling layer, so the pooled output is
        // present here.
        let pooled_output = base_model_output.pooled_output.unwrap();
        let logits = pooled_output
            .apply_t(&self.dropout, train)
            .apply(&self.classifier);

        Ok(FNetSequenceClassificationOutput {
            logits,
            all_hidden_states: base_model_output.all_hidden_states,
        })
    }
}
/// FNet base model (with pooler) followed by dropout and a single-output linear layer whose
/// scores are regrouped per choice.
pub struct FNetForMultipleChoice {
    fnet: FNetModel,
    dropout: Dropout,
    classifier: nn::Linear,
}

impl FNetForMultipleChoice {
    /// Builds a new `FNetForMultipleChoice`.
    ///
    /// The base model is registered under the `fnet` sub-path of `p` (with a pooling layer)
    /// and a `hidden_size` -> 1 linear layer under `classifier`.
    pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetForMultipleChoice
    where
        P: Borrow<nn::Path<'p>>,
    {
        let root = p.borrow();

        let fnet = FNetModel::new(root / "fnet", config, true);
        let dropout = Dropout::new(config.hidden_dropout_prob);
        let classifier = nn::linear(
            root / "classifier",
            config.hidden_size,
            1,
            Default::default(),
        );

        Self {
            fnet,
            dropout,
            classifier,
        }
    }

    /// Scores every choice and returns logits viewed as `(-1, num_choices)`.
    ///
    /// `num_choices` is read from index 1 of the input shape. Id-like tensors are viewed as
    /// `(-1, last_dim)` and `input_embeddings` as `(-1, size[2], size[3])` before the base
    /// model runs, so the choices dimension is folded into the leading dimension.
    ///
    /// # Errors
    ///
    /// Propagates errors from `get_shape_and_device_from_ids_embeddings_pair` and from the
    /// base model.
    ///
    /// # Panics
    ///
    /// Panics if the input shape has fewer than two dimensions, or if `input_embeddings` has
    /// fewer than four.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeddings: Option<&Tensor>,
        train: bool,
    ) -> Result<FNetSequenceClassificationOutput, RustBertError> {
        // Views a tensor as two-dimensional, keeping only its last dimension intact.
        fn flatten_to_last_dim(tensor: &Tensor) -> Tensor {
            let last_dim = *tensor.size().last().unwrap();
            tensor.view((-1, last_dim))
        }

        let (input_shape, _) =
            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeddings)?;
        let num_choices = input_shape[1];

        let flat_input_ids = input_ids.map(flatten_to_last_dim);
        let flat_token_type_ids = token_type_ids.map(flatten_to_last_dim);
        let flat_position_ids = position_ids.map(flatten_to_last_dim);
        let flat_input_embeddings = input_embeddings.map(|tensor| {
            let dims = tensor.size();
            tensor.view((-1, dims[2], dims[3]))
        });

        let base_model_output = self.fnet.forward_t(
            flat_input_ids.as_ref(),
            flat_token_type_ids.as_ref(),
            flat_position_ids.as_ref(),
            flat_input_embeddings.as_ref(),
            train,
        )?;

        // `new` always builds the base model with a pooling layer, so the pooled output is
        // present here.
        let pooled_output = base_model_output.pooled_output.unwrap();
        let logits = pooled_output
            .apply_t(&self.dropout, train)
            .apply(&self.classifier)
            .view((-1, num_choices));

        Ok(FNetSequenceClassificationOutput {
            logits,
            all_hidden_states: base_model_output.all_hidden_states,
        })
    }
}
/// FNet base model (without pooler) followed by dropout and a linear classifier applied to
/// the final hidden states.
pub struct FNetForTokenClassification {
    fnet: FNetModel,
    dropout: Dropout,
    classifier: nn::Linear,
}

impl FNetForTokenClassification {
    /// Builds a new `FNetForTokenClassification`.
    ///
    /// The base model is registered under the `fnet` sub-path of `p` (without a pooling
    /// layer) and the classifier under `classifier`. The number of output labels is the
    /// number of entries in `config.id2label`.
    ///
    /// # Errors
    ///
    /// Returns `RustBertError::InvalidConfigurationError` when `config.id2label` is `None`.
    pub fn new<'p, P>(
        p: P,
        config: &FNetConfig,
    ) -> Result<FNetForTokenClassification, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let root = p.borrow();

        let fnet = FNetModel::new(root / "fnet", config, false);
        let dropout = Dropout::new(config.hidden_dropout_prob);

        let num_labels = match &config.id2label {
            Some(id2label) => id2label.len() as i64,
            None => {
                return Err(RustBertError::InvalidConfigurationError(
                    "num_labels not provided in configuration".to_string(),
                ))
            }
        };
        let classifier = nn::linear(
            root / "classifier",
            config.hidden_size,
            num_labels,
            Default::default(),
        );

        Ok(Self {
            fnet,
            dropout,
            classifier,
        })
    }

    /// Runs the base model, then applies dropout and the classifier to the final hidden
    /// states.
    ///
    /// All arguments are forwarded unchanged to `FNetModel::forward_t`; `train` is also
    /// passed to the dropout layer.
    ///
    /// # Errors
    ///
    /// Propagates any `RustBertError` returned by the base model.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeddings: Option<&Tensor>,
        train: bool,
    ) -> Result<FNetTokenClassificationOutput, RustBertError> {
        let base_model_output = self.fnet.forward_t(
            input_ids,
            token_type_ids,
            position_ids,
            input_embeddings,
            train,
        )?;

        let regularized = base_model_output
            .hidden_states
            .apply_t(&self.dropout, train);
        let logits = regularized.apply(&self.classifier);

        Ok(FNetTokenClassificationOutput {
            logits,
            all_hidden_states: base_model_output.all_hidden_states,
        })
    }
}
/// FNet base model (without pooler) followed by a two-output linear layer whose outputs are
/// split into start and end logits.
pub struct FNetForQuestionAnswering {
    fnet: FNetModel,
    qa_outputs: nn::Linear,
}

impl FNetForQuestionAnswering {
    /// Builds a new `FNetForQuestionAnswering`.
    ///
    /// The base model is registered under the `fnet` sub-path of `p` (without a pooling
    /// layer) and a `hidden_size` -> 2 linear layer under `classifier`.
    pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetForQuestionAnswering
    where
        P: Borrow<nn::Path<'p>>,
    {
        let root = p.borrow();

        let fnet = FNetModel::new(root / "fnet", config, false);
        // NOTE(review): the field is named `qa_outputs` but its variables are registered
        // under "classifier" — confirm this matches the checkpoints being loaded.
        let qa_outputs = nn::linear(
            root / "classifier",
            config.hidden_size,
            2,
            Default::default(),
        );

        Self { fnet, qa_outputs }
    }

    /// Runs the base model, projects the final hidden states to two values along the last
    /// dimension, and returns them as separate start and end logits with that dimension
    /// squeezed away.
    ///
    /// All arguments are forwarded unchanged to `FNetModel::forward_t`.
    ///
    /// # Errors
    ///
    /// Propagates any `RustBertError` returned by the base model.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeddings: Option<&Tensor>,
        train: bool,
    ) -> Result<FNetQuestionAnsweringOutput, RustBertError> {
        let base_model_output = self.fnet.forward_t(
            input_ids,
            token_type_ids,
            position_ids,
            input_embeddings,
            train,
        )?;

        // Split the two-wide last dimension into chunks of size 1: [start, end].
        let projected = base_model_output.hidden_states.apply(&self.qa_outputs);
        let split_logits = projected.split(1, -1);
        let start_logits = split_logits[0].squeeze_dim(-1);
        let end_logits = split_logits[1].squeeze_dim(-1);

        Ok(FNetQuestionAnsweringOutput {
            start_logits,
            end_logits,
            all_hidden_states: base_model_output.all_hidden_states,
        })
    }
}
/// Output of `FNetModel::forward_t`.
pub struct FNetModelOutput {
/// Final hidden states returned by the encoder.
pub hidden_states: Tensor,
/// Pooler output; `None` when the model was built without a pooling layer.
pub pooled_output: Option<Tensor>,
/// The encoder's `all_hidden_states` value, passed through unchanged.
pub all_hidden_states: Option<Vec<Tensor>>,
}
/// Output of `FNetForMaskedLM::forward_t`.
pub struct FNetMaskedLMOutput {
/// Result of applying the language-modeling head to the base model's final hidden states.
pub prediction_scores: Tensor,
/// The base model's `all_hidden_states` value, passed through unchanged.
pub all_hidden_states: Option<Vec<Tensor>>,
}
/// Output of `FNetForSequenceClassification::forward_t` and `FNetForMultipleChoice::forward_t`.
pub struct FNetSequenceClassificationOutput {
/// Classifier output (for multiple choice, viewed as `(-1, num_choices)`).
pub logits: Tensor,
/// The base model's `all_hidden_states` value, passed through unchanged.
pub all_hidden_states: Option<Vec<Tensor>>,
}
/// Output of `FNetForTokenClassification::forward_t`; same layout as the sequence
/// classification output.
pub type FNetTokenClassificationOutput = FNetSequenceClassificationOutput;
/// Output of `FNetForQuestionAnswering::forward_t`.
pub struct FNetQuestionAnsweringOutput {
/// First of the two values produced by the question answering layer, last dimension squeezed.
pub start_logits: Tensor,
/// Second of the two values produced by the question answering layer, last dimension squeezed.
pub end_logits: Tensor,
/// The base model's `all_hidden_states` value, passed through unchanged.
pub all_hidden_states: Option<Vec<Tensor>>,
}
#[cfg(test)]
mod test {
    use tch::Device;

    use crate::{
        resources::{RemoteResource, ResourceProvider},
        Config,
    };

    use super::*;

    /// Compile-time check that `FNetModel` is `Send`: the boxed model must coerce to
    /// `Box<dyn Send>`. Ignored by default; it fetches the base configuration as a remote
    /// resource.
    #[test]
    #[ignore]
    fn fnet_model_send() {
        let remote_config = Box::new(RemoteResource::from_pretrained(FNetConfigResources::BASE));
        let local_config_path = remote_config.get_local_path().expect("");

        let var_store = nn::VarStore::new(Device::cuda_if_available());
        let config = FNetConfig::from_file(local_config_path);

        let _: Box<dyn Send> = Box::new(FNetModel::new(var_store.root(), &config, true));
    }
}