Skip to main content

rlx_wav2vec2_bert/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use serde::Deserialize;
17use std::path::Path;
18
19/// Wav2Vec2-BERT model configuration (e.g. facebook/w2v-bert-2.0).
20#[derive(Debug, Clone, Deserialize)]
21pub struct Wav2Vec2BertConfig {
22    pub hidden_size: usize,
23    pub num_hidden_layers: usize,
24    pub num_attention_heads: usize,
25    pub intermediate_size: usize,
26    pub feature_projection_input_dim: usize,
27    #[serde(default = "default_layer_norm_eps")]
28    pub layer_norm_eps: f64,
29    #[serde(default = "default_hidden_act")]
30    pub hidden_act: String,
31    #[serde(default = "default_position_embeddings_type")]
32    pub position_embeddings_type: String,
33    #[serde(default = "default_left_max_position_embeddings")]
34    pub left_max_position_embeddings: usize,
35    #[serde(default = "default_right_max_position_embeddings")]
36    pub right_max_position_embeddings: usize,
37    #[serde(default = "default_conv_depthwise_kernel_size")]
38    pub conv_depthwise_kernel_size: usize,
39    #[serde(default)]
40    pub add_adapter: bool,
41    #[serde(default)]
42    pub apply_spec_augment: bool,
43    #[serde(default)]
44    pub use_intermediate_ffn_before_adapter: bool,
45    /// Present in HF configs; ignored at inference when `apply_spec_augment=false`.
46    #[serde(default)]
47    pub model_type: Option<String>,
48}
49
/// Fallback for `layer_norm_eps` when the key is missing from the config JSON.
fn default_layer_norm_eps() -> f64 {
    const LAYER_NORM_EPS: f64 = 1e-5;
    LAYER_NORM_EPS
}
/// Fallback for `hidden_act` when the key is missing from the config JSON.
fn default_hidden_act() -> String {
    String::from("swish")
}
/// Fallback for `position_embeddings_type` when the key is missing from the config JSON.
fn default_position_embeddings_type() -> String {
    String::from("relative_key")
}
/// Fallback for `left_max_position_embeddings` when the key is missing from the config JSON.
fn default_left_max_position_embeddings() -> usize {
    const LEFT_MAX: usize = 64;
    LEFT_MAX
}
/// Fallback for `right_max_position_embeddings` when the key is missing from the config JSON.
fn default_right_max_position_embeddings() -> usize {
    const RIGHT_MAX: usize = 8;
    RIGHT_MAX
}
/// Fallback for `conv_depthwise_kernel_size` when the key is missing from the config JSON.
fn default_conv_depthwise_kernel_size() -> usize {
    const KERNEL_SIZE: usize = 31;
    KERNEL_SIZE
}
68
69impl Wav2Vec2BertConfig {
70    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
71        let data = std::fs::read_to_string(path)?;
72        Ok(serde_json::from_str(&data)?)
73    }
74
75    pub fn head_dim(&self) -> usize {
76        self.hidden_size / self.num_attention_heads
77    }
78
79    pub fn num_relative_positions(&self) -> usize {
80        self.left_max_position_embeddings + self.right_max_position_embeddings + 1
81    }
82
83    /// Factory for the public W2v-BERT 2.0 checkpoint dimensions.
84    pub fn w2v_bert_2_0() -> Self {
85        Self {
86            hidden_size: 1024,
87            num_hidden_layers: 24,
88            num_attention_heads: 16,
89            intermediate_size: 4096,
90            feature_projection_input_dim: 160,
91            layer_norm_eps: 1e-5,
92            hidden_act: "swish".into(),
93            position_embeddings_type: "relative_key".into(),
94            left_max_position_embeddings: 64,
95            right_max_position_embeddings: 8,
96            conv_depthwise_kernel_size: 31,
97            add_adapter: false,
98            apply_spec_augment: false,
99            use_intermediate_ffn_before_adapter: false,
100            model_type: Some("wav2vec2-bert".into()),
101        }
102    }
103}