use ferrum_kernels::backend::{Backend, GgufQuantType};
use ferrum_kernels::Linear;
use ferrum_types::Result;
/// A linear layer whose weight matrix is held in backend-specific quantized
/// storage (`B::QuantStore`).
///
/// The layer only records its shape alongside the opaque store; all matrix
/// work is delegated to the backend.
pub struct QuantLinear<B: Backend> {
    /// Width of the input vectors this layer consumes.
    in_features: usize,
    /// Width of the output vectors this layer produces. For layers built from
    /// several fused parts this is the sum of the parts' row counts.
    out_features: usize,
    /// Backend-owned quantized weights.
    store: B::QuantStore,
}
impl<B: Backend> QuantLinear<B> {
    /// Builds a layer from a single GGUF-encoded weight tensor.
    ///
    /// `kind`, `bytes` and the `out_features` x `in_features` shape are handed
    /// to `B::load_quant`; the store it returns is wrapped as-is.
    ///
    /// # Errors
    /// Propagates any error returned by `B::load_quant`.
    pub fn from_gguf_bytes(
        kind: GgufQuantType,
        bytes: &[u8],
        out_features: usize,
        in_features: usize,
    ) -> Result<Self> {
        let loaded = B::load_quant(kind, bytes, out_features, in_features)?;
        Ok(Self::from_store(loaded, out_features, in_features))
    }

    /// Builds one layer by fusing several GGUF-encoded weight tensors that
    /// share the same input width.
    ///
    /// Each entry of `parts` is `(quant type, raw bytes, row count)`. The
    /// resulting layer's `out_features` is the sum of all row counts.
    ///
    /// # Errors
    /// Propagates any error returned by `B::load_quant_fused`.
    pub fn from_gguf_fused(
        parts: &[(GgufQuantType, &[u8], usize)],
        in_features: usize,
    ) -> Result<Self> {
        let loaded = B::load_quant_fused(parts, in_features)?;
        let total_rows: usize = parts.iter().map(|&(_, _, rows)| rows).sum();
        Ok(Self::from_store(loaded, total_rows, in_features))
    }

    /// Wraps an already-loaded quantized store with the given shape.
    ///
    /// No validation is performed: the caller is responsible for `store`
    /// actually holding an `out_features` x `in_features` matrix.
    pub fn from_store(store: B::QuantStore, out_features: usize, in_features: usize) -> Self {
        Self {
            in_features,
            out_features,
            store,
        }
    }

    /// Borrows the underlying backend quantized store.
    pub fn store(&self) -> &B::QuantStore {
        &self.store
    }
}
impl<B: Backend> Linear<B> for QuantLinear<B> {
    /// Width of the input vectors this layer consumes.
    fn in_features(&self) -> usize {
        self.in_features
    }

    /// Width of the output vectors this layer produces.
    fn out_features(&self) -> usize {
        self.out_features
    }

    /// Runs this layer's quantized GEMM through `B::gemm_quant`, writing the
    /// result into `out`.
    ///
    /// `m` is passed straight through to the backend (presumably the number
    /// of input rows/tokens in `input` — confirm against the `Backend` docs).
    ///
    /// # Panics
    /// Panics if `B::gemm_quant` returns an error. This method returns `()`,
    /// so the error cannot be propagated. The panic message includes `m` and
    /// the layer shape so the failing call can be identified.
    fn forward(&self, ctx: &mut B::Context, input: &B::Buffer, out: &mut B::Buffer, m: usize) {
        if let Err(err) = B::gemm_quant(ctx, input, &self.store, out, m) {
            panic!(
                "gemm_quant failed (m={}, in_features={}, out_features={}): {:?}",
                m, self.in_features, self.out_features, err
            );
        }
    }
}