apple-mpsgraph 0.2.8

Safe Rust bindings for Apple's MetalPerformanceShadersGraph framework on macOS, backed by a Swift bridge
Documentation
import Foundation
import Metal
import MetalPerformanceShaders
import MetalPerformanceShadersGraph

/// Creates an `MPSGraphTensorData` from a host byte buffer.
///
/// - Parameters:
///   - deviceHandle: Opaque handle resolved through `mpsgraph_graph_device`; must not be null.
///   - bytes: Source bytes; may be null only when `byteLen` is 0.
///   - byteLen: Number of bytes at `bytes`; must not be negative.
///   - shape: Pointer to `shapeLen` dimensions; may be null only when `shapeLen` is 0.
///   - shapeLen: Number of dimensions at `shape`; must not be negative.
///   - dataTypeRaw: Raw data-type value decoded by `mpsgraph_data_type`.
/// - Returns: The handle produced by `mpsgraph_retain`, or null when any argument is invalid.
@_cdecl("mpsgraph_tensor_data_new_with_bytes")
public func mpsgraph_tensor_data_new_with_bytes(
    _ deviceHandle: UnsafeMutableRawPointer?,
    _ bytes: UnsafeRawPointer?,
    _ byteLen: Int,
    _ shape: UnsafePointer<UInt>?,
    _ shapeLen: Int,
    _ dataTypeRaw: UInt32
) -> UnsafeMutableRawPointer? {
    guard let deviceHandle, let dataType = mpsgraph_data_type(dataTypeRaw) else {
        return nil
    }
    // Counts cross the C ABI unchecked: reject negative lengths and null
    // pointers paired with a non-zero length before they reach the
    // buffer-building helpers below.
    guard byteLen >= 0, byteLen == 0 || bytes != nil else {
        return nil
    }
    guard shapeLen >= 0, shapeLen == 0 || shape != nil else {
        return nil
    }

    let tensorData = MPSGraphTensorData(
        device: mpsgraph_graph_device(deviceHandle),
        data: mpsgraph_data(bytes, byteLen),
        shape: mpsgraph_shape(shape, shapeLen),
        dataType: dataType
    )
    return mpsgraph_retain(tensorData)
}

/// Creates an `MPSGraphTensorData` that references an existing `MTLBuffer`.
///
/// - Parameters:
///   - bufferHandle: Opaque `MTLBuffer` handle, borrowed (not consumed); must not be null.
///   - shape: Pointer to `shapeLen` dimensions; may be null only when `shapeLen` is 0.
///   - shapeLen: Number of dimensions at `shape`; must not be negative.
///   - dataTypeRaw: Raw data-type value decoded by `mpsgraph_data_type`.
/// - Returns: The handle produced by `mpsgraph_retain`, or null when any argument is invalid.
@_cdecl("mpsgraph_tensor_data_new_with_buffer")
public func mpsgraph_tensor_data_new_with_buffer(
    _ bufferHandle: UnsafeMutableRawPointer?,
    _ shape: UnsafePointer<UInt>?,
    _ shapeLen: Int,
    _ dataTypeRaw: UInt32
) -> UnsafeMutableRawPointer? {
    guard let bufferHandle, let dataType = mpsgraph_data_type(dataTypeRaw) else {
        return nil
    }
    // Same shape validation as the byte-buffer constructor: the count comes
    // across the C ABI unchecked.
    guard shapeLen >= 0, shapeLen == 0 || shape != nil else {
        return nil
    }

    let buffer: MTLBuffer = mpsgraph_borrow(bufferHandle)
    let tensorData = MPSGraphTensorData(buffer, shape: mpsgraph_shape(shape, shapeLen), dataType: dataType)
    return mpsgraph_retain(tensorData)
}

/// Creates an `MPSGraphTensorData` from a borrowed `MTLTensor` handle.
///
/// - Parameter tensorHandle: Opaque `MTLTensor` handle; not consumed.
/// - Returns: The handle produced by `mpsgraph_retain`, or null when the handle is
///   null or the `#available(macOS 16.0, *)` check fails.
@_cdecl("mpsgraph_tensor_data_new_with_tensor")
public func mpsgraph_tensor_data_new_with_tensor(
    _ tensorHandle: UnsafeMutableRawPointer?
) -> UnsafeMutableRawPointer? {
    if #available(macOS 16.0, *), let tensorHandle {
        let source: MTLTensor = mpsgraph_borrow(tensorHandle)
        return mpsgraph_retain(MPSGraphTensorData(source))
    }
    return nil
}

/// Returns the raw `MPSDataType` value of a tensor data handle, or 0 for a null handle.
@_cdecl("mpsgraph_tensor_data_data_type")
public func mpsgraph_tensor_data_data_type(_ handle: UnsafeMutableRawPointer?) -> UInt32 {
    if let handle {
        let borrowed = mpsgraph_borrow(handle) as MPSGraphTensorData
        return borrowed.dataType.rawValue
    }
    return 0
}

/// Returns the number of dimensions of a tensor data handle, or 0 for a null handle.
///
/// Callers use this to size the buffer passed to `mpsgraph_tensor_data_copy_shape`.
@_cdecl("mpsgraph_tensor_data_shape_len")
public func mpsgraph_tensor_data_shape_len(_ handle: UnsafeMutableRawPointer?) -> Int {
    guard let tensorHandle = handle else { return 0 }
    let dimensions = (mpsgraph_borrow(tensorHandle) as MPSGraphTensorData).shape
    return dimensions.count
}

/// Writes each dimension of a tensor data handle's shape into `outShape`.
///
/// No length is passed in: `outShape` must have room for at least
/// `mpsgraph_tensor_data_shape_len(handle)` elements. Each `NSNumber`
/// dimension is converted with `UInt(truncating:)`. Does nothing when
/// either pointer is null.
@_cdecl("mpsgraph_tensor_data_copy_shape")
public func mpsgraph_tensor_data_copy_shape(
    _ handle: UnsafeMutableRawPointer?,
    _ outShape: UnsafeMutablePointer<UInt>?
) {
    if let handle, let outShape {
        let dimensions = (mpsgraph_borrow(handle) as MPSGraphTensorData).shape
        for position in dimensions.indices {
            outShape[position] = UInt(truncating: dimensions[position])
        }
    }
}

/// Copies the contents of a tensor data object into a caller-owned buffer.
///
/// The copy is dense (`strideBytes: nil`), so it writes
/// `elementCount * bitsPerElement / 8` bytes (rounded up). The destination
/// is validated against that size instead of trusting the caller.
///
/// - Parameters:
///   - handle: Opaque `MPSGraphTensorData` handle; must not be null.
///   - dst: Destination buffer; may be null only when the tensor holds no bytes.
///   - dstLen: Capacity of `dst` in bytes; must be at least the tensor's dense byte size.
/// - Returns: `true` when the copy was performed (or there was nothing to copy),
///   `false` for a null handle, a null/too-small destination, or a shape whose
///   byte size cannot be computed (negative dimension or overflow).
@_cdecl("mpsgraph_tensor_data_read_bytes")
public func mpsgraph_tensor_data_read_bytes(
    _ handle: UnsafeMutableRawPointer?,
    _ dst: UnsafeMutableRawPointer?,
    _ dstLen: Int
) -> Bool {
    guard let handle else {
        return false
    }

    let tensorData: MPSGraphTensorData = mpsgraph_borrow(handle)

    // Dense element count, refusing negative dimensions and Int overflow.
    var elementCount = 1
    for dimension in tensorData.shape.map(\.intValue) {
        guard dimension >= 0 else {
            return false
        }
        let (product, overflow) = elementCount.multipliedReportingOverflow(by: dimension)
        guard !overflow else {
            return false
        }
        elementCount = product
    }

    // The low 16 bits of an `MPSDataType` raw value encode the element width
    // in bits (e.g. float32 = floatBit | 32). Round up for sub-byte types.
    let bitsPerElement = Int(tensorData.dataType.rawValue & 0xFFFF)
    let (totalBits, bitsOverflow) = elementCount.multipliedReportingOverflow(by: bitsPerElement)
    guard !bitsOverflow else {
        return false
    }
    let requiredBytes = totalBits / 8 + (totalBits % 8 == 0 ? 0 : 1)

    if requiredBytes == 0 {
        return true
    }
    // `readBytes` writes the whole tensor; a short buffer would be overrun.
    guard let dst, dstLen >= requiredBytes else {
        return false
    }

    tensorData.mpsndarray().readBytes(dst, strideBytes: nil)
    return true
}

/// Returns the handle produced by `mpsgraph_retain` for the `MPSGraphDevice`
/// of a tensor data object, or null for a null handle.
@_cdecl("mpsgraph_tensor_data_device")
public func mpsgraph_tensor_data_device(_ handle: UnsafeMutableRawPointer?) -> UnsafeMutableRawPointer? {
    guard let tensorHandle = handle else { return nil }
    let owningDevice = (mpsgraph_borrow(tensorHandle) as MPSGraphTensorData).device
    return mpsgraph_retain(owningDevice)
}