rlx_models_core/
config.rs1use serde::{Deserialize, Deserializer};
19use std::path::Path;
20
21fn deserialize_usize_or_float<'de, D: Deserializer<'de>>(d: D) -> Result<usize, D::Error> {
22 let v: serde_json::Value = Deserialize::deserialize(d)?;
23 match v {
24 serde_json::Value::Number(n) => {
25 if let Some(u) = n.as_u64() {
26 Ok(u as usize)
27 } else if let Some(f) = n.as_f64() {
28 Ok(f as usize)
29 } else {
30 Err(serde::de::Error::custom("expected number"))
31 }
32 }
33 _ => Err(serde::de::Error::custom("expected number")),
34 }
35}
36
37#[derive(Debug, Clone, Deserialize)]
39pub struct BertConfig {
40 pub vocab_size: usize,
41 pub hidden_size: usize,
42 pub num_hidden_layers: usize,
43 pub num_attention_heads: usize,
44 pub intermediate_size: usize,
45 pub max_position_embeddings: usize,
46 #[serde(default = "default_type_vocab_size")]
47 pub type_vocab_size: usize,
48 #[serde(default = "default_layer_norm_eps")]
49 pub layer_norm_eps: f64,
50 #[serde(default = "default_hidden_act")]
51 pub hidden_act: String,
52}
53
/// Serde fallback for a missing `type_vocab_size` key: always `2`.
fn default_type_vocab_size() -> usize {
    const FALLBACK: usize = 2;
    FALLBACK
}
/// Serde fallback for a missing `layer_norm_eps` key: always `1e-12`.
fn default_layer_norm_eps() -> f64 {
    const FALLBACK: f64 = 1e-12;
    FALLBACK
}
/// Serde fallback for a missing `hidden_act` key: the string `"gelu"`.
fn default_hidden_act() -> String {
    String::from("gelu")
}
63
64impl BertConfig {
65 pub fn from_file(path: &Path) -> anyhow::Result<Self> {
66 let data = std::fs::read_to_string(path)?;
67 Ok(serde_json::from_str(&data)?)
68 }
69
70 pub fn head_dim(&self) -> usize {
71 self.hidden_size / self.num_attention_heads
72 }
73}
74
75#[derive(Debug, Clone, Deserialize)]
77pub struct NomicBertConfig {
78 pub vocab_size: usize,
79 pub hidden_size: usize,
80 pub num_hidden_layers: usize,
81 pub num_attention_heads: usize,
82 pub intermediate_size: usize,
83 pub max_position_embeddings: usize,
84 #[serde(default = "default_type_vocab_size")]
85 pub type_vocab_size: usize,
86 #[serde(default = "default_layer_norm_eps")]
87 pub layer_norm_eps: f64,
88 #[serde(default = "default_head_dim")]
89 pub head_dim: usize,
90 #[serde(default = "default_rotary_emb_base")]
91 pub rotary_emb_base: f64,
92}
93
/// Serde fallback for a missing `head_dim` key: always `64`.
fn default_head_dim() -> usize {
    const FALLBACK: usize = 64;
    FALLBACK
}
/// Serde fallback for a missing `rotary_emb_base` key: always `1000.0`.
fn default_rotary_emb_base() -> f64 {
    const FALLBACK: f64 = 1000.0;
    FALLBACK
}
100
101impl NomicBertConfig {
102 pub fn from_file(path: &Path) -> anyhow::Result<Self> {
103 let data = std::fs::read_to_string(path)?;
104 Ok(serde_json::from_str(&data)?)
105 }
106}
107
108#[derive(Debug, Clone, Deserialize)]
110pub struct NomicVisionConfig {
111 #[serde(alias = "n_embd")]
112 pub hidden_size: usize,
113 #[serde(alias = "n_layer")]
114 pub num_hidden_layers: usize,
115 #[serde(alias = "n_head")]
116 pub num_attention_heads: usize,
117 #[serde(
118 default = "default_vision_intermediate",
119 deserialize_with = "deserialize_usize_or_float"
120 )]
121 pub n_inner: usize,
122 pub img_size: usize,
123 pub patch_size: usize,
124 #[serde(default = "default_vision_ln_eps")]
125 pub layer_norm_epsilon: f64,
126}
127
/// Serde fallback for a missing `n_inner` key: always `2048`.
fn default_vision_intermediate() -> usize {
    const FALLBACK: usize = 2048;
    FALLBACK
}
/// Serde fallback for a missing `layer_norm_epsilon` key: always `1e-6`.
fn default_vision_ln_eps() -> f64 {
    const FALLBACK: f64 = 1e-6;
    FALLBACK
}
134
135impl NomicVisionConfig {
136 pub fn from_file(path: &Path) -> anyhow::Result<Self> {
137 let data = std::fs::read_to_string(path)?;
138 Ok(serde_json::from_str(&data)?)
139 }
140 pub fn intermediate_size(&self) -> usize {
141 self.n_inner
142 }
143 pub fn layer_norm_eps(&self) -> f64 {
144 self.layer_norm_epsilon
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// A JSON object holding only the required keys must parse, and the
    /// omitted `layer_norm_eps` must take its default value.
    #[test]
    fn parse_bert_config() {
        let raw = r#"{
            "vocab_size": 30522,
            "hidden_size": 384,
            "num_hidden_layers": 6,
            "num_attention_heads": 12,
            "intermediate_size": 1536,
            "max_position_embeddings": 512
        }"#;

        let parsed = serde_json::from_str::<BertConfig>(raw).unwrap();

        assert_eq!(parsed.hidden_size, 384);
        // 384 / 12
        assert_eq!(parsed.head_dim(), 32);
        assert_eq!(parsed.layer_norm_eps, 1e-12);
    }
}