// rlx_llada2/llada2/config.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

// RLX — LLaDA2 MoE config, parsed from `config.json` (default location:
// `/Users/Shared/TIDE/model/config.json`).

use anyhow::Context as _;
use serde::Deserialize;
use std::path::Path;

21#[derive(Debug, Clone, Deserialize)]
22pub struct LLaDA2MoeConfig {
23    pub vocab_size: usize,
24    pub hidden_size: usize,
25    #[serde(default)]
26    pub intermediate_size: Option<usize>,
27    pub num_hidden_layers: usize,
28    pub num_attention_heads: usize,
29    #[serde(default)]
30    pub num_key_value_heads: usize,
31    #[serde(default)]
32    pub head_dim: Option<usize>,
33    pub num_experts: usize,
34    pub num_experts_per_tok: usize,
35    #[serde(default)]
36    pub num_shared_experts: Option<usize>,
37    #[serde(default)]
38    pub moe_intermediate_size: Option<usize>,
39    #[serde(default = "default_n_group")]
40    pub n_group: usize,
41    #[serde(default = "default_topk_group")]
42    pub topk_group: usize,
43    #[serde(default = "default_routed_scaling")]
44    pub routed_scaling_factor: f32,
45    #[serde(default)]
46    pub first_k_dense_replace: usize,
47    pub max_position_embeddings: usize,
48    #[serde(default = "default_rope_theta")]
49    pub rope_theta: f64,
50    #[serde(default = "default_rms_norm_eps")]
51    pub rms_norm_eps: f64,
52    #[serde(default = "default_partial_rotary")]
53    pub partial_rotary_factor: f32,
54    #[serde(default)]
55    pub use_qk_norm: bool,
56    #[serde(default)]
57    pub use_qkv_bias: bool,
58    #[serde(default)]
59    pub use_bias: bool,
60    #[serde(default = "default_hidden_act")]
61    pub hidden_act: String,
62    #[serde(default)]
63    pub attention_dropout: f64,
64    #[serde(default)]
65    pub embedding_dropout: f64,
66    #[serde(default)]
67    pub output_dropout: f64,
68    #[serde(default)]
69    pub tie_word_embeddings: bool,
70    #[serde(default)]
71    pub norm_topk_prob: bool,
72    #[serde(default)]
73    pub moe_router_enable_expert_bias: bool,
74    #[serde(default)]
75    pub pad_token_id: u32,
76    #[serde(default = "default_mask_id")]
77    pub mask_token_id: u32,
78    #[serde(default = "default_eos_id")]
79    pub eos_token_id: u32,
80}
81
/// Serde fallback for `n_group` when `config.json` omits the key.
fn default_n_group() -> usize {
    const N_GROUP: usize = 8;
    N_GROUP
}
/// Serde fallback for `topk_group` when `config.json` omits the key.
fn default_topk_group() -> usize {
    const TOPK_GROUP: usize = 4;
    TOPK_GROUP
}
/// Serde fallback for `routed_scaling_factor` when `config.json` omits the key.
fn default_routed_scaling() -> f32 {
    const ROUTED_SCALING: f32 = 2.5;
    ROUTED_SCALING
}
/// Serde fallback for `rms_norm_eps` when `config.json` omits the key.
fn default_rms_norm_eps() -> f64 {
    const RMS_NORM_EPS: f64 = 1.0e-6;
    RMS_NORM_EPS
}
/// Serde fallback for `rope_theta` when `config.json` omits the key.
fn default_rope_theta() -> f64 {
    const ROPE_THETA: f64 = 6.0e5;
    ROPE_THETA
}
/// Serde fallback for `partial_rotary_factor` when `config.json` omits the key.
fn default_partial_rotary() -> f32 {
    const PARTIAL_ROTARY: f32 = 0.5;
    PARTIAL_ROTARY
}
/// Serde fallback for `mask_token_id` when `config.json` omits the key.
fn default_mask_id() -> u32 {
    const MASK_TOKEN_ID: u32 = 156895;
    MASK_TOKEN_ID
}
/// Serde fallback for `eos_token_id` when `config.json` omits the key.
fn default_eos_id() -> u32 {
    const EOS_TOKEN_ID: u32 = 156892;
    EOS_TOKEN_ID
}
/// Serde fallback for `hidden_act` when `config.json` omits the key.
fn default_hidden_act() -> String {
    String::from("silu")
}

110impl LLaDA2MoeConfig {
111    pub fn from_json_str(s: &str) -> Result<Self, serde_json::Error> {
112        serde_json::from_str(s)
113    }
114
115    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
116        let data = std::fs::read_to_string(path)?;
117        Ok(serde_json::from_str(&data)?)
118    }
119
120    pub fn from_tide_repo() -> anyhow::Result<Self> {
121        Self::from_file(Path::new("/Users/Shared/TIDE/model/config.json"))
122    }
123
124    pub fn head_dim(&self) -> usize {
125        self.head_dim
126            .unwrap_or(self.hidden_size / self.num_attention_heads)
127    }
128
129    pub fn intermediate_size(&self) -> usize {
130        self.intermediate_size.unwrap_or(self.hidden_size * 4)
131    }
132
133    pub fn expert_ffn_dim(&self) -> usize {
134        self.moe_intermediate_size.unwrap_or(512)
135    }
136
137    pub fn num_kv_heads(&self) -> usize {
138        if self.num_key_value_heads == 0 {
139            self.num_attention_heads
140        } else {
141            self.num_key_value_heads
142        }
143    }
144
145    pub fn kv_group_size(&self) -> usize {
146        self.num_attention_heads / self.num_kv_heads()
147    }
148
149    pub fn rope_dim(&self) -> usize {
150        ((self.head_dim() as f32) * self.partial_rotary_factor) as usize
151    }
152
153    pub fn is_moe_layer(&self, layer: usize) -> bool {
154        self.num_experts > 0 && layer >= self.first_k_dense_replace
155    }
156
157    pub fn num_sparse_moe_layers(&self) -> usize {
158        self.num_hidden_layers
159            .saturating_sub(self.first_k_dense_replace)
160    }
161
162    pub fn expert_param_bytes_f32(&self) -> usize {
163        let h = self.hidden_size;
164        let ff = self.expert_ffn_dim();
165        3 * h * ff * std::mem::size_of::<f32>()
166    }
167}