tract-metal 0.23.0-dev.6

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
use crate::encoder::EncoderExt;
use crate::kernels::utils;
use crate::{LibraryName, MetalStream};
use anyhow::ensure;
use metal::MTLSize;
use std::fmt;
use tract_core::internal::*;
use tract_gpu::tensor::DeviceTensor;

/// Stateless marker type for the Metal `rotate_half` op.
///
/// It carries no configuration; the kernel to run is chosen from the input's datum type
/// when the op is evaluated or dispatched.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct RotateHalf;

impl fmt::Display for RotateHalf {
    /// Renders the op exactly like its `Debug` representation.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!("{self:?}"))
    }
}

impl RotateHalf {
    /// Returns whether `dt` is one of the datum types this op handles on Metal.
    ///
    /// [`RotateHalf::kernel_name`] rejects every other type.
    pub fn is_supported_dt(dt: DatumType) -> bool {
        matches!(
            dt,
            DatumType::F32
                | DatumType::F16
                | DatumType::I8
                | DatumType::I16
                | DatumType::I32
                | DatumType::I64
        )
    }

    /// Builds the name of the `array_ops` kernel specialised for `dt`.
    ///
    /// # Errors
    /// Fails if `dt` is not accepted by [`RotateHalf::is_supported_dt`], or if
    /// `DeviceTensor::tname` cannot name it.
    pub fn kernel_name(&self, dt: DatumType) -> TractResult<String> {
        ensure!(Self::is_supported_dt(dt), "Unsupported dt {:?} for metal rotate halfop", dt);
        let tname = DeviceTensor::tname(dt)?;
        Ok(format!("array_ops::rotate_half_nd2_{tname}"))
    }

    /// Runs the op synchronously.
    ///
    /// Allocates an output with the input's datum type and shape, encodes the kernel, and
    /// blocks until the stream has completed.
    ///
    /// # Errors
    /// Same conditions as [`RotateHalf::dispatch_eval`], plus allocation or stream failures.
    pub fn eval(&self, stream: &MetalStream, input: &DeviceTensor) -> TractResult<DeviceTensor> {
        // The output buffer is left uninitialized because `dispatch_eval` overwrites it
        // (or it is empty, in which case there is nothing to read).
        let output = unsafe { DeviceTensor::uninitialized_dt(input.datum_type(), input.shape())? };
        self.dispatch_eval(stream, input, &output)?;
        stream.wait_until_completed()?;
        Ok(output)
    }

    /// Encodes the rotate-half kernel reading `input` and writing `output`, without waiting
    /// for the GPU to finish.
    ///
    /// The input is viewed as a 2D `[outer, inner]` matrix whose last axis is the input's
    /// innermost axis. The dispatch grid is `inner / 2` threadgroups wide and `outer` high.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - `input` is a scalar.
    /// - `output` does not have the same datum type and shape as `input`.
    /// - The innermost dimension is odd.
    /// - The datum type is not supported.
    pub fn dispatch_eval(
        &self,
        stream: &MetalStream,
        input: &DeviceTensor,
        output: &DeviceTensor,
    ) -> TractResult<()> {
        // Validate before retaining tensors or encoding anything. `input.rank() - 1` below
        // would underflow on a scalar.
        ensure!(input.rank() >= 1, "Rotate half requires a tensor of rank >= 1, got a scalar");
        // The kernel receives only the input's shape and strides, so an output of another
        // type or shape would be read/written out of bounds.
        ensure!(
            output.datum_type() == input.datum_type(),
            "Rotate half output datum type {:?} does not match input datum type {:?}",
            output.datum_type(),
            input.datum_type()
        );
        ensure!(
            output.shape() == input.shape(),
            "Rotate half output shape {:?} does not match input shape {:?}",
            output.shape(),
            input.shape()
        );

        let shape_nd2 = utils::reshape_to_rank_2(input.shape(), input.rank() - 1);
        ensure!(
            shape_nd2[1] % 2 == 0,
            "Rotate half required most inner dimension to be a multiple of 2: {:?}",
            input.shape()
        );
        let kernel_name = self.kernel_name(input.datum_type())?;

        // An empty tensor has nothing to compute. Metal's API validation rejects dispatches
        // whose threadgroup grid has a zero-sized dimension, so skip encoding entirely.
        if shape_nd2[0] == 0 || shape_nd2[1] == 0 {
            return Ok(());
        }

        stream.retain_tensor(input);
        stream.retain_tensor(output);

        let strides_nd2 = Tensor::natural_strides(&shape_nd2);

        let pipeline = stream.load_pipeline(LibraryName::ArrayOps, &kernel_name)?;
        let command_buffer = stream.command_buffer();
        command_buffer.encode(|encoder| {
            encoder.set_compute_pipeline_state(&pipeline);
            encoder.set_metal_tensor(0, input, metal::MTLResourceUsage::Read);
            encoder.set_metal_tensor(1, output, metal::MTLResourceUsage::Write);
            encoder.set_slice(2, &shape_nd2);
            encoder.set_slice(3, &strides_nd2);

            // Only the first half of each row gets a grid cell; the rows are the outer axis.
            let grid_size =
                MTLSize { width: (shape_nd2[1] / 2) as _, height: shape_nd2[0] as _, depth: 1 };
            let group_size = utils::build_metal_size_with_ones();

            encoder.dispatch_thread_groups(grid_size, group_size);
        });
        Ok(())
    }
}

/// Encodes [`RotateHalf`] from `input` into `output` on the stream supplied by
/// `crate::with_metal_stream`, without waiting for the GPU to complete.
///
/// This is the backend callback handed to the generic GPU rotate-half op at registration.
pub fn metal_rotate_half_dispatch(input: &DeviceTensor, output: &DeviceTensor) -> TractResult<()> {
    crate::with_metal_stream(|metal_stream| {
        let op = RotateHalf;
        op.dispatch_eval(metal_stream, input, output)
    })
}

// Translation rule from the CPU `tract_transformers` `RotateHalf` op to the generic GPU
// `GpuRotateHalf` op, labelled "Metal" and backed by `metal_rotate_half_dispatch`.
crate::register_metal_op!(tract_transformers::ops::apply_rope::RotateHalf, |source, node, _op| {
    // Gate on the datum type of the node's first input. Presumably `rule_if!` makes the rule
    // decline (no translation) when the condition is false — TODO confirm against the macro.
    rule_if!(RotateHalf::is_supported_dt(source.node_input_facts(node.id)?[0].datum_type));
    Ok(Some(Box::new(tract_gpu::ops::rotate_half::GpuRotateHalf::new(
        "Metal",
        metal_rotate_half_dispatch,
    ))))
});

#[cfg(test)]
mod tests {
    use crate::utils::with_borrowed_metal_stream;

    use super::*;
    use num_traits::AsPrimitive;
    use tract_core::internal::Tensor;
    use tract_gpu::tensor::IntoDevice;
    use tract_transformers::ops::apply_rope;

    /// Builds a tensor of `shape` filled with `0, 1, 2, ...` converted to `F`.
    fn ramp<F>(shape: &[usize]) -> TractResult<Tensor>
    where
        F: Copy + 'static + Datum,
        usize: AsPrimitive<F>,
    {
        let len = shape.iter().product::<usize>();
        Tensor::from_shape(shape, &(0..len).map(|f| -> F { f.as_() }).collect::<Vec<_>>())
    }

    /// Checks that the Metal kernel matches the CPU `RotateHalf` op exactly on a ramp input.
    fn run_test_case<F>(shape: &[usize]) -> TractResult<()>
    where
        F: Copy + 'static + Datum,
        usize: AsPrimitive<F>,
    {
        with_borrowed_metal_stream(|stream| {
            let a = ramp::<F>(shape)?;
            let metal_a = a.clone().into_device()?;

            let cpu_output =
                apply_rope::RotateHalf.eval(tvec![a.clone().into()])?[0].clone().into_tensor();
            let metal_output = RotateHalf.eval(stream, &metal_a)?;

            cpu_output
                .close_enough(&metal_output.to_host()?.into_tensor(), Approximation::Exact)
                .with_context(|| {
                    format!(
                        "Input: {:?} Cpu: {:?}, Metal: {:?}",
                        a.dump(true),
                        cpu_output.dump(true),
                        metal_output.to_host().and_then(|it| it.dump(true))
                    )
                })?;
            Ok(())
        })
    }

    #[test]
    fn test_rotate_half() -> TractResult<()> {
        run_test_case::<f32>(&[2, 2])?;
        // Rank 1: the whole tensor is a single row.
        run_test_case::<f32>(&[16])?;
        run_test_case::<f32>(&[512, 512])?;
        run_test_case::<f32>(&[10, 8, 8])?;
        run_test_case::<f32>(&[10, 512, 1024])?;
        run_test_case::<f16>(&[2, 2])?;
        run_test_case::<f16>(&[10, 256, 4])?;
        Ok(())
    }

    #[test]
    fn test_rotate_half_rejects_odd_inner_dim() -> TractResult<()> {
        with_borrowed_metal_stream(|stream| {
            let metal_a = ramp::<f32>(&[3, 5])?.into_device()?;
            assert!(
                RotateHalf.eval(stream, &metal_a).is_err(),
                "an odd innermost dimension must be rejected"
            );
            Ok(())
        })
    }
}