use metal::{
Buffer, BufferRef, CommandBufferRef, ComputePipelineState, MTLSize,
NSUInteger,
};
use moeflux_metal::{GatherQmmCall, Kernels, QuantWeights};
use crate::riir::backend::buftype::DeprecatedCogitoBuf;
use crate::riir::backend::{BufId, BufferPool, MetalBufferPool};
use crate::riir::backend::gpu::gpu_matvec::{encode_matvec_n_tokens, MatvecPipelines};
use crate::riir::backend::gpu::gpu_norm::{encode_rms_norm_bf16_into, RmsNormBf16Pipelines};
use crate::riir::backend::gpu::metal::{
buffer_as_mut_slice, buffer_as_slice, MetalContext, MetalError,
MtlBuffer,
};
use crate::riir::moe::moe_router::ExpertBuckets;
use crate::riir::variants::{SharedExpertGate, Variant, GROUP_SIZE, VARIANT};
/// Name of the residual-combine kernel matching the compile-time variant.
///
/// Returns `"moe_combine_residual_unscaled"` for
/// `SharedExpertGate::Unscaled` and `"moe_combine_residual"` for
/// `SharedExpertGate::SigmoidGate`. The match is exhaustive on purpose, so a
/// new gate mode forces a decision here.
fn combine_kernel_name() -> &'static str {
    let gate_mode = VARIANT.shared_expert_gate;
    match gate_mode {
        SharedExpertGate::Unscaled => "moe_combine_residual_unscaled",
        SharedExpertGate::SigmoidGate => "moe_combine_residual",
    }
}
/// Optional follow-on work for `gpu_batched_experts_encode_mmap`.
///
/// When supplied:
/// - the residual combine writes into `combine_out` instead of the
///   `moe_hidden` buffer owned by `MoeBuffers`;
/// - `encode_rms_norm_bf16_into` is then encoded into the same command
///   buffer, reading `combine_out`.
pub(crate) struct ChainToNormed<'a> {
    /// Pipelines handed to `encode_rms_norm_bf16_into`.
    pub pipes: &'a RmsNormBf16Pipelines,
    /// Buffer holding the norm weights, read starting at `next_norm_off`.
    /// Presumably the next layer's bf16 norm weight — TODO confirm.
    pub wf_buf: &'a Buffer,
    /// Offset into `wf_buf` where the norm weights start (assumed bytes —
    /// TODO confirm against `encode_rms_norm_bf16_into`).
    pub next_norm_off: u64,
    /// Destination of the residual combine; also the RMS-norm input.
    pub combine_out: &'a Buffer,
    /// Scratch passed to the norm in its sum-of-squares argument position.
    pub chain_sum_sq: &'a Buffer,
    /// Passed to the norm in its output argument position.
    pub chain_normed: &'a Buffer,
    /// Epsilon passed to the RMS norm.
    pub eps: f32,
}
/// Upper bound on routed experts per token (`actual_k`).
///
/// Sizes the per-slot buffer arrays in `MoeBuffers` and bounds `actual_k`
/// validation. `combine_params` stores up to this many routing weights
/// ahead of the shared-gate score at index 16.
pub const MAX_K: usize = 16;
/// Host-side activations and routing scalars for one token's MoE step,
/// consumed by `gpu_batched_experts_encode`.
///
/// `h_post`, `h_mid` and `shared_out` must each be `VARIANT.hidden_dim`
/// floats. `expert_weights` holds one routing weight per selected expert
/// (length `actual_k`). `Debug` is derived so the payload can be logged
/// when a length check fails.
#[derive(Clone, Copy, Debug)]
pub struct ExpertPayload<'a> {
    /// Input to the expert FFNs; copied into the `input` buffer.
    pub h_post: &'a [f32],
    /// Bound as `h_mid` to the residual-combine kernel.
    pub h_mid: &'a [f32],
    /// Shared-expert output, bound to the residual-combine kernel.
    pub shared_out: &'a [f32],
    /// Per-expert routing weights; written to `combine_params[..actual_k]`.
    pub expert_weights: &'a [f32],
    /// Shared-expert gate score; written to `combine_params[16]`.
    pub shared_gate_score: f32,
}
/// Errors returned by the GPU expert forward/encode entry points in this
/// module.
#[derive(Debug, thiserror::Error)]
pub enum ExpertForwardError {
    /// `expert_data` does not have the byte length of the packed 4-bit expert
    /// layout (times `actual_k` for the batched entry points).
    #[error(
        "expert_data is the wrong length: expected {expected} bytes \
        (4-bit layout), got {actual}"
    )]
    BadExpertDataLen { expected: usize, actual: usize },
    /// `h_post` is not `hidden_dim` floats.
    #[error("h_post must be HIDDEN_DIM={expected} floats, got {actual}")]
    BadHPostLen { expected: usize, actual: usize },
    /// `expert_out` is not `hidden_dim` floats.
    #[error("expert_out must be HIDDEN_DIM={expected} floats, got {actual}")]
    BadExpertOutLen { expected: usize, actual: usize },
    /// `h_mid` is not `hidden_dim` floats.
    #[error("h_mid must be HIDDEN_DIM={expected} floats, got {actual}")]
    BadHMidLen { expected: usize, actual: usize },
    /// `shared_out` is not `hidden_dim` floats.
    #[error("shared_out must be HIDDEN_DIM={expected} floats, got {actual}")]
    BadSharedOutLen { expected: usize, actual: usize },
    /// `hidden_out` is not `hidden_dim` floats.
    #[error("hidden_out must be HIDDEN_DIM={expected} floats, got {actual}")]
    BadHiddenOutLen { expected: usize, actual: usize },
    /// `actual_k` is outside `1..=max` (`max` is `MAX_K`).
    #[error(
        "actual_K out of range: must be 1..={max}, got {actual}"
    )]
    BadK { actual: i32, max: usize },
    /// `expert_weights` length differs from `actual_k`.
    #[error("expert_weights must be {expected} floats, got {actual}")]
    BadWeightsLen { expected: usize, actual: usize },
    /// Error from the Metal backend, e.g. a failed pipeline lookup.
    #[error("Metal backend: {0}")]
    Metal(#[from] MetalError),
}
/// Run one packed 4-bit expert FFN on the GPU and block until it finishes.
///
/// `expert_data` is a single expert (`VARIANT.expert_size_4bit()` bytes)
/// holding gate, up and down projections at the variant's 4-bit
/// weight/scale/bias offsets. The call encodes:
/// 1. gate and up matvecs over `h_post`;
/// 2. `swiglu_fused` over their outputs;
/// 3. the down matvec into a temporary.
///
/// The result is then copied into `expert_out`. Fresh Metal buffers and a
/// fresh command buffer are created on every call.
///
/// NOTE(review): this always uses `dequant_matvec_4bit_v3`, including for
/// `in_dim = hidden_dim`. `encode_matvec_into` switches to the `fast_4bit`
/// pipeline when `in_dim > 4096`. Confirm v3 is valid for this variant's
/// `hidden_dim`.
///
/// # Errors
///
/// - `BadExpertDataLen`, `BadHPostLen` or `BadExpertOutLen` for mis-sized
///   arguments.
/// - `Metal` if a pipeline lookup fails.
pub fn gpu_expert_forward(
    metal: &mut MetalContext,
    expert_data: &[u8],
    h_post: &[f32],
    expert_out: &mut [f32],
) -> Result<(), ExpertForwardError> {
    let variant = VARIANT;
    let hidden = variant.hidden_dim;
    let inter = variant.moe_intermediate;

    let want_bytes = variant.expert_size_4bit();
    if expert_data.len() != want_bytes {
        return Err(ExpertForwardError::BadExpertDataLen {
            expected: want_bytes,
            actual: expert_data.len(),
        });
    }
    if h_post.len() != hidden {
        return Err(ExpertForwardError::BadHPostLen {
            expected: hidden,
            actual: h_post.len(),
        });
    }
    if expert_out.len() != hidden {
        return Err(ExpertForwardError::BadExpertOutLen {
            expected: hidden,
            actual: expert_out.len(),
        });
    }

    let matvec_pso = metal.pipeline("dequant_matvec_4bit_v3")?.clone();
    let swiglu_pso = metal.pipeline("swiglu_fused")?.clone();

    let device = metal.device();
    let weights = MtlBuffer::<u8>::with_data(device, expert_data);
    let x = MtlBuffer::<f32>::with_data(device, h_post);
    let gate_out = MtlBuffer::<f32>::with_len(device, inter);
    let up_out = MtlBuffer::<f32>::with_len(device, inter);
    let activated = MtlBuffer::<f32>::with_len(device, inter);
    let y = MtlBuffer::<f32>::with_len(device, hidden);

    let cmdbuf = metal.queue().new_command_buffer();

    // Gate then up projection: hidden -> moe_intermediate.
    let projections = [
        (
            variant.gate_w_off_4bit(),
            variant.gate_s_off_4bit(),
            variant.gate_b_off_4bit(),
            &gate_out,
        ),
        (
            variant.up_w_off_4bit(),
            variant.up_s_off_4bit(),
            variant.up_b_off_4bit(),
            &up_out,
        ),
    ];
    for (w_off, s_off, b_off, dst) in projections {
        encode_matvec(
            cmdbuf,
            &matvec_pso,
            &weights,
            w_off,
            s_off,
            b_off,
            &x,
            dst,
            inter as u32,
            hidden as u32,
        );
    }

    encode_swiglu(cmdbuf, &swiglu_pso, &gate_out, &up_out, &activated, inter as u32);

    // Down projection: moe_intermediate -> hidden.
    encode_matvec(
        cmdbuf,
        &matvec_pso,
        &weights,
        variant.down_w_off_4bit(),
        variant.down_s_off_4bit(),
        variant.down_b_off_4bit(),
        &activated,
        &y,
        hidden as u32,
        inter as u32,
    );

    cmdbuf.commit();
    cmdbuf.wait_until_completed();
    expert_out.copy_from_slice(&y.to_vec());
    Ok(())
}
/// Encode one 4-bit dequantizing matvec into its own compute encoder on
/// `cmdbuf`.
///
/// `w_off`, `s_off` and `b_off` are byte offsets into `data` for the packed
/// weights, scales and biases. Binding layout:
/// - slots 0..=2: `data` at those three offsets;
/// - slot 3: `input`;
/// - slot 4: `output`;
/// - scalar slots 5..=7: `out_dim`, `in_dim`, `GROUP_SIZE`.
///
/// Dispatches `ceil(out_dim / 8)` threadgroups of 256 threads.
fn encode_matvec(
    cmdbuf: &metal::CommandBufferRef,
    pipeline: &metal::ComputePipelineState,
    data: &MtlBuffer<u8>,
    w_off: usize,
    s_off: usize,
    b_off: usize,
    input: &MtlBuffer<f32>,
    output: &MtlBuffer<f32>,
    out_dim: u32,
    in_dim: u32,
) {
    let enc = cmdbuf.new_compute_command_encoder();
    enc.set_compute_pipeline_state(pipeline);

    // Weights / scales / biases are all views into the same packed buffer.
    for (slot, &byte_off) in [w_off, s_off, b_off].iter().enumerate() {
        enc.set_buffer(slot as NSUInteger, Some(data.raw()), byte_off as NSUInteger);
    }
    enc.set_buffer(3, Some(input.raw()), 0);
    enc.set_buffer(4, Some(output.raw()), 0);

    let scalars = [out_dim, in_dim, GROUP_SIZE as u32];
    for (i, value) in scalars.iter().enumerate() {
        enc.set_bytes(5 + i as NSUInteger, 4, (value as *const u32).cast());
    }

    let row_tiles = (out_dim + 7) / 8;
    enc.dispatch_thread_groups(
        MTLSize::new(row_tiles as NSUInteger, 1, 1),
        MTLSize::new(256, 1, 1),
    );
    enc.end_encoding();
}
/// Encode a `swiglu_fused` pass over `dim` floats into its own compute
/// encoder.
///
/// Binds `gate`, `up` and `act` at slots 0..=2 and `dim` as scalar slot 3.
/// Dispatches `ceil(dim / 256)` threadgroups of 256 threads.
fn encode_swiglu(
    cmdbuf: &metal::CommandBufferRef,
    pipeline: &metal::ComputePipelineState,
    gate: &MtlBuffer<f32>,
    up: &MtlBuffer<f32>,
    act: &MtlBuffer<f32>,
    dim: u32,
) {
    let enc = cmdbuf.new_compute_command_encoder();
    enc.set_compute_pipeline_state(pipeline);
    for (slot, buf) in [gate, up, act].iter().enumerate() {
        enc.set_buffer(slot as NSUInteger, Some(buf.raw()), 0);
    }
    let dim_ptr: *const u32 = &dim;
    enc.set_bytes(3, 4, dim_ptr.cast());

    let threadgroups = (dim + 255) / 256;
    enc.dispatch_thread_groups(
        MTLSize::new(threadgroups as NSUInteger, 1, 1),
        MTLSize::new(256, 1, 1),
    );
    enc.end_encoding();
}
/// Pool-allocated scratch buffers for the MoE expert path.
///
/// Sized from the compile-time `VARIANT` by `MoeBuffers::new`. Only the
/// `BufId`s are stored here; the backing Metal buffers are resolved through
/// a `MetalBufferPool`.
pub struct MoeBuffers {
    /// Host-staged expert weights, one packed 4-bit expert per slot
    /// (2 MiB-aligned allocations).
    data_synced: [BufId<DeprecatedCogitoBuf>; MAX_K],
    /// Two sets (index 0 or 1) of per-slot expert weight buffers.
    /// Presumably double-buffered prefetch targets — TODO confirm.
    data_prefetch: [[BufId<DeprecatedCogitoBuf>; MAX_K]; 2],
    /// Per-slot gate-projection outputs (e.g. the per-expert fallback).
    gate: [BufId<DeprecatedCogitoBuf>; MAX_K],
    /// Per-slot up-projection outputs.
    up: [BufId<DeprecatedCogitoBuf>; MAX_K],
    /// Per-slot SwiGLU outputs.
    act: [BufId<DeprecatedCogitoBuf>; MAX_K],
    /// Per-slot down-projection outputs.
    out: [BufId<DeprecatedCogitoBuf>; MAX_K],
    /// Gate outputs for all experts, laid out back to back (batched path).
    gate_flat: BufId<DeprecatedCogitoBuf>,
    /// Up outputs for all experts, back to back (batched path).
    up_flat: BufId<DeprecatedCogitoBuf>,
    /// SwiGLU outputs for all experts, back to back (batched path).
    act_flat: BufId<DeprecatedCogitoBuf>,
    /// Down outputs for all experts; read by the flat combine kernel.
    out_flat: BufId<DeprecatedCogitoBuf>,
    /// Expert FFN input (`hidden_dim` f32).
    input: BufId<DeprecatedCogitoBuf>,
    /// `h_mid` input to the combine (`hidden_dim` f32).
    h_mid: BufId<DeprecatedCogitoBuf>,
    /// Shared-expert output input to the combine (`hidden_dim` f32).
    shared_out: BufId<DeprecatedCogitoBuf>,
    /// Combine destination when no `ChainToNormed` is supplied.
    moe_hidden: BufId<DeprecatedCogitoBuf>,
    /// 18 f32: routing weights in `[..k]`, shared gate score at `[16]`.
    combine_params: BufId<DeprecatedCogitoBuf>,
    /// `max(num_experts, 1)` f32; presumably router logits — TODO confirm.
    gate_logits: BufId<DeprecatedCogitoBuf>,
}
impl MoeBuffers {
pub fn new(pool: &mut MetalBufferPool) -> Self {
let v: Variant = VARIANT;
const TWO_MIB: usize = 2 * 1024 * 1024;
let f32_size = std::mem::size_of::<f32>();
let data_synced: [BufId<DeprecatedCogitoBuf>; MAX_K] =
std::array::from_fn(|_| {
pool.alloc_aligned(
v.expert_size_4bit(),
TWO_MIB,
"moe.data_synced",
true,
)
});
let data_prefetch: [[BufId<DeprecatedCogitoBuf>; MAX_K]; 2] =
std::array::from_fn(|_| {
std::array::from_fn(|_| {
pool.alloc_aligned(
v.expert_size_4bit(),
TWO_MIB,
"moe.data_prefetch",
true,
)
})
});
let probe = |label: &str, ids: &[BufId<DeprecatedCogitoBuf>]| {
for (slot, &id) in ids.iter().enumerate() {
let addr = pool.handle(id).contents() as usize;
debug_assert_eq!(
addr % TWO_MIB,
0,
"data_{label} slot {slot} not 2 MB aligned (contents=0x{addr:x})",
);
}
};
probe("synced", &data_synced[..]);
probe("prefetch[0]", &data_prefetch[0][..]);
probe("prefetch[1]", &data_prefetch[1][..]);
let gate: [BufId<DeprecatedCogitoBuf>; MAX_K] =
std::array::from_fn(|_| {
pool.alloc(
v.moe_intermediate * f32_size,
"moe.gate",
true,
)
.expect("pool.alloc moe.gate")
});
let up: [BufId<DeprecatedCogitoBuf>; MAX_K] = std::array::from_fn(
|_| {
pool.alloc(v.moe_intermediate * f32_size, "moe.up", true)
.expect("pool.alloc moe.up")
},
);
let act: [BufId<DeprecatedCogitoBuf>; MAX_K] = std::array::from_fn(
|_| {
pool.alloc(v.moe_intermediate * f32_size, "moe.act", true)
.expect("pool.alloc moe.act")
},
);
let out: [BufId<DeprecatedCogitoBuf>; MAX_K] = std::array::from_fn(
|_| {
pool.alloc(v.hidden_dim * f32_size, "moe.out", true)
.expect("pool.alloc moe.out")
},
);
let k_max = v.num_experts_per_tok;
let gate_flat = pool
.alloc(k_max * v.moe_intermediate * f32_size, "moe.gate_flat", true)
.expect("pool.alloc moe.gate_flat");
let up_flat = pool
.alloc(k_max * v.moe_intermediate * f32_size, "moe.up_flat", true)
.expect("pool.alloc moe.up_flat");
let act_flat = pool
.alloc(k_max * v.moe_intermediate * f32_size, "moe.act_flat", true)
.expect("pool.alloc moe.act_flat");
let out_flat = pool
.alloc(k_max * v.hidden_dim * f32_size, "moe.out_flat", true)
.expect("pool.alloc moe.out_flat");
let input: BufId<DeprecatedCogitoBuf> = pool
.alloc(v.hidden_dim * f32_size, "moe.input", true)
.expect("pool.alloc moe.input");
let h_mid: BufId<DeprecatedCogitoBuf> = pool
.alloc(v.hidden_dim * f32_size, "moe.h_mid", true)
.expect("pool.alloc moe.h_mid");
let shared_out: BufId<DeprecatedCogitoBuf> = pool
.alloc(v.hidden_dim * f32_size, "moe.shared_out", true)
.expect("pool.alloc moe.shared_out");
let moe_hidden: BufId<DeprecatedCogitoBuf> = pool
.alloc(v.hidden_dim * f32_size, "moe.moe_hidden", true)
.expect("pool.alloc moe.moe_hidden");
let combine_params: BufId<DeprecatedCogitoBuf> = pool
.alloc(18 * f32_size, "moe.combine_params", true)
.expect("pool.alloc moe.combine_params");
let gate_logits: BufId<DeprecatedCogitoBuf> = pool
.alloc(
v.num_experts.max(1) * f32_size,
"moe.gate_logits",
true,
)
.expect("pool.alloc moe.gate_logits");
Self {
data_synced,
data_prefetch,
gate,
up,
act,
out,
gate_flat,
up_flat,
act_flat,
out_flat,
input,
h_mid,
shared_out,
moe_hidden,
combine_params,
gate_logits,
}
}
pub(crate) fn moe_hidden_id(&self) -> BufId<DeprecatedCogitoBuf> {
self.moe_hidden
}
pub(crate) fn out_id(&self, slot: usize) -> BufId<DeprecatedCogitoBuf> {
self.out[slot]
}
pub(crate) fn gate_id(&self, slot: usize) -> BufId<DeprecatedCogitoBuf> {
self.gate[slot]
}
pub(crate) fn up_id(&self, slot: usize) -> BufId<DeprecatedCogitoBuf> {
self.up[slot]
}
pub(crate) fn act_id(&self, slot: usize) -> BufId<DeprecatedCogitoBuf> {
self.act[slot]
}
pub(crate) fn h_mid_id(&self) -> BufId<DeprecatedCogitoBuf> {
self.h_mid
}
pub(crate) fn shared_out_id(&self) -> BufId<DeprecatedCogitoBuf> {
self.shared_out
}
pub(crate) fn combine_params_id(&self) -> BufId<DeprecatedCogitoBuf> {
self.combine_params
}
pub(crate) fn data_synced_id(
&self,
slot: usize,
) -> BufId<DeprecatedCogitoBuf> {
self.data_synced[slot]
}
pub(crate) fn data_prefetch_id(
&self,
set: usize,
slot: usize,
) -> BufId<DeprecatedCogitoBuf> {
debug_assert!(set < 2);
debug_assert!(slot < MAX_K);
self.data_prefetch[set][slot]
}
pub(crate) fn gate_flat_id(&self) -> BufId<DeprecatedCogitoBuf> {
self.gate_flat
}
pub(crate) fn up_flat_id(&self) -> BufId<DeprecatedCogitoBuf> {
self.up_flat
}
pub(crate) fn act_flat_id(&self) -> BufId<DeprecatedCogitoBuf> {
self.act_flat
}
pub(crate) fn out_flat_id(&self) -> BufId<DeprecatedCogitoBuf> {
self.out_flat
}
pub(crate) fn data_synced_slots_mut_array<'p>(
&self,
pool: &'p MetalBufferPool,
) -> [&'p mut [u8]; MAX_K] {
pool.as_mut_slices_u8(self.data_synced)
}
pub(crate) fn data_prefetch_slots_mut_array<'p>(
&self,
pool: &'p MetalBufferPool,
set: usize,
) -> [&'p mut [u8]; MAX_K] {
debug_assert!(set < 2, "prefetch set index must be 0 or 1");
pool.as_mut_slices_u8(self.data_prefetch[set])
}
pub(crate) fn input_buffer<'p>(
&self,
pool: &'p MetalBufferPool,
) -> &'p metal::Buffer {
pool.handle(self.input)
}
pub(crate) fn h_mid_buffer<'p>(
&self,
pool: &'p MetalBufferPool,
) -> &'p metal::Buffer {
pool.handle(self.h_mid)
}
pub(crate) fn shared_out_buffer<'p>(
&self,
pool: &'p MetalBufferPool,
) -> &'p metal::Buffer {
pool.handle(self.shared_out)
}
pub(crate) fn stage_host_input(
&self,
pool: &MetalBufferPool,
hidden: &[f32],
) {
debug_assert_eq!(hidden.len(), VARIANT.hidden_dim);
let buf = pool.handle(self.input);
let dst: &mut [f32] =
unsafe { buffer_as_mut_slice::<f32>(buf, VARIANT.hidden_dim) };
dst.copy_from_slice(hidden);
}
pub(crate) fn stage_host_h_mid_zero(&self, pool: &MetalBufferPool) {
let buf = pool.handle(self.h_mid);
let dst: &mut [f32] =
unsafe { buffer_as_mut_slice::<f32>(buf, VARIANT.hidden_dim) };
dst.fill(0.0);
}
pub(crate) fn moe_hidden_to_vec(
&self,
pool: &MetalBufferPool,
) -> Vec<f32> {
let buf = pool.handle(self.moe_hidden);
let src: &[f32] =
unsafe { buffer_as_slice::<f32>(buf, VARIANT.hidden_dim) };
src.to_vec()
}
pub fn moe_hidden_ref<'p>(
&self,
pool: &'p MetalBufferPool,
) -> &'p metal::Buffer {
pool.handle(self.moe_hidden)
}
pub(crate) fn gate_logits_buffer<'p>(
&self,
pool: &'p MetalBufferPool,
) -> &'p metal::Buffer {
pool.handle(self.gate_logits)
}
pub(crate) fn gate_logits_to_vec(
&self,
pool: &MetalBufferPool,
) -> Vec<f32> {
let buf = pool.handle(self.gate_logits);
let n = VARIANT.num_experts.max(1);
let src: &[f32] = unsafe { buffer_as_slice::<f32>(buf, n) };
src.to_vec()
}
}
impl std::fmt::Debug for MoeBuffers {
    /// Summarise the buffer set by the `VARIANT` dimensions that size it.
    /// Individual buffer ids are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let v = VARIANT;
        let mut out = f.debug_struct("MoeBuffers");
        out.field("max_k", &MAX_K);
        out.field("hidden_dim", &v.hidden_dim);
        out.field("moe_intermediate", &v.moe_intermediate);
        out.field("expert_size_4bit", &v.expert_size_4bit());
        out.finish()
    }
}
/// Synchronous wrapper around `gpu_batched_experts_encode`.
///
/// Stages the inputs, encodes the batched expert FFN plus residual combine,
/// commits and waits. It then copies the combined hidden state
/// (`hidden_dim` floats) into `hidden_out`.
///
/// # Errors
///
/// - `BadHiddenOutLen` if `hidden_out` is mis-sized; checked before any GPU
///   work.
/// - Otherwise, any error from `gpu_batched_experts_encode`.
// Allowed: the signature mirrors the encode entry point's inputs.
#[allow(clippy::too_many_arguments)]
pub fn gpu_batched_experts_forward(
    metal: &mut MetalContext,
    bufs: &mut MoeBuffers,
    buffer_pool: &MetalBufferPool,
    actual_k: i32,
    expert_data: &[u8],
    payload: ExpertPayload<'_>,
    hidden_out: &mut [f32],
) -> Result<(), ExpertForwardError> {
    let hidden_dim = VARIANT.hidden_dim;
    if hidden_out.len() != hidden_dim {
        return Err(ExpertForwardError::BadHiddenOutLen {
            expected: hidden_dim,
            actual: hidden_out.len(),
        });
    }

    let cmdbuf = gpu_batched_experts_encode(
        metal,
        bufs,
        buffer_pool,
        actual_k,
        expert_data,
        payload,
        true,
    )?;
    cmdbuf.commit();
    cmdbuf.wait_until_completed();

    let combined = bufs.moe_hidden_to_vec(buffer_pool);
    hidden_out.copy_from_slice(&combined);
    Ok(())
}
/// Stage host-side inputs into `bufs` and encode the batched expert FFN plus
/// residual combine into a new command buffer, which is not committed.
///
/// * `actual_k` — number of routed experts, `1..=MAX_K`.
/// * `expert_data` — `actual_k` packed 4-bit experts back to back. Expert
///   `i` is copied into `data_synced` slot `i`.
/// * `payload` — activations (`hidden_dim` floats each) and `actual_k`
///   routing weights. The weights go to `combine_params[..k]` and the shared
///   gate score to `combine_params[16]`.
/// * `_gpu_combine` — currently ignored.
///
/// The caller commits the returned buffer. After completion, it reads the
/// result from the `moe_hidden` buffer of `bufs`.
///
/// # Errors
///
/// - `BadK`, `BadExpertDataLen` or `BadWeightsLen` from `validate_inputs`.
/// - `BadHPostLen`, `BadHMidLen` or `BadSharedOutLen` for mis-sized payload
///   slices.
/// - An error if a required pipeline cannot be fetched.
///
/// # Panics
///
/// Panics while encoding when the batched expert kernel is not usable
/// (`hidden_dim > 4096`, `k > 8`, or no `dequant_matvec_4bit_v3_experts`
/// pipeline). In that case `emit_batched_experts` has no combine wired for
/// the per-expert path.
#[allow(clippy::too_many_arguments)]
pub(crate) fn gpu_batched_experts_encode(
    metal: &mut MetalContext,
    bufs: &mut MoeBuffers,
    buffer_pool: &MetalBufferPool,
    actual_k: i32,
    expert_data: &[u8],
    payload: ExpertPayload<'_>,
    _gpu_combine: bool,
) -> Result<metal::CommandBuffer, ExpertForwardError> {
    let ExpertPayload {
        h_post,
        h_mid,
        shared_out,
        expert_weights,
        shared_gate_score,
    } = payload;
    let v = VARIANT;
    validate_inputs(actual_k, expert_data, expert_weights)?;
    let k = actual_k as usize;
    if h_post.len() != v.hidden_dim {
        return Err(ExpertForwardError::BadHPostLen {
            expected: v.hidden_dim,
            actual: h_post.len(),
        });
    }
    if h_mid.len() != v.hidden_dim {
        return Err(ExpertForwardError::BadHMidLen {
            expected: v.hidden_dim,
            actual: h_mid.len(),
        });
    }
    if shared_out.len() != v.hidden_dim {
        return Err(ExpertForwardError::BadSharedOutLen {
            expected: v.hidden_dim,
            actual: shared_out.len(),
        });
    }

    let matvec = MatvecPipelines::fetch(metal)?;
    let swiglu = metal.pipeline("swiglu_fused")?.clone();
    // Optional: when this pipeline is missing, the batched path is disabled.
    let v3_experts = metal.pipeline("dequant_matvec_4bit_v3_experts").ok().cloned();
    let combine_flat = metal.pipeline("moe_combine_residual_flat")?.clone();

    // validate_inputs guarantees expert_data.len() == k * expert_size.
    let expert_size = v.expert_size_4bit();
    for slot in 0..k {
        let src = &expert_data[slot * expert_size..(slot + 1) * expert_size];
        let dst = buffer_pool.as_mut_slice_u8(bufs.data_synced_id(slot));
        dst.copy_from_slice(src);
    }

    bufs.stage_host_input(buffer_pool, h_post);
    {
        let buf = buffer_pool.handle(bufs.h_mid_id());
        // SAFETY: h_mid is allocated with hidden_dim f32s in MoeBuffers::new.
        // Assumes no in-flight GPU work is reading it.
        let dst: &mut [f32] =
            unsafe { buffer_as_mut_slice::<f32>(buf, v.hidden_dim) };
        dst.copy_from_slice(h_mid);
    }
    {
        let buf = buffer_pool.handle(bufs.shared_out_id());
        // SAFETY: as above; shared_out holds hidden_dim f32s.
        let dst: &mut [f32] =
            unsafe { buffer_as_mut_slice::<f32>(buf, v.hidden_dim) };
        dst.copy_from_slice(shared_out);
    }
    {
        let buf = buffer_pool.handle(bufs.combine_params_id());
        // SAFETY: combine_params is allocated with 18 f32s.
        let params: &mut [f32] =
            unsafe { buffer_as_mut_slice::<f32>(buf, 18) };
        params.fill(0.0);
        params[..k].copy_from_slice(expert_weights);
        params[16] = shared_gate_score;
    }

    let bindings: Vec<(&Buffer, u64)> = (0..k)
        .map(|slot| (buffer_pool.handle(bufs.data_synced_id(slot)), 0u64))
        .collect();
    let cmdbuf = metal.queue().new_command_buffer();

    // These references borrow `buffer_pool`, not `bufs`, so no retained
    // clone is needed to pass `bufs` alongside them.
    let input_buf = bufs.input_buffer(buffer_pool);
    let h_mid_buf = bufs.h_mid_buffer(buffer_pool);
    let shared_out_buf = bufs.shared_out_buffer(buffer_pool);
    emit_batched_experts(
        cmdbuf, &matvec, &swiglu, v3_experts.as_ref(), &combine_flat,
        bufs, buffer_pool, input_buf, h_mid_buf, shared_out_buf,
        k, v, &bindings, None,
    );
    Ok(cmdbuf.to_owned())
}
/// Encode the expert FFN plus residual combine for experts whose weights are
/// already in Metal buffers.
///
/// Returns the new command buffer without committing it.
///
/// * `input`, `h_mid`, `shared_out` — GPU buffers bound directly; no host
///   staging is done for them.
/// * `expert_weights` — `actual_k` routing weights. They are written on the
///   host to `combine_params[..k]` (rest zeroed); `shared_gate_score` goes to
///   `combine_params[16]`.
/// * `expert_bindings` — one `(buffer, base byte offset)` pair per expert
///   slot. Must contain at least `actual_k` entries.
/// * `chain` — optional fused RMS norm. When set, the combine writes to
///   `chain.combine_out` instead of the `moe_hidden` buffer.
///
/// # Errors
///
/// - `BadK` if `actual_k` is outside `1..=MAX_K`.
/// - `BadWeightsLen` if `expert_weights.len() != actual_k`.
/// - An error if a required pipeline cannot be fetched.
///
/// # Panics
///
/// - If `expert_bindings` has fewer than `actual_k` entries.
/// - While encoding, if the batched expert kernel is not usable (see
///   `emit_batched_experts`).
#[allow(clippy::too_many_arguments)]
pub(crate) fn gpu_batched_experts_encode_mmap(
    metal: &mut MetalContext,
    bufs: &mut MoeBuffers,
    buffer_pool: &MetalBufferPool,
    actual_k: i32,
    input: &BufferRef,
    h_mid: &BufferRef,
    shared_out: &BufferRef,
    expert_weights: &[f32],
    shared_gate_score: f32,
    expert_bindings: &[(&Buffer, u64)],
    chain: Option<ChainToNormed<'_>>,
) -> Result<metal::CommandBuffer, ExpertForwardError> {
    let v = VARIANT;
    if actual_k < 1 || (actual_k as usize) > MAX_K {
        return Err(ExpertForwardError::BadK {
            actual: actual_k,
            max: MAX_K,
        });
    }
    let k = actual_k as usize;
    if expert_weights.len() != k {
        return Err(ExpertForwardError::BadWeightsLen {
            expected: k,
            actual: expert_weights.len(),
        });
    }
    // Every routed expert needs a weight binding. With too few bindings:
    // - the per-expert fallback indexes `expert_bindings[slot]` out of
    //   bounds;
    // - the batched kernel silently substitutes binding 0 for the missing
    //   slots and computes the wrong experts.
    assert!(
        expert_bindings.len() >= k,
        "expert_bindings has {} entries but actual_k = {}",
        expert_bindings.len(),
        k,
    );

    let matvec = MatvecPipelines::fetch(metal)?;
    let swiglu = metal.pipeline("swiglu_fused")?.clone();
    // Optional: when this pipeline is missing, the batched path is disabled.
    let v3_experts = metal.pipeline("dequant_matvec_4bit_v3_experts").ok().cloned();
    let combine_flat = metal.pipeline("moe_combine_residual_flat")?.clone();
    {
        let buf = buffer_pool.handle(bufs.combine_params_id());
        // SAFETY: combine_params is allocated with 18 f32s in MoeBuffers::new.
        // Assumes no in-flight GPU work is reading it.
        let params: &mut [f32] =
            unsafe { buffer_as_mut_slice::<f32>(buf, 18) };
        params.fill(0.0);
        params[..k].copy_from_slice(expert_weights);
        params[16] = shared_gate_score;
    }
    let cmdbuf = metal.queue().new_command_buffer();
    emit_batched_experts(
        cmdbuf, &matvec, &swiglu, v3_experts.as_ref(), &combine_flat,
        bufs, buffer_pool, input, h_mid, shared_out,
        k, v, expert_bindings, chain,
    );
    Ok(cmdbuf.to_owned())
}
/// Check the host-staged batched-expert inputs against `VARIANT`.
///
/// Checks run in this order, stopping at the first failure:
/// 1. `1 <= actual_k <= MAX_K`, else `BadK`;
/// 2. `expert_data.len() == actual_k * expert_size_4bit()`, else
///    `BadExpertDataLen`;
/// 3. `expert_weights.len() == actual_k`, else `BadWeightsLen`.
fn validate_inputs(
    actual_k: i32,
    expert_data: &[u8],
    expert_weights: &[f32],
) -> Result<(), ExpertForwardError> {
    let k_in_range = actual_k >= 1 && (actual_k as usize) <= MAX_K;
    if !k_in_range {
        return Err(ExpertForwardError::BadK {
            actual: actual_k,
            max: MAX_K,
        });
    }
    let k = actual_k as usize;

    let want_data = k * VARIANT.expert_size_4bit();
    let got_data = expert_data.len();
    if got_data != want_data {
        return Err(ExpertForwardError::BadExpertDataLen {
            expected: want_data,
            actual: got_data,
        });
    }

    let got_weights = expert_weights.len();
    if got_weights == k {
        Ok(())
    } else {
        Err(ExpertForwardError::BadWeightsLen {
            expected: k,
            actual: got_weights,
        })
    }
}
/// Encode the routed-expert FFN for `k` experts into `cmdbuf`, followed by
/// the residual combine and, optionally, a chained RMS norm.
///
/// Expert weights come from `expert_bindings`: `(buffer, base byte offset)`
/// pairs, one per expert slot. Two compute strategies exist:
///
/// * **batched** — selected when `hidden_dim <= 4096`, `k <= 8`, and the
///   `dequant_matvec_4bit_v3_experts` pipeline is available.
///   - Gate/up, SwiGLU and down each run once across all `k` experts,
///     using the flat scratch buffers in `bufs`.
///   - `combine_flat` then produces the combined output.
/// * **per-expert fallback** — one gate/up/SwiGLU/down chain per slot,
///   using the per-slot scratch buffers. No combine is wired for this path
///   yet, so it panics after encoding the expert work.
///
/// When `chain` is `Some`, the combine writes into `chain.combine_out` and
/// an RMS norm over it is appended. Otherwise the combine writes into the
/// `moe_hidden` buffer of `bufs`.
///
/// The first call in the process logs the chosen strategy to stderr.
///
/// # Panics
///
/// Panics if the batched strategy is not selected.
#[allow(clippy::too_many_arguments)]
fn emit_batched_experts(
    cmdbuf: &CommandBufferRef,
    matvec: &MatvecPipelines,
    swiglu: &ComputePipelineState,
    v3_experts: Option<&ComputePipelineState>,
    combine_flat: &ComputePipelineState,
    bufs: &MoeBuffers,
    buffer_pool: &MetalBufferPool,
    input: &BufferRef,
    h_mid: &BufferRef,
    shared_out: &BufferRef,
    k: usize,
    v: Variant,
    expert_bindings: &[(&Buffer, u64)],
    chain: Option<ChainToNormed<'_>>,
) {
    // Decide the strategy once, so the log line and the code path that
    // actually runs cannot disagree. `encode_matvec_experts` binds exactly
    // 8 expert slots, hence `k <= 8`.
    let batched_pipe: Option<&ComputePipelineState> =
        v3_experts.filter(|_| v.hidden_dim <= 4096 && k <= 8);
    let use_batched = batched_pipe.is_some();
    {
        use std::sync::atomic::{AtomicBool, Ordering};
        // Only the first call in the process logs.
        static LOGGED: AtomicBool = AtomicBool::new(false);
        if !LOGGED.swap(true, Ordering::Relaxed) {
            eprintln!(
                "[moe] expert dispatch: hidden_dim={} k={} v3_experts={} → {}",
                v.hidden_dim, k, v3_experts.is_some(),
                if use_batched {
                    "batched"
                } else {
                    "per-expert (fallback)"
                },
            );
        }
    }
    if let Some(v3e) = batched_pipe {
        let gate_flat = buffer_pool.handle(bufs.gate_flat_id());
        let up_flat = buffer_pool.handle(bufs.up_flat_id());
        let act_flat = buffer_pool.handle(bufs.act_flat_id());
        let out_flat = buffer_pool.handle(bufs.out_flat_id());
        let group_size = GROUP_SIZE as u32;
        let k_u32 = k as u32;
        // Gate and up projections (hidden_dim -> moe_intermediate) share one
        // encoder.
        {
            let enc = cmdbuf.new_compute_command_encoder();
            encode_matvec_experts(
                enc, v3e, expert_bindings, input, gate_flat,
                v.moe_intermediate as u32, v.hidden_dim as u32, group_size,
                v.gate_w_off_4bit() as u32,
                v.gate_s_off_4bit() as u32,
                v.gate_b_off_4bit() as u32,
                k_u32, 0,
            );
            encode_matvec_experts(
                enc, v3e, expert_bindings, input, up_flat,
                v.moe_intermediate as u32, v.hidden_dim as u32, group_size,
                v.up_w_off_4bit() as u32,
                v.up_s_off_4bit() as u32,
                v.up_b_off_4bit() as u32,
                k_u32, 0,
            );
            enc.end_encoding();
        }
        // One SwiGLU over all k experts' intermediate activations.
        {
            let enc = cmdbuf.new_compute_command_encoder();
            encode_swiglu_into_buf(
                enc, swiglu, gate_flat, up_flat, act_flat,
                k_u32 * v.moe_intermediate as u32,
            );
            enc.end_encoding();
        }
        // Down projection (moe_intermediate -> hidden_dim). The input stride
        // is moe_intermediate, presumably so expert i reads row i of
        // act_flat.
        {
            let enc = cmdbuf.new_compute_command_encoder();
            encode_matvec_experts(
                enc, v3e, expert_bindings, act_flat, out_flat,
                v.hidden_dim as u32, v.moe_intermediate as u32, group_size,
                v.down_w_off_4bit() as u32,
                v.down_s_off_4bit() as u32,
                v.down_b_off_4bit() as u32,
                k_u32, v.moe_intermediate as u32,
            );
            enc.end_encoding();
        }
    } else {
        // Per-expert fallback. NOTE: no combine is wired for this path, so
        // the panic below fires after this work has been encoded.
        for slot in 0..k {
            let (wb, base_off) = expert_bindings[slot];
            let gate_buf = buffer_pool.handle(bufs.gate_id(slot));
            let up_buf = buffer_pool.handle(bufs.up_id(slot));
            let act_buf = buffer_pool.handle(bufs.act_id(slot));
            let out_buf = buffer_pool.handle(bufs.out_id(slot));
            {
                let enc = cmdbuf.new_compute_command_encoder();
                encode_matvec_into(
                    enc, matvec, wb, base_off,
                    v.gate_w_off_4bit(), v.gate_s_off_4bit(), v.gate_b_off_4bit(),
                    input, gate_buf, v.moe_intermediate as u32, v.hidden_dim as u32,
                );
                encode_matvec_into(
                    enc, matvec, wb, base_off,
                    v.up_w_off_4bit(), v.up_s_off_4bit(), v.up_b_off_4bit(),
                    input, up_buf, v.moe_intermediate as u32, v.hidden_dim as u32,
                );
                enc.end_encoding();
            }
            {
                let enc = cmdbuf.new_compute_command_encoder();
                encode_swiglu_into_buf(
                    enc, swiglu, gate_buf, up_buf, act_buf,
                    v.moe_intermediate as u32,
                );
                encode_matvec_into(
                    enc, matvec, wb, base_off,
                    v.down_w_off_4bit(), v.down_s_off_4bit(), v.down_b_off_4bit(),
                    act_buf, out_buf, v.hidden_dim as u32, v.moe_intermediate as u32,
                );
                enc.end_encoding();
            }
        }
    }

    if use_batched {
        let moe_hidden_buf = buffer_pool.handle(bufs.moe_hidden_id());
        // With a chained norm, the combine writes into the caller's buffer
        // so the norm can read it. Otherwise it writes into moe_hidden.
        let combine_out: &BufferRef = match chain.as_ref() {
            Some(c) => c.combine_out,
            None => moe_hidden_buf,
        };
        let combine_params_buf =
            buffer_pool.handle(bufs.combine_params_id());
        let enc = cmdbuf.new_compute_command_encoder();
        encode_combine_flat(
            enc, combine_flat, h_mid, shared_out, combine_out,
            buffer_pool.handle(bufs.out_flat_id()),
            combine_params_buf,
            v.hidden_dim as u32, k as u32,
        );
        enc.end_encoding();
    } else {
        panic!(
            "[moe] per-expert fallback combine not wired for mmap \
             (hidden_dim={}, k={}). This path is only for Cogito-V2.",
            v.hidden_dim, k,
        );
    }
    if let Some(c) = chain {
        encode_rms_norm_bf16_into(
            cmdbuf,
            c.pipes,
            c.combine_out,
            c.wf_buf,
            c.next_norm_off,
            c.chain_sum_sq,
            c.chain_normed,
            v.hidden_dim as u32,
            c.eps,
        );
    }
}
/// Encode one 4-bit dequantizing matvec into an already-open encoder `enc`.
///
/// Weights, scales and biases are bound from `data` at
/// `base_off + {w,s,b}_off` bytes. Kernel and dispatch depend on `in_dim`:
/// - `in_dim <= 4096`: `pipes.v3_4bit`, with `ceil(out_dim / 8)`
///   threadgroups of 256 threads;
/// - otherwise: `pipes.fast_4bit`, with `out_dim` threadgroups of 64
///   threads.
fn encode_matvec_into(
    enc: &metal::ComputeCommandEncoderRef,
    pipes: &MatvecPipelines,
    data: &Buffer,
    base_off: u64,
    w_off: usize,
    s_off: usize,
    b_off: usize,
    input: &BufferRef,
    output: &Buffer,
    out_dim: u32,
    in_dim: u32,
) {
    let narrow_input = in_dim <= 4096;
    let pipeline = if narrow_input {
        &pipes.v3_4bit
    } else {
        &pipes.fast_4bit
    };
    let base = base_off as usize;

    enc.set_compute_pipeline_state(pipeline);
    for (slot, &off) in [w_off, s_off, b_off].iter().enumerate() {
        enc.set_buffer(slot as NSUInteger, Some(data), (base + off) as NSUInteger);
    }
    enc.set_buffer(3, Some(input), 0);
    enc.set_buffer(4, Some(output), 0);

    let scalars = [out_dim, in_dim, GROUP_SIZE as u32];
    for (i, value) in scalars.iter().enumerate() {
        enc.set_bytes(5 + i as NSUInteger, 4, (value as *const u32).cast());
    }

    let (groups, threads_per_group) = if narrow_input {
        ((out_dim + 7) / 8, 256)
    } else {
        (out_dim, 64)
    };
    enc.dispatch_thread_groups(
        MTLSize::new(groups as NSUInteger, 1, 1),
        MTLSize::new(threads_per_group, 1, 1),
    );
}
/// Encode a SwiGLU over `dim` floats into an already-open encoder `enc`.
///
/// Binds `gate`, `up` and `act` at slots 0..=2 (offset 0) and `dim` as
/// scalar slot 3. Dispatches `ceil(dim / 256)` threadgroups of 256 threads.
fn encode_swiglu_into_buf(
    enc: &metal::ComputeCommandEncoderRef,
    pipeline: &metal::ComputePipelineState,
    gate: &Buffer,
    up: &Buffer,
    act: &Buffer,
    dim: u32,
) {
    enc.set_compute_pipeline_state(pipeline);
    for (slot, &buf) in [gate, up, act].iter().enumerate() {
        enc.set_buffer(slot as NSUInteger, Some(buf), 0);
    }
    let dim_ptr: *const u32 = &dim;
    enc.set_bytes(3, 4, dim_ptr.cast());

    let threadgroups = (dim + 255) / 256;
    enc.dispatch_thread_groups(
        MTLSize::new(threadgroups as NSUInteger, 1, 1),
        MTLSize::new(256, 1, 1),
    );
}
/// Encode one multi-expert 4-bit matvec (`dequant_matvec_4bit_v3_experts`)
/// into an already-open encoder.
///
/// Binding layout:
/// - slots 0..8: one expert weight buffer each, at its base offset.
///   Slots past `expert_bindings.len()` repeat binding 0, so every slot is
///   populated.
/// - slot 8: `input`; slot 9: `output`.
/// - scalar slots 10..=18: `out_dim`, `in_dim`, `group_size`, the three
///   intra-expert byte offsets, `ceil(out_dim / 8)`, `k` and
///   `input_stride`.
///
/// Dispatches `ceil(out_dim / 8) * k` threadgroups of 256 threads.
///
/// Panics if `expert_bindings` is empty.
fn encode_matvec_experts(
    enc: &metal::ComputeCommandEncoderRef,
    pipeline: &ComputePipelineState,
    expert_bindings: &[(&Buffer, u64)],
    input: &BufferRef,
    output: &Buffer,
    out_dim: u32,
    in_dim: u32,
    group_size: u32,
    w_byte_off: u32,
    s_byte_off: u32,
    b_byte_off: u32,
    k: u32,
    input_stride: u32,
) {
    const EXPERT_SLOTS: usize = 8;
    enc.set_compute_pipeline_state(pipeline);
    for slot in 0..EXPERT_SLOTS {
        let (buf, off) = expert_bindings
            .get(slot)
            .copied()
            .unwrap_or_else(|| expert_bindings[0]);
        enc.set_buffer(slot as NSUInteger, Some(buf), off as NSUInteger);
    }
    enc.set_buffer(8, Some(input), 0);
    enc.set_buffer(9, Some(output), 0);

    let num_row_tiles = (out_dim + 7) / 8;
    let scalars: [u32; 9] = [
        out_dim,
        in_dim,
        group_size,
        w_byte_off,
        s_byte_off,
        b_byte_off,
        num_row_tiles,
        k,
        input_stride,
    ];
    for (i, value) in scalars.iter().enumerate() {
        enc.set_bytes(10 + i as NSUInteger, 4, (value as *const u32).cast());
    }

    let total_tgs = num_row_tiles * k;
    enc.dispatch_thread_groups(
        MTLSize::new(total_tgs as NSUInteger, 1, 1),
        MTLSize::new(256, 1, 1),
    );
}
/// Encode the flat residual-combine kernel into an already-open encoder.
///
/// Binding layout:
/// - slots 0..=4: `h_mid`, `shared_out`, `hidden_out`, `expert_out_flat`,
///   `params`;
/// - scalar slots 5 and 6: `dim` and `k`.
///
/// Dispatches `ceil(dim / 256)` threadgroups of 256 threads.
fn encode_combine_flat(
    enc: &metal::ComputeCommandEncoderRef,
    pipeline: &ComputePipelineState,
    h_mid: &BufferRef,
    shared_out: &BufferRef,
    hidden_out: &BufferRef,
    expert_out_flat: &Buffer,
    params: &Buffer,
    dim: u32,
    k: u32,
) {
    enc.set_compute_pipeline_state(pipeline);
    let bound: [&BufferRef; 5] = [h_mid, shared_out, hidden_out, expert_out_flat, params];
    for (slot, &buf) in bound.iter().enumerate() {
        enc.set_buffer(slot as NSUInteger, Some(buf), 0);
    }
    for (i, value) in [dim, k].iter().enumerate() {
        enc.set_bytes(5 + i as NSUInteger, 4, (value as *const u32).cast());
    }

    let threadgroups = (dim + 255) / 256;
    enc.dispatch_thread_groups(
        MTLSize::new(threadgroups as NSUInteger, 1, 1),
        MTLSize::new(256, 1, 1),
    );
}
/// Encode the multi-token MoE FFN for tokens already grouped into
/// per-expert buckets (`buckets`), producing `out_sum`.
///
/// There are two paths:
/// - `gather == true`: the gather-QMM path (`encode_moe_gather`) runs each
///   projection once over all bucketed rows, selecting experts through
///   `expert_indices`. It then scatters every bucket into `out_sum`.
/// - otherwise: each bucket is encoded separately
///   (`encode_moe_per_bucket`), using `expert_slots`.
///
/// `matvec` is used only by the per-bucket path; `kernels` and
/// `expert_indices` only by the gather path. `expert_slots` must have one
/// entry per bucket, which is checked in debug builds.
#[allow(clippy::too_many_arguments)]
pub fn encode_moe_batched_permute_fuse(
    cmdbuf: &CommandBufferRef,
    matvec: &MatvecPipelines,
    kernels: &Kernels,
    swiglu: &ComputePipelineState,
    bucket_accumulate: &ComputePipelineState,
    expert_base: &Buffer,
    expert_stride: u64,
    expert_indices: &Buffer,
    expert_slots: &[u32],
    bucket_input: &Buffer,
    bucket_gate: &Buffer,
    bucket_up: &Buffer,
    bucket_act: &Buffer,
    bucket_out: &Buffer,
    bucket_token_idx: &Buffer,
    bucket_weights: &Buffer,
    out_sum: &Buffer,
    buckets: &ExpertBuckets,
    v: Variant,
    gather: bool,
) {
    debug_assert_eq!(expert_slots.len(), buckets.expert_ids.len());
    if !gather {
        encode_moe_per_bucket(
            cmdbuf, matvec, swiglu, bucket_accumulate, expert_base,
            expert_stride, expert_slots, bucket_input, bucket_gate,
            bucket_up, bucket_act, bucket_out, bucket_token_idx,
            bucket_weights, out_sum, buckets, v,
        );
        return;
    }
    encode_moe_gather(
        cmdbuf, kernels, swiglu, bucket_accumulate, expert_base,
        expert_stride, expert_indices, bucket_input, bucket_gate,
        bucket_up, bucket_act, bucket_out, bucket_token_idx,
        bucket_weights, out_sum, buckets, v,
    );
}
/// Encode a SwiGLU over `dim` floats in its own compute encoder.
///
/// `gate`, `up` and `act` are all bound at byte offset `off`, and `dim` is
/// scalar slot 3. `act` may be the same buffer as `gate`:
/// `encode_moe_gather_id_fuse` uses it in place that way. Dispatches
/// `ceil(dim / 256)` threadgroups of 256 threads.
fn encode_swiglu_at(
    cmdbuf: &CommandBufferRef,
    swiglu: &ComputePipelineState,
    gate: &Buffer,
    up: &Buffer,
    act: &Buffer,
    off: u64,
    dim: u32,
) {
    let byte_off = off as NSUInteger;
    let enc = cmdbuf.new_compute_command_encoder();
    enc.set_compute_pipeline_state(swiglu);
    for (slot, &buf) in [gate, up, act].iter().enumerate() {
        enc.set_buffer(slot as NSUInteger, Some(buf), byte_off);
    }
    let dim_ptr: *const u32 = &dim;
    enc.set_bytes(3, 4, dim_ptr.cast());

    let threadgroups = (dim + 255) / 256;
    enc.dispatch_thread_groups(
        MTLSize::new(threadgroups as NSUInteger, 1, 1),
        MTLSize::new(256, 1, 1),
    );
    enc.end_encoding();
}
/// Scatter-accumulate one expert bucket's outputs into `out_sum`.
///
/// The bucket is entries `start..start + b_size`. Three buffers are bound
/// at the byte offset of entry `start`:
/// - `bucket_out`: `hidden_dim` f32 per entry;
/// - `bucket_token_idx`: one i32 per entry;
/// - `bucket_weights`: one f32 per entry.
///
/// `out_sum` is bound at offset 0. Dispatches a
/// `b_size × ceil(hidden_dim / 256)` grid of `1 × 256` threadgroups.
#[allow(clippy::too_many_arguments)]
fn encode_bucket_scatter(
    cmdbuf: &CommandBufferRef,
    bucket_accumulate: &ComputePipelineState,
    bucket_out: &Buffer,
    bucket_token_idx: &Buffer,
    bucket_weights: &Buffer,
    out_sum: &Buffer,
    start: u64,
    b_size: u32,
    hidden_dim: u32,
) {
    let f32_bytes = std::mem::size_of::<f32>() as u64;
    let i32_bytes = std::mem::size_of::<i32>() as u64;
    let bindings = [
        (bucket_out, start * hidden_dim as u64 * f32_bytes),
        (bucket_token_idx, start * i32_bytes),
        (bucket_weights, start * f32_bytes),
        (out_sum, 0),
    ];

    let enc = cmdbuf.new_compute_command_encoder();
    enc.set_compute_pipeline_state(bucket_accumulate);
    for (slot, &(buf, byte_off)) in bindings.iter().enumerate() {
        enc.set_buffer(slot as NSUInteger, Some(buf), byte_off as NSUInteger);
    }
    for (i, value) in [hidden_dim, b_size].iter().enumerate() {
        enc.set_bytes(4 + i as NSUInteger, 4, (value as *const u32).cast());
    }

    let grid = MTLSize::new(
        b_size as NSUInteger,
        ((hidden_dim + 255) / 256) as NSUInteger,
        1,
    );
    enc.dispatch_thread_groups(grid, MTLSize::new(1, 256, 1));
    enc.end_encoding();
}
/// Encode the gather-QMM MoE FFN once per projection over all bucketed rows,
/// then scatter each bucket's outputs into `out_sum`.
///
/// Encodes, in order:
/// 1. gate and up projections (`hidden_dim -> moe_intermediate`) from
///    `bucket_input` into `bucket_gate` / `bucket_up`;
/// 2. SwiGLU over `total * moe_intermediate` floats into `bucket_act`;
/// 3. the down projection (`moe_intermediate -> hidden_dim`) into
///    `bucket_out`;
/// 4. one `encode_bucket_scatter` per non-empty bucket.
///
/// `total` is `buckets.token_idx.len()`; nothing is encoded when it is 0.
/// Every projection passes `expert_indices` to `GatherQmmCall` for expert
/// selection (presumably one expert index per row — TODO confirm).
#[allow(clippy::too_many_arguments)]
fn encode_moe_gather(
    cmdbuf: &CommandBufferRef,
    kernels: &Kernels,
    swiglu: &ComputePipelineState,
    bucket_accumulate: &ComputePipelineState,
    expert_base: &Buffer,
    expert_stride: u64,
    expert_indices: &Buffer,
    bucket_input: &Buffer,
    bucket_gate: &Buffer,
    bucket_up: &Buffer,
    bucket_act: &Buffer,
    bucket_out: &Buffer,
    bucket_token_idx: &Buffer,
    bucket_weights: &Buffer,
    out_sum: &Buffer,
    buckets: &ExpertBuckets,
    v: Variant,
) {
    let total = buckets.token_idx.len() as u32;
    if total == 0 {
        return;
    }
    let hidden_dim = v.hidden_dim as u32;
    let moe_inter = v.moe_intermediate as u32;
    // NOTE(review): the scale/bias stride is half the weight stride. This is
    // presumably a unit conversion GatherQmmCall expects (e.g. 2-byte
    // elements vs bytes); confirm against the kernel.
    let stride_s = expert_stride / 2;
    // One gather-QMM projection over all `total` rows. Every call shares the
    // expert base buffer, per-expert stride and index buffer; only the
    // intra-expert offsets, input/output and dims vary.
    let gather = |w_off: u64, s_off: u64, b_off: u64, input: &Buffer,
                  output: &Buffer, in_dim: u32, out_dim: u32| {
        kernels.encode(
            cmdbuf,
            &GatherQmmCall {
                weights: QuantWeights {
                    buffer: expert_base,
                    packed_offset: w_off,
                    scales_offset: s_off,
                    biases_offset: b_off,
                },
                input,
                input_offset: 0,
                output,
                output_offset: 0,
                indices: expert_indices,
                indices_offset: 0,
                in_dim,
                out_dim,
                n_tokens: total,
                stride_w: expert_stride,
                stride_s,
            },
        );
    };
    // Gate projection: hidden_dim -> moe_intermediate.
    gather(
        v.gate_w_off_4bit() as u64,
        v.gate_s_off_4bit() as u64,
        v.gate_b_off_4bit() as u64,
        bucket_input,
        bucket_gate,
        hidden_dim,
        moe_inter,
    );
    // Up projection: hidden_dim -> moe_intermediate.
    gather(
        v.up_w_off_4bit() as u64,
        v.up_s_off_4bit() as u64,
        v.up_b_off_4bit() as u64,
        bucket_input,
        bucket_up,
        hidden_dim,
        moe_inter,
    );
    // SwiGLU over every row's intermediate activations at once.
    encode_swiglu_at(
        cmdbuf,
        swiglu,
        bucket_gate,
        bucket_up,
        bucket_act,
        0,
        total * moe_inter,
    );
    // Down projection: moe_intermediate -> hidden_dim.
    gather(
        v.down_w_off_4bit() as u64,
        v.down_s_off_4bit() as u64,
        v.down_b_off_4bit() as u64,
        bucket_act,
        bucket_out,
        moe_inter,
        hidden_dim,
    );
    // Scatter bucket `bi` (rows offsets[bi]..offsets[bi + 1]) into out_sum.
    for bi in 0..buckets.expert_ids.len() {
        let start = buckets.offsets[bi] as u64;
        let b_size = (buckets.offsets[bi + 1] - buckets.offsets[bi]) as u32;
        if b_size == 0 {
            continue;
        }
        encode_bucket_scatter(
            cmdbuf, bucket_accumulate, bucket_out, bucket_token_idx,
            bucket_weights, out_sum, start, b_size, hidden_dim,
        );
    }
}
/// Encode the indexed ("mm_id") MoE FFN for `n_tokens` tokens with `k`
/// routed experts each, writing the combined result to `out_sum`.
///
/// Encodes, in order:
/// 1. `MoeIdMap0Call`, which derives `htpe` / `hids` from `indices`;
/// 2. gate and up projections from `mlp_in` into `gate_mid` / `up_mid`;
/// 3. SwiGLU written in place into `gate_mid`;
/// 4. the down projection from `gate_mid` into `down_mid`;
/// 5. `MoeCombineTopkCall` over `down_mid` with `weights` into `out_sum`
///    (presumably a per-token weighted sum of the `k` outputs — TODO
///    confirm).
///
/// Nothing is encoded when `n_tokens == 0`.
#[allow(clippy::too_many_arguments)]
pub fn encode_moe_gather_id_fuse(
    cmdbuf: &CommandBufferRef,
    kernels: &Kernels,
    swiglu: &ComputePipelineState,
    expert_base: &Buffer,
    expert_stride: u64,
    indices: &Buffer,
    weights: &Buffer,
    mlp_in: &Buffer,
    out_sum: &Buffer,
    htpe: &Buffer,
    hids: &Buffer,
    gate_mid: &Buffer,
    up_mid: &Buffer,
    down_mid: &Buffer,
    n_tokens: u32,
    n_experts: u32,
    k: u32,
    v: Variant,
) {
    if n_tokens == 0 {
        return;
    }
    let hidden_dim = v.hidden_dim as u32;
    let moe_inter = v.moe_intermediate as u32;
    kernels.encode(
        cmdbuf,
        &moeflux_metal::MoeIdMap0Call {
            indices,
            indices_offset: 0,
            htpe,
            htpe_offset: 0,
            hids,
            hids_offset: 0,
            n_experts,
            n_tokens,
            k,
        },
    );
    // One indexed projection. `ne11` / `nb11` / `nb12` describe how `src1`
    // is walked per expert slot and per token; they look like ggml-style
    // element counts and byte strides (TODO confirm against
    // MoeGatherIdCall).
    let mm_id = |src1: &Buffer,
                 dst: &Buffer,
                 k_in: u32,
                 n_out: u32,
                 packed_off: u64,
                 scales_off: u64,
                 biases_off: u64,
                 ne11: u32,
                 nb11: u64,
                 nb12: u64| {
        kernels.encode(
            cmdbuf,
            &moeflux_metal::MoeGatherIdCall {
                src0: expert_base,
                src0_offset: 0,
                src1,
                src1_offset: 0,
                htpe,
                htpe_offset: 0,
                hids,
                hids_offset: 0,
                dst,
                dst_offset: 0,
                k_in,
                n_out,
                n_experts,
                n_tokens,
                k,
                ne11,
                nb02: expert_stride,
                // Packed 4-bit weights: k_in / 2 bytes per output row.
                nb01_w: (k_in / 2) as u64,
                // NOTE(review): hard-coded 32. Presumably the scale bytes per
                // row for GROUP_SIZE 64 with 2-byte scales; confirm it
                // matches GROUP_SIZE.
                nb01_s: (k_in / 32) as u64, packed_off,
                scales_off,
                biases_off,
                nb10: 4,
                nb11,
                nb12,
            },
        );
    };
    // Gate: all k slots of a token read the same input row (nb11 = 0); the
    // token stride is one hidden_dim row of f32.
    mm_id(
        mlp_in,
        gate_mid,
        hidden_dim,
        moe_inter,
        v.gate_w_off_4bit() as u64,
        v.gate_s_off_4bit() as u64,
        v.gate_b_off_4bit() as u64,
        1,
        0,
        (hidden_dim * 4) as u64,
    );
    // Up: same input layout as gate.
    mm_id(
        mlp_in,
        up_mid,
        hidden_dim,
        moe_inter,
        v.up_w_off_4bit() as u64,
        v.up_s_off_4bit() as u64,
        v.up_b_off_4bit() as u64,
        1,
        0,
        (hidden_dim * 4) as u64,
    );
    // SwiGLU in place: the activation overwrites gate_mid.
    encode_swiglu_at(
        cmdbuf,
        swiglu,
        gate_mid,
        up_mid,
        gate_mid, 0,
        n_tokens * k * moe_inter,
    );
    // Down: each of the k slots reads its own moe_intermediate row of the
    // activation, and k rows make up one token.
    mm_id(
        gate_mid,
        down_mid,
        moe_inter,
        hidden_dim,
        v.down_w_off_4bit() as u64,
        v.down_s_off_4bit() as u64,
        v.down_b_off_4bit() as u64,
        k,
        (moe_inter * 4) as u64,
        (k * moe_inter * 4) as u64,
    );
    kernels.encode(
        cmdbuf,
        &moeflux_metal::MoeCombineTopkCall {
            mid: down_mid,
            mid_offset: 0,
            weights,
            weights_offset: 0,
            out: out_sum,
            out_offset: 0,
            n_tokens,
            hidden_dim,
            k,
        },
    );
}
/// Encodes the per-bucket MoE path: an independent chain of kernels for each routed expert.
///
/// Bucket `i` corresponds to `expert_slots[i]` and owns rows
/// `buckets.offsets[i]..buckets.offsets[i + 1]` of the scratch buffers. Empty buckets are
/// skipped. For every non-empty bucket, these are encoded in order:
/// 1. gate matvec: `bucket_input` (hidden-sized rows) -> `bucket_gate` (intermediate rows),
/// 2. up matvec:   `bucket_input` -> `bucket_up`,
/// 3. SwiGLU of `bucket_gate` / `bucket_up` into `bucket_act`,
/// 4. down matvec: `bucket_act` -> `bucket_out` (hidden-sized rows),
/// 5. `encode_bucket_scatter` of `bucket_out` into `out_sum`, using `bucket_token_idx` and
///    `bucket_weights`.
///
/// The expert's weights start `slot * expert_stride` bytes into `expert_base`. The
/// gate/up/down sub-offsets come from the `*_4bit` accessors on `v`.
///
/// # Panics
/// Indexes `buckets.offsets[i + 1]`, so `buckets.offsets` needs at least
/// `expert_slots.len() + 1` entries.
// The parameter list mirrors the kernels' buffer arguments one-to-one.
#[allow(clippy::too_many_arguments)]
fn encode_moe_per_bucket(
    cmdbuf: &CommandBufferRef,
    matvec: &MatvecPipelines,
    swiglu: &ComputePipelineState,
    bucket_accumulate: &ComputePipelineState,
    expert_base: &Buffer,
    expert_stride: u64,
    expert_slots: &[u32],
    bucket_input: &Buffer,
    bucket_gate: &Buffer,
    bucket_up: &Buffer,
    bucket_act: &Buffer,
    bucket_out: &Buffer,
    bucket_token_idx: &Buffer,
    bucket_weights: &Buffer,
    out_sum: &Buffer,
    buckets: &ExpertBuckets,
    v: Variant,
) {
    let hidden = v.hidden_dim as u32;
    let inter = v.moe_intermediate as u32;
    let f32_bytes = std::mem::size_of::<f32>() as u64;
    // Byte length of one f32 row in the hidden-sized and intermediate-sized scratch buffers.
    let hidden_row = v.hidden_dim as u64 * f32_bytes;
    let inter_row = v.moe_intermediate as u64 * f32_bytes;

    let gate_offs = (
        v.gate_w_off_4bit() as u64,
        v.gate_s_off_4bit() as u64,
        v.gate_b_off_4bit() as u64,
    );
    let up_offs = (
        v.up_w_off_4bit() as u64,
        v.up_s_off_4bit() as u64,
        v.up_b_off_4bit() as u64,
    );
    let down_offs = (
        v.down_w_off_4bit() as u64,
        v.down_s_off_4bit() as u64,
        v.down_b_off_4bit() as u64,
    );

    for (bucket, &slot) in expert_slots.iter().enumerate() {
        let first_row = buckets.offsets[bucket] as u64;
        let end_row = buckets.offsets[bucket + 1] as u64;
        let rows = (end_row - first_row) as u32;
        if rows == 0 {
            continue;
        }

        let expert_off = slot as u64 * expert_stride;
        let hidden_off = first_row * hidden_row;
        let inter_off = first_row * inter_row;

        // One quantized matvec over this bucket's `rows` tokens. The trailing `4` is
        // forwarded unchanged (presumably the quantization bit width — confirm).
        let project = |(w_off, s_off, b_off): (u64, u64, u64),
                       src: &Buffer,
                       src_off: u64,
                       dst: &Buffer,
                       dst_off: u64,
                       in_dim: u32,
                       out_dim: u32| {
            encode_matvec_n_tokens(
                cmdbuf,
                matvec,
                expert_base,
                expert_off + w_off,
                expert_off + s_off,
                expert_off + b_off,
                src,
                src_off,
                dst,
                dst_off,
                in_dim,
                out_dim,
                rows,
                4,
            );
        };

        project(gate_offs, bucket_input, hidden_off, bucket_gate, inter_off, hidden, inter);
        project(up_offs, bucket_input, hidden_off, bucket_up, inter_off, hidden, inter);
        encode_swiglu_at(
            cmdbuf,
            swiglu,
            bucket_gate,
            bucket_up,
            bucket_act,
            inter_off,
            rows * inter,
        );
        project(down_offs, bucket_act, inter_off, bucket_out, hidden_off, inter, hidden);
        encode_bucket_scatter(
            cmdbuf,
            bucket_accumulate,
            bucket_out,
            bucket_token_idx,
            bucket_weights,
            out_sum,
            first_row,
            rows,
            hidden,
        );
    }
}
/// Encodes one compute pass of `pipeline` over `n_tokens * dim` elements.
///
/// Bindings: buffer slots 0..=4 are `h_mid`, `moe_sum`, `shared_out`, `shared_gate` and
/// `hidden_out`, in that order. Slots 5 and 6 carry `n_tokens` and `dim` as inline `u32`
/// constants. The grid is `ceil(n_tokens * dim / 256)` threadgroups of 256 threads each.
///
/// The combine formula itself lives in the shader behind `pipeline`. NOTE(review): this is
/// presumably the kernel named by `combine_kernel_name()` — confirm at the call site.
// The parameter list mirrors the kernel's buffer arguments one-to-one.
#[allow(clippy::too_many_arguments)]
pub fn encode_moe_combine_residual_n_tokens(
    cmdbuf: &CommandBufferRef,
    pipeline: &ComputePipelineState,
    h_mid: &Buffer,
    moe_sum: &Buffer,
    shared_out: &Buffer,
    shared_gate: &Buffer,
    hidden_out: &Buffer,
    n_tokens: u32,
    dim: u32,
) {
    const THREADS_PER_GROUP: u32 = 256;

    let elements = n_tokens * dim;
    let encoder = cmdbuf.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(pipeline);

    let bound = [h_mid, moe_sum, shared_out, shared_gate, hidden_out];
    for (slot, buf) in bound.iter().enumerate() {
        encoder.set_buffer(slot as NSUInteger, Some(*buf), 0);
    }

    let scalar_len = std::mem::size_of::<u32>() as NSUInteger;
    encoder.set_bytes(5, scalar_len, (&n_tokens as *const u32).cast());
    encoder.set_bytes(6, scalar_len, (&dim as *const u32).cast());

    // Round up so every element gets a thread.
    let groups = (elements + THREADS_PER_GROUP - 1) / THREADS_PER_GROUP;
    encoder.dispatch_thread_groups(
        MTLSize::new(groups as NSUInteger, 1, 1),
        MTLSize::new(THREADS_PER_GROUP as NSUInteger, 1, 1),
    );
    encoder.end_encoding();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test for the single-expert GPU forward: run it on seeded synthetic weights and
    /// input, then check that the output is finite and not identically zero.
    #[test]
    #[ignore = "needs Metal device + access to shaders.metal source"]
    fn gpu_expert_forward_runs_and_produces_finite_output() {
        let mut ctx = MetalContext::new().expect("MetalContext::new");
        let weights = synth::expert_data_seeded();
        let input = synth::h_post_seeded();
        let mut result = vec![0.0f32; VARIANT.hidden_dim];

        gpu_expert_forward(&mut ctx, &weights, &input, &mut result)
            .expect("gpu_expert_forward");

        assert!(result.iter().copied().all(f32::is_finite), "output has NaN/Inf");
        assert!(
            result.iter().any(|x| x.abs() > 0.0),
            "output is all zero — kernel didn't write?"
        );
    }
}
/// Deterministic synthetic inputs for the GPU expert / MoE tests.
///
/// Every generator uses fixed seeds, so repeated calls return identical data.
pub mod synth {
    use super::*;

    /// Multiplier of the 64-bit linear congruential generator used for weight bytes.
    const LCG_MUL: u64 = 6364136223846793005;
    /// Increment of the same generator.
    const LCG_INC: u64 = 1442695040888963407;
    /// Number of quantized blocks filled per expert image.
    // NOTE(review): presumably the gate / up / down projections — confirm against the
    // `expert_block_bytes_4bit` layout.
    const BLOCKS_PER_EXPERT: usize = 3;

    /// Fills the packed-weight and scale regions of one expert image in `dst`.
    ///
    /// For each block, the generator is seeded with `seed_for_block(block)`. Each weight byte
    /// is then bits 32..40 of the next generator state. Each 2-byte scale is set to bytes
    /// `[0x00, 0x3C]`, which is 0x3C00 when read little-endian (1.0 as IEEE f16).
    /// NOTE(review): confirm the scale dtype, since as bf16 this value would be 2^-7.
    /// Bias regions are not touched; callers pass freshly zeroed buffers.
    fn fill_expert_blocks(dst: &mut [u8], v: &Variant, seed_for_block: impl Fn(usize) -> u64) {
        for block in 0..BLOCKS_PER_EXPERT {
            let block_off = block * v.expert_block_bytes_4bit();
            let w_end = block_off + v.expert_weight_bytes_4bit();
            let mut state = seed_for_block(block);
            for byte in &mut dst[block_off..w_end] {
                state = state.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
                *byte = (state >> 32) as u8;
            }
            let s_end = w_end + v.expert_scale_bytes();
            for chunk in dst[w_end..s_end].chunks_exact_mut(2) {
                chunk.copy_from_slice(&[0x00, 0x3C]);
            }
        }
    }

    /// One expert image (`expert_size_4bit()` bytes) with seeded weights and constant scales.
    pub fn expert_data_seeded() -> Vec<u8> {
        let v: Variant = VARIANT;
        let mut data = vec![0u8; v.expert_size_4bit()];
        fill_expert_blocks(&mut data, &v, |block| 0xCAFE_BEEF + block as u64);
        data
    }

    /// A `hidden_dim`-long input vector: a small linear ramp centred on zero.
    pub fn h_post_seeded() -> Vec<f32> {
        let v = VARIANT;
        (0..v.hidden_dim)
            .map(|i| {
                (i as f32 - v.hidden_dim as f32 / 2.0) * 1e-3
                    / v.hidden_dim as f32
            })
            .collect()
    }

    /// `k` back-to-back expert images. Each slot gets its own seed, so slot contents differ.
    pub fn k_expert_data_seeded(k: usize) -> Vec<u8> {
        let v: Variant = VARIANT;
        let per_expert = v.expert_size_4bit();
        let mut data = vec![0u8; k * per_expert];
        for slot in 0..k {
            let dst = &mut data[slot * per_expert..(slot + 1) * per_expert];
            fill_expert_blocks(dst, &v, |block| {
                0xCAFE_BEEF ^ ((slot as u64) << 32) ^ (block as u64)
            });
        }
        data
    }

    /// A `hidden_dim`-long residual vector with small sinusoidal values (|x| <= 0.001).
    pub fn h_mid_seeded() -> Vec<f32> {
        let v = VARIANT;
        (0..v.hidden_dim)
            .map(|i| (i as f32 * 0.0007 - 0.05).sin() * 0.001)
            .collect()
    }

    /// A `hidden_dim`-long shared-expert output with small cosine values (|x| <= 0.001).
    pub fn shared_out_seeded() -> Vec<f32> {
        let v = VARIANT;
        (0..v.hidden_dim)
            .map(|i| (i as f32 * 0.0011 + 0.03).cos() * 0.001)
            .collect()
    }

    /// `k` positive routing weights that sum to 1. Returns an empty vector for `k == 0`.
    pub fn expert_weights_seeded(k: usize) -> Vec<f32> {
        // Raw weights are `1.0 + 0.37 * i`, which is always >= 1.0, so no sign fix-up is needed.
        let raw: Vec<f32> = (0..k).map(|i| (i as f32) * 0.37 + 1.0).collect();
        let total: f32 = raw.iter().sum();
        raw.iter().map(|w| w / total).collect()
    }
}