use metal::{Buffer, Device, MTLResourceOptions, NSUInteger};
use rayon::prelude::*;
use crate::riir::backend::gpu::dense_mlp_gpu::{
DenseMlpGpuError, DenseMlpPipelines, encode_swiglu_ffn_layer_forward_gpu,
};
use crate::riir::io::embedding::bf16_to_f32;
use crate::riir::moe::expert_forward::{
ExpertForwardError, MoeBuffers, gpu_batched_experts_encode_mmap,
};
use crate::riir::io::expert_io::{ExpertFiles, ExpertIoError};
use crate::riir::backend::{BufferPool, MetalBufferPool};
use crate::riir::backend::gpu::gpu_matvec::{BfMatvecPipelines, encode_bf16_matvec};
use crate::riir::backend::gpu::metal::MetalContext;
use crate::riir::moe::moe_router::{MoeRouterError, noaux_tc_router_cpu};
use crate::riir::io::mtl_weight_buf::MtlWeightBuf;
use crate::riir::variants::VARIANT;
use crate::riir::io::weight_file::WeightFile;
/// Failures that the GPU forward pass of a Cogito MoE layer can report.
#[derive(Debug, thiserror::Error)]
pub enum CogitoMoeGpuError {
    /// The host-side input slice does not contain exactly `hidden_dim` values.
    #[error("hidden length {got} != hidden_dim ({expected})")]
    HiddenLen { got: usize, expected: usize },

    /// The host-side output slice does not contain exactly `hidden_dim` values.
    #[error("hidden_out length {got} != hidden_dim ({expected})")]
    HiddenOutLen { got: usize, expected: usize },

    /// A tensor required by the layer could not be found under `name`.
    #[error("missing tensor '{name}'")]
    MissingTensor { name: String },

    /// A tensor was found but its byte length differs from what the layer expects.
    #[error("tensor '{name}' size {got} bytes, expected {expected} bytes")]
    TensorSize {
        name: String,
        got: usize,
        expected: usize,
    },

    /// The CPU router rejected its inputs or failed.
    #[error("router: {0}")]
    Router(#[from] MoeRouterError),

    /// The dense-MLP GPU code used for the shared expert reported an error.
    #[error("shared-expert FFN (GPU): {0}")]
    SharedFfn(#[from] DenseMlpGpuError),

    /// Reading expert data failed.
    #[error("expert I/O: {0}")]
    Io(#[from] ExpertIoError),

    /// Encoding the routed-expert GPU work failed.
    #[error("GPU experts: {0}")]
    Experts(#[from] ExpertForwardError),
}
/// Metal scratch buffers handed to the shared-expert SwiGLU FFN encoder.
///
/// [`SharedExpertBuffers::new`] allocates all three with room for
/// `VARIANT.shared_intermediate` `f32` values each.
pub struct SharedExpertBuffers {
    /// Scratch buffer; presumably receives the gate projection — confirm against the encoder.
    pub gate_out: Buffer,
    /// Scratch buffer; presumably receives the up projection — confirm against the encoder.
    pub up_out: Buffer,
    /// Scratch buffer; presumably receives the activated product — confirm against the encoder.
    pub act: Buffer,
}
impl SharedExpertBuffers {
    /// Allocates the three shared-expert scratch buffers on `device`.
    ///
    /// Every buffer uses shared storage, holds `VARIANT.shared_intermediate` `f32` values,
    /// and is explicitly zero-filled before being returned.
    ///
    /// # Panics
    ///
    /// Panics if a non-empty buffer comes back without CPU-visible contents, which is how a
    /// failed Metal allocation shows up. Previously this case wrote through a null pointer.
    pub fn new(device: &Device) -> Self {
        let elems = VARIANT.shared_intermediate;
        let zeroed_f32_buffer = || {
            let byte_len = elems * std::mem::size_of::<f32>();
            let buffer = device.new_buffer(
                byte_len as NSUInteger,
                MTLResourceOptions::StorageModeShared,
            );
            if byte_len > 0 {
                let contents = buffer.contents() as *mut u8;
                assert!(
                    !contents.is_null(),
                    "Metal failed to allocate a {byte_len}-byte shared-expert scratch buffer"
                );
                // SAFETY: `contents` is non-null and points at the CPU-visible storage of a
                // shared-mode buffer that was just allocated with `byte_len` bytes; no other
                // reference to that storage exists yet.
                unsafe {
                    std::ptr::write_bytes(contents, 0, byte_len);
                }
            }
            buffer
        };
        Self {
            gate_out: zeroed_f32_buffer(),
            up_out: zeroed_f32_buffer(),
            act: zeroed_f32_buffer(),
        }
    }
}
/// Runs one MoE layer on the GPU for a hidden vector that lives in host memory.
///
/// `hidden` is staged into the input buffer owned by `bufs`, the layer is executed by
/// `cogito_moe_layer_forward_gpu_inner`, and the result is read back into `out`.
///
/// # Errors
///
/// Returns [`CogitoMoeGpuError::HiddenLen`] / [`CogitoMoeGpuError::HiddenOutLen`] when
/// `hidden` / `out` do not have exactly `VARIANT.hidden_dim` elements (checked in that
/// order, before any GPU work), and otherwise forwards whatever the inner routine reports.
///
/// # Panics
///
/// Panics if the vector read back from `bufs` does not have the same length as `out`.
#[allow(clippy::too_many_arguments)]
pub fn cogito_moe_layer_forward_gpu(
    metal: &mut MetalContext,
    bufs: &mut MoeBuffers,
    buffer_pool: &MetalBufferPool,
    shared_bufs: &SharedExpertBuffers,
    dense_pipes: &DenseMlpPipelines,
    bf_pipes: &BfMatvecPipelines,
    wf: &WeightFile,
    wf_buf: &MtlWeightBuf,
    expert_files: &ExpertFiles,
    pool: &rayon::ThreadPool,
    layer_idx: usize,
    hidden: &[f32],
    out: &mut [f32],
) -> Result<(), CogitoMoeGpuError> {
    let expected = VARIANT.hidden_dim;

    // Validate both host slices up front so nothing is staged for a malformed call.
    let got = hidden.len();
    if got != expected {
        return Err(CogitoMoeGpuError::HiddenLen { got, expected });
    }
    let got = out.len();
    if got != expected {
        return Err(CogitoMoeGpuError::HiddenOutLen { got, expected });
    }

    // Upload the host vector, then take an owned handle to the staged buffer because the
    // inner routine needs `bufs` mutably while also reading the input.
    bufs.stage_host_input(buffer_pool, hidden);
    let staged_input: Buffer = bufs.input_buffer(buffer_pool).clone();

    cogito_moe_layer_forward_gpu_inner(
        metal,
        bufs,
        buffer_pool,
        shared_bufs,
        dense_pipes,
        bf_pipes,
        wf,
        wf_buf,
        expert_files,
        pool,
        layer_idx,
        staged_input,
    )?;

    // Read the layer output back to the host.
    let result = bufs.moe_hidden_to_vec(buffer_pool);
    out.copy_from_slice(&result);
    Ok(())
}
/// Runs one MoE layer on the GPU with the input already resident in a Metal buffer.
///
/// Unlike `cogito_moe_layer_forward_gpu`, this entry point performs no length validation,
/// no host-to-GPU staging and no readback: it only forwards to
/// `cogito_moe_layer_forward_gpu_inner`, leaving the result in the buffers owned by `bufs`.
///
/// # Errors
///
/// Forwards any error reported by the inner routine.
#[allow(clippy::too_many_arguments)]
pub fn cogito_moe_layer_forward_gpu_buf_io(
    metal: &mut MetalContext,
    bufs: &mut MoeBuffers,
    buffer_pool: &MetalBufferPool,
    shared_bufs: &SharedExpertBuffers,
    dense_pipes: &DenseMlpPipelines,
    bf_pipes: &BfMatvecPipelines,
    wf: &WeightFile,
    wf_buf: &MtlWeightBuf,
    expert_files: &ExpertFiles,
    pool: &rayon::ThreadPool,
    layer_idx: usize,
    input_buf: &Buffer,
) -> Result<(), CogitoMoeGpuError> {
    // The inner routine takes its input buffer by value, so give it its own handle.
    let input: Buffer = input_buf.clone();
    cogito_moe_layer_forward_gpu_inner(
        metal,
        bufs,
        buffer_pool,
        shared_bufs,
        dense_pipes,
        bf_pipes,
        wf,
        wf_buf,
        expert_files,
        pool,
        layer_idx,
        input,
    )
}
/// Shared implementation behind the two public MoE-layer entry points.
///
/// Work is performed in this order:
/// 1. Encode a BF16 matvec of the layer's `mlp.gate.weight` tensor against `input_buf`
///    into the gate-logits buffer, commit it and block until it has completed.
/// 2. Read the logits back to the host, decode `mlp.gate.e_score_correction_bias` from
///    BF16 to `f32`, and run `noaux_tc_router_cpu` to fill `k` expert indices and weights.
/// 3. Read the `k` selected experts into the first `k` data slots, in parallel on `pool`.
/// 4. Encode and commit the shared-expert SwiGLU FFN (this command buffer is not waited on
///    here).
/// 5. Encode the batched routed-expert pass, commit it and block until it has completed.
///
/// Nothing is returned on success; results are left in buffers owned by `bufs`.
///
/// # Errors
/// - `MissingTensor` if the gate weight offset or the bias tensor cannot be found.
/// - `TensorSize` if the bias tensor is not exactly `num_experts * 2` bytes long.
/// - `SharedFfn` for failures of `tensor_offset` and of the shared-expert encoder.
/// - Router, expert-I/O and routed-expert encoder failures are propagated with `?`.
///
/// # Panics
/// - In `bytemuck_u16` if the bias bytes do not start on a 2-byte boundary.
/// - When slicing `dsts[..k]` if fewer than `k` data slots are available.
#[allow(clippy::too_many_arguments)]
fn cogito_moe_layer_forward_gpu_inner(
metal: &mut MetalContext,
bufs: &mut MoeBuffers,
buffer_pool: &MetalBufferPool,
shared_bufs: &SharedExpertBuffers,
dense_pipes: &DenseMlpPipelines,
bf_pipes: &BfMatvecPipelines,
wf: &WeightFile,
wf_buf: &MtlWeightBuf,
expert_files: &ExpertFiles,
pool: &rayon::ThreadPool,
layer_idx: usize,
input_buf: Buffer,
) -> Result<(), CogitoMoeGpuError> {
let v = VARIANT;
let hidden_dim = v.hidden_dim;
let num_experts = v.num_experts;
// Number of experts selected per token.
let k = v.num_experts_per_tok;
let gate_w_name = format!("model.layers.{layer_idx}.mlp.gate.weight");
let bias_name = format!("model.layers.{layer_idx}.mlp.gate.e_score_correction_bias");
// Offset of the router gate weight inside the GPU weight buffer.
// NOTE(review): a `tensor_offset` failure is wrapped as `SharedFfn`, so it will be reported
// with the "shared-expert FFN (GPU)" prefix even though this is the router gate lookup.
let gate_w_off = wf_buf
.tensor_offset(wf, &gate_w_name)
.map_err(|e| CogitoMoeGpuError::SharedFfn(e.into()))?
.ok_or_else(|| CogitoMoeGpuError::MissingTensor {
name: gate_w_name.clone(),
})?;
// Step 1: router gate matvec. The CPU router below needs the logits, so this command
// buffer is waited on before continuing.
{
let cmdbuf_gate = metal.queue().new_command_buffer();
encode_bf16_matvec(
cmdbuf_gate,
bf_pipes,
wf_buf.buffer(),
gate_w_off,
&input_buf,
bufs.gate_logits_buffer(buffer_pool),
// NOTE(review): narrowing casts; assumes both dimensions fit in `u32` — confirm.
hidden_dim as u32,
num_experts as u32,
);
cmdbuf_gate.commit();
cmdbuf_gate.wait_until_completed();
}
// Step 2: read the logits back and route on the CPU.
let mut gate_logits = bufs.gate_logits_to_vec(buffer_pool);
let bias_bytes =
wf.tensor_bytes(&bias_name)
.ok_or_else(|| CogitoMoeGpuError::MissingTensor {
name: bias_name.clone(),
})?;
// The bias is decoded as one BF16 value (2 bytes) per expert.
let expected_bias_bytes = num_experts * 2;
if bias_bytes.len() != expected_bias_bytes {
return Err(CogitoMoeGpuError::TensorSize {
name: bias_name,
got: bias_bytes.len(),
expected: expected_bias_bytes,
});
}
// Zero-copy view of the bias as `u16` words; panics if the bytes are misaligned.
let bias_u16 = bytemuck_u16(bias_bytes);
let bias_f32: Vec<f32> = bias_u16.iter().map(|&b| bf16_to_f32(b)).collect();
// Router outputs: one expert index and one weight per selected expert.
let mut indices = vec![0i32; k];
let mut weights = vec![0f32; k];
noaux_tc_router_cpu(
&mut gate_logits,
&bias_f32,
v.n_group,
v.topk_group,
k,
v.routed_scaling_factor,
&mut indices,
&mut weights,
)?;
// NOTE(review): presumably zeroes the `h_mid` staging buffer before the expert pass — confirm.
bufs.stage_host_h_mid_zero(buffer_pool);
// Step 3: load each selected expert into its own destination slot, in parallel on `pool`.
let mut dsts = bufs.data_synced_slots_mut_array(buffer_pool);
pool.install(|| -> Result<(), ExpertIoError> {
dsts[..k]
.par_iter_mut()
.zip(indices.par_iter().take(k))
.try_for_each(|(slot, &e)| expert_files.read_expert(layer_idx, e as usize, slot))
})?;
// Step 4: shared-expert SwiGLU FFN. Committed without waiting.
// NOTE(review): presumably relies on same-queue ordering with the command buffer below,
// which is waited on — confirm.
{
let cmdbuf_shared = metal.queue().new_command_buffer();
let prefix = format!("model.layers.{layer_idx}.mlp.shared_experts");
encode_swiglu_ffn_layer_forward_gpu(
cmdbuf_shared,
dense_pipes,
wf,
wf_buf,
&prefix,
v.shared_intermediate as u32,
&input_buf,
&shared_bufs.gate_out,
&shared_bufs.up_out,
&shared_bufs.act,
bufs.shared_out_buffer(buffer_pool),
)?;
cmdbuf_shared.commit();
}
// Step 5: routed experts. One (buffer, offset) binding per slot, each at offset 0.
let bindings: Vec<(&metal::Buffer, u64)> = (0..k)
.map(|slot| (buffer_pool.handle(bufs.data_synced_id(slot)), 0u64))
.collect();
// Owned handles are taken first because `bufs` is passed mutably to the encoder below.
let h_mid_clone: Buffer = bufs.h_mid_buffer(buffer_pool).clone();
let shared_clone: Buffer = bufs.shared_out_buffer(buffer_pool).clone();
// NOTE(review): the meaning of the `0.0` and `None` arguments is not visible here — confirm
// against `gpu_batched_experts_encode_mmap`.
let cmdbuf = gpu_batched_experts_encode_mmap(
metal,
bufs,
buffer_pool,
k as i32,
&input_buf,
&h_mid_clone,
&shared_clone,
&weights,
0.0,
&bindings,
None,
)?;
// Block until the routed-expert pass has finished before returning to the caller.
cmdbuf.commit();
cmdbuf.wait_until_completed();
Ok(())
}
/// Reinterprets `bytes` as a slice of native-endian `u16` words without copying.
///
/// # Panics
///
/// Panics if `bytes` does not start on a 2-byte boundary, or if its length is odd. The two
/// conditions are reported with distinct messages so that a truncated tensor is not
/// mistaken for an alignment problem.
fn bytemuck_u16(bytes: &[u8]) -> &[u16] {
    // Compile-time layout checks the reinterpretation relies on.
    const _: () = assert!(std::mem::size_of::<u16>() == 2);
    const _: () = assert!(std::mem::align_of::<u16>() == 2);

    // SAFETY: every bit pattern is a valid `u16`, and `align_to` only places correctly
    // aligned, in-bounds elements in the middle slice.
    let (head, body, tail) = unsafe { bytes.align_to::<u16>() };
    assert!(
        head.is_empty(),
        "BF16 tensor not aligned to 2-byte boundary (head={}, tail={})",
        head.len(),
        tail.len()
    );
    assert!(
        tail.is_empty(),
        "BF16 tensor has an odd byte length ({} bytes, tail={})",
        bytes.len(),
        tail.len()
    );
    body
}