Skip to main content

rlx_qwen3/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Qwen3 configuration. Matches HuggingFace `Qwen3ForCausalLM` config.json.
17//!
18//! Qwen3 introduces three things over Qwen2 that this struct must capture:
19//!   - **GQA** with an explicit `head_dim` (not derived from
20//!     `hidden_size / num_attention_heads`), so KV projection width is
21//!     `num_key_value_heads * head_dim` rather than `hidden_size`.
22//!   - **QK-norm**: per-head RMSNorm on Q and K before RoPE. Weight
23//!     shape `[head_dim]`, no bias.
24//!   - **Sliding-window attention** (optional, per-layer): `sliding_window`
25//!     window size, `max_window_layers` controls how many leading layers
26//!     use full attention.
27
28use serde::Deserialize;
29use std::path::Path;
30
31#[derive(Debug, Clone, Deserialize)]
32pub struct Qwen3Config {
33    pub vocab_size: usize,
34    pub hidden_size: usize,
35    pub intermediate_size: usize,
36    pub num_hidden_layers: usize,
37    pub num_attention_heads: usize,
38    pub num_key_value_heads: usize,
39    pub head_dim: usize,
40    pub max_position_embeddings: usize,
41
42    #[serde(default = "default_rms_norm_eps")]
43    pub rms_norm_eps: f64,
44    #[serde(default = "default_rope_theta")]
45    pub rope_theta: f64,
46    #[serde(default = "default_hidden_act")]
47    pub hidden_act: String,
48    #[serde(default)]
49    pub tie_word_embeddings: bool,
50
51    #[serde(default)]
52    pub attention_bias: bool,
53
54    /// Whether the model uses per-head RMS-norm on Q/K *before* RoPE
55    /// (a.k.a. "QK-norm"). Qwen 3 has it; Qwen 2 does NOT. Defaults to
56    /// `true` to match the historical Qwen 3 build path.
57    #[serde(default = "default_qk_norm")]
58    pub qk_norm: bool,
59
60    /// Sliding-window size; `None` (or absent) means full causal.
61    #[serde(default)]
62    pub sliding_window: Option<usize>,
63    /// Number of leading layers that use full causal attention; layers
64    /// `[max_window_layers, num_hidden_layers)` use sliding window when
65    /// `use_sliding_window` is true. HF default: all layers full.
66    #[serde(default = "default_max_window_layers")]
67    pub max_window_layers: usize,
68    #[serde(default)]
69    pub use_sliding_window: bool,
70
71    // ─── MoE fields (PLAN.md M1 — for `qwen3-30b-a3b-instruct`) ───
72    /// Total number of routed experts per MoE layer (0 = dense model).
73    /// HF key: `num_experts` / `n_routed_experts`. GGUF key:
74    /// `qwen3.expert_count`.
75    #[serde(default, alias = "n_routed_experts")]
76    pub num_experts: usize,
77    /// Number of experts activated per token (top-k routing).
78    /// HF key: `num_experts_per_tok`. GGUF key: `qwen3.expert_used_count`.
79    #[serde(default, alias = "num_experts_per_tok")]
80    pub num_experts_used: usize,
81    /// FFN inner width for each routed expert. When 0 falls back to
82    /// `intermediate_size / num_experts_used` to match upstream defaults.
83    /// GGUF key: `qwen3.expert_feed_forward_length`.
84    #[serde(default)]
85    pub expert_ffn_size: usize,
86    /// FFN inner width for the always-on shared expert (0 = no shared
87    /// expert). GGUF key: `qwen3.expert_shared_feed_forward_length`.
88    #[serde(default)]
89    pub shared_expert_ffn_size: usize,
90    /// Multiplier applied to routed-expert logits before softmax
91    /// (default 1.0). GGUF key: `qwen3.expert_weights_scale`.
92    #[serde(default = "default_expert_weights_scale")]
93    pub expert_weights_scale: f32,
94}
95
/// Serde fallback for `expert_weights_scale`: the identity multiplier.
fn default_expert_weights_scale() -> f32 {
    const IDENTITY_SCALE: f32 = 1.0;
    IDENTITY_SCALE
}
99
/// Serde fallback for `rms_norm_eps` when the key is absent.
fn default_rms_norm_eps() -> f64 {
    const EPS: f64 = 1e-6;
    EPS
}
/// Serde fallback for `rope_theta` when the key is absent (one million).
fn default_rope_theta() -> f64 {
    const THETA: f64 = 1_000_000.0;
    THETA
}
/// Serde fallback for `hidden_act` when the key is absent.
fn default_hidden_act() -> String {
    String::from("silu")
}
/// Serde fallback for `max_window_layers`: the largest `usize`, so the
/// threshold sits above every realistic layer index.
fn default_max_window_layers() -> usize {
    const NO_WINDOWED_LAYERS: usize = usize::MAX;
    NO_WINDOWED_LAYERS
}
/// Serde fallback for `qk_norm`: enabled when the key is absent.
fn default_qk_norm() -> bool {
    const QK_NORM_ENABLED: bool = true;
    QK_NORM_ENABLED
}
115
116impl Qwen3Config {
117    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
118        let data = std::fs::read_to_string(path)?;
119        Ok(serde_json::from_str(&data)?)
120    }
121
122    /// Repetition factor for GQA: how many Q heads share each KV head.
123    pub fn kv_group_size(&self) -> usize {
124        self.num_attention_heads / self.num_key_value_heads
125    }
126
127    /// Q projection output width (`num_attention_heads * head_dim`).
128    pub fn q_proj_dim(&self) -> usize {
129        self.num_attention_heads * self.head_dim
130    }
131
132    /// K/V projection output width (`num_key_value_heads * head_dim`).
133    pub fn kv_proj_dim(&self) -> usize {
134        self.num_key_value_heads * self.head_dim
135    }
136
137    /// True when the config carries MoE routing (`num_experts > 0`).
138    /// `qwen3-30b-a3b-instruct` and `qwen3-coder-next` MoE variants
139    /// will return true; dense Qwen3 returns false.
140    pub fn is_moe(&self) -> bool {
141        self.num_experts > 0
142    }
143
144    /// Routed-expert SwiGLU inner width. Falls back to
145    /// `intermediate_size / num_experts_used` when the explicit
146    /// `expert_ffn_size` is absent (matches upstream defaults).
147    pub fn expert_ffn_dim(&self) -> usize {
148        if self.expert_ffn_size > 0 {
149            self.expert_ffn_size
150        } else if self.num_experts_used > 0 {
151            self.intermediate_size
152                .checked_div(self.num_experts_used)
153                .unwrap_or(self.intermediate_size)
154        } else {
155            self.intermediate_size
156        }
157    }
158
159    /// Shared-expert SwiGLU inner width (0 when there's no shared
160    /// expert).
161    pub fn shared_expert_ffn_dim(&self) -> usize {
162        if self.shared_expert_ffn_size > 0 {
163            self.shared_expert_ffn_size
164        } else {
165            0
166        }
167    }
168
169    /// Does layer `idx` use sliding-window attention?
170    pub fn layer_uses_swa(&self, idx: usize) -> bool {
171        self.use_sliding_window && self.sliding_window.is_some() && idx >= self.max_window_layers
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a JSON fixture into a config; panics on invalid input.
    fn parse(json: &str) -> Qwen3Config {
        serde_json::from_str(json).unwrap()
    }

    /// A Qwen3-0.6B-shaped config parses, and the derived projection
    /// widths and serde defaults come out as expected.
    #[test]
    fn parse_qwen3_0_6b_like() {
        let cfg = parse(
            r#"{
            "vocab_size": 151936,
            "hidden_size": 1024,
            "intermediate_size": 3072,
            "num_hidden_layers": 28,
            "num_attention_heads": 16,
            "num_key_value_heads": 8,
            "head_dim": 128,
            "max_position_embeddings": 32768,
            "rope_theta": 1000000.0,
            "tie_word_embeddings": true
        }"#,
        );

        // 16 Q heads over 8 KV heads, each 128 wide.
        assert_eq!(cfg.kv_group_size(), 2);
        assert_eq!(cfg.q_proj_dim(), 2048);
        assert_eq!(cfg.kv_proj_dim(), 1024);

        assert!(cfg.tie_word_embeddings);
        // Key absent from the fixture, so the serde default applies.
        assert_eq!(cfg.rms_norm_eps, 1e-6);
    }

    /// With no sliding-window keys in the input, no layer reports SWA.
    #[test]
    fn sliding_window_off_by_default() {
        let cfg = parse(
            r#"{
            "vocab_size": 100,
            "hidden_size": 64,
            "intermediate_size": 128,
            "num_hidden_layers": 2,
            "num_attention_heads": 4,
            "num_key_value_heads": 2,
            "head_dim": 16,
            "max_position_embeddings": 512
        }"#,
        );

        for layer in 0..2 {
            assert!(!cfg.layer_uses_swa(layer));
        }
    }
}