tract-metal 0.22.2

Tiny, no-nonsense, self contained, TensorFlow and ONNX inference
use crate::encoder::EncoderExt;
use crate::kernels::{BroadcastKind, utils};
use crate::{LibraryName, MetalStream};
use anyhow::ensure;
use std::fmt;
use tract_core::internal::*;
use tract_gpu::tensor::DeviceTensor;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Concat {
    pub axis: usize,
}

impl fmt::Display for Concat {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{:?}", self)
    }
}

impl Concat {
    pub fn is_supported_dt(dt: DatumType) -> bool {
        matches!(
            dt,
            DatumType::F32
                | DatumType::F16
                | DatumType::U8
                | DatumType::U16
                | DatumType::U32
                | DatumType::U64
                | DatumType::I8
                | DatumType::I16
                | DatumType::I32
                | DatumType::I64
        )
    }

    pub fn kernel_name(&self, dt: DatumType, broadcast_kind: BroadcastKind) -> TractResult<String> {
        ensure!(Self::is_supported_dt(dt), "Unsupported dt {:?} for metal concatop", dt);
        let tname = DeviceTensor::tname(dt)?;
        let broadcast_name = broadcast_kind.name();
        Ok(format!("array_ops::copy_{broadcast_name}_{tname}"))
    }

    pub fn eval(
        &self,
        stream: &MetalStream,
        inputs: &[&DeviceTensor],
    ) -> TractResult<DeviceTensor> {
        ensure!(!inputs.is_empty());
        let mut output_shape = inputs[0].shape().to_vec();
        output_shape[self.axis] = inputs.iter().map(|it| it.shape()[self.axis]).sum();
        let output =
            unsafe { DeviceTensor::uninitialized_dt(inputs[0].datum_type(), &output_shape)? };

        self.dispatch_eval(stream, inputs, &output)?;
        stream.wait_until_completed()?;
        Ok(output)
    }

    pub fn dispatch_eval(
        &self,
        stream: &MetalStream,
        inputs: &[&DeviceTensor],
        output: &DeviceTensor,
    ) -> TractResult<()> {
        ensure!(!inputs.is_empty());

        stream.retain_tensor(output);

        let output_shape = output.shape();
        let output_strides = output.strides();

        let mut offsets = tvec![0; inputs.len()];
        let mut cursor = 0;

        for (i_idx, input) in inputs.iter().enumerate() {
            let i_shape = input.shape();
            ensure!(i_shape[..self.axis] == output_shape[..self.axis]);
            ensure!(i_shape[self.axis + 1..] == output_shape[self.axis + 1..]);
            offsets[i_idx] =
                cursor * (output_strides[self.axis] as usize) * output.datum_type().size_of();
            cursor += i_shape[self.axis];
        }

        let broadcast_kind = BroadcastKind::from_rank(output.rank()).with_context(|| {
            format!(
                "Unsupported broadcast for broadcast op: (in: {:?}, out: {:?})",
                inputs[0].shape(),
                output_shape
            )
        })?;

        let kernel_name = self.kernel_name(output.datum_type(), broadcast_kind)?;
        let pipeline = stream.load_pipeline(LibraryName::ArrayOps, &kernel_name)?;
        let command_buffer = stream.command_buffer();

        for (input, offset) in inputs.iter().zip(offsets.into_iter()) {
            if input.len() == 0 {
                continue;
            }
            stream.retain_tensor(input);

            let i_strides = input.strides();
            let i_shape = input.shape();

            command_buffer.encode(|encoder| {
                encoder.set_compute_pipeline_state(&pipeline);
                encoder.set_metal_tensor(0, input, metal::MTLResourceUsage::Read);
                encoder.set_slice(1, i_strides);
                encoder.set_metal_tensor_with_offset(
                    2,
                    output,
                    offset as _,
                    metal::MTLResourceUsage::Write,
                );
                encoder.set_slice(3, i_shape);
                encoder.set_slice(4, output_strides);

                let (grid_size, group_size) = utils::build_metal_grid_and_groups_for_el_wise_op(
                    i_shape,
                    pipeline.max_total_threads_per_threadgroup() as _,
                );
                encoder.dispatch_thread_groups(grid_size, group_size);
            });
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use crate::utils::with_borrowed_metal_stream;

    use super::*;
    use tract_gpu::tensor::IntoDevice;
    use tract_itertools::Itertools;

    use num_traits::Zero;

    use tract_core::internal::Tensor;

    fn run_test_case<F: Datum + Zero + Copy>(shapes: &[&[usize]], axis: usize) -> TractResult<()> {
        with_borrowed_metal_stream(|stream| {
            let mut inputs = tvec![];
            for shape in shapes {
                let len = shape.iter().product::<usize>();
                let data = (0..len).map(|f| f as f32).collect::<Vec<_>>();
                inputs.push(Tensor::from_shape(shape, &data)?.into_device()?);
            }

            let output = Concat { axis }.eval(stream, &inputs.iter().collect_vec())?;
            let ref_output = Tensor::stack_tensors(
                axis,
                &inputs.iter().map(|it| it.to_host()).collect::<TractResult<Vec<_>>>()?,
            )?;
            assert_eq!(output.to_host()?.into_tensor(), ref_output);
            Ok(())
        })
    }

    #[test]
    fn test_concat() -> TractResult<()> {
        run_test_case::<f32>(&[&[3, 4], &[3, 4]], 0)?;
        run_test_case::<f32>(&[&[3, 4], &[3, 4]], 1)?;
        run_test_case::<f32>(&[&[1, 5, 4], &[2, 5, 4]], 0)?;
        run_test_case::<f32>(&[&[3, 1, 8], &[3, 2, 8]], 1)?;
        run_test_case::<f32>(&[&[3, 5, 1], &[3, 5, 2]], 2)?;
        run_test_case::<f32>(&[&[1, 5, 4, 10], &[2, 5, 4, 10]], 0)?;
        run_test_case::<f32>(&[&[3, 1, 8, 10], &[3, 2, 8, 10]], 1)?;
        run_test_case::<f32>(&[&[3, 5, 1, 10], &[3, 5, 2, 10]], 2)?;
        run_test_case::<f32>(&[&[3, 5, 1, 10], &[3, 5, 0, 10]], 2)?;
        run_test_case::<f32>(&[&[3, 5, 1, 2], &[3, 5, 1, 8]], 3)?;
        Ok(())
    }
}