rlx_wav2vec2_bert/
config.rs1use serde::Deserialize;
17use std::path::Path;
18
19#[derive(Debug, Clone, Deserialize)]
21pub struct Wav2Vec2BertConfig {
22 pub hidden_size: usize,
23 pub num_hidden_layers: usize,
24 pub num_attention_heads: usize,
25 pub intermediate_size: usize,
26 pub feature_projection_input_dim: usize,
27 #[serde(default = "default_layer_norm_eps")]
28 pub layer_norm_eps: f64,
29 #[serde(default = "default_hidden_act")]
30 pub hidden_act: String,
31 #[serde(default = "default_position_embeddings_type")]
32 pub position_embeddings_type: String,
33 #[serde(default = "default_left_max_position_embeddings")]
34 pub left_max_position_embeddings: usize,
35 #[serde(default = "default_right_max_position_embeddings")]
36 pub right_max_position_embeddings: usize,
37 #[serde(default = "default_conv_depthwise_kernel_size")]
38 pub conv_depthwise_kernel_size: usize,
39 #[serde(default)]
40 pub add_adapter: bool,
41 #[serde(default)]
42 pub apply_spec_augment: bool,
43 #[serde(default)]
44 pub use_intermediate_ffn_before_adapter: bool,
45 #[serde(default)]
47 pub model_type: Option<String>,
48}
49
/// Serde fallback for `layer_norm_eps` when the key is missing from the input.
fn default_layer_norm_eps() -> f64 {
    1e-5
}
/// Serde fallback for `hidden_act` when the key is missing from the input.
fn default_hidden_act() -> String {
    String::from("swish")
}
/// Serde fallback for `position_embeddings_type` when the key is missing.
fn default_position_embeddings_type() -> String {
    String::from("relative_key")
}
/// Serde fallback for `left_max_position_embeddings` when the key is missing.
fn default_left_max_position_embeddings() -> usize {
    64
}
/// Serde fallback for `right_max_position_embeddings` when the key is missing.
fn default_right_max_position_embeddings() -> usize {
    8
}
/// Serde fallback for `conv_depthwise_kernel_size` when the key is missing.
fn default_conv_depthwise_kernel_size() -> usize {
    31
}
68
69impl Wav2Vec2BertConfig {
70 pub fn from_file(path: &Path) -> anyhow::Result<Self> {
71 let data = std::fs::read_to_string(path)?;
72 Ok(serde_json::from_str(&data)?)
73 }
74
75 pub fn head_dim(&self) -> usize {
76 self.hidden_size / self.num_attention_heads
77 }
78
79 pub fn num_relative_positions(&self) -> usize {
80 self.left_max_position_embeddings + self.right_max_position_embeddings + 1
81 }
82
83 pub fn w2v_bert_2_0() -> Self {
85 Self {
86 hidden_size: 1024,
87 num_hidden_layers: 24,
88 num_attention_heads: 16,
89 intermediate_size: 4096,
90 feature_projection_input_dim: 160,
91 layer_norm_eps: 1e-5,
92 hidden_act: "swish".into(),
93 position_embeddings_type: "relative_key".into(),
94 left_max_position_embeddings: 64,
95 right_max_position_embeddings: 8,
96 conv_depthwise_kernel_size: 31,
97 add_adapter: false,
98 apply_spec_augment: false,
99 use_intermediate_ffn_before_adapter: false,
100 model_type: Some("wav2vec2-bert".into()),
101 }
102 }
103}