// rlx_llada2/llada2/config.rs
use serde::Deserialize;
use std::path::Path;

21#[derive(Debug, Clone, Deserialize)]
22pub struct LLaDA2MoeConfig {
23 pub vocab_size: usize,
24 pub hidden_size: usize,
25 #[serde(default)]
26 pub intermediate_size: Option<usize>,
27 pub num_hidden_layers: usize,
28 pub num_attention_heads: usize,
29 #[serde(default)]
30 pub num_key_value_heads: usize,
31 #[serde(default)]
32 pub head_dim: Option<usize>,
33 pub num_experts: usize,
34 pub num_experts_per_tok: usize,
35 #[serde(default)]
36 pub num_shared_experts: Option<usize>,
37 #[serde(default)]
38 pub moe_intermediate_size: Option<usize>,
39 #[serde(default = "default_n_group")]
40 pub n_group: usize,
41 #[serde(default = "default_topk_group")]
42 pub topk_group: usize,
43 #[serde(default = "default_routed_scaling")]
44 pub routed_scaling_factor: f32,
45 #[serde(default)]
46 pub first_k_dense_replace: usize,
47 pub max_position_embeddings: usize,
48 #[serde(default = "default_rope_theta")]
49 pub rope_theta: f64,
50 #[serde(default = "default_rms_norm_eps")]
51 pub rms_norm_eps: f64,
52 #[serde(default = "default_partial_rotary")]
53 pub partial_rotary_factor: f32,
54 #[serde(default)]
55 pub use_qk_norm: bool,
56 #[serde(default)]
57 pub use_qkv_bias: bool,
58 #[serde(default)]
59 pub use_bias: bool,
60 #[serde(default = "default_hidden_act")]
61 pub hidden_act: String,
62 #[serde(default)]
63 pub attention_dropout: f64,
64 #[serde(default)]
65 pub embedding_dropout: f64,
66 #[serde(default)]
67 pub output_dropout: f64,
68 #[serde(default)]
69 pub tie_word_embeddings: bool,
70 #[serde(default)]
71 pub norm_topk_prob: bool,
72 #[serde(default)]
73 pub moe_router_enable_expert_bias: bool,
74 #[serde(default)]
75 pub pad_token_id: u32,
76 #[serde(default = "default_mask_id")]
77 pub mask_token_id: u32,
78 #[serde(default = "default_eos_id")]
79 pub eos_token_id: u32,
80}
81
/// Serde fallback for `n_group` when the key is absent from the config JSON.
fn default_n_group() -> usize {
    8
}
/// Serde fallback for `topk_group` when the key is absent from the config JSON.
fn default_topk_group() -> usize {
    4
}
/// Serde fallback for `routed_scaling_factor` when the key is absent from
/// the config JSON.
fn default_routed_scaling() -> f32 {
    2.5
}
/// Serde fallback for `rms_norm_eps` when the key is absent from the config JSON.
fn default_rms_norm_eps() -> f64 {
    1e-6
}
/// Serde fallback for `rope_theta` when the key is absent from the config JSON.
fn default_rope_theta() -> f64 {
    600_000.0
}
/// Serde fallback for `partial_rotary_factor` when the key is absent from
/// the config JSON.
fn default_partial_rotary() -> f32 {
    0.5
}
/// Serde fallback for `mask_token_id` when the key is absent from the config JSON.
fn default_mask_id() -> u32 {
    156_895
}
/// Serde fallback for `eos_token_id` when the key is absent from the config JSON.
fn default_eos_id() -> u32 {
    156_892
}
/// Serde fallback for `hidden_act` when the key is absent from the config JSON.
fn default_hidden_act() -> String {
    String::from("silu")
}

110impl LLaDA2MoeConfig {
111 pub fn from_json_str(s: &str) -> Result<Self, serde_json::Error> {
112 serde_json::from_str(s)
113 }
114
115 pub fn from_file(path: &Path) -> anyhow::Result<Self> {
116 let data = std::fs::read_to_string(path)?;
117 Ok(serde_json::from_str(&data)?)
118 }
119
120 pub fn from_tide_repo() -> anyhow::Result<Self> {
121 Self::from_file(Path::new("/Users/Shared/TIDE/model/config.json"))
122 }
123
124 pub fn head_dim(&self) -> usize {
125 self.head_dim
126 .unwrap_or(self.hidden_size / self.num_attention_heads)
127 }
128
129 pub fn intermediate_size(&self) -> usize {
130 self.intermediate_size.unwrap_or(self.hidden_size * 4)
131 }
132
133 pub fn expert_ffn_dim(&self) -> usize {
134 self.moe_intermediate_size.unwrap_or(512)
135 }
136
137 pub fn num_kv_heads(&self) -> usize {
138 if self.num_key_value_heads == 0 {
139 self.num_attention_heads
140 } else {
141 self.num_key_value_heads
142 }
143 }
144
145 pub fn kv_group_size(&self) -> usize {
146 self.num_attention_heads / self.num_kv_heads()
147 }
148
149 pub fn rope_dim(&self) -> usize {
150 ((self.head_dim() as f32) * self.partial_rotary_factor) as usize
151 }
152
153 pub fn is_moe_layer(&self, layer: usize) -> bool {
154 self.num_experts > 0 && layer >= self.first_k_dense_replace
155 }
156
157 pub fn num_sparse_moe_layers(&self) -> usize {
158 self.num_hidden_layers
159 .saturating_sub(self.first_k_dense_replace)
160 }
161
162 pub fn expert_param_bytes_f32(&self) -> usize {
163 let h = self.hidden_size;
164 let ff = self.expert_ffn_dim();
165 3 * h * ff * std::mem::size_of::<f32>()
166 }
167}