use crate::common::dropout::{Dropout, XDropout};
use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
use crate::deberta::{
deserialize_attention_type, ContextPooler, DebertaConfig, DebertaLMPredictionHead,
DebertaMaskedLMOutput, DebertaModelOutput, DebertaQuestionAnsweringOutput,
DebertaSequenceClassificationOutput, DebertaTokenClassificationOutput, PositionAttentionTypes,
};
use crate::deberta_v2::embeddings::DebertaV2Embeddings;
use crate::deberta_v2::encoder::DebertaV2Encoder;
use crate::{Activation, Config, RustBertError};
use serde::de::{SeqAccess, Visitor};
use serde::{de, Deserialize, Deserializer, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use std::fmt;
use std::str::FromStr;
use tch::{nn, Kind, Tensor};
/// # DeBERTa V2 Pretrained model weight files
pub struct DebertaV2ModelResources;
/// # DeBERTa V2 Pretrained model config files
pub struct DebertaV2ConfigResources;
/// # DeBERTa V2 Pretrained model vocab files
pub struct DebertaV2VocabResources;
impl DebertaV2ModelResources {
    /// `(local alias, download URL)` pair for the `microsoft/deberta-v3-base` model weights.
    pub const DEBERTA_V3_BASE: (&'static str, &'static str) = (
        "deberta-v3-base/model",
        "https://huggingface.co/microsoft/deberta-v3-base/resolve/main/rust_model.ot",
    );
}
impl DebertaV2ConfigResources {
    /// `(local alias, download URL)` pair for the `microsoft/deberta-v3-base` configuration file.
    pub const DEBERTA_V3_BASE: (&'static str, &'static str) = (
        "deberta-v3-base/config",
        "https://huggingface.co/microsoft/deberta-v3-base/resolve/main/config.json",
    );
}
impl DebertaV2VocabResources {
    /// `(local alias, download URL)` pair for the `microsoft/deberta-v3-base` SentencePiece
    /// vocabulary model (`spm.model`).
    pub const DEBERTA_V3_BASE: (&'static str, &'static str) = (
        "deberta-v3-base/vocab",
        "https://huggingface.co/microsoft/deberta-v3-base/resolve/main/spm.model",
    );
}
/// # DeBERTa V2 model configuration
/// Architecture hyper-parameters for DeBERTa V2 models, deserialized from the model's
/// JSON configuration file.
#[derive(Debug, Serialize, Deserialize)]
pub struct DebertaV2Config {
    /// Vocabulary size (number of token embeddings)
    pub vocab_size: i64,
    /// Hidden layer dimension
    pub hidden_size: i64,
    /// Number of encoder layers
    pub num_hidden_layers: i64,
    /// Activation function for the encoder intermediate layers
    pub hidden_act: Activation,
    /// Dropout probability applied to attention probabilities
    pub attention_probs_dropout_prob: f64,
    /// Dropout probability applied to hidden states
    pub hidden_dropout_prob: f64,
    /// Standard deviation used for weight initialization
    pub initializer_range: f64,
    /// Dimension of the feed-forward (intermediate) layers
    pub intermediate_size: i64,
    /// Maximum sequence length supported by the position embeddings
    pub max_position_embeddings: i64,
    /// Number of buckets for relative position encoding
    /// (NOTE(review): semantics defined by the upstream DeBERTa implementation — confirm)
    pub position_buckets: Option<i64>,
    /// Number of attention heads
    pub num_attention_heads: i64,
    /// Token type (segment) vocabulary size
    pub type_vocab_size: i64,
    /// Whether absolute position embeddings are added to the input embeddings
    pub position_biased_input: Option<bool>,
    /// Relative attention types (e.g. content-to-position); accepts a string or a sequence in JSON
    #[serde(default, deserialize_with = "deserialize_attention_type")]
    pub pos_att_type: Option<PositionAttentionTypes>,
    /// Normalization applied to relative embeddings; accepts a `|`-separated string or a sequence
    #[serde(default, deserialize_with = "deserialize_norm_type")]
    pub norm_rel_ebd: Option<NormRelEmbedTypes>,
    /// Whether the attention key/query projections are shared for relative attention
    /// (NOTE(review): inferred from the field name — confirm against the encoder implementation)
    pub share_att_key: Option<bool>,
    /// Kernel size for the (optional) convolution layer
    pub conv_kernel_size: Option<i64>,
    /// Number of groups for the (optional) convolution layer
    pub conv_groups: Option<i64>,
    /// Activation function for the (optional) convolution layer
    pub conv_act: Option<Activation>,
    /// Dropout probability for the context pooler
    pub pooler_dropout: Option<f64>,
    /// Activation function for the context pooler
    pub pooler_hidden_act: Option<Activation>,
    /// Hidden dimension of the context pooler
    pub pooler_hidden_size: Option<i64>,
    /// Epsilon used by layer normalization
    pub layer_norm_eps: Option<f64>,
    /// Padding token id
    pub pad_token_id: Option<i64>,
    /// Whether relative attention is enabled
    pub relative_attention: Option<bool>,
    /// Maximum relative position distance
    pub max_relative_positions: Option<i64>,
    /// Embedding dimension, if different from `hidden_size`
    pub embedding_size: Option<i64>,
    /// Whether talking-head attention is used
    pub talking_head: Option<bool>,
    /// Whether all hidden states should be returned by the forward pass
    pub output_hidden_states: Option<bool>,
    /// Whether all attention weights should be returned by the forward pass
    pub output_attentions: Option<bool>,
    /// Whether an activation is applied in the classifier head
    pub classifier_activation: Option<bool>,
    /// Dropout probability for classification heads (falls back to `hidden_dropout_prob`)
    pub classifier_dropout: Option<f64>,
    /// Whether the model is used as a decoder
    pub is_decoder: Option<bool>,
    /// Label id to label string mapping; its length defines `num_labels` for classification heads
    pub id2label: Option<HashMap<i64, String>>,
    /// Label string to label id mapping
    pub label2id: Option<HashMap<String, i64>>,
}
/// Normalization mode applied to relative embeddings, parsed from the
/// `norm_rel_ebd` configuration entry.
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize, Copy, PartialEq, Eq)]
pub enum NormRelEmbedType {
    /// Apply layer normalization to the relative embeddings
    layer_norm,
}
impl FromStr for NormRelEmbedType {
    type Err = RustBertError;

    /// Parses a normalization type identifier; `"layer_norm"` is the only accepted value.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "layer_norm" {
            Ok(NormRelEmbedType::layer_norm)
        } else {
            Err(RustBertError::InvalidConfigurationError(format!(
                "Layer normalization type `{s}` not in accepted variants (`layer_norm`)",
            )))
        }
    }
}
/// Collection of normalization modes applied to relative embeddings
/// (the `norm_rel_ebd` entry may list several, `|`-separated).
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct NormRelEmbedTypes {
    // Parsed normalization types, in configuration order
    types: Vec<NormRelEmbedType>,
}
impl FromStr for NormRelEmbedTypes {
    type Err = RustBertError;

    /// Parses a `|`-separated, case-insensitive list of normalization types
    /// (e.g. `"layer_norm"`), failing on the first unrecognized entry.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowercase = s.to_lowercase();
        let mut types = Vec::new();
        for fragment in lowercase.split('|') {
            types.push(NormRelEmbedType::from_str(fragment)?);
        }
        Ok(NormRelEmbedTypes { types })
    }
}
impl NormRelEmbedTypes {
    /// Returns `true` if `norm_type` is one of the configured normalization types.
    pub fn has_type(&self, norm_type: NormRelEmbedType) -> bool {
        // `contains` is equivalent to the explicit `iter().any(..)` comparison
        // since `NormRelEmbedType` implements `PartialEq`.
        self.types.contains(&norm_type)
    }

    /// Number of configured normalization types.
    pub fn len(&self) -> usize {
        self.types.len()
    }

    /// Returns `true` if no normalization type is configured.
    /// Added to accompany `len` per Rust API conventions (clippy::len_without_is_empty).
    pub fn is_empty(&self) -> bool {
        self.types.is_empty()
    }
}
pub fn deserialize_norm_type<'de, D>(deserializer: D) -> Result<Option<NormRelEmbedTypes>, D::Error>
where
D: Deserializer<'de>,
{
struct NormTypeVisitor;
impl<'de> Visitor<'de> for NormTypeVisitor {
type Value = NormRelEmbedTypes;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("null, string or sequence")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(FromStr::from_str(value).unwrap())
}
fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
where
S: SeqAccess<'de>,
{
let mut types = vec![];
while let Some(norm_type) = seq.next_element::<String>()? {
types.push(FromStr::from_str(norm_type.as_str()).unwrap())
}
Ok(NormRelEmbedTypes { types })
}
}
deserializer.deserialize_any(NormTypeVisitor).map(Some)
}
impl Config for DebertaV2Config {}
impl Default for DebertaV2Config {
    /// Default configuration values.
    /// NOTE(review): these values (vocab 128100, hidden 1536, 24 layers / 24 heads,
    /// intermediate 6144) appear to correspond to the upstream `deberta-v2-xxlarge`
    /// checkpoint configuration — confirm against the Hugging Face reference.
    fn default() -> Self {
        DebertaV2Config {
            vocab_size: 128100,
            hidden_size: 1536,
            num_hidden_layers: 24,
            hidden_act: Activation::gelu,
            attention_probs_dropout_prob: 0.1,
            hidden_dropout_prob: 0.1,
            initializer_range: 0.02,
            intermediate_size: 6144,
            max_position_embeddings: 512,
            position_buckets: None,
            num_attention_heads: 24,
            type_vocab_size: 0,
            position_biased_input: Some(true),
            pos_att_type: None,
            norm_rel_ebd: None,
            share_att_key: None,
            conv_kernel_size: None,
            conv_groups: None,
            conv_act: None,
            pooler_dropout: Some(0.0),
            pooler_hidden_act: Some(Activation::gelu),
            pooler_hidden_size: None,
            layer_norm_eps: Some(1e-7),
            pad_token_id: Some(0),
            relative_attention: None,
            max_relative_positions: None,
            embedding_size: None,
            talking_head: None,
            output_hidden_states: None,
            output_attentions: None,
            classifier_activation: None,
            classifier_dropout: None,
            is_decoder: None,
            id2label: None,
            label2id: None,
        }
    }
}
impl From<DebertaV2Config> for DebertaConfig {
fn from(v2_config: DebertaV2Config) -> Self {
DebertaConfig {
hidden_act: v2_config.hidden_act,
attention_probs_dropout_prob: v2_config.attention_probs_dropout_prob,
hidden_dropout_prob: v2_config.hidden_dropout_prob,
hidden_size: v2_config.hidden_size,
initializer_range: v2_config.initializer_range,
intermediate_size: v2_config.intermediate_size,
max_position_embeddings: v2_config.max_position_embeddings,
num_attention_heads: v2_config.num_attention_heads,
num_hidden_layers: v2_config.num_hidden_layers,
type_vocab_size: v2_config.type_vocab_size,
vocab_size: v2_config.vocab_size,
position_biased_input: v2_config.position_biased_input,
pos_att_type: v2_config.pos_att_type,
pooler_dropout: v2_config.pooler_dropout,
pooler_hidden_act: v2_config.pooler_hidden_act,
pooler_hidden_size: v2_config.pooler_hidden_size,
layer_norm_eps: v2_config.layer_norm_eps,
pad_token_id: v2_config.pad_token_id,
relative_attention: v2_config.relative_attention,
max_relative_positions: v2_config.max_relative_positions,
embedding_size: v2_config.embedding_size,
talking_head: v2_config.talking_head,
output_hidden_states: v2_config.output_hidden_states,
output_attentions: v2_config.output_attentions,
classifier_dropout: v2_config.classifier_dropout,
is_decoder: v2_config.is_decoder,
id2label: v2_config.id2label,
label2id: v2_config.label2id,
share_att_key: v2_config.share_att_key,
position_buckets: v2_config.position_buckets,
}
}
}
/// Conversion from a borrowed DeBERTa V2 configuration to a (V1) `DebertaConfig`.
/// Shared fields are copied (cloning the non-`Copy` ones); V2-only fields
/// (`norm_rel_ebd`, `conv_kernel_size`, `conv_groups`, `conv_act`,
/// `classifier_activation`) have no V1 counterpart and are dropped.
impl From<&DebertaV2Config> for DebertaConfig {
    fn from(v2_config: &DebertaV2Config) -> Self {
        DebertaConfig {
            hidden_act: v2_config.hidden_act,
            attention_probs_dropout_prob: v2_config.attention_probs_dropout_prob,
            hidden_dropout_prob: v2_config.hidden_dropout_prob,
            hidden_size: v2_config.hidden_size,
            initializer_range: v2_config.initializer_range,
            intermediate_size: v2_config.intermediate_size,
            max_position_embeddings: v2_config.max_position_embeddings,
            num_attention_heads: v2_config.num_attention_heads,
            num_hidden_layers: v2_config.num_hidden_layers,
            type_vocab_size: v2_config.type_vocab_size,
            vocab_size: v2_config.vocab_size,
            position_biased_input: v2_config.position_biased_input,
            pos_att_type: v2_config.pos_att_type.clone(),
            pooler_dropout: v2_config.pooler_dropout,
            pooler_hidden_act: v2_config.pooler_hidden_act,
            pooler_hidden_size: v2_config.pooler_hidden_size,
            layer_norm_eps: v2_config.layer_norm_eps,
            pad_token_id: v2_config.pad_token_id,
            relative_attention: v2_config.relative_attention,
            max_relative_positions: v2_config.max_relative_positions,
            embedding_size: v2_config.embedding_size,
            talking_head: v2_config.talking_head,
            output_hidden_states: v2_config.output_hidden_states,
            output_attentions: v2_config.output_attentions,
            classifier_dropout: v2_config.classifier_dropout,
            is_decoder: v2_config.is_decoder,
            id2label: v2_config.id2label.clone(),
            label2id: v2_config.label2id.clone(),
            share_att_key: v2_config.share_att_key,
            position_buckets: v2_config.position_buckets,
        }
    }
}
/// # DeBERTa V2 Base model
/// Embeddings followed by the DeBERTa V2 encoder; no task-specific head.
pub struct DebertaV2Model {
    // Input embeddings (token / position / token-type)
    embeddings: DebertaV2Embeddings,
    // Transformer encoder stack
    encoder: DebertaV2Encoder,
}
impl DebertaV2Model {
    /// Builds a new `DebertaV2Model`.
    ///
    /// # Arguments
    /// * `p` - Variable store path holding the model parameters
    /// * `config` - Model configuration
    pub fn new<'p, P>(p: P, config: &DebertaV2Config) -> DebertaV2Model
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        DebertaV2Model {
            embeddings: DebertaV2Embeddings::new(path / "embeddings", &config.into()),
            encoder: DebertaV2Encoder::new(path / "encoder", config),
        }
    }

    /// Forward pass through embeddings and encoder.
    ///
    /// # Arguments
    /// * `input_ids` - Token ids; mutually exclusive with `input_embeds`
    /// * `attention_mask` - Optional mask; defaults to all ones over the input shape
    /// * `token_type_ids` / `position_ids` - Optional auxiliary inputs forwarded to the embeddings
    /// * `input_embeds` - Pre-computed embeddings; mutually exclusive with `input_ids`
    /// * `train` - Enables dropout when `true`
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<DebertaV2ModelOutput, RustBertError> {
        let (input_shape, device) =
            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
        // When no mask is supplied, attend to every position.
        let default_mask;
        let attention_mask = match attention_mask {
            Some(mask) => mask,
            None => {
                default_mask = Tensor::ones(input_shape.as_slice(), (Kind::Bool, device));
                &default_mask
            }
        };
        let embedding_output = self.embeddings.forward_t(
            input_ids,
            token_type_ids,
            position_ids,
            attention_mask,
            input_embeds,
            train,
        )?;
        self.encoder
            .forward_t(&embedding_output, attention_mask, None, None, train)
    }
}
/// # DeBERTa V2 for masked language modeling
/// Base DeBERTa V2 model with an LM prediction head on top.
pub struct DebertaV2ForMaskedLM {
    // Base transformer model
    deberta: DebertaV2Model,
    // Language-model prediction head producing vocabulary logits
    cls: DebertaLMPredictionHead,
}

impl DebertaV2ForMaskedLM {
    /// Builds a new `DebertaV2ForMaskedLM`.
    ///
    /// # Arguments
    /// * `p` - Variable store path holding the model parameters
    /// * `config` - Model configuration
    pub fn new<'p, P>(p: P, config: &DebertaV2Config) -> DebertaV2ForMaskedLM
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        DebertaV2ForMaskedLM {
            deberta: DebertaV2Model::new(path / "deberta", config),
            cls: DebertaLMPredictionHead::new(
                path.sub("cls").sub("predictions"),
                &config.into(),
                false,
            ),
        }
    }

    /// Forward pass returning vocabulary logits for each input position.
    ///
    /// Arguments mirror `DebertaV2Model::forward_t`.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<DebertaV2MaskedLMOutput, RustBertError> {
        let base_output = self.deberta.forward_t(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            input_embeds,
            train,
        )?;
        Ok(DebertaV2MaskedLMOutput {
            logits: base_output.hidden_state.apply(&self.cls),
            all_hidden_states: base_output.all_hidden_states,
            all_attentions: base_output.all_attentions,
        })
    }
}
/// # DeBERTa V2 for sequence classification
/// Base DeBERTa V2 model with a context pooler, dropout and a linear classification head.
pub struct DebertaV2ForSequenceClassification {
    // Base transformer model
    deberta: DebertaV2Model,
    // Pools the sequence representation before classification
    pooler: ContextPooler,
    // Linear head mapping the pooled representation to class logits
    classifier: nn::Linear,
    // Dropout applied between the pooler and the classifier
    dropout: XDropout,
}

impl DebertaV2ForSequenceClassification {
    /// Builds a new `DebertaV2ForSequenceClassification`.
    ///
    /// # Arguments
    /// * `p` - Variable store path holding the model parameters
    /// * `config` - Model configuration; `id2label` must be populated, its length
    ///   defines the number of output classes.
    ///
    /// # Errors
    /// Returns an `InvalidConfigurationError` if `config.id2label` is `None`.
    pub fn new<'p, P>(
        p: P,
        config: &DebertaV2Config,
    ) -> Result<DebertaV2ForSequenceClassification, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        let deberta = DebertaV2Model::new(p / "deberta", config);
        let pooler = ContextPooler::new(p / "pooler", &config.into());
        // Classifier dropout falls back to the generic hidden dropout when unspecified.
        let dropout = XDropout::new(
            config
                .classifier_dropout
                .unwrap_or(config.hidden_dropout_prob),
        );
        let num_labels = config
            .id2label
            .as_ref()
            .ok_or_else(|| {
                RustBertError::InvalidConfigurationError(
                    "num_labels not provided in configuration".to_string(),
                )
            })?
            .len() as i64;
        let classifier = nn::linear(
            p / "classifier",
            pooler.output_dim,
            num_labels,
            Default::default(),
        );
        Ok(DebertaV2ForSequenceClassification {
            deberta,
            pooler,
            classifier,
            dropout,
        })
    }

    /// Forward pass returning one logit per class for the input sequence.
    ///
    /// Arguments mirror `DebertaV2Model::forward_t`; `train` enables dropout.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<DebertaV2SequenceClassificationOutput, RustBertError> {
        let base_model_output = self.deberta.forward_t(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            input_embeds,
            train,
        )?;
        // pooler -> dropout -> linear classifier
        let logits = base_model_output
            .hidden_state
            .apply_t(&self.pooler, train)
            .apply_t(&self.dropout, train)
            .apply(&self.classifier);
        Ok(DebertaV2SequenceClassificationOutput {
            logits,
            all_hidden_states: base_model_output.all_hidden_states,
            all_attentions: base_model_output.all_attentions,
        })
    }
}
/// # DeBERTa V2 for token classification
/// Base DeBERTa V2 model with dropout and a per-token linear classification head.
pub struct DebertaV2ForTokenClassification {
    // Base transformer model
    deberta: DebertaV2Model,
    // Dropout applied to the hidden states before classification
    dropout: Dropout,
    // Linear head mapping each token representation to class logits
    classifier: nn::Linear,
}

impl DebertaV2ForTokenClassification {
    /// Builds a new `DebertaV2ForTokenClassification`.
    ///
    /// # Arguments
    /// * `p` - Variable store path holding the model parameters
    /// * `config` - Model configuration; `id2label` must be populated, its length
    ///   defines the number of output classes.
    ///
    /// # Errors
    /// Returns an `InvalidConfigurationError` if `config.id2label` is `None`.
    pub fn new<'p, P>(
        p: P,
        config: &DebertaV2Config,
    ) -> Result<DebertaV2ForTokenClassification, RustBertError>
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        let deberta = DebertaV2Model::new(path / "deberta", config);
        let dropout = Dropout::new(config.hidden_dropout_prob);
        let num_labels = match config.id2label.as_ref() {
            Some(label_map) => label_map.len() as i64,
            None => {
                return Err(RustBertError::InvalidConfigurationError(
                    "num_labels not provided in configuration".to_string(),
                ));
            }
        };
        let classifier = nn::linear(
            path / "classifier",
            config.hidden_size,
            num_labels,
            Default::default(),
        );
        Ok(DebertaV2ForTokenClassification {
            deberta,
            dropout,
            classifier,
        })
    }

    /// Forward pass returning one logit per class for each input token.
    ///
    /// Arguments mirror `DebertaV2Model::forward_t`; `train` enables dropout.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<DebertaV2TokenClassificationOutput, RustBertError> {
        let base_output = self.deberta.forward_t(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            input_embeds,
            train,
        )?;
        Ok(DebertaV2TokenClassificationOutput {
            logits: base_output
                .hidden_state
                .apply_t(&self.dropout, train)
                .apply(&self.classifier),
            all_hidden_states: base_output.all_hidden_states,
            all_attentions: base_output.all_attentions,
        })
    }
}
/// # DeBERTa V2 for extractive question answering
/// Base DeBERTa V2 model with a linear head predicting answer-span start and end logits.
pub struct DebertaV2ForQuestionAnswering {
    // Base transformer model
    deberta: DebertaV2Model,
    // Linear head producing 2 values per token (start and end logits)
    qa_outputs: nn::Linear,
}

impl DebertaV2ForQuestionAnswering {
    /// Builds a new `DebertaV2ForQuestionAnswering`.
    ///
    /// # Arguments
    /// * `p` - Variable store path holding the model parameters
    /// * `config` - Model configuration
    pub fn new<'p, P>(p: P, config: &DebertaV2Config) -> DebertaV2ForQuestionAnswering
    where
        P: Borrow<nn::Path<'p>>,
    {
        let path = p.borrow();
        DebertaV2ForQuestionAnswering {
            deberta: DebertaV2Model::new(path / "deberta", config),
            // Span extraction always predicts exactly two logits per token: start and end.
            qa_outputs: nn::linear(
                path / "qa_outputs",
                config.hidden_size,
                2,
                Default::default(),
            ),
        }
    }

    /// Forward pass returning per-token answer-span start and end logits.
    ///
    /// Arguments mirror `DebertaV2Model::forward_t`.
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        train: bool,
    ) -> Result<DebertaV2QuestionAnsweringOutput, RustBertError> {
        let base_output = self.deberta.forward_t(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            input_embeds,
            train,
        )?;
        // Project to 2 values per token, then split the last dimension into
        // separate start/end logit tensors.
        let split_logits = base_output.hidden_state.apply(&self.qa_outputs).split(1, -1);
        let start_logits = split_logits[0].squeeze_dim(-1);
        let end_logits = split_logits[1].squeeze_dim(-1);
        Ok(DebertaV2QuestionAnsweringOutput {
            start_logits,
            end_logits,
            all_hidden_states: base_output.all_hidden_states,
            all_attentions: base_output.all_attentions,
        })
    }
}
// The V2 models reuse the DeBERTa (V1) output containers, as the payloads are identical.
/// Output container for the DeBERTa V2 base model
pub type DebertaV2ModelOutput = DebertaModelOutput;
/// Output container for DeBERTa V2 masked language modeling
pub type DebertaV2MaskedLMOutput = DebertaMaskedLMOutput;
/// Output container for DeBERTa V2 sequence classification
pub type DebertaV2SequenceClassificationOutput = DebertaSequenceClassificationOutput;
/// Output container for DeBERTa V2 token classification
pub type DebertaV2TokenClassificationOutput = DebertaTokenClassificationOutput;
/// Output container for DeBERTa V2 question answering
pub type DebertaV2QuestionAnsweringOutput = DebertaQuestionAnsweringOutput;