use std::collections::VecDeque;
use super::expert_forward::{
gpu_batched_experts_encode, gpu_batched_experts_encode_pre_staged,
ChainToNormed, ExpertForwardError, MoeBuffers,
};
use super::metal::MetalBackend;
use super::variants::VARIANT;
use super::{RsCtx, RsError};
/// Bounded FIFO of expert dispatches that have been committed to the GPU but
/// not yet waited on.
///
/// At most [`DeferredRing::DEPTH`] dispatches are tracked at once. `push`
/// rejects a new dispatch with [`DeferredError::RingFull`] once that limit is
/// reached, and `pop_oldest` returns dispatches in the order they were pushed.
pub struct DeferredRing {
    /// In-flight dispatches; the oldest sits at the front.
    states: VecDeque<DeferredState>,
}

impl DeferredRing {
    /// Maximum number of dispatches tracked simultaneously.
    pub(crate) const DEPTH: usize = 2;

    /// Creates an empty ring with storage for `DEPTH` dispatches preallocated.
    pub fn new() -> Self {
        let states = VecDeque::with_capacity(Self::DEPTH);
        Self { states }
    }

    /// Returns `true` when no dispatch is in flight.
    pub fn is_empty(&self) -> bool {
        self.states.is_empty()
    }

    /// Returns `true` when a further `push` would be rejected.
    pub fn is_full(&self) -> bool {
        self.len() >= Self::DEPTH
    }

    /// Number of dispatches currently tracked.
    pub fn len(&self) -> usize {
        self.states.len()
    }

    /// Records `state` as the newest in-flight dispatch.
    ///
    /// # Errors
    ///
    /// Returns [`DeferredError::RingFull`] if the ring already holds `DEPTH`
    /// dispatches. In that case `state` is not stored and is dropped.
    pub(crate) fn push(&mut self, state: DeferredState) -> Result<(), DeferredError> {
        if self.is_full() {
            Err(DeferredError::RingFull)
        } else {
            self.states.push_back(state);
            Ok(())
        }
    }

    /// Removes and returns the oldest in-flight dispatch, if any.
    pub(crate) fn pop_oldest(&mut self) -> Option<DeferredState> {
        self.states.pop_front()
    }
}

impl Default for DeferredRing {
    /// Equivalent to [`DeferredRing::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// How the final combine of a deferred dispatch is performed at completion.
pub enum DeferredMode {
    /// The combine was encoded on the GPU. At completion the layer output is
    /// copied out of the MoE buffers' `moe_hidden()`.
    Gpu,
    /// The combine runs on the CPU at completion (`cpu_combine`). These are
    /// host copies of the inputs it needs, captured when the dispatch began.
    Cpu {
        /// Gate score for the shared expert; passed through
        /// `cpu_sigmoid_scalar` during the combine.
        shared_gate_score: f32,
        /// Per-slot routing weights for the routed experts.
        expert_weights: Vec<f32>,
        /// Residual input added to the combined expert output.
        h_mid: Vec<f32>,
        /// Output of the shared expert.
        shared_out: Vec<f32>,
    },
}
/// One committed-but-unfinished expert dispatch tracked by [`DeferredRing`].
pub struct DeferredState {
    /// Committed command buffer; completion waits on it before reading results.
    cmd_buffer: metal::CommandBuffer,
    /// Where the final combine happens (and the CPU inputs if it is on the CPU).
    mode: DeferredMode,
    /// Number of routed expert slots, converted from the caller's `i32`.
    /// Consumed by the CPU combine.
    actual_k: usize,
    /// Layer this dispatch belongs to. It is recorded but never read in this
    /// file; presumably kept for debugging (TODO confirm). That is the reason
    /// for the `dead_code` allowance.
    #[allow(dead_code)]
    layer_idx: i32,
}
#[derive(Debug, thiserror::Error)]
pub enum DeferredError {
#[error(
"deferred-experts ring is at capacity; drain the oldest dispatch \
before pushing a new one"
)]
RingFull,
#[error(
"hidden_out must be HIDDEN_DIM={expected} floats, got {actual}"
)]
BadHiddenOutLen { expected: usize, actual: usize },
#[error("Metal backend or MoE buffers init failed")]
Init,
#[error(transparent)]
Encode(#[from] ExpertForwardError),
}
impl From<RsError> for DeferredError {
fn from(_: RsError) -> Self {
DeferredError::Init
}
}
/// Encodes one layer's batched expert dispatch, commits it, and records it in
/// `ring` without waiting for the GPU.
///
/// If `gpu_combine` is set, the dispatch is recorded as [`DeferredMode::Gpu`].
/// Otherwise `h_mid`, `shared_out`, `expert_weights` and `shared_gate_score`
/// are copied into a [`DeferredMode::Cpu`] record, so that
/// `complete_deferred_experts_into` can run the combine on the CPU later.
///
/// `actual_k` is stored as `usize` via `as`. This assumes it is non-negative,
/// which is presumably enforced by the encoder (TODO confirm).
///
/// # Errors
///
/// * [`DeferredError::RingFull`] if `ring` is already full. This is checked
///   before encoding, so nothing is committed in that case.
/// * [`DeferredError::Encode`] if `gpu_batched_experts_encode` fails.
#[allow(clippy::too_many_arguments)] // flat parameter list mirrors the encoder
pub(crate) fn gpu_batched_experts_begin(
    metal: &mut MetalBackend,
    bufs: &mut MoeBuffers,
    ring: &mut DeferredRing,
    actual_k: i32,
    expert_data: &[u8],
    h_post: &[f32],
    h_mid: &[f32],
    shared_out: &[f32],
    expert_weights: &[f32],
    shared_gate_score: f32,
    layer_idx: i32,
    gpu_combine: bool,
) -> Result<(), DeferredError> {
    // Reject before committing: a committed buffer that could not be pushed
    // would never be waited on.
    if ring.is_full() {
        return Err(DeferredError::RingFull);
    }

    let cmd_buffer = gpu_batched_experts_encode(
        metal,
        bufs,
        actual_k,
        expert_data,
        h_post,
        h_mid,
        shared_out,
        expert_weights,
        shared_gate_score,
        gpu_combine,
    )?;
    // Committed before the host copies below are taken. Presumably this lets
    // the GPU start while the CPU copies.
    cmd_buffer.commit();

    let mode = if gpu_combine {
        DeferredMode::Gpu
    } else {
        DeferredMode::Cpu {
            h_mid: h_mid.to_vec(),
            shared_out: shared_out.to_vec(),
            expert_weights: expert_weights.to_vec(),
            shared_gate_score,
        }
    };

    let pending = DeferredState {
        cmd_buffer,
        mode,
        actual_k: actual_k as usize,
        layer_idx,
    };
    ring.push(pending)
}
/// Variant of `gpu_batched_experts_begin` whose inputs are already resident in
/// Metal buffers.
///
/// `data_set_per_slot`, `prefetch_set` and `chain` are forwarded unchanged to
/// `gpu_batched_experts_encode_pre_staged`. The resulting command buffer is
/// committed and recorded in `ring` as [`DeferredMode::Gpu`].
///
/// # Errors
///
/// * [`DeferredError::RingFull`] if `ring` is already full. This is checked
///   before encoding, so nothing is committed in that case.
/// * [`DeferredError::Encode`] if encoding fails.
#[allow(clippy::too_many_arguments)] // flat parameter list mirrors the encoder
pub(crate) fn gpu_batched_experts_begin_pre_staged(
    metal: &mut MetalBackend,
    bufs: &mut MoeBuffers,
    ring: &mut DeferredRing,
    actual_k: i32,
    input: &metal::BufferRef,
    h_mid: &metal::BufferRef,
    shared_out: &metal::BufferRef,
    expert_weights: &[f32],
    shared_gate_score: f32,
    layer_idx: i32,
    data_set_per_slot: &[super::SlotSource; super::MAX_K],
    prefetch_set: usize,
    chain: Option<ChainToNormed<'_>>,
) -> Result<(), DeferredError> {
    if ring.is_full() {
        return Err(DeferredError::RingFull);
    }

    let cmd_buffer = gpu_batched_experts_encode_pre_staged(
        metal,
        bufs,
        actual_k,
        input,
        h_mid,
        shared_out,
        expert_weights,
        shared_gate_score,
        data_set_per_slot,
        prefetch_set,
        chain,
    )?;
    cmd_buffer.commit();

    // Pre-staged dispatches are always tracked as GPU-mode: the inputs live
    // in Metal buffers, so no host-side copies are captured.
    ring.push(DeferredState {
        mode: DeferredMode::Gpu,
        cmd_buffer,
        actual_k: actual_k as usize,
        layer_idx,
    })
}
pub(crate) fn complete_deferred_experts_into(
ring: &mut DeferredRing,
bufs: &MoeBuffers,
hidden_out: &mut [f32],
) -> Result<(), DeferredError> {
if hidden_out.len() != VARIANT.hidden_dim {
return Err(DeferredError::BadHiddenOutLen {
expected: VARIANT.hidden_dim,
actual: hidden_out.len(),
});
}
let Some(state) = ring.pop_oldest() else {
return Ok(());
};
state.cmd_buffer.wait_until_completed();
match state.mode {
DeferredMode::Gpu => {
hidden_out.copy_from_slice(&bufs.moe_hidden().to_vec());
}
DeferredMode::Cpu {
h_mid,
shared_out,
expert_weights,
shared_gate_score,
} => {
cpu_combine(
bufs,
state.actual_k,
&h_mid,
&shared_out,
&expert_weights,
shared_gate_score,
hidden_out,
);
}
}
Ok(())
}
pub(crate) fn complete_deferred_experts_chained(
ring: &mut DeferredRing,
) -> Result<(), DeferredError> {
let Some(state) = ring.pop_oldest() else {
return Ok(());
};
state.cmd_buffer.wait_until_completed();
Ok(())
}
/// CPU fallback for one layer's final MoE combine.
///
/// For each of the `VARIANT.hidden_dim` lanes it computes
/// `hidden_out = h_mid + shared_out * g + routed`, where:
///
/// * `g = cpu_sigmoid_scalar(shared_gate_score)`;
/// * `routed` accumulates `cpu_vec_madd(routed, bufs.out(k), expert_weights[k])`
///   over the first `actual_k` slots. Presumably this is
///   `routed += w_k * out_k` (TODO confirm against `cpu_ops`).
///
/// The shared term is folded in with `f32::mul_add`.
///
/// # Panics
///
/// Panics if `expert_weights` has fewer than `actual_k` entries, or if `h_mid`,
/// `shared_out` or `hidden_out` are shorter than `hidden_dim`. Exact-length
/// mismatches are caught only by `debug_assert`s.
fn cpu_combine(
    bufs: &MoeBuffers,
    actual_k: usize,
    h_mid: &[f32],
    shared_out: &[f32],
    expert_weights: &[f32],
    shared_gate_score: f32,
    hidden_out: &mut [f32],
) {
    let dim = VARIANT.hidden_dim;
    debug_assert_eq!(h_mid.len(), dim);
    debug_assert_eq!(shared_out.len(), dim);
    debug_assert_eq!(expert_weights.len(), actual_k);
    debug_assert_eq!(hidden_out.len(), dim);

    // Routed experts: accumulate each slot's output scaled by its weight.
    let mut routed = vec![0.0f32; dim];
    for (slot, &weight) in expert_weights[..actual_k].iter().enumerate() {
        let slot_out = bufs.out(slot).to_vec();
        debug_assert_eq!(slot_out.len(), dim);
        super::cpu_ops::cpu_vec_madd(&mut routed, &slot_out, weight);
    }

    // Shared expert plus residual. Slicing to `dim` first makes a short input
    // panic once, up front, and lets the zipped loop run without per-element
    // bounds checks.
    let gate = super::cpu_ops::cpu_sigmoid_scalar(shared_gate_score);
    let lanes = hidden_out[..dim]
        .iter_mut()
        .zip(&h_mid[..dim])
        .zip(&shared_out[..dim])
        .zip(&routed);
    for (((dst, &residual), &shared), &moe) in lanes {
        *dst = residual + shared.mul_add(gate, moe);
    }
}
/// Drains `ring`, waiting on every in-flight dispatch and ignoring its output.
///
/// Each command buffer is waited on rather than simply dropped. On return,
/// the GPU has finished every discarded dispatch and `ring` is empty.
pub(crate) fn discard_deferred_experts_in(ring: &mut DeferredRing) {
    for pending in std::iter::from_fn(|| ring.pop_oldest()) {
        pending.cmd_buffer.wait_until_completed();
    }
}
impl RsCtx {
    /// Begins a deferred, GPU-combined expert dispatch for one layer.
    ///
    /// First calls `metal_and_moe_mut` so that `metal` and `moe_buffers` are
    /// populated; the `expect`s below rely on this. It then forwards to
    /// `gpu_batched_experts_begin` with `gpu_combine = true`. The result is
    /// collected with [`RsCtx::complete_deferred_experts`] or dropped with
    /// [`RsCtx::discard_deferred_experts`].
    ///
    /// # Errors
    ///
    /// [`DeferredError::Init`] if `metal_and_moe_mut` fails. Otherwise any
    /// error from `gpu_batched_experts_begin` (`RingFull`, `Encode`).
    #[allow(clippy::too_many_arguments)] // flat parameter list mirrors the encoder
    pub fn begin_deferred_experts(
        &mut self,
        actual_k: i32,
        expert_data: &[u8],
        h_post: &[f32],
        h_mid: &[f32],
        shared_out: &[f32],
        expert_weights: &[f32],
        shared_gate_score: f32,
        layer_idx: i32,
    ) -> Result<(), DeferredError> {
        // Only the initialisation side effect is wanted here. The returned
        // borrow is dropped at once so the fields can be borrowed separately.
        let _ = self.metal_and_moe_mut()?;
        let metal = self.metal.as_mut().expect("metal_and_moe_mut just-set");
        let bufs = self
            .moe_buffers
            .as_mut()
            .expect("metal_and_moe_mut just-set");
        gpu_batched_experts_begin(
            metal,
            bufs,
            &mut self.deferred,
            actual_k,
            expert_data,
            h_post,
            h_mid,
            shared_out,
            expert_weights,
            shared_gate_score,
            layer_idx,
            true, // gpu_combine
        )
    }

    /// Waits for the oldest deferred dispatch and writes its output into
    /// `hidden_out` (see `complete_deferred_experts_into`).
    ///
    /// If the MoE buffers were never initialised, this returns `Ok(())` at once.
    /// In that case `hidden_out` is neither written nor length-checked.
    ///
    /// # Errors
    ///
    /// [`DeferredError::BadHiddenOutLen`] if `hidden_out` has the wrong length.
    pub fn complete_deferred_experts(
        &mut self,
        hidden_out: &mut [f32],
    ) -> Result<(), DeferredError> {
        match self.moe_buffers.as_ref() {
            None => Ok(()),
            Some(bufs) => complete_deferred_experts_into(&mut self.deferred, bufs, hidden_out),
        }
    }

    /// Waits for and drops every in-flight deferred dispatch without reading
    /// any output.
    pub fn discard_deferred_experts(&mut self) {
        let ring = &mut self.deferred;
        discard_deferred_experts_in(ring);
    }
}