rlx_qwen3/config.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Qwen3 configuration. Matches HuggingFace `Qwen3ForCausalLM` config.json.
17//!
18//! Qwen3 introduces three things over Qwen2 that this struct must capture:
19//! - **GQA** with an explicit `head_dim` (not derived from
20//! `hidden_size / num_attention_heads`), so KV projection width is
21//! `num_key_value_heads * head_dim` rather than `hidden_size`.
22//! - **QK-norm**: per-head RMSNorm on Q and K before RoPE. Weight
23//! shape `[head_dim]`, no bias.
24//! - **Sliding-window attention** (optional, per-layer): `sliding_window`
25//! window size, `max_window_layers` controls how many leading layers
26//! use full attention.
27
28use serde::Deserialize;
29use std::path::Path;
30
31#[derive(Debug, Clone, Deserialize)]
32pub struct Qwen3Config {
33 pub vocab_size: usize,
34 pub hidden_size: usize,
35 pub intermediate_size: usize,
36 pub num_hidden_layers: usize,
37 pub num_attention_heads: usize,
38 pub num_key_value_heads: usize,
39 pub head_dim: usize,
40 pub max_position_embeddings: usize,
41
42 #[serde(default = "default_rms_norm_eps")]
43 pub rms_norm_eps: f64,
44 #[serde(default = "default_rope_theta")]
45 pub rope_theta: f64,
46 #[serde(default = "default_hidden_act")]
47 pub hidden_act: String,
48 #[serde(default)]
49 pub tie_word_embeddings: bool,
50
51 #[serde(default)]
52 pub attention_bias: bool,
53
54 /// Whether the model uses per-head RMS-norm on Q/K *before* RoPE
55 /// (a.k.a. "QK-norm"). Qwen 3 has it; Qwen 2 does NOT. Defaults to
56 /// `true` to match the historical Qwen 3 build path.
57 #[serde(default = "default_qk_norm")]
58 pub qk_norm: bool,
59
60 /// Sliding-window size; `None` (or absent) means full causal.
61 #[serde(default)]
62 pub sliding_window: Option<usize>,
63 /// Number of leading layers that use full causal attention; layers
64 /// `[max_window_layers, num_hidden_layers)` use sliding window when
65 /// `use_sliding_window` is true. HF default: all layers full.
66 #[serde(default = "default_max_window_layers")]
67 pub max_window_layers: usize,
68 #[serde(default)]
69 pub use_sliding_window: bool,
70
71 // ─── MoE fields (PLAN.md M1 — for `qwen3-30b-a3b-instruct`) ───
72 /// Total number of routed experts per MoE layer (0 = dense model).
73 /// HF key: `num_experts` / `n_routed_experts`. GGUF key:
74 /// `qwen3.expert_count`.
75 #[serde(default, alias = "n_routed_experts")]
76 pub num_experts: usize,
77 /// Number of experts activated per token (top-k routing).
78 /// HF key: `num_experts_per_tok`. GGUF key: `qwen3.expert_used_count`.
79 #[serde(default, alias = "num_experts_per_tok")]
80 pub num_experts_used: usize,
81 /// FFN inner width for each routed expert. When 0 falls back to
82 /// `intermediate_size / num_experts_used` to match upstream defaults.
83 /// GGUF key: `qwen3.expert_feed_forward_length`.
84 #[serde(default)]
85 pub expert_ffn_size: usize,
86 /// FFN inner width for the always-on shared expert (0 = no shared
87 /// expert). GGUF key: `qwen3.expert_shared_feed_forward_length`.
88 #[serde(default)]
89 pub shared_expert_ffn_size: usize,
90 /// Multiplier applied to routed-expert logits before softmax
91 /// (default 1.0). GGUF key: `qwen3.expert_weights_scale`.
92 #[serde(default = "default_expert_weights_scale")]
93 pub expert_weights_scale: f32,
94}
95
/// Serde fallback for `expert_weights_scale`: the neutral multiplier.
fn default_expert_weights_scale() -> f32 {
    const NEUTRAL_SCALE: f32 = 1.0;
    NEUTRAL_SCALE
}
99
/// Serde fallback for `rms_norm_eps` when the key is absent.
fn default_rms_norm_eps() -> f64 {
    const EPS: f64 = 1.0e-6;
    EPS
}
/// Serde fallback for `rope_theta` when the key is absent.
fn default_rope_theta() -> f64 {
    const ROPE_BASE: f64 = 1.0e6;
    ROPE_BASE
}
/// Serde fallback for `hidden_act` when the key is absent.
fn default_hidden_act() -> String {
    String::from("silu")
}
/// Serde fallback for `max_window_layers`: the largest `usize`, used as
/// the threshold that layer indices are compared against.
fn default_max_window_layers() -> usize {
    const ALL_LAYERS_FULL: usize = usize::MAX;
    ALL_LAYERS_FULL
}
/// Serde fallback for `qk_norm`: enabled unless the config says otherwise.
fn default_qk_norm() -> bool {
    const QK_NORM_ENABLED: bool = true;
    QK_NORM_ENABLED
}
115
116impl Qwen3Config {
117 pub fn from_file(path: &Path) -> anyhow::Result<Self> {
118 let data = std::fs::read_to_string(path)?;
119 Ok(serde_json::from_str(&data)?)
120 }
121
122 /// Repetition factor for GQA: how many Q heads share each KV head.
123 pub fn kv_group_size(&self) -> usize {
124 self.num_attention_heads / self.num_key_value_heads
125 }
126
127 /// Q projection output width (`num_attention_heads * head_dim`).
128 pub fn q_proj_dim(&self) -> usize {
129 self.num_attention_heads * self.head_dim
130 }
131
132 /// K/V projection output width (`num_key_value_heads * head_dim`).
133 pub fn kv_proj_dim(&self) -> usize {
134 self.num_key_value_heads * self.head_dim
135 }
136
137 /// True when the config carries MoE routing (`num_experts > 0`).
138 /// `qwen3-30b-a3b-instruct` and `qwen3-coder-next` MoE variants
139 /// will return true; dense Qwen3 returns false.
140 pub fn is_moe(&self) -> bool {
141 self.num_experts > 0
142 }
143
144 /// Routed-expert SwiGLU inner width. Falls back to
145 /// `intermediate_size / num_experts_used` when the explicit
146 /// `expert_ffn_size` is absent (matches upstream defaults).
147 pub fn expert_ffn_dim(&self) -> usize {
148 if self.expert_ffn_size > 0 {
149 self.expert_ffn_size
150 } else if self.num_experts_used > 0 {
151 self.intermediate_size
152 .checked_div(self.num_experts_used)
153 .unwrap_or(self.intermediate_size)
154 } else {
155 self.intermediate_size
156 }
157 }
158
159 /// Shared-expert SwiGLU inner width (0 when there's no shared
160 /// expert).
161 pub fn shared_expert_ffn_dim(&self) -> usize {
162 if self.shared_expert_ffn_size > 0 {
163 self.shared_expert_ffn_size
164 } else {
165 0
166 }
167 }
168
169 /// Does layer `idx` use sliding-window attention?
170 pub fn layer_uses_swa(&self, idx: usize) -> bool {
171 self.use_sliding_window && self.sliding_window.is_some() && idx >= self.max_window_layers
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes an inline JSON fixture, panicking if it is malformed.
    fn parse(json: &str) -> Qwen3Config {
        serde_json::from_str(json).unwrap()
    }

    /// A Qwen3-0.6B-shaped config parses, and the derived projection
    /// widths and defaults come out as expected.
    #[test]
    fn parse_qwen3_0_6b_like() {
        let cfg = parse(
            r#"{
                "vocab_size": 151936,
                "hidden_size": 1024,
                "intermediate_size": 3072,
                "num_hidden_layers": 28,
                "num_attention_heads": 16,
                "num_key_value_heads": 8,
                "head_dim": 128,
                "max_position_embeddings": 32768,
                "rope_theta": 1000000.0,
                "tie_word_embeddings": true
            }"#,
        );

        // Derived GQA geometry.
        assert_eq!(cfg.kv_group_size(), 2);
        assert_eq!(cfg.q_proj_dim(), 2048);
        assert_eq!(cfg.kv_proj_dim(), 1024);

        // One explicit key and one defaulted key.
        assert!(cfg.tie_word_embeddings);
        assert_eq!(cfg.rms_norm_eps, 1e-6);
    }

    /// With no sliding-window keys present, every layer reports full
    /// attention.
    #[test]
    fn sliding_window_off_by_default() {
        let cfg = parse(
            r#"{
                "vocab_size": 100,
                "hidden_size": 64,
                "intermediate_size": 128,
                "num_hidden_layers": 2,
                "num_attention_heads": 4,
                "num_key_value_heads": 2,
                "head_dim": 16,
                "max_position_embeddings": 512
            }"#,
        );

        for layer in [0, 1] {
            assert!(!cfg.layer_uses_swa(layer));
        }
    }
}