tract-linalg 0.23.0-dev.4

Tiny, no-nonsense, self-contained, TensorFlow and ONNX inference
Documentation
use std::fmt::Debug;
use tract_data::prelude::DatumType;

use crate::block_quant::{BlockQuant, PackedBlockQuantFormat};

use crate::mmm::MMMInputFormat;
use crate::pack::PackedFormat;

/// Weight type descriptor: either a plain [`DatumType`] or a boxed
/// [`BlockQuant`] implementation.
#[derive(Clone)]
pub enum WeightType {
    /// Weights identified by a datum type.
    Plain(DatumType),
    /// Weights identified by a block-quantization implementation, held as a
    /// trait object.
    BlockQuant(Box<dyn BlockQuant>),
}

impl From<DatumType> for WeightType {
    /// Builds a [`WeightType::Plain`] from a datum type.
    ///
    /// `F16`, `F32`, `F64`, `I32`, `I8` and `U8` are kept as they are. The
    /// `QI8(_)` and `QU8(_)` variants are reduced to plain `I8` and `U8`
    /// respectively; their payload is discarded.
    ///
    /// # Panics
    ///
    /// Panics for every other datum type.
    fn from(value: DatumType) -> Self {
        let plain = match value {
            // These variants carry no payload, so the input can be reused as-is.
            DatumType::F16 | DatumType::F32 | DatumType::F64 | DatumType::I32 => value,
            DatumType::I8 | DatumType::QI8(_) => DatumType::I8,
            DatumType::U8 | DatumType::QU8(_) => DatumType::U8,
            _ => panic!("Can't build a WeightType from {value:?}"),
        };
        WeightType::Plain(plain)
    }
}

impl From<Box<dyn MMMInputFormat>> for WeightType {
    /// Converts an owned, boxed input format by delegating to the
    /// `From<&dyn MMMInputFormat>` conversion on the boxed value.
    fn from(value: Box<dyn MMMInputFormat>) -> Self {
        let format: &dyn MMMInputFormat = &*value;
        Self::from(format)
    }
}

impl From<&dyn MMMInputFormat> for WeightType {
    /// Derives a weight type from an input format by downcasting it to one of
    /// the concrete formats handled here.
    ///
    /// * [`PackedFormat`] yields [`WeightType::Plain`] with the format's `dt`.
    /// * [`PackedBlockQuantFormat`] yields [`WeightType::BlockQuant`] holding a
    ///   clone of the format's `bq`.
    ///
    /// # Panics
    ///
    /// Any other concrete format hits `todo!()` and panics.
    fn from(value: &dyn MMMInputFormat) -> Self {
        if let Some(packed) = value.downcast_ref::<PackedFormat>() {
            return WeightType::Plain(packed.dt);
        }
        if let Some(packed_bq) = value.downcast_ref::<PackedBlockQuantFormat>() {
            return WeightType::BlockQuant(dyn_clone::clone_box(&*packed_bq.bq));
        }
        // No other input format is supported yet.
        todo!()
    }
}

impl PartialEq for WeightType {
    fn eq(&self, other: &Self) -> bool {
        use WeightType::*;
        match (self, other) {
            (Plain(a), Plain(b)) => a == b,
            (BlockQuant(a), BlockQuant(b)) => a == b,
            _ => false,
        }
    }
}

impl<BQ: BlockQuant> From<BQ> for WeightType {
    /// Wraps an owned [`BlockQuant`] implementation into
    /// [`WeightType::BlockQuant`].
    ///
    /// `value` is already owned, so it is moved straight into the box rather
    /// than being cloned first and then dropped.
    fn from(value: BQ) -> Self {
        WeightType::BlockQuant(Box::new(value))
    }
}

impl Debug for WeightType {
    /// Writes the `Debug` representation of the wrapped value only; the
    /// variant name itself is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner: &dyn Debug = match self {
            Self::Plain(dt) => dt,
            Self::BlockQuant(scheme) => scheme,
        };
        write!(f, "{inner:?}")
    }
}

impl WeightType {
    /// Returns the datum type of a [`WeightType::Plain`] value, or `None` for
    /// any other variant.
    pub fn as_dt(&self) -> Option<DatumType> {
        if let WeightType::Plain(dt) = self {
            Some(*dt)
        } else {
            None
        }
    }
}