use candle_core::quantized::QTensor;
use candle_core::{Device, Result as CandleResult};
use ferrum_kernels::backend::Backend;
use crate::dense::DenseLinear;
use crate::traits::Linear;
/// Linear layer whose weights originate from a GGUF quantized tensor ([`QTensor`]).
///
/// The quantized data is not retained: the constructors dequantize it to `f32` on the
/// CPU and store the result in a [`DenseLinear`], to which every [`Linear`] method
/// delegates.
pub struct GgufLinear<B: Backend> {
/// Dense layer holding the dequantized weights (and optional bias); all `Linear`
/// calls are forwarded to it.
inner: DenseLinear<B>,
}
impl<B: Backend> GgufLinear<B> {
pub fn from_qtensor(qt: &QTensor) -> CandleResult<Self> {
let dims = qt.shape().dims();
if dims.len() != 2 {
return Err(candle_core::Error::Msg(format!(
"GgufLinear: expected 2-D weight tensor, got rank {} (shape {:?})",
dims.len(),
dims
)));
}
let out_features = dims[0];
let in_features = dims[1];
let weights = dequantize_to_vec(qt)?;
Ok(Self {
inner: DenseLinear::<B>::from_rows(&weights, out_features, in_features),
})
}
pub fn from_qtensor_with_bias(qt: &QTensor, bias_qt: &QTensor) -> CandleResult<Self> {
let weight_dims = qt.shape().dims();
if weight_dims.len() != 2 {
return Err(candle_core::Error::Msg(format!(
"GgufLinear: expected 2-D weight, got rank {}",
weight_dims.len()
)));
}
let out_features = weight_dims[0];
let in_features = weight_dims[1];
let bias_dims = bias_qt.shape().dims();
if bias_dims.len() != 1 || bias_dims[0] != out_features {
return Err(candle_core::Error::Msg(format!(
"GgufLinear: bias shape {:?} doesn't match weight out_features {}",
bias_dims, out_features
)));
}
let weights = dequantize_to_vec(qt)?;
let bias = dequantize_to_vec(bias_qt)?;
Ok(Self {
inner: DenseLinear::<B>::from_rows_with_bias(
&weights,
&bias,
out_features,
in_features,
),
})
}
pub fn from_dense_rows(
weight_row_major: &[f32],
out_features: usize,
in_features: usize,
) -> Self {
Self {
inner: DenseLinear::<B>::from_rows(weight_row_major, out_features, in_features),
}
}
}
impl<B: Backend> Linear<B> for GgufLinear<B> {
    /// Input width, as reported by the wrapped dense layer.
    fn in_features(&self) -> usize {
        let Self { inner } = self;
        inner.in_features()
    }

    /// Output width, as reported by the wrapped dense layer.
    fn out_features(&self) -> usize {
        let Self { inner } = self;
        inner.out_features()
    }

    /// Runs the wrapped dense layer on `input`, writing into `out`.
    ///
    /// `ctx`, `input`, `out` and `m` are passed through unchanged; this wrapper adds no
    /// behavior of its own.
    fn forward(&self, ctx: &mut B::Context, input: &B::Buffer, out: &mut B::Buffer, m: usize) {
        let Self { inner } = self;
        inner.forward(ctx, input, out, m);
    }
}
/// Builds a bias-free [`GgufLinear`] from `qt` and returns it type-erased as a boxed
/// [`Linear`] trait object.
///
/// # Errors
///
/// Same as [`GgufLinear::from_qtensor`]: fails for non-2-D weights or when
/// dequantization fails.
pub fn linear_from_qtensor<B: Backend>(qt: &QTensor) -> CandleResult<Box<dyn Linear<B>>> {
    let layer = GgufLinear::<B>::from_qtensor(qt)?;
    Ok(Box::new(layer))
}
/// Dequantizes `qt` onto the CPU and returns all of its elements as one flat `Vec<f32>`.
///
/// The tensor is flattened with `flatten_all` before extraction, so rank-2 weights
/// come back as a single contiguous buffer.
///
/// # Errors
///
/// Propagates any candle error from dequantization, flattening, or the `f32`
/// extraction.
fn dequantize_to_vec(qt: &QTensor) -> CandleResult<Vec<f32>> {
    qt.dequantize(&Device::Cpu)?
        .flatten_all()?
        .to_vec1::<f32>()
}