use crate::backend::cpu::CpuBackend;
use crate::marlin_expert_stack::MarlinExpertStack;
use crate::Linear;
use ferrum_types::Result;
use std::sync::Arc;
/// CPU implementation of [`MarlinExpertStack`]: every expert's weights live
/// in one shared GPTQ store, stacked along the output (N) dimension.
///
/// Expert `e` owns output rows `e * n_per_expert .. (e + 1) * n_per_expert`
/// of `store`.
pub struct CpuMarlinExpertStack {
    /// Shared backing store holding the stacked weights of all experts.
    pub store: Arc<crate::backend::cpu::CpuGptqStore>,
    /// How many experts are packed into `store`.
    pub num_experts: usize,
    /// Output rows (N) belonging to each individual expert.
    pub n_per_expert: usize,
    /// Input dimension (K) shared by every expert.
    pub k: usize,
}
impl CpuMarlinExpertStack {
    /// Builds a stack view over `store`.
    ///
    /// No validation happens here. The caller is responsible for
    /// `num_experts`, `n_per_expert` and `k` agreeing with the layout of
    /// `store`.
    pub fn new(
        store: Arc<crate::backend::cpu::CpuGptqStore>,
        num_experts: usize,
        n_per_expert: usize,
        k: usize,
    ) -> Self {
        Self {
            k,
            n_per_expert,
            num_experts,
            store,
        }
    }
}
impl MarlinExpertStack<CpuBackend> for CpuMarlinExpertStack {
    /// Output rows (N) owned by each expert.
    fn n_per_expert(&self) -> usize {
        self.n_per_expert
    }

    /// Input dimension (K) shared by all experts.
    fn k(&self) -> usize {
        self.k
    }

    /// Number of experts in the stack.
    fn num_experts(&self) -> usize {
        self.num_experts
    }

    /// Exposes the concrete type so callers can downcast.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// No-op: this CPU implementation keeps no workspace that needs clearing.
    fn zero_workspace(
        &self,
        _ctx: &mut <CpuBackend as crate::backend::Backend>::Context,
    ) -> Result<()> {
        Ok(())
    }

    /// Runs one GEMM per dispatch entry against the stacked expert weights.
    ///
    /// Each entry is `(expert_idx, in_row_offset, out_row_offset, m)`. Expert
    /// `expert_idx` uses the `n_per_expert` store rows starting at
    /// `expert_idx * n_per_expert`. The row offsets and `m` (row count) are
    /// forwarded unchanged to `cpu_gemm_gptq_with_offset_strided`, together
    /// with `k`.
    ///
    /// # Errors
    /// Returns an error if any `expert_idx >= num_experts`. Otherwise it
    /// propagates the first error from the underlying GEMM; entries before
    /// that one have already been written to `output`.
    fn gemm_phase_batched(
        &self,
        ctx: &mut <CpuBackend as crate::backend::Backend>::Context,
        input: &<CpuBackend as crate::backend::Backend>::Buffer,
        dispatches: &[(usize, usize, usize, usize)],
        output: &mut <CpuBackend as crate::backend::Backend>::Buffer,
        k: usize,
    ) -> Result<()> {
        for &(expert_idx, in_row_offset, out_row_offset, m) in dispatches {
            // A bad routing index would otherwise silently read another
            // expert's rows (or rows past the stack), so reject it up front.
            if expert_idx >= self.num_experts {
                return Err(ferrum_types::FerrumError::model(format!(
                    "gemm_phase_batched: expert_idx {expert_idx} >= num_experts {}",
                    self.num_experts
                )));
            }
            crate::backend::cpu::cpu_gemm_gptq_with_offset_strided(
                ctx,
                input,
                in_row_offset,
                &self.store,
                expert_idx * self.n_per_expert,
                self.n_per_expert,
                output,
                out_row_offset,
                m,
                k,
            )?;
        }
        Ok(())
    }

    /// Copies store rows `expert_offset .. expert_offset + expert_n` into a
    /// standalone `CpuGptqLinear` with `in_features = k` and
    /// `out_features = expert_n`.
    ///
    /// `store.weight_f32` is treated as row-major `[store.n, store.k]`: each
    /// output row is `k` contiguous elements.
    ///
    /// # Errors
    /// Returns an error in three cases:
    /// - the requested row range exceeds `store.n`, including when the
    ///   addition overflows;
    /// - `self.k` differs from `store.k`;
    /// - `store.weight_f32` is too short for the requested rows.
    fn make_expert_linear(
        self: Arc<Self>,
        expert_offset: usize,
        expert_n: usize,
        bias_host: Option<&[f32]>,
    ) -> Result<Box<dyn Linear<CpuBackend> + Send + Sync>> {
        // `checked_add` so a huge offset cannot wrap around and slip past the bound.
        let row_end = match expert_offset.checked_add(expert_n) {
            Some(end) if end <= self.store.n => end,
            _ => {
                return Err(ferrum_types::FerrumError::model(format!(
                    "make_expert_linear OOB: offset {expert_offset} + n {expert_n} > stacked_n {}",
                    self.store.n
                )))
            }
        };
        if self.k != self.store.k {
            return Err(ferrum_types::FerrumError::model(format!(
                "make_expert_linear k mismatch: stack k {} vs store.k {}",
                self.k, self.store.k
            )));
        }
        // Convert row bounds to element bounds. Use a checked slice so a
        // short buffer becomes an error rather than a panic.
        let weight_f32 = expert_offset
            .checked_mul(self.k)
            .zip(row_end.checked_mul(self.k))
            .and_then(|(start, end)| self.store.weight_f32.get(start..end))
            .ok_or_else(|| {
                ferrum_types::FerrumError::model(format!(
                    "make_expert_linear: rows {expert_offset}..{row_end} (k={}) exceed weight_f32 len {}",
                    self.k,
                    self.store.weight_f32.len()
                ))
            })?
            .to_vec();
        Ok(Box::new(crate::quant_linear::cpu_dequant::CpuGptqLinear {
            weight_f32,
            bias: bias_host.map(|b| b.to_vec()),
            in_features: self.k,
            out_features: expert_n,
        }))
    }
}