// mtl-rs 0.1.10
//
// Rust bindings for Apple's Metal API
use objc2::{extern_protocol, runtime::ProtocolObject};

use crate::*;

extern_protocol!(
    /// Encodes commands for dispatching machine learning networks on Apple silicon.
    ///
    /// This trait is declared through `extern_protocol!`, so it describes an Objective-C
    /// protocol rather than a native Rust trait. It extends [`MTL4CommandEncoder`], and each
    /// method below is bound to the Objective-C selector named in its `method(...)` attribute.
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metal/mtl4machinelearningcommandencoder?language=objc)
    pub unsafe trait MTL4MachineLearningCommandEncoder: MTL4CommandEncoder {
        /// Configures the encoder with a machine learning pipeline state instance.
        ///
        /// The pipeline state instance affects all subsequent machine learning commands.
        ///
        /// Bound to the Objective-C selector `setPipelineState:`.
        ///
        /// # Parameters
        ///
        /// - `pipeline_state`: A machine learning pipeline state instance.
        #[unsafe(method(setPipelineState:))]
        #[unsafe(method_family = none)]
        fn set_pipeline_state(
            &self,
            pipeline_state: &ProtocolObject<dyn MTL4MachineLearningPipelineState>,
        );

        /// Sets an argument table for the command encoder's machine learning shader stage.
        ///
        /// The argument table provides inputs to all subsequent machine learning dispatches.
        ///
        /// Bound to the Objective-C selector `setArgumentTable:`.
        ///
        /// # Parameters
        ///
        /// - `argument_table`: An argument table to set on the command encoder's machine
        ///   learning stage.
        #[unsafe(method(setArgumentTable:))]
        #[unsafe(method_family = none)]
        fn set_argument_table(
            &self,
            argument_table: &ProtocolObject<dyn MTL4ArgumentTable>,
        );

        /// Dispatches a machine learning network using the current pipeline state and argument table.
        ///
        /// This method takes a parameter consisting of an `MTLHeap` that Metal can use to
        /// allocate intermediate tensors. You can query the minimum size Metal requires for
        /// this heap through the `intermediatesHeapSize` property of
        /// `MTL4MachineLearningPipelineState` (Objective-C name; the Rust accessor is not
        /// visible from this file).
        ///
        /// Bound to the Objective-C selector `dispatchNetworkWithIntermediatesHeap:`.
        ///
        /// # Parameters
        ///
        /// - `heap`: A heap that Metal can use to allocate intermediate tensors.
        #[unsafe(method(dispatchNetworkWithIntermediatesHeap:))]
        #[unsafe(method_family = none)]
        fn dispatch_network_with_intermediates_heap(
            &self,
            heap: &ProtocolObject<dyn MTLHeap>,
        );
    }
);