Skip to main content

rlx_models_core/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Model configuration structs — parsed from HuggingFace config.json.
17
use anyhow::Context;
use serde::{Deserialize, Deserializer};
use std::path::Path;
20
21fn deserialize_usize_or_float<'de, D: Deserializer<'de>>(d: D) -> Result<usize, D::Error> {
22    let v: serde_json::Value = Deserialize::deserialize(d)?;
23    match v {
24        serde_json::Value::Number(n) => {
25            if let Some(u) = n.as_u64() {
26                Ok(u as usize)
27            } else if let Some(f) = n.as_f64() {
28                Ok(f as usize)
29            } else {
30                Err(serde::de::Error::custom("expected number"))
31            }
32        }
33        _ => Err(serde::de::Error::custom("expected number")),
34    }
35}
36
37/// BERT model configuration.
38#[derive(Debug, Clone, Deserialize)]
39pub struct BertConfig {
40    pub vocab_size: usize,
41    pub hidden_size: usize,
42    pub num_hidden_layers: usize,
43    pub num_attention_heads: usize,
44    pub intermediate_size: usize,
45    pub max_position_embeddings: usize,
46    #[serde(default = "default_type_vocab_size")]
47    pub type_vocab_size: usize,
48    #[serde(default = "default_layer_norm_eps")]
49    pub layer_norm_eps: f64,
50    #[serde(default = "default_hidden_act")]
51    pub hidden_act: String,
52}
53
/// Serde fallback for a missing `type_vocab_size` key (shared by the BERT and
/// NomicBERT configs).
fn default_type_vocab_size() -> usize {
    2
}

/// Serde fallback for a missing `layer_norm_eps` key (shared by the BERT and
/// NomicBERT configs).
fn default_layer_norm_eps() -> f64 {
    1.0e-12
}

/// Serde fallback for a missing `hidden_act` key in the BERT config.
fn default_hidden_act() -> String {
    String::from("gelu")
}
63
64impl BertConfig {
65    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
66        let data = std::fs::read_to_string(path)?;
67        Ok(serde_json::from_str(&data)?)
68    }
69
70    pub fn head_dim(&self) -> usize {
71        self.hidden_size / self.num_attention_heads
72    }
73}
74
75/// NomicBERT model configuration.
76#[derive(Debug, Clone, Deserialize)]
77pub struct NomicBertConfig {
78    pub vocab_size: usize,
79    pub hidden_size: usize,
80    pub num_hidden_layers: usize,
81    pub num_attention_heads: usize,
82    pub intermediate_size: usize,
83    pub max_position_embeddings: usize,
84    #[serde(default = "default_type_vocab_size")]
85    pub type_vocab_size: usize,
86    #[serde(default = "default_layer_norm_eps")]
87    pub layer_norm_eps: f64,
88    #[serde(default = "default_head_dim")]
89    pub head_dim: usize,
90    #[serde(default = "default_rotary_emb_base")]
91    pub rotary_emb_base: f64,
92}
93
/// Serde fallback for a missing `head_dim` key in the NomicBERT config.
fn default_head_dim() -> usize {
    64
}

/// Serde fallback for a missing `rotary_emb_base` key in the NomicBERT config.
fn default_rotary_emb_base() -> f64 {
    1_000.0
}
100
101impl NomicBertConfig {
102    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
103        let data = std::fs::read_to_string(path)?;
104        Ok(serde_json::from_str(&data)?)
105    }
106}
107
108/// NomicVision model configuration.
109#[derive(Debug, Clone, Deserialize)]
110pub struct NomicVisionConfig {
111    #[serde(alias = "n_embd")]
112    pub hidden_size: usize,
113    #[serde(alias = "n_layer")]
114    pub num_hidden_layers: usize,
115    #[serde(alias = "n_head")]
116    pub num_attention_heads: usize,
117    #[serde(
118        default = "default_vision_intermediate",
119        deserialize_with = "deserialize_usize_or_float"
120    )]
121    pub n_inner: usize,
122    pub img_size: usize,
123    pub patch_size: usize,
124    #[serde(default = "default_vision_ln_eps")]
125    pub layer_norm_epsilon: f64,
126}
127
/// Serde fallback for a missing `n_inner` key in the NomicVision config.
fn default_vision_intermediate() -> usize {
    2_048
}

/// Serde fallback for a missing `layer_norm_epsilon` key in the NomicVision config.
fn default_vision_ln_eps() -> f64 {
    1.0e-6
}
134
135impl NomicVisionConfig {
136    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
137        let data = std::fs::read_to_string(path)?;
138        Ok(serde_json::from_str(&data)?)
139    }
140    pub fn intermediate_size(&self) -> usize {
141        self.n_inner
142    }
143    pub fn layer_norm_eps(&self) -> f64 {
144        self.layer_norm_epsilon
145    }
146}
147
148// DinoV2Config moved to crate::dinov2::config (subfolder module).
149
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_bert_config() {
        let json = r#"{
            "vocab_size": 30522,
            "hidden_size": 384,
            "num_hidden_layers": 6,
            "num_attention_heads": 12,
            "intermediate_size": 1536,
            "max_position_embeddings": 512
        }"#;
        let cfg: BertConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.hidden_size, 384);
        assert_eq!(cfg.head_dim(), 32);
        assert_eq!(cfg.layer_norm_eps, 1e-12);
        // Optional keys fall back to their serde defaults.
        assert_eq!(cfg.type_vocab_size, 2);
        assert_eq!(cfg.hidden_act, "gelu");
    }

    #[test]
    fn parse_nomic_bert_config_defaults() {
        let json = r#"{
            "vocab_size": 30528,
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
            "intermediate_size": 3072,
            "max_position_embeddings": 8192
        }"#;
        let cfg: NomicBertConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.head_dim, 64);
        assert_eq!(cfg.rotary_emb_base, 1000.0);
        assert_eq!(cfg.type_vocab_size, 2);
        assert_eq!(cfg.layer_norm_eps, 1e-12);
    }

    #[test]
    fn parse_nomic_vision_config_aliases_and_float_n_inner() {
        // Short `n_*` aliases plus an integral float for `n_inner`.
        let json = r#"{
            "n_embd": 768,
            "n_layer": 12,
            "n_head": 12,
            "n_inner": 3072.0,
            "img_size": 224,
            "patch_size": 16
        }"#;
        let cfg: NomicVisionConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.hidden_size, 768);
        assert_eq!(cfg.num_hidden_layers, 12);
        assert_eq!(cfg.num_attention_heads, 12);
        assert_eq!(cfg.intermediate_size(), 3072);
        assert_eq!(cfg.layer_norm_eps(), 1e-6);
    }

    #[test]
    fn nomic_vision_n_inner_default_and_type_check() {
        let missing = r#"{
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
            "img_size": 224,
            "patch_size": 16
        }"#;
        let cfg: NomicVisionConfig = serde_json::from_str(missing).unwrap();
        assert_eq!(cfg.n_inner, 2048);

        let stringly = r#"{
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
            "n_inner": "3072",
            "img_size": 224,
            "patch_size": 16
        }"#;
        assert!(serde_json::from_str::<NomicVisionConfig>(stringly).is_err());
    }
}