use crate::{
array::Array,
dtype::Dtype,
error::{
ArithmeticOverflowPayload, InvariantViolationPayload, LengthMismatchPayload, OutOfRangePayload,
RankMismatchPayload, Result, ShapePairMismatchPayload, UnknownEnumValuePayload,
},
ops::{arithmetic, indexing, linalg_basic, misc, quantized, shape},
};
use smol_str::format_smolstr;
use super::activations::silu;
/// Per-expert linear projection whose expert is selected by routing indices.
///
/// Instances are built through `SwitchLinear::from_parts`, which enforces the
/// shapes documented on the fields below.
#[derive(Debug)]
pub struct SwitchLinear {
/// Stacked expert weights; validated as rank 3
/// `[num_experts, output_dims, input_dims]`.
weight: Array,
/// Optional per-expert bias; when present, validated as rank 2
/// `[num_experts, output_dims]`.
bias: Option<Array>,
}
impl SwitchLinear {
    /// Builds a `SwitchLinear` from already-loaded parameters after checking their shapes.
    ///
    /// `weight` must be rank 3, laid out `[num_experts, output_dims, input_dims]`.
    /// When `bias` is supplied it must be rank 2 with shape `[num_experts, output_dims]`.
    ///
    /// # Errors
    ///
    /// * `RankMismatch` when `weight` is not 3-D or `bias` is not 2-D.
    /// * `ShapePairMismatch` when the bias dims differ from the weight's two leading dims.
    pub fn from_parts(weight: Array, bias: Option<Array>) -> Result<Self> {
        let weight_dims = weight.shape();
        let weight_rank = weight_dims.len();
        if weight_rank != 3 {
            return Err(crate::Error::RankMismatch(RankMismatchPayload::new(
                "SwitchLinear::from_parts: weight must be 3-D [num_experts, output_dims, input_dims]",
                weight_rank as u32,
                weight_dims.to_vec(),
            )));
        }

        // Rank is known to be 3 here, so the two leading dims can be read directly.
        let (experts, out_dims) = (weight_dims[0], weight_dims[1]);

        if let Some(bias_arr) = bias.as_ref() {
            let bias_dims = bias_arr.shape();
            let bias_rank = bias_dims.len();
            if bias_rank != 2 {
                return Err(crate::Error::RankMismatch(RankMismatchPayload::new(
                    "SwitchLinear::from_parts: bias must be rank-2 [num_experts, output_dims]",
                    bias_rank as u32,
                    bias_dims.to_vec(),
                )));
            }
            if (bias_dims[0], bias_dims[1]) != (experts, out_dims) {
                return Err(crate::Error::ShapePairMismatch(
                    ShapePairMismatchPayload::new(
                        "SwitchLinear::from_parts: bias must be [num_experts, output_dims]",
                        vec![experts, out_dims],
                        bias_dims.to_vec(),
                    ),
                ));
            }
        }

        Ok(Self { weight, bias })
    }

    /// Borrows the stacked expert weight array.
    pub fn weight_ref(&self) -> &Array {
        &self.weight
    }

    /// Borrows the per-expert bias, if one was supplied.
    pub fn bias(&self) -> Option<&Array> {
        self.bias.as_ref()
    }

    /// Number of experts: axis 0 of the weight.
    ///
    /// The index is in bounds for values built via `from_parts`, which enforces rank 3.
    pub fn num_experts(&self) -> usize {
        self.weight.shape()[0]
    }

    /// Output feature count: axis 1 of the weight.
    pub fn output_dims(&self) -> usize {
        self.weight.shape()[1]
    }

    /// Input feature count: axis 2 of the weight.
    pub fn input_dims(&self) -> usize {
        self.weight.shape()[2]
    }

    /// Applies the routed projection to `x`.
    ///
    /// The weight's last two axes are swapped and handed to `linalg_basic::gather_mm`
    /// together with `indices` and the `sorted_indices` flag. When a bias exists, the
    /// rows picked by `indices` along axis 0 get a singleton axis inserted at `-2` and
    /// are added to the product.
    ///
    /// NOTE(review): `indices` is passed as the fourth `gather_mm` argument —
    /// presumably the expert selector for the weight operand; confirm against
    /// `linalg_basic::gather_mm`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying array operations.
    pub fn apply(&self, x: &Array, indices: &Array, sorted_indices: bool) -> Result<Array> {
        let transposed = shape::swapaxes(&self.weight, -1, -2)?;
        let projected =
            linalg_basic::gather_mm(x, &transposed, None, Some(indices), sorted_indices)?;

        match self.bias.as_ref() {
            None => Ok(projected),
            Some(bias) => {
                let routed_bias = indexing::take_axis(bias, indices, 0)?;
                let routed_bias = shape::expand_dims_axes(&routed_bias, &[-2])?;
                let biased = projected.add(&routed_bias)?;
                Ok(biased)
            }
        }
    }
}
/// Quantized counterpart of a per-expert linear projection.
///
/// Instances are built through `QuantizedSwitchLinear::from_parts`, which
/// enforces the constraints documented on the fields below.
#[derive(Debug)]
pub struct QuantizedSwitchLinear {
/// Packed quantized weights; validated as rank 3
/// `[num_experts, output_dims, packed_input_dims]` with dtype `U32`.
weight: Array,
/// Quantization scales; same rank as `weight`, with matching leading `(E, O)` dims.
scales: Array,
/// Quantization biases; required for `affine` mode, rejected for
/// `mxfp4` / `mxfp8` / `nvfp4`. When present, the shape equals that of `scales`.
quant_biases: Option<Array>,
/// Optional additive output bias; when present, validated as rank 2
/// `[num_experts, output_dims]`.
bias: Option<Array>,
/// Quantization group size; validated to be `> 0`.
group_size: i32,
/// Bits per quantized value; validated to be `> 0`.
bits: i32,
/// Quantization mode name: one of `affine`, `mxfp4`, `mxfp8`, `nvfp4`.
mode: String,
}
impl QuantizedSwitchLinear {
#[allow(clippy::too_many_arguments)]
pub fn from_parts(
weight: Array,
scales: Array,
quant_biases: Option<Array>,
bias: Option<Array>,
group_size: i32,
bits: i32,
mode: impl Into<String>,
) -> Result<Self> {
let mode = mode.into();
let w_shape = weight.shape();
if w_shape.len() != 3 {
return Err(crate::Error::RankMismatch(RankMismatchPayload::new(
"QuantizedSwitchLinear::from_parts: weight must be 3-D [num_experts, output_dims, packed_input_dims]",
w_shape.len() as u32,
w_shape.to_vec(),
)));
}
let w_dtype = weight.dtype()?;
if w_dtype != Dtype::U32 {
return Err(crate::Error::InvariantViolation(
InvariantViolationPayload::new(
"QuantizedSwitchLinear::from_parts: weight dtype (gather_qmm rejects non-`uint32` quantized weights)",
"must be `uint32` (the mlx-quantized-weight dtype)",
),
));
}
let e = w_shape[0];
let o = w_shape[1];
let s_shape = scales.shape();
if s_shape.len() != w_shape.len() {
return Err(crate::Error::RankMismatch(RankMismatchPayload::new(
"QuantizedSwitchLinear::from_parts: scales rank must match weight rank (mlx `quantize` preserves leading shape across (weight, scales, biases))",
s_shape.len() as u32,
s_shape.to_vec(),
)));
}
if s_shape[0] != e || s_shape[1] != o {
return Err(crate::Error::ShapePairMismatch(
ShapePairMismatchPayload::new(
"QuantizedSwitchLinear::from_parts: scales leading dims (E, O) must match weight",
vec![e, o],
vec![s_shape[0], s_shape[1]],
),
));
}
if let Some(qb) = &quant_biases {
let qb_shape = qb.shape();
if qb_shape.len() != s_shape.len() {
return Err(crate::Error::RankMismatch(RankMismatchPayload::new(
"QuantizedSwitchLinear::from_parts: quant_biases rank must match scales rank (mlx `affine_quantize` writes identical `[E, O, n_groups]` shape)",
qb_shape.len() as u32,
qb_shape.to_vec(),
)));
}
if qb_shape != s_shape {
return Err(crate::Error::ShapePairMismatch(
ShapePairMismatchPayload::new(
"QuantizedSwitchLinear::from_parts: quant_biases shape must match scales (mlx `affine_quantize` writes identical `[E, O, n_groups]` shape)",
s_shape.to_vec(),
qb_shape.to_vec(),
),
));
}
}
match (mode.as_str(), quant_biases.as_ref()) {
("affine", None) => {
return Err(crate::Error::InvariantViolation(
InvariantViolationPayload::new(
"QuantizedSwitchLinear::from_parts: `affine` mode quant_biases (mlx `affine_quantize` always writes {w_q, scales, biases})",
"must be Some for `affine` mode",
),
));
}
("mxfp4" | "mxfp8" | "nvfp4", Some(_)) => {
return Err(crate::Error::InvariantViolation(
InvariantViolationPayload::new(
"QuantizedSwitchLinear::from_parts: mxfp4 / mxfp8 / nvfp4 mode is scale-only (mlx `fp_quantize` writes {w_q, scales} with no biases); got a stale `quant_biases`",
"must be None for mxfp4 / mxfp8 / nvfp4 mode",
),
));
}
("affine", Some(_)) | ("mxfp4" | "mxfp8" | "nvfp4", None) => {
}
(other, _) => {
return Err(crate::Error::UnknownEnumValue(
UnknownEnumValuePayload::new(
"QuantizedSwitchLinear::mode",
other.to_string(),
&["affine", "mxfp4", "mxfp8", "nvfp4"],
),
));
}
}
if bits <= 0 {
return Err(crate::Error::OutOfRange(OutOfRangePayload::new(
"QuantizedSwitchLinear::from_parts: bits (per-mode value tables validated by mlx-c)",
"must be > 0",
format_smolstr!("{bits}"),
)));
}
if group_size <= 0 {
return Err(crate::Error::OutOfRange(OutOfRangePayload::new(
"QuantizedSwitchLinear::from_parts: group_size (per-mode value tables validated by mlx-c)",
"must be > 0",
format_smolstr!("{group_size}"),
)));
}
if let Some(b) = &bias {
let b_shape = b.shape();
if b_shape.len() != 2 {
return Err(crate::Error::RankMismatch(RankMismatchPayload::new(
"QuantizedSwitchLinear::from_parts: bias must be rank-2 [num_experts, output_dims]",
b_shape.len() as u32,
b_shape.to_vec(),
)));
}
if b_shape[0] != e || b_shape[1] != o {
return Err(crate::Error::ShapePairMismatch(
ShapePairMismatchPayload::new(
"QuantizedSwitchLinear::from_parts: bias must be [num_experts, output_dims]",
vec![e, o],
b_shape.to_vec(),
),
));
}
}
Ok(Self {
weight,
scales,
quant_biases,
bias,
group_size,
bits,
mode,
})
}
pub fn weight_ref(&self) -> &Array {
&self.weight
}
pub fn scales_ref(&self) -> &Array {
&self.scales
}
pub fn quant_biases(&self) -> Option<&Array> {
self.quant_biases.as_ref()
}
pub fn bias(&self) -> Option<&Array> {
self.bias.as_ref()
}
pub fn group_size(&self) -> i32 {
self.group_size
}
pub fn bits(&self) -> i32 {
self.bits
}
pub fn mode(&self) -> &str {
&self.mode
}
pub fn apply(&self, x: &Array, indices: &Array, sorted_indices: bool) -> Result<Array> {
let mut out = quantized::gather_qmm(
x,
&self.weight,
&self.scales,
self.quant_biases.as_ref(),
None, Some(indices),
true, self.group_size,
self.bits,
&self.mode,
sorted_indices,
)?;
if let Some(bias) = &self.bias {
let selected = indexing::take_axis(bias, indices, 0)?;
let broadcastable = shape::expand_dims_axes(&selected, &[-2])?;
out = out.add(&broadcastable)?;
}
Ok(out)
}
}
/// Boxed activation callback: takes a borrowed input array and returns a new
/// array or an error. Used by `SwitchGLU` and `SwitchMLP` as their non-linearity.
pub type Activation = Box<dyn Fn(&Array) -> Result<Array>>;
fn check_routing_indices(x: &Array, indices: &Array) -> Result<()> {
let x_shape = x.shape();
if x_shape.is_empty() {
return Err(crate::Error::RankMismatch(RankMismatchPayload::new(
"SwitchGLU/SwitchMLP::forward: x must have at least one axis ([..batch.., input_dims])",
0,
x_shape.to_vec(),
)));
}
let x_batch = &x_shape[..x_shape.len() - 1];
let idx_shape = indices.shape();
let expected_rank = x_batch.len() + 1;
if idx_shape.len() != expected_rank {
return Err(crate::Error::RankMismatch(RankMismatchPayload::new(
"SwitchGLU/SwitchMLP::forward: indices must be [..batch.., k] — missing or extra trailing top-k axis (pass [..batch.., 1] for top-1 routing)",
idx_shape.len() as u32,
idx_shape.to_vec(),
)));
}
let idx_lead = &idx_shape[..x_batch.len()];
if idx_lead != x_batch {
let mut diff_idx: Option<usize> = None;
let mut diff_count = 0usize;
for (i, (e, a)) in x_batch.iter().zip(idx_lead.iter()).enumerate() {
if e != a {
diff_count += 1;
if diff_count == 1 {
diff_idx = Some(i);
}
}
}
debug_assert!(diff_count >= 1, "idx_lead != x_batch ⇒ at least one diff");
if diff_count == 1 {
let i = diff_idx.expect("diff_count == 1 ⇒ diff_idx is Some");
return Err(crate::Error::LengthMismatch(LengthMismatchPayload::new(
"SwitchGLU/SwitchMLP::forward: indices leading-dim length must match x's corresponding batch dim",
x_batch[i],
idx_lead[i],
)));
}
return Err(crate::Error::ShapePairMismatch(
ShapePairMismatchPayload::new(
"SwitchGLU/SwitchMLP::forward: indices leading dims must match x's batch dims",
x_batch,
idx_lead,
),
));
}
Ok(())
}
fn gather_sort(x: &Array, indices: &Array) -> Result<(Array, Array, Array)> {
let m = *indices
.shape()
.last()
.expect("gather_sort: indices must have at least one axis");
let indices_flat = shape::flatten(indices, 0, -1)?;
let order = misc::argsort(&indices_flat)?;
let inv_order = misc::argsort(&order)?;
let x_flat = shape::flatten(x, 0, -3)?;
let m_u32 = u32::try_from(m).map_err(|_| {
crate::Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"gather_sort: top-k count exceeds u32::MAX",
"u32",
[("top_k", m as u64)],
))
})?;
let m_arr = Array::from_slice::<u32>(&[m_u32], &(1usize,))?;
let token_rows = arithmetic::floor_divide(&order, &m_arr)?;
let x_sorted = indexing::take_axis(&x_flat, &token_rows, 0)?;
let indices_sorted = indexing::take_axis(&indices_flat, &order, 0)?;
Ok((x_sorted, indices_sorted, inv_order))
}
/// Undoes the row permutation applied by `gather_sort` and restores the leading shape.
///
/// Rows of `x` are picked along axis 0 by `inv_order`; the result is then reshaped
/// to `shape` followed by whatever axes trail axis 0 of the permuted array.
///
/// # Errors
///
/// Propagates errors from `indexing::take_axis` and `shape::reshape`.
fn scatter_unsort(x: &Array, inv_order: &Array, shape: &[usize]) -> Result<Array> {
    let restored = indexing::take_axis(x, inv_order, 0)?;
    let trailing = &restored.shape()[1..];
    // Target shape = caller-supplied leading dims ++ the untouched trailing dims.
    let target: Vec<usize> = [shape, trailing].concat();
    shape::reshape(&restored, &target.as_slice())
}
/// Gated expert feed-forward block.
///
/// `forward` computes `down_proj(activation(gate_proj(x)) * up_proj(x))`, with
/// every projection routed by the same expert indices.
pub struct SwitchGLU {
/// Projection whose output is passed through `activation`.
gate_proj: SwitchLinear,
/// Projection whose output is multiplied with the activated gate.
up_proj: SwitchLinear,
/// Projection applied to the gated product.
down_proj: SwitchLinear,
/// Non-linearity applied to the gate projection's output.
activation: Activation,
}
impl SwitchGLU {
    /// Returns the default activation for this block: boxed `silu`.
    pub fn default_activation() -> Activation {
        Box::new(silu)
    }

    /// Assembles a `SwitchGLU` after verifying that the three projections are
    /// shape-compatible (see `check_glu_shapes`).
    ///
    /// # Errors
    ///
    /// Returns whatever `check_glu_shapes` reports for mismatched projections.
    pub fn new(
        gate_proj: SwitchLinear,
        up_proj: SwitchLinear,
        down_proj: SwitchLinear,
        activation: Activation,
    ) -> Result<Self> {
        check_glu_shapes(&gate_proj, &up_proj, &down_proj)?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            activation,
        })
    }

    /// Borrows the gate projection.
    pub fn gate_proj(&self) -> &SwitchLinear {
        &self.gate_proj
    }

    /// Borrows the up projection.
    pub fn up_proj(&self) -> &SwitchLinear {
        &self.up_proj
    }

    /// Borrows the down projection.
    pub fn down_proj(&self) -> &SwitchLinear {
        &self.down_proj
    }

    /// Runs the gated expert block on `x`, routed by `indices`.
    ///
    /// Steps:
    ///
    /// 1. Validate `indices` against `x` with `check_routing_indices`.
    /// 2. Insert singleton axes at `-2` and `-3` of `x`.
    /// 3. When `indices` holds at least 64 entries, sort the routing slots with
    ///    `gather_sort` and tell the projections the indices are sorted.
    /// 4. Compute `down_proj(activation(gate_proj(x)) * up_proj(x))`.
    /// 5. If sorting happened, undo it with `scatter_unsort` using `indices`' shape.
    /// 6. Squeeze axis `-2` from the result.
    ///
    /// The routing indices and the expanded input are only ever borrowed: the
    /// projections take `&Array`, so no handle is cloned on either path.
    ///
    /// NOTE(review): the 64-entry threshold presumably trades sort cost against more
    /// ordered expert access — confirm with the upstream reference implementation.
    ///
    /// # Errors
    ///
    /// Returns validation errors from `check_routing_indices`, or any error from the
    /// underlying array operations and the activation callback.
    pub fn forward(&self, x: &Array, indices: &Array) -> Result<Array> {
        check_routing_indices(x, indices)?;
        let expanded = shape::expand_dims_axes(x, &[-2, -3])?;
        let do_sort = indices.size() >= 64;

        // Own the sorted arrays (if any) here so the borrows below can outlive the branch.
        let sorted = if do_sort {
            Some(gather_sort(&expanded, indices)?)
        } else {
            None
        };
        let (input, idx, inv_order): (&Array, &Array, Option<&Array>) = match &sorted {
            Some((x_sorted, idx_sorted, inv)) => (x_sorted, idx_sorted, Some(inv)),
            None => (&expanded, indices, None),
        };

        let x_up = self.up_proj.apply(input, idx, do_sort)?;
        let x_gate = self.gate_proj.apply(input, idx, do_sort)?;
        let gated = (self.activation)(&x_gate)?.multiply(&x_up)?;
        let mut out = self.down_proj.apply(&gated, idx, do_sort)?;

        if let Some(inv) = inv_order {
            out = scatter_unsort(&out, inv, &indices.shape())?;
        }
        shape::squeeze_axes(&out, &[-2])
    }
}
/// Manual `Debug`: the boxed activation closure has no `Debug` impl, so it is
/// rendered as the placeholder string `"<fn>"`.
impl std::fmt::Debug for SwitchGLU {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SwitchGLU");
        builder.field("gate_proj", &self.gate_proj);
        builder.field("up_proj", &self.up_proj);
        builder.field("down_proj", &self.down_proj);
        builder.field("activation", &"<fn>");
        builder.finish()
    }
}
/// Verifies that the three GLU projections can be chained.
///
/// Checks, in order:
///
/// 1. `gate_proj` and `up_proj` share `(input_dims, output_dims)`.
/// 2. `down_proj` maps the gate's output dims back to the gate's input dims.
/// 3. All three projections have the same number of experts.
///
/// # Errors
///
/// Returns `ShapePairMismatch` describing the first check that fails.
fn check_glu_shapes(
    gate_proj: &SwitchLinear,
    up_proj: &SwitchLinear,
    down_proj: &SwitchLinear,
) -> Result<()> {
    let gate_in = gate_proj.input_dims();
    let gate_hidden = gate_proj.output_dims();
    let gate_experts = gate_proj.num_experts();

    let up_in = up_proj.input_dims();
    let up_hidden = up_proj.output_dims();
    let up_experts = up_proj.num_experts();

    let down_in = down_proj.input_dims();
    let down_out = down_proj.output_dims();
    let down_experts = down_proj.num_experts();

    if (gate_in, gate_hidden) != (up_in, up_hidden) {
        return Err(crate::Error::ShapePairMismatch(
            ShapePairMismatchPayload::new(
                "SwitchGLU: gate_proj and up_proj must share [input_dims, hidden_dims]",
                vec![gate_in, gate_hidden],
                vec![up_in, up_hidden],
            ),
        ));
    }

    if (down_in, down_out) != (gate_hidden, gate_in) {
        return Err(crate::Error::ShapePairMismatch(
            ShapePairMismatchPayload::new(
                "SwitchGLU: down_proj must be the [hidden_dims, input_dims] inverse of gate_proj/up_proj",
                vec![gate_hidden, gate_in],
                vec![down_in, down_out],
            ),
        ));
    }

    if [up_experts, down_experts] != [gate_experts, gate_experts] {
        return Err(crate::Error::ShapePairMismatch(
            ShapePairMismatchPayload::new(
                "SwitchGLU: all projections must have the same num_experts (gate_proj, up_proj, down_proj)",
                vec![gate_experts, gate_experts, gate_experts],
                vec![gate_experts, up_experts, down_experts],
            ),
        ));
    }

    Ok(())
}
/// Two-layer expert feed-forward block.
///
/// `forward` computes `fc2(activation(fc1(x)))`, with both projections routed by
/// the same expert indices.
pub struct SwitchMLP {
/// First projection, applied to the input.
fc1: SwitchLinear,
/// Second projection, applied to the activated hidden state.
fc2: SwitchLinear,
/// Non-linearity applied between `fc1` and `fc2`.
activation: Activation,
}
impl SwitchMLP {
    /// Returns the default activation for this block: boxed `gelu_approx`.
    pub fn default_activation() -> Activation {
        Box::new(super::activations::gelu_approx)
    }

    /// Assembles a `SwitchMLP` after verifying that `fc1` and `fc2` can be chained.
    ///
    /// # Errors
    ///
    /// * `ShapePairMismatch` when `fc2` is not the `[hidden_dims, input_dims]`
    ///   inverse of `fc1`.
    /// * `LengthMismatch` when the two projections have different expert counts.
    pub fn new(fc1: SwitchLinear, fc2: SwitchLinear, activation: Activation) -> Result<Self> {
        if fc2.input_dims() != fc1.output_dims() || fc2.output_dims() != fc1.input_dims() {
            return Err(crate::Error::ShapePairMismatch(
                ShapePairMismatchPayload::new(
                    "SwitchMLP: fc2 must be the [hidden_dims, input_dims] inverse of fc1 [input_dims, hidden_dims]",
                    vec![fc1.output_dims(), fc1.input_dims()],
                    vec![fc2.input_dims(), fc2.output_dims()],
                ),
            ));
        }
        if fc1.num_experts() != fc2.num_experts() {
            return Err(crate::Error::LengthMismatch(LengthMismatchPayload::new(
                "SwitchMLP: fc1 and fc2 num_experts",
                fc1.num_experts(),
                fc2.num_experts(),
            )));
        }
        Ok(Self {
            fc1,
            fc2,
            activation,
        })
    }

    /// Borrows the first projection.
    pub fn fc1(&self) -> &SwitchLinear {
        &self.fc1
    }

    /// Borrows the second projection.
    pub fn fc2(&self) -> &SwitchLinear {
        &self.fc2
    }

    /// Runs the expert MLP on `x`, routed by `indices`.
    ///
    /// Steps:
    ///
    /// 1. Validate `indices` against `x` with `check_routing_indices`.
    /// 2. Insert singleton axes at `-2` and `-3` of `x`.
    /// 3. When `indices` holds at least 64 entries, sort the routing slots with
    ///    `gather_sort` and tell the projections the indices are sorted.
    /// 4. Compute `fc2(activation(fc1(x)))`.
    /// 5. If sorting happened, undo it with `scatter_unsort` using `indices`' shape.
    /// 6. Squeeze axis `-2` from the result.
    ///
    /// The routing indices and the expanded input are only ever borrowed: the
    /// projections take `&Array`, so no handle is cloned on either path.
    ///
    /// NOTE(review): the 64-entry threshold presumably trades sort cost against more
    /// ordered expert access — confirm with the upstream reference implementation.
    ///
    /// # Errors
    ///
    /// Returns validation errors from `check_routing_indices`, or any error from the
    /// underlying array operations and the activation callback.
    pub fn forward(&self, x: &Array, indices: &Array) -> Result<Array> {
        check_routing_indices(x, indices)?;
        let expanded = shape::expand_dims_axes(x, &[-2, -3])?;
        let do_sort = indices.size() >= 64;

        // Own the sorted arrays (if any) here so the borrows below can outlive the branch.
        let sorted = if do_sort {
            Some(gather_sort(&expanded, indices)?)
        } else {
            None
        };
        let (input, idx, inv_order): (&Array, &Array, Option<&Array>) = match &sorted {
            Some((x_sorted, idx_sorted, inv)) => (x_sorted, idx_sorted, Some(inv)),
            None => (&expanded, indices, None),
        };

        let hidden = self.fc1.apply(input, idx, do_sort)?;
        let hidden = (self.activation)(&hidden)?;
        let mut out = self.fc2.apply(&hidden, idx, do_sort)?;

        if let Some(inv) = inv_order {
            out = scatter_unsort(&out, inv, &indices.shape())?;
        }
        shape::squeeze_axes(&out, &[-2])
    }
}
/// Manual `Debug`: the boxed activation closure has no `Debug` impl, so it is
/// rendered as the placeholder string `"<fn>"`.
impl std::fmt::Debug for SwitchMLP {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SwitchMLP");
        builder.field("fc1", &self.fc1);
        builder.field("fc2", &self.fc2);
        builder.field("activation", &"<fn>");
        builder.finish()
    }
}
#[cfg(test)]
mod tests;