use ferrum_kernels::backend::{Backend, BackendQuantGguf, GgufQuantType};
use ferrum_kernels::Linear;
use ferrum_types::Result;
/// A [`Linear`] layer backed by a backend-provided quantized implementation.
///
/// The concrete layer is produced by the backend `B` (see the constructors)
/// and held behind a boxed trait object; every [`Linear`] method on this
/// type simply forwards to that inner object.
pub struct QuantLinear<B>
where
    B: Backend + BackendQuantGguf,
{
    /// Backend-created layer that performs the actual work.
    inner: Box<dyn Linear<B> + Send + Sync>,
}
impl<B> QuantLinear<B>
where
    B: Backend + BackendQuantGguf,
{
    /// Creates a layer from a single quantized weight blob.
    ///
    /// All arguments are handed unchanged to `B::load_quant`, and the layer
    /// it returns becomes the inner delegate.
    ///
    /// NOTE(review): `bytes` is presumably the raw GGUF tensor payload for
    /// an `out_features x in_features` weight; confirm against the backend.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `B::load_quant`.
    pub fn from_gguf_bytes(
        kind: GgufQuantType,
        bytes: &[u8],
        out_features: usize,
        in_features: usize,
    ) -> Result<Self> {
        Ok(Self {
            inner: B::load_quant(kind, bytes, out_features, in_features)?,
        })
    }

    /// Creates a layer from several quantized weight blobs fused together.
    ///
    /// `parts` and `in_features` are handed unchanged to
    /// `B::load_quant_fused`, and the layer it returns becomes the inner
    /// delegate.
    ///
    /// NOTE(review): each tuple in `parts` looks like
    /// `(quant type, raw bytes, out_features of that part)`; confirm the
    /// meaning of the trailing `usize` against the backend.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `B::load_quant_fused`.
    pub fn from_gguf_fused(
        parts: &[(GgufQuantType, &[u8], usize)],
        in_features: usize,
    ) -> Result<Self> {
        Ok(Self {
            inner: B::load_quant_fused(parts, in_features)?,
        })
    }
}
/// Forwards every [`Linear`] operation to the backend-created inner layer.
impl<B> Linear<B> for QuantLinear<B>
where
    B: Backend + BackendQuantGguf,
{
    /// Returns the inner layer's reported input feature count.
    fn in_features(&self) -> usize {
        let Self { inner } = self;
        inner.in_features()
    }

    /// Returns the inner layer's reported output feature count.
    fn out_features(&self) -> usize {
        let Self { inner } = self;
        inner.out_features()
    }

    /// Runs the inner layer's forward pass with the arguments unchanged.
    ///
    /// NOTE(review): `m` is presumably the number of input rows (tokens);
    /// confirm against the `Linear` trait documentation.
    fn forward(
        &self,
        ctx: &mut B::Context,
        input: &B::Buffer,
        out: &mut B::Buffer,
        m: usize,
    ) {
        let Self { inner } = self;
        inner.forward(ctx, input, out, m);
    }
}