use crate::kernels::matmul::{GemmImpl, GemmKernel};

use anyhow::{bail, ensure};
use tract_core::internal::*;
use tract_core::tract_linalg::block_quant::Q4_0;
use tract_gpu::tensor::DeviceTensorExt;
use tract_gpu::utils::as_quant_fact;

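/// Metal-backed GEMM op, generic over the concrete kernel implementation `K`.
///
/// A minimal usage sketch (assuming a concrete kernel type such as `MlxGemm`
/// from this crate is in scope; illustrative only):
///
/// ```ignore
/// // C = A x B^T: A taken as-is, B transposed.
/// let op = MetalGemm::<MlxGemm>::new(false, true);
/// ```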
#[derive(Debug, Default, Clone, Hash, PartialEq, Eq)]
pub struct MetalGemm<K: GemmKernel> {
    pub kernel: GemmImpl<K>,
}

impl<K: GemmKernel> MetalGemm<K> {
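    /// Builds the op with the given transpose flags for inputs A and B.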
    pub fn new(transpose_a: bool, transpose_b: bool) -> Self {
        Self { kernel: GemmImpl::<K>::new(transpose_a, transpose_b) }
    }
}

impl<K: GemmKernel + 'static> Op for MetalGemm<K> {
    fn name(&self) -> StaticName {
        format!("Metal{}", self.kernel).into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!(
            "transpose_a: {} transpose_b: {}",
            self.transpose_a(),
            self.transpose_b()
        )])
    }

    op_as_typed_op!();
}

impl<K: GemmKernel> MetalGemm<K> {
    fn transpose_a(&self) -> bool {
        self.kernel.transpose_a
    }

    fn transpose_b(&self) -> bool {
        self.kernel.transpose_b
    }

    fn resolve_output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let [a, b] = inputs else {
            bail!("Expects 2 inputs");
        };

        if a.is_plain() && b.is_plain() && a.datum_type.is_number() && b.datum_type.is_number() {
            ensure!(a.rank() == b.rank());
            ensure!(a.rank() >= 2);
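            // The reduction (k) axis is A's last axis unless A is transposed,
            // and B's next-to-last axis unless B is transposed: e.g. with no
            // transposition, A: [.., m, k] x B: [.., k, n] => C: [.., m, n].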
            ensure!(
                a.shape[a.rank() - 2 + !self.transpose_a() as usize]
                    == b.shape[b.rank() - 2 + self.transpose_b() as usize]
            );
            let out_shape = self.kernel.output_shape(&a.shape, &b.shape);
            Ok(self.kernel.output_facts(&out_shape, a.datum_type, b.datum_type)?)
        } else if as_quant_fact(a, &Q4_0).is_some() || as_quant_fact(b, &Q4_0).is_some() {
            // Exotic tensors now carry their full logical shape directly on the
            // fact, so no need to chain with BlockQuantFact.shape().
            let out_shape = self.kernel.output_shape(&a.shape, &b.shape);
            Ok(self.kernel.output_facts(&out_shape, a.datum_type, b.datum_type)?)
        } else {
            bail!("Unsupported input facts for MetalGemm: a: {a:?}, b: {b:?}")
        }
    }
}

impl<K: GemmKernel + 'static> EvalOp for MetalGemm<K> {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval_with_session(
        &self,
        node_id: usize,
        session: &SessionState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let (a_raw, b_raw) = args_2!(inputs);
        let a = a_raw
            .to_device_tensor()
            .with_context(|| format!("A tensor is not a metal tensor: {:?}", a_raw))?;
        let b = b_raw
            .to_device_tensor()
            .with_context(|| format!("B tensor is not a metal tensor {:?}", b_raw))?;

        // For q40 weights the tensor shape already carries the full logical
        // dimensions [batch, n, k].  No need to chain with the fact shape.
        let b_shape = b.shape().to_vec();

        let c_dt = self.kernel.matmul.output_dt(a.datum_type(), b.datum_type())?;
        let c_shape = self.kernel.output_shape(a.shape(), &b_shape);
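        // Allocate the output tensor for this node through the GPU session handler.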
        let c = tract_gpu::session_handler::make_tensor_for_node(session, node_id, c_dt, &c_shape)?;

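        // Dispatch the kernel on the stream provided by the crate's Metal context.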
        crate::with_metal_stream(|stream| self.kernel.dispatch_eval(stream, a, b, &c))?;

        Ok(tvec![c.into_tensor().into_tvalue()])
    }
}

impl<K: GemmKernel + 'static> TypedOp for MetalGemm<K> {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
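        // Compute the facts on the underlying host facts; facts_to_device_facts
        // re-wraps the results as device facts.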
        tract_gpu::utils::facts_to_device_facts(inputs, |input_facts| {
            self.resolve_output_facts(input_facts)
        })
        .with_context(|| format!("Error while computing output facts for {}", self.name()))
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        tract_gpu::utils::get_device_facts(inputs, |input_facts| {
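            // Cost model: one FMA per output element per step of the reduction
            // axis, approximated here by the last dimension of input A.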
            let fma = self.resolve_output_facts(input_facts)?[0].shape.iter().product::<TDim>()
                * input_facts[0].shape.last().unwrap();
            if input_facts[0].datum_type == f16::datum_type() {
                Ok(tvec!((Cost::FMA(f16::datum_type()), fma)))
            } else {
                Ok(tvec!((Cost::FMA(f32::datum_type()), fma)))
            }
        })
        .with_context(|| format!("Error while computing cost for {:?}", self.name()))
    }

    as_op!();
}