mtl-rs 0.1.11

Rust bindings for Apple's Metal API
use core::ops::Range;

use objc2::{
    Message, extern_class, extern_conformance, extern_methods, extern_protocol, msg_send,
    rc::{Allocated, Retained},
    runtime::ProtocolObject,
};
use objc2_foundation::{CopyingHelper, NSArray, NSCopying, NSObject, NSObjectProtocol, NSRange, NSString};

use crate::*;

extern_class!(
    /// Descriptor used to configure a machine learning pipeline state.
    ///
    /// This class derives from `MTL4PipelineDescriptor` (and ultimately `NSObject`). Use the
    /// methods on this type to set the function descriptor, an optional label, and the
    /// dimensions of the pipeline's input tensors.
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metal/mtl4machinelearningpipelinedescriptor?language=objc)
    #[unsafe(super(MTL4PipelineDescriptor, NSObject))]
    #[derive(Debug, PartialEq, Eq, Hash)]
    pub struct MTL4MachineLearningPipelineDescriptor;
);

// Declares that the descriptor class conforms to `NSCopying`, so it can be copied
// (for example, when a setter documented as "copied when set" receives it).
extern_conformance!(
    unsafe impl NSCopying for MTL4MachineLearningPipelineDescriptor {}
);

// Copying a descriptor yields another `MTL4MachineLearningPipelineDescriptor`
// (the copy's Rust type is `Self`).
unsafe impl CopyingHelper for MTL4MachineLearningPipelineDescriptor {
    type Result = Self;
}

// Declares that the descriptor class conforms to the `NSObject` protocol.
extern_conformance!(
    unsafe impl NSObjectProtocol for MTL4MachineLearningPipelineDescriptor {}
);

impl MTL4MachineLearningPipelineDescriptor {
    extern_methods!(
        /// Returns the function descriptor that machine learning pipelines created from this
        /// descriptor execute, or `None` if none has been assigned.
        #[unsafe(method(machineLearningFunctionDescriptor))]
        #[unsafe(method_family = none)]
        pub fn machine_learning_function_descriptor(&self) -> Option<Retained<MTL4FunctionDescriptor>>;

        /// Setter for [`machine_learning_function_descriptor`][Self::machine_learning_function_descriptor].
        ///
        /// Passing `None` sends `nil` to Objective-C.
        ///
        /// This is [copied][objc2_foundation::NSCopying::copy] when set.
        #[unsafe(method(setMachineLearningFunctionDescriptor:))]
        #[unsafe(method_family = none)]
        pub fn set_machine_learning_function_descriptor(
            &self,
            machine_learning_function_descriptor: Option<&MTL4FunctionDescriptor>,
        );

        /// Sets the dimensions of the input tensor bound at a buffer index.
        ///
        /// # Parameters
        ///
        /// - `dimensions`: The dimensions of the tensor; `None` sends `nil` to Objective-C.
        /// - `buffer_index`: Index of the tensor to modify.
        #[unsafe(method(setInputDimensions:atBufferIndex:))]
        #[unsafe(method_family = none)]
        pub fn set_input_dimensions_at_buffer_index(
            &self,
            dimensions: Option<&MTLTensorExtents>,
            buffer_index: isize,
        );

        /// Returns the dimensions of the input tensor at `buffer_index` if they were set,
        /// or `None` otherwise.
        #[unsafe(method(inputDimensionsAtBufferIndex:))]
        #[unsafe(method_family = none)]
        pub fn input_dimensions_at_buffer_index(
            &self,
            buffer_index: isize,
        ) -> Option<Retained<MTLTensorExtents>>;

        /// Resets the descriptor to its default values.
        #[unsafe(method(reset))]
        #[unsafe(method_family = none)]
        pub fn reset(&self);
    );
}

/// Methods declared on superclass `NSObject`.
impl MTL4MachineLearningPipelineDescriptor {
    extern_methods!(
        /// Initializes an allocated descriptor by sending it `-init`.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub fn init(this: Allocated<Self>) -> Retained<Self>;

        /// Allocates and initializes a new descriptor by sending the class `+new`.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub fn new() -> Retained<Self>;
    );
}

impl MTL4MachineLearningPipelineDescriptor {
    /// Assigns an optional string that helps identify pipeline states you create from this descriptor.
    pub fn label(&self) -> Option<String> {
        let s: Option<Retained<NSString>> = unsafe { msg_send![self, label] };
        s.map(|v| v.to_string())
    }

    /// Setter for [`label`][Self::label].
    pub fn set_label(
        &self,
        label: Option<&str>,
    ) {
        unsafe {
            let _: () = msg_send![self, setLabel: label.map(NSString::from_str).as_deref()];
        }
    }

    /// Sets the dimensions of multiple input tensors on a range of buffer bindings.
    ///
    /// Use this method to specify the dimensions of multiple input tensors at a range of indices in a single call.
    ///
    /// You can indicate that any tensors in the range have unspecified dimensions by providing `NSNull` at the their
    /// corresponding index location in the array.
    ///
    /// - Important: The range's length property needs to match the number of dimensions you provide. Specifically,
    /// `range.length` needs to match `dimensions.count`.
    ///
    /// - Parameters:
    /// - dimensions: An array of tensor extents.
    /// - range: The range of inputs of the `dimensions` argument.
    /// The range's `length` needs to match the dimensions' `count` property.
    pub fn set_input_dimensions_with_range(
        &self,
        dimensions: &[&MTLTensorExtents],
        range: Range<usize>,
    ) {
        let dimensions = NSArray::from_slice(dimensions);
        let ns_range = NSRange::from(range);
        unsafe {
            let _: () = msg_send![self, setInputDimensions: &*dimensions, withRange: ns_range];
        }
    }
}

extern_class!(
    /// Reflection information for a machine learning pipeline state.
    ///
    /// This class derives directly from `NSObject`. Use the `bindings` method on this type to
    /// inspect the pipeline's inputs and outputs.
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metal/mtl4machinelearningpipelinereflection?language=objc)
    #[unsafe(super(NSObject))]
    #[derive(Debug, PartialEq, Eq, Hash)]
    pub struct MTL4MachineLearningPipelineReflection;
);

// NOTE(review): these impls assert that reflection objects may be moved to, and shared
// between, threads. Their soundness depends on Metal's own threading guarantees for
// reflection objects, which this file does not show — confirm against Apple's documentation.
unsafe impl Send for MTL4MachineLearningPipelineReflection {}

unsafe impl Sync for MTL4MachineLearningPipelineReflection {}

// Declares that the reflection class conforms to the `NSObject` protocol.
extern_conformance!(
    unsafe impl NSObjectProtocol for MTL4MachineLearningPipelineReflection {}
);

impl MTL4MachineLearningPipelineReflection {
    /// Returns the bindings that describe every input and output of the pipeline.
    ///
    /// The Objective-C array is copied into an owned boxed slice; each element is retained.
    pub fn bindings(&self) -> Box<[Retained<ProtocolObject<dyn MTLBinding>>]> {
        // SAFETY: `bindings` takes no arguments; its result is read as a non-null array of
        // `MTLBinding` objects, assumed to match the Objective-C property type.
        let array: Retained<NSArray<ProtocolObject<dyn MTLBinding>>> =
            unsafe { msg_send![self, bindings] };
        let owned: Vec<_> = array.to_vec();
        owned.into()
    }
}

/// Methods declared on superclass `NSObject`.
impl MTL4MachineLearningPipelineReflection {
    extern_methods!(
        /// Initializes an allocated reflection object by sending it `-init`.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub fn init(this: Allocated<Self>) -> Retained<Self>;

        /// Allocates and initializes a new reflection object by sending the class `+new`.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub fn new() -> Retained<Self>;
    );
}

extern_protocol!(
    /// A pipeline state that you can use with machine-learning encoder instances.
    ///
    /// See `MTL4MachineLearningCommandEncoder` for more information.
    ///
    /// Implementors must also be `MTLAllocation`, `NSObjectProtocol`, `Send`, and `Sync`.
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metal/mtl4machinelearningpipelinestate?language=objc)
    pub unsafe trait MTL4MachineLearningPipelineState: MTLAllocation + NSObjectProtocol + Send + Sync {
        /// Returns the device the pipeline state belongs to.
        #[unsafe(method(device))]
        #[unsafe(method_family = none)]
        fn device(&self) -> Retained<ProtocolObject<dyn MTLDevice>>;

        /// Returns reflection information for this machine learning pipeline state, or `None`
        /// if none is available.
        #[unsafe(method(reflection))]
        #[unsafe(method_family = none)]
        fn reflection(&self) -> Option<Retained<MTL4MachineLearningPipelineReflection>>;

        /// Returns the size of the heap, in bytes, that this pipeline requires during execution.
        ///
        /// Use this value to allocate an `MTLHeap` of sufficient size that you can then provide to
        /// `MTL4MachineLearningCommandEncoder`'s `dispatchNetworkWithIntermediatesHeap:`.
        ///
        /// Metal uses this heap to store intermediate data as it executes the pipeline. It is your
        /// responsibility to provide a heap at least as large as this value.
        #[unsafe(method(intermediatesHeapSize))]
        #[unsafe(method_family = none)]
        fn intermediates_heap_size(&self) -> usize;
    }
);

/// Convenience methods available on every `MTL4MachineLearningPipelineState` implementor.
pub trait MTL4MachineLearningPipelineStateExt: MTL4MachineLearningPipelineState + Message {
    /// Returns the string that helps identify this object, or `None` when no label is set.
    fn label(&self) -> Option<String> {
        // SAFETY: `label` takes no arguments; its result is read as an optional `NSString`,
        // assumed to match the Objective-C property type.
        let ns_label: Option<Retained<NSString>> = unsafe { msg_send![self, label] };
        ns_label.as_deref().map(|text| text.to_string())
    }
}

// Blanket implementation: every pipeline-state type that is also an Objective-C message
// receiver gets the extension methods automatically.
impl<T: MTL4MachineLearningPipelineState + Message> MTL4MachineLearningPipelineStateExt for T {}