use candle_core::{
quantized::{QMatMul, QTensor},
Result, Tensor,
};
use candle_nn::Linear;
use std::sync::Arc;
use crate::{QuantMethod, QuantMethodConfig, UnquantLinear};
/// Runs an indexed (gathered) forward pass using a quantized weight tensor.
///
/// The weights in `qtensor` are dequantized onto the device that holds `x`,
/// wrapped in an unquantized, bias-free [`Linear`] layer, and then handed to
/// `UnquantLinear::gather_forward` together with `x` and `ids`.
///
/// Note that the dequantization happens on every call; nothing is cached here.
///
/// # Arguments
///
/// * `qtensor` - quantized weights to dequantize.
/// * `x` - input tensor; its device is the dequantization target.
/// * `ids` - index tensor forwarded unchanged to `gather_forward`
///   (presumably per-token expert ids — confirm against `UnquantLinear`).
///
/// # Errors
///
/// Propagates any error from dequantization, from constructing the
/// `UnquantLinear`, or from `gather_forward` itself.
pub fn qtensor_indexed_moe_forward(
    qtensor: &Arc<QTensor>,
    x: &Tensor,
    ids: &Tensor,
) -> Result<Tensor> {
    // Materialize full-precision weights next to the input.
    let dense = qtensor.dequantize(x.device())?;
    // `None` => the temporary linear layer carries no bias term.
    let layer_cfg = QuantMethodConfig::Unquantized(Linear::new(dense, None));
    UnquantLinear::new(layer_cfg)?.gather_forward(x, ids)
}
/// Runs an indexed (gathered) forward pass for any [`QMatMul`] variant.
///
/// * `QMatMul::QTensor` is delegated to [`qtensor_indexed_moe_forward`].
/// * `QMatMul::Tensor` and `QMatMul::TensorF16` already hold a plain tensor,
///   which is cloned into an unquantized, bias-free [`Linear`] layer before
///   calling `UnquantLinear::gather_forward`.
///
/// # Arguments
///
/// * `qmatmul` - the weight container to dispatch on.
/// * `x` - input tensor passed through to `gather_forward`.
/// * `ids` - index tensor passed through to `gather_forward`.
///
/// # Errors
///
/// Propagates any error from the delegated quantized path, from constructing
/// the `UnquantLinear`, or from `gather_forward`.
pub fn cpu_indexed_moe_forward(qmatmul: &QMatMul, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
    let dense = match qmatmul {
        // Quantized weights take the dedicated dequantizing path.
        QMatMul::QTensor(quantized) => return qtensor_indexed_moe_forward(quantized, x, ids),
        // Both dense variants are handled identically from here on.
        QMatMul::Tensor(w) | QMatMul::TensorF16(w) => w.clone(),
    };

    // `None` => the temporary linear layer carries no bias term.
    let layer_cfg = QuantMethodConfig::Unquantized(Linear::new(dense, None));
    UnquantLinear::new(layer_cfg)?.gather_forward(x, ids)
}