/// Attention flavour of a single layer, as returned by `Variant::layer_kind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerKind {
/// Returned for every layer that is not selected as `FullAttn`.
LinearAttn,
/// Returned when `(layer_idx + 1) % full_attn_interval == 0`.
FullAttn,
}
/// Attention family of a variant. It selects the `Variant::rotary_dim`
/// formula and which compile-time invariants are checked.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AttnKind {
/// `rotary_dim` is `head_dim / 4`; `num_attn_heads` must be a multiple of
/// `num_kv_heads`. (Presumably grouped-query attention.)
Gqa,
/// `rotary_dim` is `qk_rope_head_dim`; the `*_lora_rank`, `qk_*_head_dim`
/// and `v_head_dim` fields are validated. (Presumably multi-head latent
/// attention.)
Mla,
}
/// Feed-forward flavour of a single layer, as returned by
/// `Variant::mlp_kind_at`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlpKind {
/// Returned for layers with `layer_idx < first_k_dense_replace`.
Dense,
/// Returned for every other layer.
MoE,
}
/// Rotary-embedding scheme of a variant.
///
/// NOTE(review): nothing in this file branches on this value; consumers
/// presumably live elsewhere in the crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RopeKind {
/// Used by the two Qwen `VARIANT` definitions, which leave the `yarn_*`
/// fields at neutral values.
Vanilla,
/// Used by the Cogito `VARIANT` definition, which fills in the `yarn_*`
/// fields.
Yarn,
}
/// Expert-routing scheme of a variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouterKind {
/// Used by the two Qwen `VARIANT` definitions; no grouping fields are
/// validated for it.
Softmax,
/// Group-limited routing: the compile-time checks require `n_group > 0`,
/// `topk_group > 0`, `topk_group <= n_group` and `num_experts` divisible
/// by `n_group`.
NoauxTc,
}
/// How the shared expert's output is gated.
///
/// NOTE(review): the gating math is not visible in this file; the variant
/// descriptions below only record which `VARIANT` definitions use them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SharedExpertGate {
/// Selected by the two Qwen `VARIANT` definitions.
SigmoidGate,
/// Selected by the Cogito `VARIANT` definition.
Unscaled,
}
/// Static description of one model variant.
///
/// Exactly one instance is exposed as the feature-selected `VARIANT`
/// constant. The struct derives `PartialEq` but not `Eq` because it carries
/// `f32` fields.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Variant {
/// Human-readable label of the variant.
pub name: &'static str,
/// Must be a multiple of `GROUP_SIZE` (checked at compile time); feeds the
/// expert byte-size formulas.
pub hidden_dim: usize,
/// Number of layers; `first_k_dense_replace` must be strictly smaller.
pub num_layers: usize,
/// For `AttnKind::Gqa` this must be a multiple of `num_kv_heads`.
pub num_attn_heads: usize,
pub num_kv_heads: usize,
/// For `AttnKind::Gqa`, `rotary_dim()` is `head_dim / 4`.
pub head_dim: usize,
pub vocab_size: usize,
/// Upper bound for `num_experts_per_tok` (checked at compile time).
pub num_experts: usize,
pub num_experts_per_tok: usize,
/// Feeds the expert byte-size formulas together with `hidden_dim`.
pub moe_intermediate: usize,
// NOTE(review): presumably the shared expert's intermediate width — not
// used in this file.
pub shared_intermediate: usize,
/// Modulus used by `layer_kind`: layer `i` is `FullAttn` when
/// `(i + 1) % full_attn_interval == 0`.
pub full_attn_interval: usize,
/// Multiplied by `LINEAR_VALUE_DIM` in `linear_total_value`.
pub linear_num_v_heads: usize,
/// Multiplied by `LINEAR_KEY_DIM` in `linear_total_key`.
pub linear_num_k_heads: usize,
// Token ids. NOTE(review): the Cogito variant sets both think tokens to -1,
// presumably meaning "no such token" — confirm against callers.
pub eos_token_1: i32,
pub eos_token_2: i32,
pub think_start_token: i32,
pub think_end_token: i32,
pub attn_kind: AttnKind,
pub rope_kind: RopeKind,
pub router_kind: RouterKind,
pub shared_expert_gate: SharedExpertGate,
// The next five fields are validated only when `attn_kind` is
// `AttnKind::Mla`; the GQA variants set them to 0.
/// Must be a multiple of `GROUP_SIZE` for MLA.
pub q_lora_rank: usize,
/// Must be a multiple of `GROUP_SIZE` for MLA.
pub kv_lora_rank: usize,
pub qk_nope_head_dim: usize,
/// For `AttnKind::Mla`, this is the value of `rotary_dim()`.
pub qk_rope_head_dim: usize,
/// Must be non-zero for MLA.
pub v_head_dim: usize,
// `n_group` / `topk_group` are validated only when `router_kind` is
// `RouterKind::NoauxTc`.
pub n_group: usize,
pub topk_group: usize,
pub routed_scaling_factor: f32,
/// Layers with index below this value are `MlpKind::Dense`; the rest are
/// `MlpKind::MoE`.
pub first_k_dense_replace: usize,
// NOTE(review): presumably the intermediate width of the dense layers —
// not used in this file.
pub dense_intermediate: usize,
// YaRN parameters. Only the variant with `RopeKind::Yarn` sets them to
// non-neutral values; their consumers are outside this file.
pub yarn_factor: f32,
pub yarn_original_max_pos: usize,
pub yarn_beta_fast: f32,
pub yarn_beta_slow: f32,
pub yarn_mscale: f32,
pub yarn_mscale_all_dim: f32,
}
/// Derived sizes, per-layer classification and packed-expert byte offsets
/// computed from the raw fields of a [`Variant`].
impl Variant {
    /// Per-head width multiplied by `linear_num_k_heads`.
    pub const LINEAR_KEY_DIM: usize = 128;
    /// Per-head width multiplied by `linear_num_v_heads`.
    pub const LINEAR_VALUE_DIM: usize = 128;
    /// Kernel size constant; not used by any method in this block.
    pub const CONV_KERNEL_SIZE: usize = 4;

    /// `linear_num_k_heads * LINEAR_KEY_DIM`.
    pub const fn linear_total_key(&self) -> usize {
        Self::LINEAR_KEY_DIM * self.linear_num_k_heads
    }

    /// `linear_num_v_heads * LINEAR_VALUE_DIM`.
    pub const fn linear_total_value(&self) -> usize {
        Self::LINEAR_VALUE_DIM * self.linear_num_v_heads
    }

    /// Twice the total key width plus the total value width.
    pub const fn linear_conv_dim(&self) -> usize {
        let keys = self.linear_total_key();
        let values = self.linear_total_value();
        keys * 2 + values
    }

    /// Number of rotary dimensions: `qk_rope_head_dim` for MLA, a quarter of
    /// `head_dim` for GQA.
    pub const fn rotary_dim(&self) -> usize {
        match self.attn_kind {
            AttnKind::Mla => self.qk_rope_head_dim,
            AttnKind::Gqa => self.head_dim / 4,
        }
    }

    /// Classifies layer `layer_idx`: every `full_attn_interval`-th layer
    /// (counting from one) is `FullAttn`, all others are `LinearAttn`.
    ///
    /// # Panics
    /// Panics if `full_attn_interval` is zero (remainder by zero).
    pub const fn layer_kind(&self, layer_idx: usize) -> LayerKind {
        let one_based = layer_idx + 1;
        match one_based % self.full_attn_interval {
            0 => LayerKind::FullAttn,
            _ => LayerKind::LinearAttn,
        }
    }

    /// Classifies the MLP of layer `layer_idx`: the first
    /// `first_k_dense_replace` layers are `Dense`, the rest `MoE`.
    pub const fn mlp_kind_at(&self, layer_idx: usize) -> MlpKind {
        if layer_idx >= self.first_k_dense_replace {
            MlpKind::MoE
        } else {
            MlpKind::Dense
        }
    }

    /// Bytes of one packed weight matrix at `BITS` bits per element.
    pub const fn expert_weight_bytes_4bit(&self) -> usize {
        let elements = self.moe_intermediate * self.hidden_dim;
        elements * BITS / 8
    }

    /// Bytes of one scale-sized region: two bytes per `GROUP_SIZE`-wide group
    /// of `hidden_dim`, for each of the `moe_intermediate` rows.
    pub const fn expert_scale_bytes(&self) -> usize {
        let groups_per_row = self.hidden_dim / GROUP_SIZE;
        self.moe_intermediate * groups_per_row * 2
    }

    /// Bytes of one projection block: the packed weights followed by two
    /// scale-sized regions (the `_s_` and `_b_` regions of the offsets below).
    pub const fn expert_block_bytes_4bit(&self) -> usize {
        let weights = self.expert_weight_bytes_4bit();
        let scale_like = self.expert_scale_bytes();
        weights + 2 * scale_like
    }

    /// Bytes of a whole expert: three projection blocks (gate, up, down).
    pub const fn expert_size_4bit(&self) -> usize {
        3 * self.expert_block_bytes_4bit()
    }

    // Byte offsets inside one 4-bit expert. Each projection block is laid out
    // as `w` (weights), then `s`, then `b`; blocks follow in the order
    // gate, up, down.

    /// Offset of the gate weights (start of the expert).
    pub const fn gate_w_off_4bit(&self) -> usize {
        0
    }

    /// Offset of the gate `s` region, right after the gate weights.
    pub const fn gate_s_off_4bit(&self) -> usize {
        self.gate_w_off_4bit() + self.expert_weight_bytes_4bit()
    }

    /// Offset of the gate `b` region, right after the gate `s` region.
    pub const fn gate_b_off_4bit(&self) -> usize {
        self.gate_s_off_4bit() + self.expert_scale_bytes()
    }

    /// Offset of the up weights (start of the second block).
    pub const fn up_w_off_4bit(&self) -> usize {
        self.expert_block_bytes_4bit()
    }

    /// Offset of the up `s` region, right after the up weights.
    pub const fn up_s_off_4bit(&self) -> usize {
        self.up_w_off_4bit() + self.expert_weight_bytes_4bit()
    }

    /// Offset of the up `b` region, right after the up `s` region.
    pub const fn up_b_off_4bit(&self) -> usize {
        self.up_s_off_4bit() + self.expert_scale_bytes()
    }

    /// Offset of the down weights (start of the third block).
    pub const fn down_w_off_4bit(&self) -> usize {
        2 * self.expert_block_bytes_4bit()
    }

    /// Offset of the down `s` region, right after the down weights.
    pub const fn down_s_off_4bit(&self) -> usize {
        self.down_w_off_4bit() + self.expert_weight_bytes_4bit()
    }

    /// Offset of the down `b` region, right after the down `s` region.
    pub const fn down_b_off_4bit(&self) -> usize {
        self.down_s_off_4bit() + self.expert_scale_bytes()
    }

    /// Bytes of one packed weight matrix at 2 bits per element.
    pub const fn expert_weight_bytes_2bit(&self) -> usize {
        let elements = self.moe_intermediate * self.hidden_dim;
        elements * 2 / 8
    }

    /// Bytes of one 2-bit projection block: the packed weights followed by
    /// two scale-sized regions.
    pub const fn expert_block_bytes_2bit(&self) -> usize {
        let weights = self.expert_weight_bytes_2bit();
        let scale_like = self.expert_scale_bytes();
        weights + 2 * scale_like
    }

    /// Bytes of a whole 2-bit expert: three projection blocks.
    pub const fn expert_size_2bit(&self) -> usize {
        3 * self.expert_block_bytes_2bit()
    }
}
/// Epsilon constant. NOTE(review): presumably the RMSNorm stabiliser; it is
/// not used in this file.
pub const RMS_NORM_EPS: f32 = 1e-6;
/// Quantisation group width: `hidden_dim` is divided by it in
/// `Variant::expert_scale_bytes`, and several dimensions must be multiples
/// of it (see the compile-time checks).
pub const GROUP_SIZE: usize = 64;
/// Bits per packed weight element used by `Variant::expert_weight_bytes_4bit`.
pub const BITS: usize = 4;
/// RoPE theta for the two Qwen variants.
#[cfg(any(feature = "model-qwen3-5-a17b", feature = "model-qwen3-6-35b-a3b"))]
pub const ROPE_THETA: f32 = 10_000_000.0;
/// RoPE theta for the Cogito variant.
#[cfg(feature = "model-cogito-v2-671b")]
pub const ROPE_THETA: f32 = 10_000.0;
/// Sequence-length limit (2^20). NOTE(review): its consumers are outside this
/// file.
pub const MAX_SEQ_LEN: usize = 1_048_576;
/// NOTE(review): presumably the number of KV positions kept on the GPU —
/// confirm against callers; not used in this file.
pub const GPU_KV_SEQ: usize = 8192;
/// Variant selected by the `model-qwen3-5-a17b` feature.
///
/// GQA attention, vanilla RoPE, softmax router and a sigmoid-gated shared
/// expert. With `full_attn_interval = 4` and 60 layers, `layer_kind` marks
/// 15 layers as `FullAttn`. `first_k_dense_replace = 0`, so every layer is
/// `MlpKind::MoE`.
#[cfg(feature = "model-qwen3-5-a17b")]
pub const VARIANT: Variant = Variant {
name: "Qwen3.5-397B-A17B-4bit",
hidden_dim: 4096,
num_layers: 60,
num_attn_heads: 32,
num_kv_heads: 2,
head_dim: 256,
vocab_size: 248320,
num_experts: 512,
num_experts_per_tok: 10,
moe_intermediate: 1024,
shared_intermediate: 1024,
full_attn_interval: 4,
linear_num_v_heads: 64,
linear_num_k_heads: 16,
eos_token_1: 248046,
eos_token_2: 248044,
think_start_token: 248068,
think_end_token: 248069,
attn_kind: AttnKind::Gqa,
rope_kind: RopeKind::Vanilla,
router_kind: RouterKind::Softmax,
shared_expert_gate: SharedExpertGate::SigmoidGate,
// MLA-only fields: unused for GQA, left at 0.
q_lora_rank: 0,
kv_lora_rank: 0,
qk_nope_head_dim: 0,
qk_rope_head_dim: 0,
v_head_dim: 0,
// Group-routing fields: not validated for the softmax router, left at 0.
n_group: 0,
topk_group: 0,
routed_scaling_factor: 1.0,
first_k_dense_replace: 0,
dense_intermediate: 0,
// YaRN fields: left at neutral values because `rope_kind` is `Vanilla`.
yarn_factor: 1.0,
yarn_original_max_pos: 0,
yarn_beta_fast: 0.0,
yarn_beta_slow: 0.0,
yarn_mscale: 1.0,
yarn_mscale_all_dim: 1.0,
};
/// Variant selected by the `model-qwen3-6-35b-a3b` feature.
///
/// GQA attention, vanilla RoPE, softmax router and a sigmoid-gated shared
/// expert. With `full_attn_interval = 4` and 40 layers, `layer_kind` marks
/// 10 layers as `FullAttn`. `first_k_dense_replace = 0`, so every layer is
/// `MlpKind::MoE`.
#[cfg(feature = "model-qwen3-6-35b-a3b")]
pub const VARIANT: Variant = Variant {
name: "Qwen3.6-35B-A3B-4bit",
hidden_dim: 2048,
num_layers: 40,
num_attn_heads: 16,
num_kv_heads: 2,
head_dim: 256,
vocab_size: 248320,
num_experts: 256,
num_experts_per_tok: 8,
moe_intermediate: 512,
shared_intermediate: 512,
full_attn_interval: 4,
linear_num_v_heads: 32,
linear_num_k_heads: 16,
eos_token_1: 248046,
eos_token_2: 248044,
think_start_token: 248068,
think_end_token: 248069,
attn_kind: AttnKind::Gqa,
rope_kind: RopeKind::Vanilla,
router_kind: RouterKind::Softmax,
shared_expert_gate: SharedExpertGate::SigmoidGate,
// MLA-only fields: unused for GQA, left at 0.
q_lora_rank: 0,
kv_lora_rank: 0,
qk_nope_head_dim: 0,
qk_rope_head_dim: 0,
v_head_dim: 0,
// Group-routing fields: not validated for the softmax router, left at 0.
n_group: 0,
topk_group: 0,
routed_scaling_factor: 1.0,
first_k_dense_replace: 0,
dense_intermediate: 0,
// YaRN fields: left at neutral values because `rope_kind` is `Vanilla`.
yarn_factor: 1.0,
yarn_original_max_pos: 0,
yarn_beta_fast: 0.0,
yarn_beta_slow: 0.0,
yarn_mscale: 1.0,
yarn_mscale_all_dim: 1.0,
};
/// Variant selected by the `model-cogito-v2-671b` feature.
///
/// MLA attention, YaRN RoPE, `NoauxTc` router and an unscaled shared expert.
/// `full_attn_interval = 1` makes `layer_kind` return `FullAttn` for every
/// layer, and `first_k_dense_replace = 3` makes layers 0..=2 `MlpKind::Dense`.
#[cfg(feature = "model-cogito-v2-671b")]
pub const VARIANT: Variant = Variant {
name: "Cogito-V2-Preview-671B-4bit",
hidden_dim: 7168,
num_layers: 61,
num_attn_heads: 128,
num_kv_heads: 128,
head_dim: 192,
vocab_size: 128815,
num_experts: 256,
num_experts_per_tok: 8,
moe_intermediate: 2048,
shared_intermediate: 2048,
full_attn_interval: 1,
// No layer is ever `LinearAttn` here, so the linear head counts are 0.
linear_num_v_heads: 0,
linear_num_k_heads: 0,
// Both EOS slots carry the same id.
eos_token_1: 1,
eos_token_2: 1,
// NOTE(review): -1 presumably means "no think tokens" — confirm against
// callers.
think_start_token: -1,
think_end_token: -1,
attn_kind: AttnKind::Mla,
rope_kind: RopeKind::Yarn,
router_kind: RouterKind::NoauxTc,
shared_expert_gate: SharedExpertGate::Unscaled,
// Both LoRA ranks are multiples of `GROUP_SIZE`, as the compile-time checks
// require.
q_lora_rank: 1536,
kv_lora_rank: 512,
// `head_dim` above equals `qk_nope_head_dim + qk_rope_head_dim` (128 + 64).
qk_nope_head_dim: 128,
qk_rope_head_dim: 64,
v_head_dim: 128,
// 256 experts split evenly into 8 groups.
n_group: 8,
topk_group: 4,
routed_scaling_factor: 2.5,
first_k_dense_replace: 3,
dense_intermediate: 18432,
yarn_factor: 40.0,
yarn_original_max_pos: 4096,
yarn_beta_fast: 32.0,
yarn_beta_slow: 1.0,
yarn_mscale: 1.0,
yarn_mscale_all_dim: 1.0,
};
// Guard: no model variant feature selected. Without one, neither `VARIANT`
// nor `ROPE_THETA` exists, so fail early with an actionable message.
#[cfg(not(any(
    feature = "model-qwen3-5-a17b",
    feature = "model-qwen3-6-35b-a3b",
    feature = "model-cogito-v2-671b",
)))]
compile_error!(
    "moeflux: enable exactly one model variant feature \
     (`model-qwen3-5-a17b`, `model-qwen3-6-35b-a3b`, or \
     `model-cogito-v2-671b`)."
);

// Guard: more than one model variant feature selected. Each feature defines
// its own `VARIANT`, so any pair would otherwise surface as a
// duplicate-definition error; report the real cause instead.
#[cfg(any(
    all(feature = "model-qwen3-5-a17b", feature = "model-qwen3-6-35b-a3b"),
    all(feature = "model-qwen3-5-a17b", feature = "model-cogito-v2-671b"),
    all(feature = "model-qwen3-6-35b-a3b", feature = "model-cogito-v2-671b"),
))]
compile_error!(
    "moeflux: model variant features are mutually exclusive; enable exactly \
     one of `model-qwen3-5-a17b`, `model-qwen3-6-35b-a3b`, or \
     `model-cogito-v2-671b`."
);
// Compile-time sanity checks on the selected `VARIANT`. A violated invariant
// aborts the build with the corresponding message.
const _: () = {
    assert!(
        VARIANT.hidden_dim % GROUP_SIZE == 0,
        "HIDDEN_DIM must be a multiple of GROUP_SIZE"
    );
    assert!(
        VARIANT.num_experts_per_tok <= VARIANT.num_experts,
        "num_experts_per_tok must be ≤ num_experts"
    );
    // `Variant::layer_kind` evaluates `(layer_idx + 1) % full_attn_interval`;
    // a zero interval would compile and then panic on the first call.
    assert!(
        VARIANT.full_attn_interval > 0,
        "full_attn_interval must be > 0 (it is the modulus in layer_kind)"
    );
    if matches!(VARIANT.attn_kind, AttnKind::Gqa) {
        // Checked before the modulo below so a zero value reports a readable
        // message rather than a raw remainder-by-zero const-eval error.
        assert!(
            VARIANT.num_kv_heads > 0,
            "num_kv_heads must be > 0 (GQA)"
        );
        assert!(
            VARIANT.num_attn_heads % VARIANT.num_kv_heads == 0,
            "num_attn_heads must be a multiple of num_kv_heads (GQA)"
        );
        assert!(
            (VARIANT.num_attn_heads * VARIANT.head_dim) % VARIANT.hidden_dim
                == 0,
            "num_attn_heads * head_dim must be a multiple of hidden_dim"
        );
    }
    if matches!(VARIANT.attn_kind, AttnKind::Mla) {
        assert!(
            VARIANT.kv_lora_rank % GROUP_SIZE == 0,
            "kv_lora_rank must be a multiple of GROUP_SIZE"
        );
        assert!(
            VARIANT.q_lora_rank % GROUP_SIZE == 0,
            "q_lora_rank must be a multiple of GROUP_SIZE"
        );
        assert!(
            VARIANT.qk_nope_head_dim + VARIANT.qk_rope_head_dim > 0,
            "MLA must define qk_nope_head_dim + qk_rope_head_dim"
        );
        assert!(VARIANT.v_head_dim > 0, "MLA must define v_head_dim");
    }
    if matches!(VARIANT.router_kind, RouterKind::NoauxTc) {
        // The positivity check must stay ahead of the `% n_group` below.
        assert!(
            VARIANT.n_group > 0 && VARIANT.topk_group > 0,
            "noaux_tc requires n_group and topk_group > 0"
        );
        assert!(
            VARIANT.num_experts % VARIANT.n_group == 0,
            "num_experts must be divisible by n_group for group-limit routing"
        );
        assert!(
            VARIANT.topk_group <= VARIANT.n_group,
            "topk_group must be ≤ n_group"
        );
    }
    assert!(
        VARIANT.first_k_dense_replace < VARIANT.num_layers,
        "first_k_dense_replace must be strictly less than num_layers"
    );
};
/// No-op hook with an empty body.
///
/// The invariants themselves are enforced by the anonymous `const _` block in
/// this file, which is evaluated at compile time regardless of whether this
/// function is called. The unit tests call it as a named reference to those
/// checks.
pub const fn assert_static_invariants() {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic sanity of the derived sizes for the selected variant.
    #[test]
    fn variant_is_well_formed() {
        assert_static_invariants();
        let v = VARIANT;
        assert!(v.expert_size_4bit() > 0);
        assert!(v.expert_size_2bit() > 0);
        assert!(v.rotary_dim() > 0);
        assert!(v.expert_size_2bit() < v.expert_size_4bit());
        if matches!(v.attn_kind, AttnKind::Gqa) {
            assert_eq!(v.num_attn_heads % v.num_kv_heads, 0);
            assert!(v.linear_conv_dim() > 0);
        }
    }

    /// `layer_kind` must agree with the plain modulo predicate, and the number
    /// of full-attention layers must be `num_layers / full_attn_interval`.
    #[test]
    fn layer_kind_matches_legacy_modulo() {
        let v = VARIANT;
        for i in 0..v.num_layers {
            let legacy_full = (i + 1) % v.full_attn_interval == 0;
            let kind = v.layer_kind(i);
            assert_eq!(
                kind == LayerKind::FullAttn,
                legacy_full,
                "layer_kind({i}) disagrees with legacy modulo predicate \
                 (full_attn_interval = {})",
                v.full_attn_interval,
            );
        }
        let n_full = (0..v.num_layers)
            .filter(|&i| v.layer_kind(i) == LayerKind::FullAttn)
            .count();
        assert!(n_full > 0, "every shipping variant has full-attn layers");
        assert_eq!(n_full, v.num_layers / v.full_attn_interval);
    }

    /// The first `first_k_dense_replace` layers are dense, the rest MoE.
    #[test]
    fn mlp_kind_dense_then_moe() {
        let v = VARIANT;
        for i in 0..v.num_layers {
            let expected = if i < v.first_k_dense_replace {
                MlpKind::Dense
            } else {
                MlpKind::MoE
            };
            assert_eq!(v.mlp_kind_at(i), expected, "mlp_kind_at({i})");
        }
    }

    /// The nine 4-bit offsets must tile the expert blob without gaps or
    /// overlap, in the order gate/up/down with `w`, `s`, `b` inside each
    /// block, and end exactly at `expert_size_4bit()`.
    #[test]
    fn offsets_4bit_are_contiguous() {
        let v = VARIANT;
        let w = v.expert_weight_bytes_4bit();
        let s = v.expert_scale_bytes();
        let offsets = [
            v.gate_w_off_4bit(),
            v.gate_s_off_4bit(),
            v.gate_b_off_4bit(),
            v.up_w_off_4bit(),
            v.up_s_off_4bit(),
            v.up_b_off_4bit(),
            v.down_w_off_4bit(),
            v.down_s_off_4bit(),
            v.down_b_off_4bit(),
        ];
        let sizes = [w, s, s, w, s, s, w, s, s];

        let mut cursor = 0;
        for (region, (&offset, &size)) in offsets.iter().zip(&sizes).enumerate() {
            assert_eq!(offset, cursor, "4-bit region {region} is misplaced");
            cursor += size;
        }
        assert_eq!(cursor, v.expert_size_4bit());
    }
}