tract-core 0.22.1

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
Documentation
use tract_linalg::block_quant::{BlockQuantFact, BlockQuantValue};

use crate::internal::*;

/// Operator that re-declares the shape of block-quantized values.
///
/// Evaluation rebuilds every `BlockQuantValue` found in the input with the
/// same `value` and format, but with a `BlockQuantFact` carrying `shape`.
#[derive(Debug, Clone, new, Hash)]
pub struct BlockQuantIntoShape {
    /// Shape given to each rebuilt `BlockQuantFact`.
    pub shape: TVec<usize>,
}

impl Op for BlockQuantIntoShape {
    /// Returns the operator's static name, `"BlockQuantIntoShape"`.
    fn name(&self) -> StaticName {
        "BlockQuantIntoShape".into()
    }
    // NOTE(review): macro defined elsewhere in the crate; presumably supplies
    // the `Op` accessors exposing this type as a typed op — confirm.
    op_as_typed_op!();
}

impl EvalOp for BlockQuantIntoShape {
    /// This operator keeps no state between evaluations.
    fn is_stateless(&self) -> bool {
        true
    }

    /// Never creates an `OpState`: always answers `Ok(None)`.
    fn state(
        &self,
        _session: &mut SessionState,
        _node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        Ok(None)
    }

    /// Rewrites every opaque element of the single input tensor.
    ///
    /// Each element must be a `BlockQuantValue`; it is replaced by a new one
    /// sharing the same `value` and format, whose fact carries `self.shape`.
    ///
    /// # Errors
    ///
    /// Fails if the input can not be viewed as a slice of `Opaque`, or if any
    /// element is not a `BlockQuantValue`.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut tensor = args_1!(inputs).into_tensor();
        tensor.as_slice_mut::<Opaque>()?.iter_mut().try_for_each(|slot| -> TractResult<()> {
            let previous = slot
                .0
                .downcast_ref::<BlockQuantValue>()
                .context("Expects only BlockQuantValues")?;
            let value = previous.value.clone();
            let fact = BlockQuantFact::new(previous.fact.format.clone(), self.shape.clone());
            // Swap the payload in place: same data, new declared shape.
            *slot = Opaque(Arc::new(BlockQuantValue { value, fact }));
            Ok(())
        })?;
        Ok(tvec!(tensor.into_tvalue()))
    }
}

impl TypedOp for BlockQuantIntoShape {
    /// Computes the output fact: the input fact with its opaque fact replaced
    /// by a `BlockQuantFact` of the same format and `self.shape`.
    ///
    /// # Errors
    ///
    /// Fails if the first input has no opaque fact, or if that opaque fact is
    /// not a `BlockQuantFact`.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let input = inputs[0];
        let previous = input
            .opaque_fact
            .as_ref()
            .and_then(|of| of.downcast_ref::<BlockQuantFact>())
            .context("Expects BlockQuantFact")?;
        // Only the format is carried over; the shape comes from the operator.
        let reshaped = BlockQuantFact::new(previous.format.clone(), self.shape.clone());
        Ok(tvec!(input.clone().with_opaque_fact(reshaped)))
    }
    as_op!();
}

/// Operator that splits one block-quantized value into `group` row slices.
///
/// Evaluation takes a scalar opaque input holding a `BlockQuantValue` and
/// produces a tensor of `group` opaque values, each covering a consecutive
/// range of rows of the input.
#[derive(Debug, Clone, new, Hash)]
pub struct SplitGroupBlockQuant {
    /// Number of row groups to split the value into.
    pub group: usize,
}

impl Op for SplitGroupBlockQuant {
    /// Returns the operator's static name, `"SplitGroupBlockQuant"`.
    fn name(&self) -> StaticName {
        "SplitGroupBlockQuant".into()
    }

    // NOTE(review): macro defined elsewhere in the crate; presumably supplies
    // the `Op` accessors exposing this type as a typed op — confirm.
    op_as_typed_op!();
}

impl EvalOp for SplitGroupBlockQuant {
    fn is_stateless(&self) -> bool {
        true
    }

    fn state(
        &self,
        _session: &mut SessionState,
        _node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        Ok(None)
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let bqv = input
            .to_scalar::<Opaque>()?
            .downcast_ref::<BlockQuantValue>()
            .context("Expect BlockQuantValue")?;
        let rows_per_group = bqv.fact.m() / self.group;
        let splits = (0..self.group)
            .map(|g| {
                Ok(Opaque(Arc::new(bqv.split_rows(g * rows_per_group..(g + 1) * rows_per_group)?)))
            })
            .collect::<TractResult<Vec<_>>>()?;
        let output = tensor1(&splits);
        Ok(tvec!(output.into_tvalue()))
    }
}

impl TypedOp for SplitGroupBlockQuant {
    /// Computes the output fact: an `Opaque` tensor of shape `[self.group]`
    /// whose opaque fact is the input `BlockQuantFact` with its first
    /// dimension divided by `self.group`.
    ///
    /// # Errors
    ///
    /// Fails if the input is not rank 0, if it carries no `BlockQuantFact`, if
    /// that fact has an empty shape, if `self.group` is zero, or if the first
    /// dimension is not a multiple of `self.group`.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let input = inputs[0];
        ensure!(input.shape.rank() == 0);
        let bqf = input
            .opaque_fact
            .as_ref()
            .and_then(|of| of.downcast_ref::<BlockQuantFact>())
            .context("Expect BlockQuantFact")?;
        let mut shape: TVec<usize> = bqf.shape().into();
        // Each guard below replaces a panic (indexing, division by zero) or a
        // silent truncation with a reported error.
        ensure!(!shape.is_empty(), "SplitGroupBlockQuant: block quant shape must not be empty");
        ensure!(self.group > 0, "SplitGroupBlockQuant: group must be non-zero");
        ensure!(
            shape[0] % self.group == 0,
            "SplitGroupBlockQuant: can not split {} rows evenly into {} groups",
            shape[0],
            self.group
        );
        shape[0] /= self.group;
        let opaque_fact = BlockQuantFact::new(bqf.format.clone(), shape);
        let fact = Opaque::fact([self.group]).with_opaque_fact(opaque_fact);
        Ok(tvec!(fact))
    }
    as_op!();
}