use crate::error::{Error, Result};
use crate::quant::QuantFormat;
use crate::quant::decomposed::DecomposedQuantTensor;
use crate::quant::tensor::QuantTensor;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// A model weight in one of three storage representations.
///
/// Use [`Weight::is_quantized`] to distinguish dense from quantized storage,
/// and the `as_*` / `into_*` accessors for checked extraction.
pub enum Weight<R: Runtime> {
    /// Dense, unquantized tensor.
    Standard(Tensor<R>),
    /// Block-quantized tensor with a single `QuantFormat`.
    Quantized(QuantTensor<R>),
    /// Decomposed quantized tensor; boxed so the enum stays small
    /// (avoids every `Weight` paying for the largest variant's size).
    DecomposedQuant(Box<DecomposedQuantTensor<R>>),
}
impl<R: Runtime> Weight<R> {
pub fn is_quantized(&self) -> bool {
matches!(self, Self::Quantized(_) | Self::DecomposedQuant(_))
}
pub fn as_tensor(&self) -> Result<&Tensor<R>> {
match self {
Self::Standard(t) => Ok(t),
_ => Err(Error::ModelError {
reason: "expected standard tensor, got quantized".into(),
}),
}
}
pub fn as_quant_tensor(&self) -> Result<&QuantTensor<R>> {
match self {
Self::Quantized(q) => Ok(q),
_ => Err(Error::ModelError {
reason: "expected block-quantized tensor".into(),
}),
}
}
pub fn as_decomposed_quant_tensor(&self) -> Result<&DecomposedQuantTensor<R>> {
match self {
Self::DecomposedQuant(dq) => Ok(dq),
_ => Err(Error::ModelError {
reason: "expected decomposed quantized tensor".into(),
}),
}
}
pub fn into_tensor(self) -> Result<Tensor<R>> {
match self {
Self::Standard(t) => Ok(t),
_ => Err(Error::ModelError {
reason: "expected standard tensor, got quantized".into(),
}),
}
}
pub fn into_quant_tensor(self) -> Result<QuantTensor<R>> {
match self {
Self::Quantized(q) => Ok(q),
_ => Err(Error::ModelError {
reason: "expected block-quantized tensor".into(),
}),
}
}
pub fn into_decomposed_quant_tensor(self) -> Result<DecomposedQuantTensor<R>> {
match self {
Self::DecomposedQuant(dq) => Ok(*dq),
_ => Err(Error::ModelError {
reason: "expected decomposed quantized tensor".into(),
}),
}
}
}
impl<R: Runtime<DType = numr::dtype::DType>> Weight<R> {
    /// Logical shape of the weight, regardless of storage representation.
    pub fn shape(&self) -> &[usize] {
        match self {
            Self::Standard(t) => t.shape(),
            Self::Quantized(q) => q.shape(),
            Self::DecomposedQuant(dq) => dq.shape(),
        }
    }

    /// The block-quantization format, if this is a `Quantized` weight.
    ///
    /// NOTE(review): `DecomposedQuant` returns `None` here even though
    /// `is_quantized()` reports `true` for it — presumably it has no single
    /// block format. Confirm callers do not treat `None` as "not quantized".
    pub fn quant_format(&self) -> Option<QuantFormat> {
        // Exhaustive match so a new variant forces a decision here
        // rather than silently returning `None`.
        match self {
            Self::Quantized(q) => Some(q.format()),
            Self::Standard(_) | Self::DecomposedQuant(_) => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuDevice, CpuRuntime};

    /// Shared CPU device for all tests.
    fn device() -> CpuDevice {
        CpuDevice::new()
    }

    #[test]
    fn test_standard_weight() {
        let d = device();
        let t = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &d);
        let w = Weight::Standard(t);
        assert!(!w.is_quantized());
        assert_eq!(w.shape(), &[3]);
        assert!(w.as_tensor().is_ok());
        assert!(w.as_quant_tensor().is_err());
        assert!(w.quant_format().is_none());
    }

    #[test]
    fn test_quantized_weight() {
        let d = device();
        // 18 bytes = one Q4_0 block (32 4-bit values + scale) covering 32 elements.
        let data = vec![0u8; 18];
        let qt =
            QuantTensor::<CpuRuntime>::from_bytes(&data, QuantFormat::Q4_0, &[32], &d).unwrap();
        let w = Weight::Quantized(qt);
        assert!(w.is_quantized());
        assert_eq!(w.shape(), &[32]);
        assert!(w.as_tensor().is_err());
        assert!(w.as_quant_tensor().is_ok());
        assert_eq!(w.quant_format(), Some(QuantFormat::Q4_0));
    }

    #[test]
    fn test_into_tensor() {
        let d = device();
        let t = Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &d);
        let w = Weight::Standard(t);
        assert!(w.into_tensor().is_ok());
    }

    #[test]
    fn test_into_quant_tensor() {
        let d = device();
        let data = vec![0u8; 18];
        let qt =
            QuantTensor::<CpuRuntime>::from_bytes(&data, QuantFormat::Q4_0, &[32], &d).unwrap();
        let w = Weight::Quantized(qt);
        assert!(w.into_quant_tensor().is_ok());
    }
}