tract-metal 0.22.2

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
#![allow(unused)]

pub mod array;
mod bin_ops;
mod element_wise;
pub mod matmul;
pub mod nn;
mod utils;

pub use bin_ops::BinOps;
pub use element_wise::ElementWiseOps;

use tract_core::internal::*;

// Metal Flash Attention library, embedded into the binary at compile time
// with `include_bytes!`. The three `cfg` conditions below are mutually
// exclusive, so exactly one definition is active for any target.

/// Metal Flash Attention `.metallib` bundled for iOS targets.
#[cfg(target_os = "ios")]
const METAL_FLASH_ATTENTION_LIB: &[u8] =
    include_bytes!("matmul/mfa/libMetalFlashAttention-ios.metallib");

/// Metal Flash Attention `.metallib` bundled for macOS targets.
#[cfg(target_os = "macos")]
const METAL_FLASH_ATTENTION_LIB: &[u8] =
    include_bytes!("matmul/mfa/libMetalFlashAttention-macos.metallib");

/// Fallback for targets other than iOS and macOS: no library file is bundled,
/// so `LibraryName::MfaLib` resolves to an empty byte slice there.
#[cfg(not(any(target_os = "ios", target_os = "macos")))]
const METAL_FLASH_ATTENTION_LIB: &[u8] = &[];

// Metal shader sources, embedded as text at compile time with `include_str!`.
// `LibraryName::content` wraps each of these in `LibraryContent::Source`.
// NOTE(review): presumably they are compiled into Metal libraries at runtime
// by the loader — confirm against the code that consumes `LibraryContent`.

/// Source backing `LibraryName::MlxGemm`.
const MLX_GEMM: &str = include_str!("matmul/mlx_gemm/mlx_gemm.metal");
/// Source backing `LibraryName::MlxGemv`.
const MLX_GEMV: &str = include_str!("matmul/mlx_gemm/mlx_gemv.metal");
/// Source backing `LibraryName::Ggml`.
const GGML: &str = include_str!("matmul/ggml_gemm/ggml_mm_mv.metal");
/// Source backing `LibraryName::BasicMatMul`.
const BASIC_MAT_MUL: &str = include_str!("matmul/basic/basic_mat_mul.metal");
/// Source backing `LibraryName::ArrayOps`.
const ARRAY_OPS: &str = include_str!("array/array_ops.metal");
/// Source backing `LibraryName::BinOps`.
const BIN_OPS: &str = include_str!("bin_ops.metal");
/// Source backing `LibraryName::NNOps`.
const NN_OPS: &str = include_str!("nn/nn_ops.metal");
/// Source backing `LibraryName::ElementWiseOps`.
const ELEMENT_WISE_OPS: &str = include_str!("element_wise.metal");

/// Payload of a Metal library bundled with this crate.
///
/// A library is shipped either as raw bytes (`Data`, e.g. an embedded
/// `.metallib` file) or as shader source text (`Source`, e.g. an embedded
/// `.metal` file).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum LibraryContent<'a> {
    /// Binary library bytes embedded in the crate.
    Data(&'a [u8]),
    /// Shader source text embedded in the crate.
    Source(&'a str),
}

/// Identifies one of the Metal libraries bundled with this crate.
///
/// [`LibraryName::content`] resolves each variant to its embedded payload.
/// `MfaLib` is stored as binary data, and every other variant is stored as
/// shader source text.
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
pub enum LibraryName {
    /// Source from `matmul/mlx_gemm/mlx_gemm.metal`.
    MlxGemm,
    /// Source from `matmul/mlx_gemm/mlx_gemv.metal`.
    MlxGemv,
    /// Metal Flash Attention `.metallib`. The payload is empty on targets
    /// other than iOS and macOS.
    MfaLib,
    /// Source from `matmul/basic/basic_mat_mul.metal`.
    BasicMatMul,
    /// Source from `bin_ops.metal`.
    BinOps,
    /// Source from `array/array_ops.metal`.
    ArrayOps,
    /// Source from `nn/nn_ops.metal`.
    NNOps,
    /// Source from `element_wise.metal`.
    ElementWiseOps,
    /// Source from `matmul/ggml_gemm/ggml_mm_mv.metal`.
    Ggml,
}

impl LibraryName {
    /// Returns the embedded payload backing this library.
    ///
    /// `MfaLib` is the only library stored as bytes and yields
    /// [`LibraryContent::Data`]. On targets other than iOS and macOS those
    /// bytes are an empty slice. Every other library yields
    /// [`LibraryContent::Source`] with its embedded shader text.
    pub fn content(&self) -> LibraryContent<'static> {
        let source_text = match self {
            Self::MfaLib => return LibraryContent::Data(METAL_FLASH_ATTENTION_LIB),
            Self::MlxGemm => MLX_GEMM,
            Self::MlxGemv => MLX_GEMV,
            Self::Ggml => GGML,
            Self::BasicMatMul => BASIC_MAT_MUL,
            Self::ArrayOps => ARRAY_OPS,
            Self::BinOps => BIN_OPS,
            Self::NNOps => NN_OPS,
            Self::ElementWiseOps => ELEMENT_WISE_OPS,
        };
        LibraryContent::Source(source_text)
    }
}

/// Broadcasting scheme, identified by a short lowercase name.
///
/// [`BroadcastKind::name`] gives the string form of each variant.
/// [`BroadcastKind::from_rank`] picks `Nd1`..`Nd6` from a rank.
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
pub enum BroadcastKind {
    /// Named `unicast`. NOTE(review): presumably both operands share one
    /// shape — confirm against the kernels.
    Unicast,
    /// Named `by_scalar_lhs`. NOTE(review): presumably the left operand is a
    /// scalar — confirm.
    ByScalarLeft,
    /// Named `by_scalar_rhs`. NOTE(review): presumably the right operand is a
    /// scalar — confirm.
    ByScalarRight,
    /// Rank-1 case, returned by `from_rank(1)`.
    Nd1,
    /// Rank-2 case, returned by `from_rank(2)`.
    Nd2,
    /// Rank-3 case, returned by `from_rank(3)`.
    Nd3,
    /// Rank-4 case, returned by `from_rank(4)`.
    Nd4,
    /// Rank-5 case, returned by `from_rank(5)`.
    Nd5,
    /// Rank-6 case, returned by `from_rank(6)`.
    Nd6,
}

impl BroadcastKind {
    /// Maps a tensor rank to the matching `NdX` variant (rank 1 → `Nd1`, …,
    /// rank 6 → `Nd6`).
    ///
    /// # Errors
    /// Fails for any rank outside `1..=6`, including rank 0.
    pub fn from_rank(rank: usize) -> TractResult<Self> {
        // Index `i` holds the variant for rank `i + 1`.
        const BY_RANK: [BroadcastKind; 6] = [
            BroadcastKind::Nd1,
            BroadcastKind::Nd2,
            BroadcastKind::Nd3,
            BroadcastKind::Nd4,
            BroadcastKind::Nd5,
            BroadcastKind::Nd6,
        ];
        // `checked_sub` maps rank 0 to `None` instead of underflowing, and
        // `get` maps ranks above 6 to `None`.
        match rank.checked_sub(1).and_then(|idx| BY_RANK.get(idx)) {
            Some(kind) => Ok(*kind),
            None => bail!("Unsupported rank {rank} for broadcasting"),
        }
    }

    /// Returns the short lowercase identifier for this broadcast kind.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Unicast => "unicast",
            Self::ByScalarLeft => "by_scalar_lhs",
            Self::ByScalarRight => "by_scalar_rhs",
            Self::Nd1 => "nd1",
            Self::Nd2 => "nd2",
            Self::Nd3 => "nd3",
            Self::Nd4 => "nd4",
            Self::Nd5 => "nd5",
            Self::Nd6 => "nd6",
        }
    }
}

impl BroadcastKind {
    /// Every `BroadcastKind` variant, in declaration order.
    ///
    /// This list includes `Nd6`, which `from_rank` returns for rank-6 inputs.
    /// Keep it in sync with the enum definition.
    /// NOTE(review): if a particular kernel family intentionally stops at
    /// rank 5, filter at that call site rather than truncating this list.
    const ALL: [BroadcastKind; 9] = [
        Self::Unicast,
        Self::ByScalarLeft,
        Self::ByScalarRight,
        Self::Nd1,
        Self::Nd2,
        Self::Nd3,
        Self::Nd4,
        Self::Nd5,
        Self::Nd6,
    ];

    /// Compile-time reminder: this exhaustive match stops compiling when a
    /// variant is added to the enum. That prompts an update of `ALL` above.
    #[allow(dead_code)]
    const fn _all_covers_every_variant(kind: BroadcastKind) {
        match kind {
            Self::Unicast
            | Self::ByScalarLeft
            | Self::ByScalarRight
            | Self::Nd1
            | Self::Nd2
            | Self::Nd3
            | Self::Nd4
            | Self::Nd5
            | Self::Nd6 => {}
        }
    }
}