ferrum_kernels/quant_linear/
cpu_marlin_stack.rs1use crate::backend::cpu::CpuBackend;
13use crate::marlin_expert_stack::MarlinExpertStack;
14use crate::Linear;
15use ferrum_types::Result;
16use std::sync::Arc;
17
18pub struct CpuMarlinExpertStack {
19 pub store: Arc<crate::backend::cpu::CpuGptqStore>,
20 pub num_experts: usize,
21 pub n_per_expert: usize,
22 pub k: usize,
23}
24
25impl CpuMarlinExpertStack {
26 pub fn new(
27 store: Arc<crate::backend::cpu::CpuGptqStore>,
28 num_experts: usize,
29 n_per_expert: usize,
30 k: usize,
31 ) -> Self {
32 Self {
33 store,
34 num_experts,
35 n_per_expert,
36 k,
37 }
38 }
39}
40
41impl MarlinExpertStack<CpuBackend> for CpuMarlinExpertStack {
42 fn n_per_expert(&self) -> usize {
43 self.n_per_expert
44 }
45 fn k(&self) -> usize {
46 self.k
47 }
48 fn num_experts(&self) -> usize {
49 self.num_experts
50 }
51
52 fn as_any(&self) -> &dyn std::any::Any {
53 self
54 }
55
56 fn zero_workspace(
57 &self,
58 _ctx: &mut <CpuBackend as crate::backend::Backend>::Context,
59 ) -> Result<()> {
60 Ok(())
63 }
64
65 fn gemm_phase_batched(
66 &self,
67 ctx: &mut <CpuBackend as crate::backend::Backend>::Context,
68 input: &<CpuBackend as crate::backend::Backend>::Buffer,
69 dispatches: &[(usize, usize, usize, usize)],
70 output: &mut <CpuBackend as crate::backend::Backend>::Buffer,
71 k: usize,
72 ) -> Result<()> {
73 for (expert_idx, in_row_offset, out_row_offset, m) in dispatches {
77 crate::backend::cpu::cpu_gemm_gptq_with_offset_strided(
78 ctx,
79 input,
80 *in_row_offset,
81 &self.store,
82 expert_idx * self.n_per_expert,
83 self.n_per_expert,
84 output,
85 *out_row_offset,
86 *m,
87 k,
88 )?;
89 }
90 Ok(())
91 }
92
93 fn make_expert_linear(
94 self: Arc<Self>,
95 expert_offset: usize,
96 expert_n: usize,
97 bias_host: Option<&[f32]>,
98 ) -> Result<Box<dyn Linear<CpuBackend> + Send + Sync>> {
99 if expert_offset + expert_n > self.store.n {
102 return Err(ferrum_types::FerrumError::model(format!(
103 "make_expert_linear OOB: offset {expert_offset} + n {expert_n} > stacked_n {}",
104 self.store.n
105 )));
106 }
107 if self.k != self.store.k {
108 return Err(ferrum_types::FerrumError::model(format!(
109 "make_expert_linear k mismatch: arg {} vs store.k {}",
110 self.k, self.store.k
111 )));
112 }
113 let row_start = expert_offset * self.k;
114 let row_end = (expert_offset + expert_n) * self.k;
115 let slice = self.store.weight_f32[row_start..row_end].to_vec();
116 Ok(Box::new(crate::quant_linear::cpu_dequant::CpuGptqLinear {
117 weight_f32: slice,
118 bias: bias_host.map(|b| b.to_vec()),
119 in_features: self.k,
120 out_features: expert_n,
121 }))
122 }
123}