use crate::backend::cuda::{CudaBackend, GptqStoreCuda};
use crate::backend::CudaBuf;
use crate::Linear;
use ferrum_types::{FerrumError, Result};
use std::sync::Arc;
/// GPTQ INT4-quantized linear layer executed on the CUDA backend.
///
/// `forward` chooses its GEMM kernel from the compiled cargo features
/// (`marlin`, `triton-kernels`) and, when `triton-kernels` is enabled, from the
/// `store` variant (Marlin-packed or Triton layout).
pub struct CudaMarlinLinear {
/// Quantized weights. With `triton-kernels` enabled this is an enum holding a
/// Marlin or Triton weight; otherwise it is handed straight to the Marlin kernel.
pub store: GptqStoreCuda,
/// Optional bias, applied via the backend's `add_bias` after the GEMM
/// (presumably broadcast across all `m` rows — confirm in `add_bias`).
pub bias: Option<CudaBuf>,
/// Number of input features, reported by `Linear::in_features`.
pub in_features: usize,
/// Number of output features, reported by `Linear::out_features` and used as
/// the column count for the bias add.
pub out_features: usize,
}
impl Linear<CudaBackend> for CudaMarlinLinear {
    fn in_features(&self) -> usize {
        self.in_features
    }

    fn out_features(&self) -> usize {
        self.out_features
    }

    /// Computes `out = input · W` for `m` rows of f16 activations using the
    /// GPTQ INT4 kernel that matches `self.store`, then adds `self.bias` if set.
    ///
    /// Kernel selection:
    /// * Marlin weights → `marlin_gemm_with_perm` (requires the `marlin` feature);
    /// * Triton weights (only present with `triton-kernels`) → the Triton w4a16 kernel.
    ///
    /// # Panics
    ///
    /// Panics if the selected kernel returns an error, if `m` does not fit in
    /// `i32` on the Triton path, or if the kernel required for `self.store`
    /// was not compiled in.
    fn forward(
        &self,
        ctx: &mut <CudaBackend as crate::backend::Backend>::Context,
        input: &<CudaBackend as crate::backend::Backend>::Buffer,
        out: &mut <CudaBackend as crate::backend::Backend>::Buffer,
        m: usize,
    ) {
        let res: Result<()> = {
            // With `triton-kernels`, `store` is an enum that may hold either layout.
            #[cfg(feature = "triton-kernels")]
            {
                match &self.store {
                    GptqStoreCuda::Marlin(mw) => {
                        #[cfg(feature = "marlin")]
                        {
                            crate::backend::cuda::marlin_gemm_with_perm(
                                ctx,
                                input.as_f16(),
                                mw,
                                out.as_f16_mut(),
                                m,
                            )
                        }
                        #[cfg(not(feature = "marlin"))]
                        {
                            let _ = mw;
                            Err(FerrumError::unsupported(
                                "cargo feature `marlin` disabled — Marlin variant unusable; \
                                 set FERRUM_TRITON_INT4=1 to force the triton path",
                            ))
                        }
                    }
                    GptqStoreCuda::Triton(tw) => {
                        let func = ctx.func(
                            "triton_w4a16_gptq",
                            crate::triton_ptx::w4a16_gptq_f16::PTX,
                            crate::triton_w4a16::fn_name(),
                        );
                        let stream = ctx.stream.clone();
                        // The kernel takes `m` as i32; refuse instead of letting `as` wrap it.
                        i32::try_from(m)
                            .map_err(|_| {
                                FerrumError::model(format!(
                                    "triton w4a16: m={m} does not fit in i32"
                                ))
                            })
                            .and_then(|rows| {
                                crate::triton_w4a16::launch_w4a16_gptq_triton(
                                    &stream,
                                    &func,
                                    input.as_f16(),
                                    tw,
                                    out.as_f16_mut(),
                                    rows,
                                )
                                .map_err(|e| FerrumError::model(format!("triton w4a16: {e}")))
                            })
                    }
                }
            }
            // Without `triton-kernels`, `store` is the Marlin weight itself.
            #[cfg(all(feature = "marlin", not(feature = "triton-kernels")))]
            {
                crate::backend::cuda::marlin_gemm_with_perm(
                    ctx,
                    input.as_f16(),
                    &self.store,
                    out.as_f16_mut(),
                    m,
                )
            }
            #[cfg(all(not(feature = "marlin"), not(feature = "triton-kernels")))]
            {
                // Only `input` is otherwise unused here. `ctx` and `out` must NOT be
                // moved (e.g. into a tuple): the bias add below still needs them.
                let _ = input;
                Err(FerrumError::unsupported(
                    "cargo features `marlin` and `triton-kernels` both disabled — \
                     GPTQ not available",
                ))
            }
        };
        res.unwrap_or_else(|e| panic!("CudaMarlinLinear forward failed: {e}"));
        if let Some(bias) = &self.bias {
            <CudaBackend as crate::backend::Backend>::add_bias(
                ctx,
                out,
                bias,
                m,
                self.out_features,
            );
        }
    }
}
/// Linear layer backed by one expert's slice of a stacked GPTQ weight.
///
/// `forward` calls `marlin_gemm_with_offset` with `expert_offset` and
/// `expert_n`, so several instances can share one reference-counted `store`.
/// `forward` panics if the store is in the Triton layout or the `marlin`
/// feature is disabled.
pub struct CudaMarlinStackedExpertLinear {
/// Stacked quantized weights, reference-counted so it can be shared —
/// presumably by all experts of one MoE layer (confirm at the construction site).
pub store: Arc<GptqStoreCuda>,
/// Offset of this expert's slice, forwarded to `marlin_gemm_with_offset`.
/// NOTE(review): assumed to be measured in output columns — confirm against the kernel.
pub expert_offset: usize,
/// Output width of this expert's slice, reported by `Linear::out_features`.
pub expert_n: usize,
/// Input width, reported by `Linear::in_features`.
pub k: usize,
/// Optional bias, applied via the backend's `add_bias` with width `expert_n`
/// after the GEMM.
pub bias: Option<CudaBuf>,
}
impl Linear<CudaBackend> for CudaMarlinStackedExpertLinear {
    fn in_features(&self) -> usize {
        self.k
    }

    fn out_features(&self) -> usize {
        self.expert_n
    }

    /// Computes this expert's `out = input · W[expert]` for `m` rows of f16
    /// activations with `marlin_gemm_with_offset`, then adds `self.bias` if set.
    ///
    /// # Panics
    ///
    /// Panics if the store holds a Triton-layout weight (no stride-aware Triton
    /// kernel exists), if the `marlin` feature is disabled, if `m`,
    /// `expert_offset` or `expert_n` does not fit in `i32`, or if the kernel
    /// returns an error.
    fn forward(
        &self,
        ctx: &mut <CudaBackend as crate::backend::Backend>::Context,
        input: &<CudaBackend as crate::backend::Backend>::Buffer,
        out: &mut <CudaBackend as crate::backend::Backend>::Buffer,
        m: usize,
    ) {
        let res: Result<()> = {
            #[cfg(feature = "marlin")]
            {
                // Resolve the Marlin weight; a Triton store becomes an error that is
                // reported through the same panic as every other failure below.
                #[cfg(feature = "triton-kernels")]
                let weight = match self.store.as_ref() {
                    GptqStoreCuda::Marlin(mw) => Ok(mw),
                    GptqStoreCuda::Triton(_) => Err(FerrumError::unsupported(
                        "CudaMarlinStackedExpertLinear: Triton w4a16 store has no \
                         stride-aware variant; load MoE via Marlin (default)",
                    )),
                };
                #[cfg(not(feature = "triton-kernels"))]
                let weight: Result<&crate::marlin::MarlinWeight> = {
                    let mw: &crate::marlin::MarlinWeight = self.store.as_ref();
                    Ok(mw)
                };
                let stream = ctx.stream.clone();
                weight.and_then(|mw| -> Result<()> {
                    // The kernel takes i32 dimensions; refuse values `as` would wrap.
                    let to_i32 = |what: &str, v: usize| -> Result<i32> {
                        i32::try_from(v).map_err(|_| {
                            FerrumError::model(format!(
                                "marlin offset gemm: {what}={v} does not fit in i32"
                            ))
                        })
                    };
                    let rows = to_i32("m", m)?;
                    let offset = to_i32("expert_offset", self.expert_offset)?;
                    let cols = to_i32("expert_n", self.expert_n)?;
                    crate::marlin::marlin_gemm_with_offset(
                        &stream,
                        input.as_f16(),
                        mw,
                        out.as_f16_mut(),
                        rows,
                        offset,
                        cols,
                    )
                    .map_err(|e| FerrumError::model(format!("marlin offset gemm: {e}")))
                })
            }
            #[cfg(not(feature = "marlin"))]
            {
                // Only `input` is otherwise unused here. `ctx` and `out` must NOT be
                // moved (e.g. into a tuple): the bias add below still needs them.
                let _ = input;
                Err(FerrumError::unsupported(
                    "cargo feature `marlin` disabled — \
                     CudaMarlinStackedExpertLinear unavailable",
                ))
            }
        };
        res.unwrap_or_else(|e| panic!("CudaMarlinStackedExpertLinear forward failed: {e}"));
        if let Some(bias) = &self.bias {
            <CudaBackend as crate::backend::Backend>::add_bias(ctx, out, bias, m, self.expert_n);
        }
    }
}