// tract-core 0.23.0-dev.4
//
// Tiny, no-nonsense, self-contained TensorFlow and ONNX inference.
// Documentation
use tract_linalg::block_quant::{BlockQuantFact, BlockQuantStorage};

use crate::internal::*;

/// Reinterprets a block-quantized tensor whose leading axis is `g` as a `[g, m, k]` tensor.
///
/// `shape` is the per-group target shape: its first entry becomes `m` and the product
/// of the remaining entries becomes `k` (an empty tail yields `k == 1`).
#[derive(Clone, Debug, Eq, Hash, PartialEq, new)]
pub struct BlockQuantIntoShape {
    /// Per-group target shape; `shape[0]` is `m`, the rest are folded into `k`.
    pub shape: TVec<usize>,
}

impl Op for BlockQuantIntoShape {
    fn name(&self) -> StaticName {
        "BlockQuantIntoShape".into()
    }
    op_as_typed_op!();
}

impl EvalOp for BlockQuantIntoShape {
    fn is_stateless(&self) -> bool {
        true
    }

    fn state(
        &self,
        _session: &TurnState,
        _node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        Ok(None)
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs).into_tensor();
        let g = input.shape()[0];
        let bqs = input.try_storage_as::<BlockQuantStorage>()?.clone();
        let new_m = self.shape[0];
        let new_k: usize = self.shape[1..].iter().product();
        Ok(tvec!(bqs.into_tensor_with_shape(input.datum_type(), &[g, new_m, new_k]).into_tvalue()))
    }
}

impl TypedOp for BlockQuantIntoShape {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let input = inputs[0];
        let old = input
            .exotic_fact
            .as_ref()
            .and_then(|of| of.downcast_ref::<BlockQuantFact>())
            .context("Expects BlockQuantFact")?;
        let g: usize = input.shape[0].to_usize()?;
        let new_m = self.shape[0];
        let new_k: usize = self.shape[1..].iter().product();
        let bqf_shape = tvec!(g, new_m, new_k);
        let new = BlockQuantFact::new(old.format.clone(), bqf_shape.clone());
        let shape: TVec<TDim> = bqf_shape.iter().map(|d| d.to_dim()).collect();
        let fact = inputs[0].datum_type.fact(&*shape).with_exotic_fact(new);
        Ok(tvec!(fact))
    }
    as_op!();
}

/// Splits the leading axis `o` of a block-quantized tensor into `[group, o / group]`,
/// leaving the remaining axes untouched.
#[derive(Clone, Debug, Eq, Hash, PartialEq, new)]
pub struct SplitGroupBlockQuant {
    /// Number of groups the leading axis is split into.
    pub group: usize,
}

impl Op for SplitGroupBlockQuant {
    fn name(&self) -> StaticName {
        "SplitGroupBlockQuant".into()
    }

    op_as_typed_op!();
}

impl EvalOp for SplitGroupBlockQuant {
    fn is_stateless(&self) -> bool {
        true
    }

    fn state(
        &self,
        _session: &TurnState,
        _node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        Ok(None)
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let bqs = input.try_storage_as::<BlockQuantStorage>()?.clone();
        let mut new_shape: TVec<usize> = input.shape().into();
        let o = new_shape[0];
        new_shape[0] = o / self.group;
        new_shape.insert(0, self.group);
        Ok(tvec!(bqs.into_tensor_with_shape(input.datum_type(), &new_shape).into_tvalue()))
    }
}

impl TypedOp for SplitGroupBlockQuant {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let input = inputs[0];
        let bqf = input
            .exotic_fact
            .as_ref()
            .and_then(|of| of.downcast_ref::<BlockQuantFact>())
            .context("Expect BlockQuantFact")?;
        let o: usize = input.shape[0].to_usize()?;
        ensure!(o % self.group == 0);
        let mut new_shape: TVec<usize> =
            input.shape.iter().map(|d| d.to_usize()).collect::<TractResult<_>>()?;
        new_shape[0] = o / self.group;
        new_shape.insert(0, self.group);
        let exotic_fact = BlockQuantFact::new(bqf.format.clone(), new_shape.clone());
        let fact = inputs[0]
            .datum_type
            .fact(&*new_shape.iter().map(|d| d.to_dim()).collect::<TVec<_>>())
            .with_exotic_fact(exotic_fact);
        Ok(tvec!(fact))
    }
    as_op!();
}