use crate::error::Result;
use crate::quant::QuantTensor;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Matrix-multiplication operations over quantized weight tensors.
///
/// Implementors provide the fused quantized kernels; `quant_matmul_batch`
/// has a default implementation built on [`QuantMatmulOps::quant_matmul`].
pub trait QuantMatmulOps<R: Runtime> {
    /// Multiplies a dense `activation` tensor by a quantized `weight`,
    /// returning a dense output tensor.
    ///
    /// # Errors
    /// Returns an error if the underlying kernel fails (e.g. shape or
    /// format mismatch — determined by the implementation).
    fn quant_matmul(&self, activation: &Tensor<R>, weight: &QuantTensor<R>) -> Result<Tensor<R>>;

    /// Applies [`QuantMatmulOps::quant_matmul`] with the same `activation`
    /// against each weight in `weights`, preserving order.
    ///
    /// # Errors
    /// Stops at and returns the first failing multiplication.
    fn quant_matmul_batch(
        &self,
        activation: &Tensor<R>,
        weights: &[&QuantTensor<R>],
    ) -> Result<Vec<Tensor<R>>> {
        let mut outputs = Vec::with_capacity(weights.len());
        for weight in weights {
            // `?` short-circuits on the first error, matching a
            // fallible-collect over the slice.
            outputs.push(self.quant_matmul(activation, weight)?);
        }
        Ok(outputs)
    }

    /// INT4 GEMM with per-group dequantization parameters.
    ///
    /// `scales` and `zeros` hold one entry per `group_size` weights along
    /// the quantized dimension (layout is implementation-defined — confirm
    /// against the concrete kernel).
    ///
    /// # Errors
    /// Returns an error if the kernel rejects the inputs.
    fn int4_gemm(
        &self,
        input: &Tensor<R>,
        qweight: &Tensor<R>,
        scales: &Tensor<R>,
        zeros: &Tensor<R>,
        group_size: usize,
    ) -> Result<Tensor<R>>;

    /// INT4 GEMM in GPTQ format: packed zeros (`qzeros`) and an explicit
    /// group-index tensor (`g_idx`) mapping rows to quantization groups.
    ///
    /// # Errors
    /// Returns an error if the kernel rejects the inputs.
    fn int4_gemm_gptq(
        &self,
        input: &Tensor<R>,
        qweight: &Tensor<R>,
        qzeros: &Tensor<R>,
        scales: &Tensor<R>,
        g_idx: &Tensor<R>,
    ) -> Result<Tensor<R>>;

    /// GEMM over a Marlin-layout quantized `weight` with per-group `scales`
    /// and `zeros` (kernel-specific layout — see the implementation).
    ///
    /// # Errors
    /// Returns an error if the kernel rejects the inputs.
    fn marlin_gemm(
        &self,
        input: &Tensor<R>,
        weight: &Tensor<R>,
        scales: &Tensor<R>,
        zeros: &Tensor<R>,
        group_size: usize,
    ) -> Result<Tensor<R>>;

    /// Fused SwiGLU over quantized gate/up projections: multiplies
    /// `activation` by both weights and combines them per the
    /// implementation's activation kernel.
    ///
    /// # Errors
    /// Returns an error if either projection or the fusion fails.
    fn quant_swiglu(
        &self,
        activation: &Tensor<R>,
        gate_weight: &QuantTensor<R>,
        up_weight: &QuantTensor<R>,
    ) -> Result<Tensor<R>>;
}