use burn_backend::{DType, Shape};
use crate::ffi::{self, Id};
use crate::tensor::{elem_size, MpsGraphTensor};
use crate::MpsGraphDevice;
pub fn burn_to_mps_dtype(dt: DType) -> u32 {
match dt {
DType::F32 | DType::Flex32 => ffi::MPSDataType::FLOAT32,
DType::F16 => ffi::MPSDataType::FLOAT16,
DType::BF16 => ffi::MPSDataType::BFLOAT16,
DType::I64 => ffi::MPSDataType::INT64,
DType::I32 => ffi::MPSDataType::INT32,
DType::I16 => ffi::MPSDataType::INT16,
DType::I8 => ffi::MPSDataType::INT8,
DType::U64 => ffi::MPSDataType::UINT64,
DType::U32 => ffi::MPSDataType::UINT32,
DType::U16 => ffi::MPSDataType::UINT16,
DType::U8 => ffi::MPSDataType::UINT8,
DType::Bool => ffi::MPSDataType::BOOL,
DType::F64 => panic!("F64 is not supported by Metal GPUs. Cast to F32 first."),
other => panic!("Unsupported dtype for MPSGraph: {:?}", other),
}
}
pub fn mps_to_burn_dtype(dt: u32) -> DType {
match dt {
ffi::MPSDataType::FLOAT32 => DType::F32,
ffi::MPSDataType::FLOAT16 => DType::F16,
ffi::MPSDataType::BFLOAT16 => DType::BF16,
ffi::MPSDataType::INT64 => DType::I64,
ffi::MPSDataType::INT32 => DType::I32,
ffi::MPSDataType::INT16 => DType::I16,
ffi::MPSDataType::INT8 => DType::I8,
ffi::MPSDataType::UINT64 => DType::U64,
ffi::MPSDataType::UINT32 => DType::U32,
ffi::MPSDataType::UINT16 => DType::U16,
ffi::MPSDataType::UINT8 => DType::U8,
ffi::MPSDataType::BOOL => DType::Bool,
_ => panic!("Unknown MPSDataType: {dt:#x}"),
}
}
/// Build an `NSArray` of dimension sizes from a burn `Shape`.
pub fn shape_to_ns(shape: &Shape) -> Id {
    let mut dims = Vec::with_capacity(shape.num_dims());
    for axis in 0..shape.num_dims() {
        dims.push(shape[axis]);
    }
    // SAFETY: `dims` is a live, valid slice for the duration of the call.
    unsafe { ffi::ns_usize_array(&dims) }
}
/// Read an `NSArray` of `NSNumber` dimension sizes back into a burn `Shape`.
pub fn ns_to_shape(ns_shape: Id) -> Shape {
    unsafe {
        let count = ffi::ns_array_count(ns_shape);
        let mut dims = Vec::with_capacity(count as usize);
        for i in 0..count {
            dims.push(ffi::ns_number_to_usize(ffi::ns_array_get(ns_shape, i)));
        }
        Shape::from(dims)
    }
}
/// Create a tensor whose backing `MTLBuffer` is initialized from `bytes`.
///
/// NOTE(review): assumes `bytes.len()` matches `shape`/`dtype` — the FFI side
/// is trusted with the length; verify at call sites.
pub fn tensor_from_bytes(bytes: &[u8], shape: Shape, dtype: DType, device: MpsGraphDevice) -> MpsGraphTensor {
    let buffer = unsafe { ffi::mtl_buffer_from_bytes(bytes) };
    MpsGraphTensor { buffer, shape, dtype, device }
}
/// Create an all-zero tensor of the given shape and dtype.
pub fn tensor_zeros(shape: Shape, dtype: DType, device: MpsGraphDevice) -> MpsGraphTensor {
    let byte_len = elem_size(dtype) * shape.num_elements();
    let buffer = unsafe { ffi::mtl_buffer_zeroed(byte_len) };
    MpsGraphTensor { buffer, shape, dtype, device }
}
/// Copy a tensor's backing buffer contents into a host-side `Vec<u8>`.
pub fn tensor_to_bytes(tensor: &MpsGraphTensor) -> Vec<u8> {
    let len = tensor.shape.num_elements() * elem_size(tensor.dtype);
    unsafe {
        // SAFETY: the buffer was allocated with at least `len` bytes for this
        // shape/dtype; `contents` is valid for reads of that length.
        let contents = ffi::mtl_buffer_contents(tensor.buffer) as *const u8;
        std::slice::from_raw_parts(contents, len).to_vec()
    }
}
/// Convert burn slice specs into the (starts, ends, strides) `NSArray` triple
/// that MPSGraph slicing ops consume.
pub fn slices_to_ns(slices: &[burn_std::Slice], shape: &Shape) -> (Id, Id, Id) {
    let mut starts = Vec::with_capacity(slices.len());
    let mut ends = Vec::with_capacity(slices.len());
    let mut strides = Vec::with_capacity(slices.len());
    for (dim, slice) in slices.iter().enumerate() {
        starts.push(slice.start);
        // An open-ended slice extends to the full size of that dimension.
        ends.push(slice.end.unwrap_or(shape[dim] as isize));
        strides.push(slice.step);
    }
    unsafe {
        (
            ffi::ns_isize_array(&starts),
            ffi::ns_isize_array(&ends),
            ffi::ns_isize_array(&strides),
        )
    }
}
/// Wrap a tensor's buffer in an `MPSGraphTensorData` object for feeding a run.
fn make_tensor_data(tensor: &MpsGraphTensor) -> Id {
    let dtype = burn_to_mps_dtype(tensor.dtype);
    let shape = shape_to_ns(&tensor.shape);
    unsafe { ffi::tensor_data_from_buffer(tensor.buffer, shape, dtype) }
}
/// Execute `graph` with the given `(tensor, placeholder)` feeds and read back
/// the single `output` tensor into a fresh buffer.
///
/// Ownership: this function releases `graph` before returning, so callers
/// must not use the graph handle afterwards. ObjC temporaries (tensor data,
/// feed dictionary, target array, results) are scoped by the autorelease
/// pool pushed at the top and drained at the bottom.
fn run_graph(
graph: Id,
inputs: &[(&MpsGraphTensor, Id)],
output: Id,
device: MpsGraphDevice,
) -> MpsGraphTensor {
unsafe {
// Pool scopes every autoreleased ObjC object created below.
let pool = ffi::objc_autoreleasePoolPush();
let mut keys = Vec::with_capacity(inputs.len());
let mut vals = Vec::with_capacity(inputs.len());
for (tensor, placeholder) in inputs {
keys.push(*placeholder);
vals.push(make_tensor_data(tensor));
}
let feeds = ffi::ns_dictionary(&keys, &vals);
let targets = ffi::ns_array(&[output]);
let results = ffi::graph_run(graph, feeds, targets);
let result_td = ffi::ns_dict_get(results, output);
assert!(!result_td.is_null(), "MPSGraph run returned no result");
// Shape/dtype are re-read from the result tensor data rather than assumed
// from the inputs.
let out_shape = ns_to_shape(ffi::tensor_data_shape(result_td));
let out_dtype = mps_to_burn_dtype(ffi::tensor_data_dtype(result_td));
let nbytes = out_shape.num_elements() * elem_size(out_dtype);
// Fresh zeroed MTLBuffer sized for the result; the run's bytes are copied in.
let buf = ffi::mtl_buffer_zeroed(nbytes);
ffi::tensor_data_read_bytes(
result_td,
std::slice::from_raw_parts_mut(ffi::mtl_buffer_contents(buf) as *mut u8, nbytes),
);
// Release the graph (owned by this call), then drain the pool — the copied
// buffer survives both.
ffi::release(graph);
ffi::objc_autoreleasePoolPop(pool);
MpsGraphTensor { buffer: buf, shape: out_shape, dtype: out_dtype, device }
}
}
/// Run a one-input graph op: build it via `build_fn(graph, placeholder)` and
/// execute it on `t`. The result inherits `t`'s device.
pub fn run_unary(
    t: &MpsGraphTensor,
    build_fn: unsafe fn(Id, Id) -> Id,
) -> MpsGraphTensor {
    unsafe {
        let graph = ffi::mpsgraph_new();
        let dims = shape_to_ns(&t.shape);
        let dtype = burn_to_mps_dtype(t.dtype);
        let input = ffi::graph_placeholder(graph, dims, dtype);
        let output = build_fn(graph, input);
        run_graph(graph, &[(t, input)], output, t.device)
    }
}
/// Run a two-input graph op built by `build_fn(graph, lhs, rhs)`.
/// The result inherits `a`'s device.
pub fn run_binary(
    a: &MpsGraphTensor,
    b: &MpsGraphTensor,
    build_fn: unsafe fn(Id, Id, Id) -> Id,
) -> MpsGraphTensor {
    unsafe {
        let graph = ffi::mpsgraph_new();
        let lhs = ffi::graph_placeholder(graph, shape_to_ns(&a.shape), burn_to_mps_dtype(a.dtype));
        let rhs = ffi::graph_placeholder(graph, shape_to_ns(&b.shape), burn_to_mps_dtype(b.dtype));
        let output = build_fn(graph, lhs, rhs);
        run_graph(graph, &[(a, lhs), (b, rhs)], output, a.device)
    }
}
/// Run a three-input graph op built by `build_fn(graph, pa, pb, pc)`.
/// The result inherits `a`'s device.
pub fn run_ternary(
    a: &MpsGraphTensor, b: &MpsGraphTensor, c: &MpsGraphTensor,
    build_fn: unsafe fn(Id, Id, Id, Id) -> Id,
) -> MpsGraphTensor {
    unsafe {
        let graph = ffi::mpsgraph_new();
        let mut placeholder = |t: &MpsGraphTensor| {
            ffi::graph_placeholder(graph, shape_to_ns(&t.shape), burn_to_mps_dtype(t.dtype))
        };
        let first = placeholder(a);
        let second = placeholder(b);
        let third = placeholder(c);
        let output = build_fn(graph, first, second, third);
        run_graph(graph, &[(a, first), (b, second), (c, third)], output, a.device)
    }
}
/// Like [`run_unary`] but accepts a closure, so the builder can capture
/// context (axes, scalars, …).
pub fn run_unary_ctx<F>(t: &MpsGraphTensor, f: F) -> MpsGraphTensor
where F: FnOnce(Id, Id) -> Id {
    unsafe {
        let graph = ffi::mpsgraph_new();
        let input = ffi::graph_placeholder(graph, shape_to_ns(&t.shape), burn_to_mps_dtype(t.dtype));
        let output = f(graph, input);
        run_graph(graph, &[(t, input)], output, t.device)
    }
}
/// Like [`run_binary`] but accepts a closure, so the builder can capture
/// context. The result inherits `a`'s device.
pub fn run_binary_ctx<F>(a: &MpsGraphTensor, b: &MpsGraphTensor, f: F) -> MpsGraphTensor
where F: FnOnce(Id, Id, Id) -> Id {
    unsafe {
        let graph = ffi::mpsgraph_new();
        let lhs = ffi::graph_placeholder(graph, shape_to_ns(&a.shape), burn_to_mps_dtype(a.dtype));
        let rhs = ffi::graph_placeholder(graph, shape_to_ns(&b.shape), burn_to_mps_dtype(b.dtype));
        let output = f(graph, lhs, rhs);
        run_graph(graph, &[(a, lhs), (b, rhs)], output, a.device)
    }
}
/// Run a graph with an arbitrary number of inputs. `f` receives the graph and
/// one placeholder per input, in order; the caller supplies the output device.
pub fn run_multi_ctx<F>(inputs: &[&MpsGraphTensor], device: MpsGraphDevice, f: F) -> MpsGraphTensor
where F: FnOnce(Id, &[Id]) -> Id {
    unsafe {
        let graph = ffi::mpsgraph_new();
        let mut placeholders = Vec::with_capacity(inputs.len());
        for t in inputs {
            placeholders.push(ffi::graph_placeholder(
                graph,
                shape_to_ns(&t.shape),
                burn_to_mps_dtype(t.dtype),
            ));
        }
        let output = f(graph, &placeholders);
        // Pair each tensor with the placeholder created for it above.
        let mut pairs = Vec::with_capacity(inputs.len());
        for (t, ph) in inputs.iter().zip(placeholders.iter()) {
            pairs.push((*t, *ph));
        }
        run_graph(graph, &pairs, output, device)
    }
}
pub fn run_unary_two_outputs(
t: &MpsGraphTensor,
f: impl FnOnce(Id, Id) -> (Id, Id),
) -> (MpsGraphTensor, MpsGraphTensor) {
unsafe {
let pool = ffi::objc_autoreleasePoolPush();
let g = ffi::mpsgraph_new();
let ph = ffi::graph_placeholder(g, shape_to_ns(&t.shape), burn_to_mps_dtype(t.dtype));
let (out1, out2) = f(g, ph);
let td = make_tensor_data(t);
let feeds = ffi::ns_dictionary(&[ph], &[td]);
let targets = ffi::ns_array(&[out1, out2]);
let results = ffi::graph_run(g, feeds, targets);
let extract = |key: Id| -> MpsGraphTensor {
let rtd = ffi::ns_dict_get(results, key);
let s = ns_to_shape(ffi::tensor_data_shape(rtd));
let dt = mps_to_burn_dtype(ffi::tensor_data_dtype(rtd));
let nb = s.num_elements() * elem_size(dt);
let buf = ffi::mtl_buffer_zeroed(nb);
ffi::tensor_data_read_bytes(
rtd,
std::slice::from_raw_parts_mut(ffi::mtl_buffer_contents(buf) as *mut u8, nb),
);
MpsGraphTensor { buffer: buf, shape: s, dtype: dt, device: t.device }
};
let r1 = extract(out1);
let r2 = extract(out2);
ffi::release(g);
ffi::objc_autoreleasePoolPop(pool);
(r1, r2)
}
}