use serde::{Deserialize, Serialize};
use trustformers_core::errors::Result;
use trustformers_core::traits::Config;
/// Configuration (architecture hyperparameters and special-token ids) for a
/// RoBERTa model. Serializable with serde so it can be loaded from / saved to
/// a JSON-style config file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RobertaConfig {
    /// Number of entries in the token vocabulary.
    pub vocab_size: usize,
    /// Dimensionality of the hidden representations. Must be divisible by
    /// `num_attention_heads` (enforced by `validate`).
    pub hidden_size: usize,
    /// Number of stacked transformer encoder layers.
    pub num_hidden_layers: usize,
    /// Number of attention heads per layer; must evenly divide `hidden_size`.
    pub num_attention_heads: usize,
    /// Inner dimensionality of the feed-forward sublayer.
    pub intermediate_size: usize,
    /// Name of the activation function used in the feed-forward sublayer
    /// (e.g. "gelu"; resolved elsewhere by name).
    pub hidden_act: String,
    /// Dropout probability applied to hidden states.
    pub hidden_dropout_prob: f32,
    /// Dropout probability applied to attention weights.
    pub attention_probs_dropout_prob: f32,
    /// Maximum sequence length the position embeddings support.
    /// NOTE(review): default is 514, not 512 — presumably RoBERTa's
    /// pad-offset position indexing; confirm against the embedding code.
    pub max_position_embeddings: usize,
    /// Number of token-type (segment) embeddings; defaults to 1 here.
    pub type_vocab_size: usize,
    /// Stddev used when initializing weights.
    pub initializer_range: f32,
    /// Epsilon used by layer normalization.
    pub layer_norm_eps: f32,
    /// Id of the padding token.
    pub pad_token_id: u32,
    /// Id of the beginning-of-sequence token.
    pub bos_token_id: u32,
    /// Id of the end-of-sequence token.
    pub eos_token_id: u32,
    /// Position-embedding variant (e.g. "absolute"); `None` leaves the
    /// consumer's default in effect.
    pub position_embedding_type: Option<String>,
    /// Whether attention key/value caching is enabled; optional so absent
    /// config files can fall back to a consumer-chosen default.
    pub use_cache: Option<bool>,
    /// Dropout for the classification head; `None` means "use
    /// `hidden_dropout_prob`" in typical implementations — TODO confirm.
    pub classifier_dropout: Option<f32>,
}
impl Default for RobertaConfig {
fn default() -> Self {
Self {
vocab_size: 50265,
hidden_size: 768,
num_hidden_layers: 12,
num_attention_heads: 12,
intermediate_size: 3072,
hidden_act: "gelu".to_string(),
hidden_dropout_prob: 0.1,
attention_probs_dropout_prob: 0.1,
max_position_embeddings: 514,
type_vocab_size: 1,
initializer_range: 0.02,
layer_norm_eps: 1e-5,
pad_token_id: 1,
bos_token_id: 0,
eos_token_id: 2,
position_embedding_type: Some("absolute".to_string()),
use_cache: Some(true),
classifier_dropout: None,
}
}
}
impl RobertaConfig {
pub fn roberta_base() -> Self {
Self::default()
}
pub fn roberta_large() -> Self {
Self {
hidden_size: 1024,
num_hidden_layers: 24,
num_attention_heads: 16,
intermediate_size: 4096,
..Self::default()
}
}
}
impl Config for RobertaConfig {
    /// Checks structural invariants of the configuration.
    ///
    /// # Errors
    /// Returns an `invalid_config` error when `num_attention_heads` is zero,
    /// or when `hidden_size` is not evenly divisible by
    /// `num_attention_heads` (each head must get an equal slice of the
    /// hidden dimension).
    fn validate(&self) -> Result<()> {
        // Reject a zero head count explicitly: the old divisibility check
        // caught it only by accident, with the misleading message
        // "hidden_size N must be divisible by num_attention_heads 0", and a
        // zero count would make any later hidden_size / num_attention_heads
        // computation divide by zero.
        if self.num_attention_heads == 0 {
            return Err(trustformers_core::errors::invalid_config(
                "num_attention_heads",
                "num_attention_heads must be greater than zero".to_string(),
            ));
        }
        // Safe now that num_attention_heads is nonzero; plain `%` also
        // avoids requiring Rust >= 1.87 for `usize::is_multiple_of`.
        if self.hidden_size % self.num_attention_heads != 0 {
            return Err(trustformers_core::errors::invalid_config(
                "hidden_size",
                format!(
                    "hidden_size {} must be divisible by num_attention_heads {}",
                    self.hidden_size, self.num_attention_heads
                ),
            ));
        }
        Ok(())
    }

    /// Human-readable architecture name for this configuration.
    fn architecture(&self) -> &'static str {
        "RoBERTa"
    }
}