/// Attention flavour of a single layer, as reported by [`Variant::layer_kind`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LayerKind {
    /// Layer whose 1-based index is *not* a multiple of
    /// `Variant::full_attn_interval`.
    LinearAttn,
    /// Layer whose 1-based index is a multiple of
    /// `Variant::full_attn_interval`.
    FullAttn,
}
/// Static description of one model variant.
///
/// Exactly one instance is compiled in as `VARIANT`, selected by a Cargo
/// feature. All derived sizes and offsets are computed by the `const fn`
/// helpers in the inherent `impl`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Variant {
    /// Human-readable variant name (e.g. `"Qwen3.6-35B-A3B-4bit"`).
    pub name: &'static str,
    /// Hidden size. Feeds the expert weight/scale byte counts and is asserted
    /// at compile time to be a multiple of `GROUP_SIZE`.
    pub hidden_dim: usize,
    /// Number of layers; the unit tests treat `0..num_layers` as the valid
    /// layer indices for `layer_kind`.
    pub num_layers: usize,
    /// Number of attention heads; asserted to be a multiple of `num_kv_heads`.
    pub num_attn_heads: usize,
    /// Number of key/value heads (grouped-query attention).
    pub num_kv_heads: usize,
    /// Per-head dimension; `rotary_dim` is a quarter of it.
    pub head_dim: usize,
    /// Vocabulary size.
    pub vocab_size: usize,
    /// Total number of experts.
    pub num_experts: usize,
    /// Experts used per token; asserted to be at most `num_experts`.
    pub num_experts_per_tok: usize,
    /// Intermediate size of one expert; drives every `expert_*` byte count.
    pub moe_intermediate: usize,
    /// Intermediate size of the shared expert — presumably; it is not
    /// referenced by the helpers in this file. TODO confirm against callers.
    pub shared_intermediate: usize,
    /// Every `full_attn_interval`-th layer (counting from 1) is
    /// `LayerKind::FullAttn`. Must be non-zero: `layer_kind` divides by it.
    pub full_attn_interval: usize,
    /// Number of value heads used by `linear_total_value`.
    pub linear_num_v_heads: usize,
    /// Number of key heads used by `linear_total_key`.
    pub linear_num_k_heads: usize,
    /// First end-of-sequence token id.
    pub eos_token_1: i32,
    /// Second end-of-sequence token id.
    pub eos_token_2: i32,
    /// Token id that presumably opens a "thinking" span — verify with the
    /// tokenizer.
    pub think_start_token: i32,
    /// Token id that presumably closes a "thinking" span — verify with the
    /// tokenizer.
    pub think_end_token: i32,
}
/// Derived dimensions and on-disk expert layout for a [`Variant`].
///
/// The `*_off_*` helpers describe one expert as nine consecutive regions:
///
/// ```text
/// gate_w | gate_s | gate_b | up_w | up_s | up_b | down_w | down_s | down_b
/// ```
///
/// Every `*_w` region is `expert_weight_bytes_*` long and every `*_s` / `*_b`
/// region is `expert_scale_bytes` long (the suffixes presumably stand for
/// weights, scales and biases).
impl Variant {
    /// Per-head key width used by [`Self::linear_total_key`].
    pub const LINEAR_KEY_DIM: usize = 128;
    /// Per-head value width used by [`Self::linear_total_value`].
    pub const LINEAR_VALUE_DIM: usize = 128;
    /// Convolution kernel size. Not referenced by the helpers in this block;
    /// presumably consumed by the linear-attention implementation.
    pub const CONV_KERNEL_SIZE: usize = 4;

    /// Key width summed over all `linear_num_k_heads` heads.
    pub const fn linear_total_key(&self) -> usize {
        Self::LINEAR_KEY_DIM * self.linear_num_k_heads
    }

    /// Value width summed over all `linear_num_v_heads` heads.
    pub const fn linear_total_value(&self) -> usize {
        Self::LINEAR_VALUE_DIM * self.linear_num_v_heads
    }

    /// Two total-key widths plus one total-value width.
    pub const fn linear_conv_dim(&self) -> usize {
        let doubled_key = 2 * self.linear_total_key();
        doubled_key + self.linear_total_value()
    }

    /// A quarter of `head_dim` (integer division, rounds down).
    pub const fn rotary_dim(&self) -> usize {
        const DENOMINATOR: usize = 4;
        self.head_dim / DENOMINATOR
    }

    /// Classifies layer `layer_idx` (0-based).
    ///
    /// A layer is [`LayerKind::FullAttn`] when its 1-based index is a multiple
    /// of `full_attn_interval`, otherwise [`LayerKind::LinearAttn`].
    ///
    /// # Panics
    ///
    /// Panics (division by zero) if `full_attn_interval` is `0`.
    pub const fn layer_kind(&self, layer_idx: usize) -> LayerKind {
        match (layer_idx + 1) % self.full_attn_interval {
            0 => LayerKind::FullAttn,
            _ => LayerKind::LinearAttn,
        }
    }

    /// Bytes for one `moe_intermediate x hidden_dim` matrix packed at `BITS`
    /// bits per element.
    pub const fn expert_weight_bytes_4bit(&self) -> usize {
        let total_bits = self.moe_intermediate * self.hidden_dim * BITS;
        total_bits / 8
    }

    /// Bytes for one scale (or bias) region: one 2-byte entry per
    /// `GROUP_SIZE` elements of `hidden_dim`, for each of the
    /// `moe_intermediate` rows.
    pub const fn expert_scale_bytes(&self) -> usize {
        let groups = self.hidden_dim / GROUP_SIZE;
        self.moe_intermediate * groups * 2
    }

    /// Bytes for one 4-bit projection block: the weights followed by two
    /// scale-sized regions.
    pub const fn expert_block_bytes_4bit(&self) -> usize {
        let weights = self.expert_weight_bytes_4bit();
        let scale_regions = 2 * self.expert_scale_bytes();
        weights + scale_regions
    }

    /// Bytes for a whole 4-bit expert: three projection blocks.
    pub const fn expert_size_4bit(&self) -> usize {
        3 * self.expert_block_bytes_4bit()
    }

    /// Offset of the gate weights — the start of the expert.
    pub const fn gate_w_off_4bit(&self) -> usize {
        0
    }

    /// Offset of the gate `s` region, directly after the gate weights.
    pub const fn gate_s_off_4bit(&self) -> usize {
        self.gate_w_off_4bit() + self.expert_weight_bytes_4bit()
    }

    /// Offset of the gate `b` region, directly after the gate `s` region.
    pub const fn gate_b_off_4bit(&self) -> usize {
        self.gate_s_off_4bit() + self.expert_scale_bytes()
    }

    /// Offset of the up weights — the start of the second block.
    pub const fn up_w_off_4bit(&self) -> usize {
        self.expert_block_bytes_4bit()
    }

    /// Offset of the up `s` region, directly after the up weights.
    pub const fn up_s_off_4bit(&self) -> usize {
        self.up_w_off_4bit() + self.expert_weight_bytes_4bit()
    }

    /// Offset of the up `b` region, directly after the up `s` region.
    pub const fn up_b_off_4bit(&self) -> usize {
        self.up_s_off_4bit() + self.expert_scale_bytes()
    }

    /// Offset of the down weights — the start of the third block.
    pub const fn down_w_off_4bit(&self) -> usize {
        2 * self.expert_block_bytes_4bit()
    }

    /// Offset of the down `s` region, directly after the down weights.
    pub const fn down_s_off_4bit(&self) -> usize {
        self.down_w_off_4bit() + self.expert_weight_bytes_4bit()
    }

    /// Offset of the down `b` region, directly after the down `s` region.
    pub const fn down_b_off_4bit(&self) -> usize {
        self.down_s_off_4bit() + self.expert_scale_bytes()
    }

    /// Bytes for one `moe_intermediate x hidden_dim` matrix packed at 2 bits
    /// per element.
    pub const fn expert_weight_bytes_2bit(&self) -> usize {
        let total_bits = self.moe_intermediate * self.hidden_dim * 2;
        total_bits / 8
    }

    /// Bytes for one 2-bit projection block; the two scale-sized regions are
    /// the same size as in the 4-bit layout.
    pub const fn expert_block_bytes_2bit(&self) -> usize {
        let weights = self.expert_weight_bytes_2bit();
        let scale_regions = 2 * self.expert_scale_bytes();
        weights + scale_regions
    }

    /// Bytes for a whole 2-bit expert: three projection blocks.
    pub const fn expert_size_2bit(&self) -> usize {
        3 * self.expert_block_bytes_2bit()
    }
}
/// Epsilon for RMS normalisation — presumably added to the variance term;
/// not referenced elsewhere in this file.
pub const RMS_NORM_EPS: f32 = 1.0e-6;

/// Number of `hidden_dim` elements that share one scale entry; see
/// `Variant::expert_scale_bytes`. `hidden_dim` is asserted to be a multiple
/// of this.
pub const GROUP_SIZE: usize = 64;

/// Bits per packed weight used by `Variant::expert_weight_bytes_4bit`.
pub const BITS: usize = 4;

/// Rotary-embedding base (10 million) — presumably the RoPE `theta`;
/// not referenced elsewhere in this file.
pub const ROPE_THETA: f32 = 1.0e7;

/// Maximum sequence length: 2^20 = 1 048 576 positions.
pub const MAX_SEQ_LEN: usize = 1 << 20;

/// Sequence-length constant of 8192 — presumably the number of KV positions
/// kept on the GPU; verify against the cache code.
pub const GPU_KV_SEQ: usize = 8 * 1024;
/// Parameters of the `Qwen3.5-397B-A17B-4bit` build, compiled in when the
/// `model-qwen3-5-a17b` feature is enabled.
#[cfg(feature = "model-qwen3-5-a17b")]
pub const VARIANT: Variant = Variant {
    name: "Qwen3.5-397B-A17B-4bit",

    // Core dimensions.
    hidden_dim: 4096,
    num_layers: 60,
    num_attn_heads: 32,
    num_kv_heads: 2,
    head_dim: 256,
    vocab_size: 248_320,

    // Mixture-of-experts configuration.
    num_experts: 512,
    num_experts_per_tok: 10,
    moe_intermediate: 1024,
    shared_intermediate: 1024,

    // Layer schedule and linear-attention head counts.
    full_attn_interval: 4,
    linear_num_v_heads: 64,
    linear_num_k_heads: 16,

    // Special token ids.
    eos_token_1: 248_046,
    eos_token_2: 248_044,
    think_start_token: 248_068,
    think_end_token: 248_069,
};
/// Parameters of the `Qwen3.6-35B-A3B-4bit` build, compiled in when the
/// `model-qwen3-6-35b-a3b` feature is enabled.
#[cfg(feature = "model-qwen3-6-35b-a3b")]
pub const VARIANT: Variant = Variant {
    name: "Qwen3.6-35B-A3B-4bit",

    // Core dimensions.
    hidden_dim: 2048,
    num_layers: 40,
    num_attn_heads: 16,
    num_kv_heads: 2,
    head_dim: 256,
    vocab_size: 248_320,

    // Mixture-of-experts configuration.
    num_experts: 256,
    num_experts_per_tok: 8,
    moe_intermediate: 512,
    shared_intermediate: 512,

    // Layer schedule and linear-attention head counts.
    full_attn_interval: 4,
    linear_num_v_heads: 32,
    linear_num_k_heads: 16,

    // Special token ids.
    eos_token_1: 248_046,
    eos_token_2: 248_044,
    think_start_token: 248_068,
    think_end_token: 248_069,
};
// Guard 1: no variant feature selected. Without this the build would fail
// later with an unresolved `VARIANT`, which says nothing about the cause.
#[cfg(not(any(
    feature = "model-qwen3-5-a17b",
    feature = "model-qwen3-6-35b-a3b",
)))]
compile_error!(
    "moeflux: enable exactly one model variant feature \
     (`model-qwen3-5-a17b` or `model-qwen3-6-35b-a3b`)."
);

// Guard 2: both variant features selected. Each feature defines its own
// `VARIANT`, so the build would otherwise stop on a duplicate-definition
// error; emit a message that names the actual problem.
#[cfg(all(
    feature = "model-qwen3-5-a17b",
    feature = "model-qwen3-6-35b-a3b",
))]
compile_error!(
    "moeflux: model variant features are mutually exclusive; enable only one of \
     `model-qwen3-5-a17b` or `model-qwen3-6-35b-a3b`."
);
// Compile-time sanity checks on the selected `VARIANT`. The block is evaluated
// during constant evaluation, so a violated assertion fails the build with the
// message given here rather than misbehaving at run time.
const _: () = {
    // `Variant::expert_scale_bytes` divides `hidden_dim` by `GROUP_SIZE`.
    assert!(
        VARIANT.hidden_dim % GROUP_SIZE == 0,
        "HIDDEN_DIM must be a multiple of GROUP_SIZE"
    );
    assert!(
        VARIANT.num_attn_heads % VARIANT.num_kv_heads == 0,
        "num_attn_heads must be a multiple of num_kv_heads (GQA)"
    );
    assert!(
        VARIANT.num_experts_per_tok <= VARIANT.num_experts,
        "num_experts_per_tok must be ≤ num_experts"
    );
    assert!(
        (VARIANT.num_attn_heads * VARIANT.head_dim) % VARIANT.hidden_dim == 0,
        "num_attn_heads * head_dim must be a multiple of hidden_dim"
    );
    // `Variant::layer_kind` takes a remainder by `full_attn_interval`; without
    // this check a zero interval would only surface as a divide-by-zero panic
    // the first time a layer is classified.
    assert!(
        VARIANT.full_attn_interval > 0,
        "full_attn_interval must be non-zero"
    );
};
/// Does nothing when called.
///
/// The invariants on `VARIANT` are enforced by the anonymous `const` block in
/// this file during compilation. This function is presumably kept so tests
/// and callers have a named symbol to reference for those checks.
pub const fn assert_static_invariants() {
    // Intentionally empty: the checks run at compile time, not here.
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The compiled-in variant yields non-zero derived sizes, and the 2-bit
    /// expert is strictly smaller than the 4-bit one.
    #[test]
    fn variant_is_well_formed() {
        assert_static_invariants();

        let variant = VARIANT;
        let size_4bit = variant.expert_size_4bit();
        let size_2bit = variant.expert_size_2bit();

        assert!(size_4bit > 0);
        assert!(size_2bit > 0);
        assert!(variant.linear_conv_dim() > 0);
        assert!(variant.rotary_dim() > 0);
        assert!(size_2bit < size_4bit);
        assert_eq!(variant.num_attn_heads % variant.num_kv_heads, 0);
    }

    /// `layer_kind` agrees with the plain `(i + 1) % interval == 0` predicate
    /// for every layer, and the number of full-attention layers is
    /// `num_layers / full_attn_interval`.
    #[test]
    fn layer_kind_matches_legacy_modulo() {
        let variant = VARIANT;
        let interval = variant.full_attn_interval;
        let mut full_layers = 0usize;

        for i in 0..variant.num_layers {
            let expected_full = (i + 1) % interval == 0;
            let is_full = matches!(variant.layer_kind(i), LayerKind::FullAttn);
            assert_eq!(
                is_full,
                expected_full,
                "layer_kind({i}) disagrees with legacy modulo predicate \
                 (full_attn_interval = {})",
                interval,
            );
            if is_full {
                full_layers += 1;
            }
        }

        assert!(full_layers > 0, "every shipping variant has full-attn layers");
        assert_eq!(full_layers, variant.num_layers / interval);
    }
}