use crate::kernels::matmul::{GemmImpl, GemmKernel};
use anyhow::{bail, ensure};
use tract_core::internal::*;
use tract_core::tract_linalg::block_quant::Q4_0;
use tract_gpu::tensor::DeviceTensorExt;
use tract_gpu::utils::as_quant_fact;
/// Metal-accelerated GEMM (matrix multiplication) operator, parameterized by
/// the concrete kernel implementation `K`.
#[derive(Debug, Default, Clone, Hash, PartialEq, Eq)]
pub struct MetalGemm<K: GemmKernel> {
    /// Underlying GEMM implementation; carries the transpose flags for A and B.
    pub kernel: GemmImpl<K>,
}
impl<K: GemmKernel> MetalGemm<K> {
    /// Build a Metal GEMM operator with the given transpose flags for the
    /// A and B operands.
    pub fn new(transpose_a: bool, transpose_b: bool) -> Self {
        let kernel = GemmImpl::<K>::new(transpose_a, transpose_b);
        Self { kernel }
    }
}
impl<K: GemmKernel + 'static> Op for MetalGemm<K> {
    /// Operator name, e.g. "MetalGemm" variants depending on the kernel.
    fn name(&self) -> StaticName {
        format!("Metal{}", self.kernel).into()
    }

    /// One-line human-readable summary of the transpose configuration.
    fn info(&self) -> TractResult<Vec<String>> {
        let line =
            format!("transpose_a: {} transpose_b: {}", self.transpose_a(), self.transpose_b(),);
        Ok(vec![line])
    }

    op_as_typed_op!();
}
impl<K: GemmKernel> MetalGemm<K> {
    /// Whether the A operand is transposed before multiplication.
    fn transpose_a(&self) -> bool {
        self.kernel.transpose_a
    }

    /// Whether the B operand is transposed before multiplication.
    fn transpose_b(&self) -> bool {
        self.kernel.transpose_b
    }

    /// Compute the output facts (shape + datum type) for the given input facts.
    ///
    /// Two supported cases:
    /// * both inputs are plain numeric tensors of equal rank >= 2 whose inner
    ///   (K) dimensions agree given the transpose flags;
    /// * at least one input is a Q4_0 block-quantized fact.
    ///
    /// Any other combination is rejected with an error instead of panicking
    /// (this used to be a `todo!()`, which would abort at analysis time).
    fn resolve_output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let [a, b] = inputs else {
            bail!("Expects 2 inputs");
        };
        if a.is_plain() && b.is_plain() && a.datum_type.is_number() && b.datum_type.is_number() {
            ensure!(a.rank() == b.rank());
            ensure!(a.rank() >= 2);
            // K dim of A is its last axis unless A is transposed; K dim of B is
            // its second-to-last axis unless B is transposed. (`!flag as usize`
            // parses as `(!flag) as usize`: unary `!` binds tighter than `as`.)
            ensure!(
                a.shape[a.rank() - 2 + !self.transpose_a() as usize]
                    == b.shape[b.rank() - 2 + self.transpose_b() as usize]
            );
            let out_shape = self.kernel.output_shape(&a.shape, &b.shape);
            Ok(self.kernel.output_facts(&out_shape, a.datum_type, b.datum_type)?)
        } else if as_quant_fact(a, &Q4_0).is_some() || as_quant_fact(b, &Q4_0).is_some() {
            let out_shape = self.kernel.output_shape(&a.shape, &b.shape);
            Ok(self.kernel.output_facts(&out_shape, a.datum_type, b.datum_type)?)
        } else {
            // Graceful failure for unsupported fact combinations rather than a
            // runtime panic in library code.
            bail!("Unsupported input facts for {}: a: {:?} b: {:?}", self.kernel, a, b)
        }
    }
}
impl<K: GemmKernel + 'static> EvalOp for MetalGemm<K> {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Run the GEMM on device tensors, allocating the output through the
    /// session handler so the buffer can be managed per node.
    fn eval_with_session(
        &self,
        node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let (a_val, b_val) = args_2!(inputs);
        let a = a_val
            .to_device_tensor()
            .with_context(|| format!("A tensor is not a metal tensor: {:?}", a_val))?;
        let b = b_val
            .to_device_tensor()
            .with_context(|| format!("B tensor is not a metal tensor {:?}", b_val))?;
        // Output datum type is dictated by the kernel for this input pairing.
        let out_dt = self.kernel.matmul.output_dt(a.datum_type(), b.datum_type())?;
        let b_dims = b.shape().to_vec();
        let out_shape = self.kernel.output_shape(a.shape(), &b_dims);
        let output =
            tract_gpu::session_handler::make_tensor_for_node(session, node_id, out_dt, &out_shape)?;
        crate::with_metal_stream(|stream| self.kernel.dispatch_eval(stream, a, b, &output))?;
        Ok(tvec![output.into_tensor().into_tvalue()])
    }
}
impl<K: GemmKernel + 'static> TypedOp for MetalGemm<K> {
    /// Derive device-side output facts from the (device) input facts.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        tract_gpu::utils::facts_to_device_facts(inputs, |input_facts| {
            self.resolve_output_facts(input_facts)
        })
        .with_context(|| format!("Error while computing output facts for {}", self.name()))
    }

    /// Estimate cost as one FMA per output element per reduction step.
    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        tract_gpu::utils::get_device_facts(inputs, |input_facts| {
            // Output element count times the reduction (last-axis) length.
            let out_elems =
                self.resolve_output_facts(input_facts)?[0].shape.iter().product::<TDim>();
            let fma = out_elems * input_facts[0].shape.last().unwrap();
            let dt = if input_facts[0].datum_type == f16::datum_type() {
                f16::datum_type()
            } else {
                f32::datum_type()
            };
            Ok(tvec!((Cost::FMA(dt), fma)))
        })
        .with_context(|| format!("Error while computing cost for {:?}", self.name()))
    }

    as_op!();
}