use serde::Deserialize;
use std::path::Path;
/// Hyper-parameters of a Wav2Vec2-BERT model, deserializable from JSON.
///
/// Fields without a `#[serde(default ...)]` attribute are required. Optional fields fall
/// back to the values noted on each field when the key is absent. JSON keys that are not
/// declared here are ignored, because no `deny_unknown_fields` is set.
///
/// Note: the `String` fields must be JSON strings when present. An explicit JSON `null`
/// fails to deserialize rather than falling back to the default.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Wav2Vec2BertConfig {
    /// Hidden (model) dimension; split across attention heads by `head_dim()`.
    pub hidden_size: usize,
    /// Number of hidden layers (presumably encoder layers — confirm against the model code).
    pub num_hidden_layers: usize,
    /// Number of attention heads; the divisor used by `head_dim()`.
    pub num_attention_heads: usize,
    /// Inner width of the feed-forward sub-layers (assumed from the name — confirm).
    pub intermediate_size: usize,
    /// Dimension of the input features before the feature projection.
    pub feature_projection_input_dim: usize,
    /// Layer-norm epsilon; defaults to `1e-5` when absent.
    #[serde(default = "default_layer_norm_eps")]
    pub layer_norm_eps: f64,
    /// Activation function name; defaults to `"swish"` when absent.
    #[serde(default = "default_hidden_act")]
    pub hidden_act: String,
    /// Position-embedding scheme name; defaults to `"relative_key"` when absent.
    #[serde(default = "default_position_embeddings_type")]
    pub position_embeddings_type: String,
    /// Left extent of the relative-position window; defaults to `64` when absent.
    #[serde(default = "default_left_max_position_embeddings")]
    pub left_max_position_embeddings: usize,
    /// Right extent of the relative-position window; defaults to `8` when absent.
    #[serde(default = "default_right_max_position_embeddings")]
    pub right_max_position_embeddings: usize,
    /// Kernel size of the depthwise convolution; defaults to `31` when absent.
    #[serde(default = "default_conv_depthwise_kernel_size")]
    pub conv_depthwise_kernel_size: usize,
    /// Whether an adapter is added; `false` when absent.
    #[serde(default)]
    pub add_adapter: bool,
    /// Whether SpecAugment is applied; `false` when absent.
    #[serde(default)]
    pub apply_spec_augment: bool,
    /// Whether an intermediate FFN precedes the adapter; `false` when absent.
    #[serde(default)]
    pub use_intermediate_ffn_before_adapter: bool,
    /// Model type tag from the JSON, if present; `None` when absent.
    #[serde(default)]
    pub model_type: Option<String>,
}
/// Serde fallback for `layer_norm_eps` when the key is missing from the JSON.
fn default_layer_norm_eps() -> f64 {
    const EPS: f64 = 0.000_01;
    EPS
}
/// Serde fallback for `hidden_act` when the key is missing from the JSON.
fn default_hidden_act() -> String {
    String::from("swish")
}
/// Serde fallback for `position_embeddings_type` when the key is missing from the JSON.
fn default_position_embeddings_type() -> String {
    "relative_key".to_owned()
}
/// Serde fallback for `left_max_position_embeddings` when the key is missing from the JSON.
fn default_left_max_position_embeddings() -> usize {
    const LEFT_CONTEXT: usize = 64;
    LEFT_CONTEXT
}
/// Serde fallback for `right_max_position_embeddings` when the key is missing from the JSON.
fn default_right_max_position_embeddings() -> usize {
    const RIGHT_CONTEXT: usize = 8;
    RIGHT_CONTEXT
}
/// Serde fallback for `conv_depthwise_kernel_size` when the key is missing from the JSON.
fn default_conv_depthwise_kernel_size() -> usize {
    const KERNEL: usize = 31;
    KERNEL
}
impl Wav2Vec2BertConfig {
pub fn from_file(path: &Path) -> anyhow::Result<Self> {
let data = std::fs::read_to_string(path)?;
Ok(serde_json::from_str(&data)?)
}
pub fn head_dim(&self) -> usize {
self.hidden_size / self.num_attention_heads
}
pub fn num_relative_positions(&self) -> usize {
self.left_max_position_embeddings + self.right_max_position_embeddings + 1
}
pub fn w2v_bert_2_0() -> Self {
Self {
hidden_size: 1024,
num_hidden_layers: 24,
num_attention_heads: 16,
intermediate_size: 4096,
feature_projection_input_dim: 160,
layer_norm_eps: 1e-5,
hidden_act: "swish".into(),
position_embeddings_type: "relative_key".into(),
left_max_position_embeddings: 64,
right_max_position_embeddings: 8,
conv_depthwise_kernel_size: 31,
add_adapter: false,
apply_spec_augment: false,
use_intermediate_ffn_before_adapter: false,
model_type: Some("wav2vec2-bert".into()),
}
}
}