burn-mpsgraph 0.0.1

Apple MPSGraph backend for the Burn deep learning framework
Documentation
//! GPU-resident tensor management and MPSGraph execution.
//!
//! Key design: tensors live as MTLBuffers on the GPU.  Data only moves to the
//! host when the user calls `*_into_data`.  Each operation builds a small graph,
//! runs it inside an autorelease pool so all temporary ObjC objects are freed,
//! and the result is a new MTLBuffer.

use burn_backend::{DType, Shape};

use crate::ffi::{self, Id};
use crate::tensor::{elem_size, MpsGraphTensor};
use crate::MpsGraphDevice;

// ─── DType conversion ───────────────────────────────────────────────────────

pub fn burn_to_mps_dtype(dt: DType) -> u32 {
    match dt {
        DType::F32 | DType::Flex32 => ffi::MPSDataType::FLOAT32,
        DType::F16              => ffi::MPSDataType::FLOAT16,
        DType::BF16             => ffi::MPSDataType::BFLOAT16,
        DType::I64              => ffi::MPSDataType::INT64,
        DType::I32              => ffi::MPSDataType::INT32,
        DType::I16              => ffi::MPSDataType::INT16,
        DType::I8               => ffi::MPSDataType::INT8,
        DType::U64              => ffi::MPSDataType::UINT64,
        DType::U32              => ffi::MPSDataType::UINT32,
        DType::U16              => ffi::MPSDataType::UINT16,
        DType::U8               => ffi::MPSDataType::UINT8,
        DType::Bool             => ffi::MPSDataType::BOOL,
        DType::F64 => panic!("F64 is not supported by Metal GPUs. Cast to F32 first."),
        other => panic!("Unsupported dtype for MPSGraph: {:?}", other),
    }
}

pub fn mps_to_burn_dtype(dt: u32) -> DType {
    match dt {
        ffi::MPSDataType::FLOAT32  => DType::F32,
        ffi::MPSDataType::FLOAT16  => DType::F16,
        ffi::MPSDataType::BFLOAT16 => DType::BF16,
        ffi::MPSDataType::INT64    => DType::I64,
        ffi::MPSDataType::INT32    => DType::I32,
        ffi::MPSDataType::INT16    => DType::I16,
        ffi::MPSDataType::INT8     => DType::I8,
        ffi::MPSDataType::UINT64   => DType::U64,
        ffi::MPSDataType::UINT32   => DType::U32,
        ffi::MPSDataType::UINT16   => DType::U16,
        ffi::MPSDataType::UINT8    => DType::U8,
        ffi::MPSDataType::BOOL     => DType::Bool,
        _ => panic!("Unknown MPSDataType: {dt:#x}"),
    }
}

// ─── Shape conversion ───────────────────────────────────────────────────────

/// Build an `NSArray` of the dimension sizes of `shape`, outermost first,
/// via `ffi::ns_usize_array`.
pub fn shape_to_ns(shape: &Shape) -> Id {
    let rank = shape.num_dims();
    let mut sizes: Vec<usize> = Vec::with_capacity(rank);
    for axis in 0..rank {
        sizes.push(shape[axis]);
    }
    unsafe { ffi::ns_usize_array(&sizes) }
}

/// Rebuild a Burn [`Shape`] from an `NSArray` of numeric dimension sizes
/// (the inverse of `shape_to_ns`).
pub fn ns_to_shape(ns_shape: Id) -> Shape {
    let len = unsafe { ffi::ns_array_count(ns_shape) };
    let mut sizes: Vec<usize> = Vec::new();
    for idx in 0..len {
        // Each element is read back as a usize dimension size.
        let item = unsafe { ffi::ns_array_get(ns_shape, idx) };
        sizes.push(unsafe { ffi::ns_number_to_usize(item) });
    }
    Shape::from(sizes)
}

// ─── MTLBuffer helpers ──────────────────────────────────────────────────────

/// Upload host bytes into an MTLBuffer-backed tensor.
pub fn tensor_from_bytes(bytes: &[u8], shape: Shape, dtype: DType, device: MpsGraphDevice) -> MpsGraphTensor {
    let buf = unsafe { ffi::mtl_buffer_from_bytes(bytes) };
    MpsGraphTensor { buffer: buf, shape, dtype, device }
}

/// Allocate a tensor of the given `shape` and `dtype` whose MTLBuffer is
/// zero-filled (`elem_size(dtype) * num_elements` bytes).
pub fn tensor_zeros(shape: Shape, dtype: DType, device: MpsGraphDevice) -> MpsGraphTensor {
    let byte_len = elem_size(dtype) * shape.num_elements();
    MpsGraphTensor {
        buffer: unsafe { ffi::mtl_buffer_zeroed(byte_len) },
        shape,
        dtype,
        device,
    }
}

/// Copy a tensor's contents back to host memory.
///
/// The buffer's CPU pointer is read directly, which is only valid once the
/// GPU has finished writing it; graph runs in this module are synchronous, so
/// tensors produced here are complete when returned.
/// NOTE(review): this also assumes `ffi` creates buffers with CPU-visible
/// (shared) storage — confirm.
///
/// Zero-element tensors return an empty `Vec` without touching the buffer:
/// `slice::from_raw_parts` requires a non-null pointer even for length 0, and
/// an empty MTLBuffer is not guaranteed to expose a non-null contents pointer.
///
/// # Panics
/// If the buffer reports a null contents pointer for a non-empty tensor.
pub fn tensor_to_bytes(tensor: &MpsGraphTensor) -> Vec<u8> {
    let nbytes = tensor.shape.num_elements() * elem_size(tensor.dtype);
    if nbytes == 0 {
        return Vec::new();
    }
    let ptr = unsafe { ffi::mtl_buffer_contents(tensor.buffer) } as *const u8;
    assert!(!ptr.is_null(), "MTLBuffer has no CPU-accessible contents");
    // SAFETY: `ptr` is non-null; the buffer is assumed to hold at least
    // `nbytes` bytes because this module sizes buffers from the same
    // shape/dtype formula, and no GPU work is pending on it.
    unsafe { std::slice::from_raw_parts(ptr, nbytes) }.to_vec()
}

// ─── Slice helpers ──────────────────────────────────────────────────────────

/// Convert Burn `Slice`s into the `(starts, ends, strides)` NSArray triple
/// used for MPSGraph slicing.
///
/// `slices[i]` applies to dimension `i` of `shape`; one entry per slice is
/// produced, in order.
///
/// - An omitted `end` becomes the full size of that dimension.
/// - Negative `start`/`end` values are treated as counting back from the end
///   of their dimension (`-1` = last element). They are resolved here against
///   `shape`, so MPSGraph always receives explicit bounds. This is harmless
///   if MPSGraph also wraps negative indices itself.
/// - Steps are forwarded unchanged.
///
/// NOTE(review): confirm that Burn's semantics for negative `step` match
/// MPSGraph's stride handling.
///
/// # Panics
/// If `slices` has more entries than `shape` has dimensions.
pub fn slices_to_ns(slices: &[burn_std::Slice], shape: &Shape) -> (Id, Id, Id) {
    /// Resolve a possibly-negative index against a dimension of size `dim`.
    fn resolve(index: isize, dim: isize) -> isize {
        if index < 0 { index + dim } else { index }
    }

    let mut starts: Vec<isize> = Vec::with_capacity(slices.len());
    let mut ends: Vec<isize> = Vec::with_capacity(slices.len());
    let mut strides: Vec<isize> = Vec::with_capacity(slices.len());
    for (axis, slice) in slices.iter().enumerate() {
        let dim = shape[axis] as isize;
        starts.push(resolve(slice.start, dim));
        ends.push(resolve(slice.end.unwrap_or(dim), dim));
        strides.push(slice.step);
    }

    unsafe {
        (ffi::ns_isize_array(&starts),
         ffi::ns_isize_array(&ends),
         ffi::ns_isize_array(&strides))
    }
}

// ─── Graph execution core ───────────────────────────────────────────────────
//
// Every graph run is wrapped in an ObjC autorelease pool so temporary objects
// (NSArray, NSDictionary, MPSGraphTensorData, the graph itself) are freed
// promptly instead of leaking until the next pool drain.

/// Create MPSGraphTensorData backed by the tensor's MTLBuffer (zero-copy).
fn make_tensor_data(tensor: &MpsGraphTensor) -> Id {
    let ns_shape = shape_to_ns(&tensor.shape);
    let mps_dt = burn_to_mps_dtype(tensor.dtype);
    unsafe { ffi::tensor_data_from_buffer(tensor.buffer, ns_shape, mps_dt) }
}

/// Run `graph` once and copy the value it computes for `output` into a fresh
/// MTLBuffer-backed tensor on `device`.
///
/// * `inputs` pairs each input tensor with the placeholder it feeds; tensors
///   are fed through `make_tensor_data`, i.e. backed by their own buffers.
/// * Takes ownership of `graph`: it is released before this returns.
///
/// All temporary ObjC objects (feeds dictionary, tensor data, results) are
/// created inside an autorelease pool. A drop guard releases the graph and
/// pops the pool, so both also happen when a panic unwinds out of this
/// function. Without it, such a panic would leak the graph and leave the pool
/// stack unbalanced.
///
/// # Panics
/// If the run yields no value for `output`, or that value's dtype has no Burn
/// equivalent.
fn run_graph(
    graph: Id,
    inputs: &[(&MpsGraphTensor, Id)],
    output: Id,
    device: MpsGraphDevice,
) -> MpsGraphTensor {
    /// Runs its closure exactly once when dropped (normal return or unwind).
    struct OnDrop<F: FnOnce()>(Option<F>);

    impl<F: FnOnce()> Drop for OnDrop<F> {
        fn drop(&mut self) {
            if let Some(cleanup) = self.0.take() {
                cleanup();
            }
        }
    }

    let pool = unsafe { ffi::objc_autoreleasePoolPush() };
    // Release the graph we own, then drain the pool (the original order).
    let _cleanup = OnDrop(Some(move || unsafe {
        ffi::release(graph);
        ffi::objc_autoreleasePoolPop(pool);
    }));

    unsafe {
        // Feeds dictionary: placeholder -> tensor data over the input's buffer.
        let (keys, vals): (Vec<Id>, Vec<Id>) = inputs
            .iter()
            .map(|&(tensor, placeholder)| (placeholder, make_tensor_data(tensor)))
            .unzip();
        let feeds = ffi::ns_dictionary(&keys, &vals);
        let targets = ffi::ns_array(&[output]);

        // Run synchronously.
        let results = ffi::graph_run(graph, feeds, targets);

        // Extract result tensor data.
        let result_td = ffi::ns_dict_get(results, output);
        assert!(!result_td.is_null(), "MPSGraph run returned no result");

        let out_shape = ns_to_shape(ffi::tensor_data_shape(result_td));
        let out_dtype = mps_to_burn_dtype(ffi::tensor_data_dtype(result_td));
        let nbytes = out_shape.num_elements() * elem_size(out_dtype);

        // Read directly into a new MTLBuffer (single copy, no temp Vec).
        let buf = ffi::mtl_buffer_zeroed(nbytes);
        // Empty results need no read. `from_raw_parts_mut` requires a non-null
        // pointer even for length 0, and an empty buffer's contents pointer is
        // not guaranteed to be non-null.
        if nbytes > 0 {
            let dst = std::slice::from_raw_parts_mut(
                ffi::mtl_buffer_contents(buf) as *mut u8,
                nbytes,
            );
            ffi::tensor_data_read_bytes(result_td, dst);
        }

        MpsGraphTensor { buffer: buf, shape: out_shape, dtype: out_dtype, device }
    }
}

// ─── Public graph execution helpers ─────────────────────────────────────────

pub fn run_unary(
    t: &MpsGraphTensor,
    build_fn: unsafe fn(Id, Id) -> Id,
) -> MpsGraphTensor {
    unsafe {
        let g = ffi::mpsgraph_new();
        let ph = ffi::graph_placeholder(g, shape_to_ns(&t.shape), burn_to_mps_dtype(t.dtype));
        let out = build_fn(g, ph);
        run_graph(g, &[(t, ph)], out, t.device)
    }
}

pub fn run_binary(
    a: &MpsGraphTensor,
    b: &MpsGraphTensor,
    build_fn: unsafe fn(Id, Id, Id) -> Id,
) -> MpsGraphTensor {
    unsafe {
        let g = ffi::mpsgraph_new();
        let pa = ffi::graph_placeholder(g, shape_to_ns(&a.shape), burn_to_mps_dtype(a.dtype));
        let pb = ffi::graph_placeholder(g, shape_to_ns(&b.shape), burn_to_mps_dtype(b.dtype));
        let out = build_fn(g, pa, pb);
        run_graph(g, &[(a, pa), (b, pb)], out, a.device)
    }
}

pub fn run_ternary(
    a: &MpsGraphTensor, b: &MpsGraphTensor, c: &MpsGraphTensor,
    build_fn: unsafe fn(Id, Id, Id, Id) -> Id,
) -> MpsGraphTensor {
    unsafe {
        let g = ffi::mpsgraph_new();
        let pa = ffi::graph_placeholder(g, shape_to_ns(&a.shape), burn_to_mps_dtype(a.dtype));
        let pb = ffi::graph_placeholder(g, shape_to_ns(&b.shape), burn_to_mps_dtype(b.dtype));
        let pc = ffi::graph_placeholder(g, shape_to_ns(&c.shape), burn_to_mps_dtype(c.dtype));
        let out = build_fn(g, pa, pb, pc);
        run_graph(g, &[(a, pa), (b, pb), (c, pc)], out, a.device)
    }
}

/// Run a single-input graph whose output node is produced by the closure
/// `f(graph, placeholder)`; the result lives on `t`'s device.
///
/// Unlike `run_unary`, `f` is any `FnOnce`, so callers can capture extra
/// parameters. The new graph is handed to `run_graph`, which releases it.
pub fn run_unary_ctx<F>(t: &MpsGraphTensor, f: F) -> MpsGraphTensor
where
    F: FnOnce(Id, Id) -> Id,
{
    let graph = unsafe { ffi::mpsgraph_new() };
    let input_shape = shape_to_ns(&t.shape);
    let input_dtype = burn_to_mps_dtype(t.dtype);
    let placeholder = unsafe { ffi::graph_placeholder(graph, input_shape, input_dtype) };
    let output = f(graph, placeholder);
    run_graph(graph, &[(t, placeholder)], output, t.device)
}

/// Run a two-input graph whose output node is produced by the closure
/// `f(graph, placeholder_a, placeholder_b)`; the result lives on `a`'s device.
///
/// The new graph is handed to `run_graph`, which releases it.
pub fn run_binary_ctx<F>(a: &MpsGraphTensor, b: &MpsGraphTensor, f: F) -> MpsGraphTensor
where
    F: FnOnce(Id, Id, Id) -> Id,
{
    let graph = unsafe { ffi::mpsgraph_new() };
    // One placeholder per input, matching its shape and dtype.
    let placeholder_for = |x: &MpsGraphTensor| unsafe {
        ffi::graph_placeholder(graph, shape_to_ns(&x.shape), burn_to_mps_dtype(x.dtype))
    };
    let lhs = placeholder_for(a);
    let rhs = placeholder_for(b);
    let output = f(graph, lhs, rhs);
    run_graph(graph, &[(a, lhs), (b, rhs)], output, a.device)
}

/// Run a graph with any number of inputs. `f(graph, placeholders)` receives
/// one placeholder per entry of `inputs`, in the same order, and returns the
/// output node. The result is tagged with `device`.
///
/// The new graph is handed to `run_graph`, which releases it.
pub fn run_multi_ctx<F>(inputs: &[&MpsGraphTensor], device: MpsGraphDevice, f: F) -> MpsGraphTensor
where
    F: FnOnce(Id, &[Id]) -> Id,
{
    let graph = unsafe { ffi::mpsgraph_new() };

    let mut placeholders: Vec<Id> = Vec::with_capacity(inputs.len());
    for tensor in inputs {
        let ph = unsafe {
            ffi::graph_placeholder(graph, shape_to_ns(&tensor.shape), burn_to_mps_dtype(tensor.dtype))
        };
        placeholders.push(ph);
    }

    let output = f(graph, &placeholders);

    let feeds: Vec<(&MpsGraphTensor, Id)> = inputs
        .iter()
        .copied()
        .zip(placeholders.iter().copied())
        .collect();
    run_graph(graph, &feeds, output, device)
}

/// Run a single-input graph with two outputs and return both as fresh
/// MTLBuffer-backed tensors on `t`'s device.
///
/// `f(graph, placeholder)` builds the graph and returns the output nodes
/// `(out1, out2)`. The returned tensors are in that same order.
///
/// The graph is created and released here, and all temporary ObjC objects
/// live in an autorelease pool. Drop guards release the graph and pop the
/// pool, so both also happen if a panic unwinds out of this function (e.g. a
/// missing result or an unknown dtype).
///
/// # Panics
/// If the run yields no value for either output, or an output's dtype has no
/// Burn equivalent.
pub fn run_unary_two_outputs(
    t: &MpsGraphTensor,
    f: impl FnOnce(Id, Id) -> (Id, Id),
) -> (MpsGraphTensor, MpsGraphTensor) {
    /// Runs its closure exactly once when dropped (normal return or unwind).
    struct OnDrop<F: FnOnce()>(Option<F>);

    impl<F: FnOnce()> Drop for OnDrop<F> {
        fn drop(&mut self) {
            if let Some(cleanup) = self.0.take() {
                cleanup();
            }
        }
    }

    let pool = unsafe { ffi::objc_autoreleasePoolPush() };
    let _pool_guard = OnDrop(Some(move || unsafe { ffi::objc_autoreleasePoolPop(pool) }));

    let g = unsafe { ffi::mpsgraph_new() };
    // Declared after `_pool_guard`, so it drops first: the graph is released
    // before the pool is popped (the original order).
    let _graph_guard = OnDrop(Some(move || unsafe { ffi::release(g) }));

    unsafe {
        let ph = ffi::graph_placeholder(g, shape_to_ns(&t.shape), burn_to_mps_dtype(t.dtype));
        let (out1, out2) = f(g, ph);

        let td = make_tensor_data(t);
        let feeds = ffi::ns_dictionary(&[ph], &[td]);
        let targets = ffi::ns_array(&[out1, out2]);
        let results = ffi::graph_run(g, feeds, targets);

        // Copy the value computed for `key` into a new MTLBuffer-backed tensor.
        let extract = |key: Id| -> MpsGraphTensor {
            let rtd = ffi::ns_dict_get(results, key);
            assert!(!rtd.is_null(), "MPSGraph run returned no result");
            let s = ns_to_shape(ffi::tensor_data_shape(rtd));
            let dt = mps_to_burn_dtype(ffi::tensor_data_dtype(rtd));
            let nb = s.num_elements() * elem_size(dt);
            let buf = ffi::mtl_buffer_zeroed(nb);
            // Skip zero-byte reads: `from_raw_parts_mut` requires a non-null
            // pointer even for length 0.
            if nb > 0 {
                let dst = std::slice::from_raw_parts_mut(
                    ffi::mtl_buffer_contents(buf) as *mut u8,
                    nb,
                );
                ffi::tensor_data_read_bytes(rtd, dst);
            }
            MpsGraphTensor { buffer: buf, shape: s, dtype: dt, device: t.device }
        };

        let r1 = extract(out1);
        let r2 = extract(out2);
        (r1, r2)
    }
}