Skip to main content

ferrum_kernels/quant_linear/
cpu_marlin_stack.rs

//! `MarlinExpertStack<CpuBackend>` impl on top of CPU's dequant-on-load
//! GptqStore. The batched GEMM phase is a serial loop over the free
//! function `crate::backend::cpu::cpu_gemm_gptq_with_offset_strided`
//! and no longer routes through any Backend trait method;
//! `make_expert_linear` carries the logic inlined from
//! `BackendQuantMarlin::make_stacked_expert_linear`.
//!
//! CPU has no real Marlin tiles; the impl exists so the bucketed MoE
//! path's parity test (`tests/moe_bucketed_parity_test.rs`) still
//! compiles after the Phase C trait-object migration. Phase C step 4
//! (sub-steps 4b, 4c and 4e) is where the kernel calls were inlined
//! here; the same step covers deleting the old trait methods.
11
12use crate::backend::cpu::CpuBackend;
13use crate::marlin_expert_stack::MarlinExpertStack;
14use crate::Linear;
15use ferrum_types::Result;
16use std::sync::Arc;
17
18pub struct CpuMarlinExpertStack {
19    pub store: Arc<crate::backend::cpu::CpuGptqStore>,
20    pub num_experts: usize,
21    pub n_per_expert: usize,
22    pub k: usize,
23}
24
25impl CpuMarlinExpertStack {
26    pub fn new(
27        store: Arc<crate::backend::cpu::CpuGptqStore>,
28        num_experts: usize,
29        n_per_expert: usize,
30        k: usize,
31    ) -> Self {
32        Self {
33            store,
34            num_experts,
35            n_per_expert,
36            k,
37        }
38    }
39}
40
41impl MarlinExpertStack<CpuBackend> for CpuMarlinExpertStack {
42    fn n_per_expert(&self) -> usize {
43        self.n_per_expert
44    }
45    fn k(&self) -> usize {
46        self.k
47    }
48    fn num_experts(&self) -> usize {
49        self.num_experts
50    }
51
52    fn as_any(&self) -> &dyn std::any::Any {
53        self
54    }
55
56    fn zero_workspace(
57        &self,
58        _ctx: &mut <CpuBackend as crate::backend::Backend>::Context,
59    ) -> Result<()> {
60        // CPU GptqStore (dequant-on-load) has no per-expert workspace
61        // mutex slots — Marlin-specific GPU artefact. No-op.
62        Ok(())
63    }
64
65    fn gemm_phase_batched(
66        &self,
67        ctx: &mut <CpuBackend as crate::backend::Backend>::Context,
68        input: &<CpuBackend as crate::backend::Backend>::Buffer,
69        dispatches: &[(usize, usize, usize, usize)],
70        output: &mut <CpuBackend as crate::backend::Backend>::Buffer,
71        k: usize,
72    ) -> Result<()> {
73        // Phase C step 4c+4e: serial loop calling the moved-out CPU
74        // gemm_gptq_with_offset_strided free function. No longer
75        // routes through any Backend trait method.
76        for (expert_idx, in_row_offset, out_row_offset, m) in dispatches {
77            crate::backend::cpu::cpu_gemm_gptq_with_offset_strided(
78                ctx,
79                input,
80                *in_row_offset,
81                &self.store,
82                expert_idx * self.n_per_expert,
83                self.n_per_expert,
84                output,
85                *out_row_offset,
86                *m,
87                k,
88            )?;
89        }
90        Ok(())
91    }
92
93    fn make_expert_linear(
94        self: Arc<Self>,
95        expert_offset: usize,
96        expert_n: usize,
97        bias_host: Option<&[f32]>,
98    ) -> Result<Box<dyn Linear<CpuBackend> + Send + Sync>> {
99        // Inlined from BackendQuantMarlin::make_stacked_expert_linear
100        // (Phase C step 4b).
101        if expert_offset + expert_n > self.store.n {
102            return Err(ferrum_types::FerrumError::model(format!(
103                "make_expert_linear OOB: offset {expert_offset} + n {expert_n} > stacked_n {}",
104                self.store.n
105            )));
106        }
107        if self.k != self.store.k {
108            return Err(ferrum_types::FerrumError::model(format!(
109                "make_expert_linear k mismatch: arg {} vs store.k {}",
110                self.k, self.store.k
111            )));
112        }
113        let row_start = expert_offset * self.k;
114        let row_end = (expert_offset + expert_n) * self.k;
115        let slice = self.store.weight_f32[row_start..row_end].to_vec();
116        Ok(Box::new(crate::quant_linear::cpu_dequant::CpuGptqLinear {
117            weight_f32: slice,
118            bias: bias_host.map(|b| b.to_vec()),
119            in_features: self.k,
120            out_features: expert_n,
121        }))
122    }
123}