§burn-mpsgraph
Apple Metal Performance Shaders Graph (MPSGraph) backend for the Burn deep learning framework.
This crate accelerates tensor operations on Apple GPUs (M1/M2/M3/M4 and
later) by dispatching to MPSGraph, Apple’s graph-based compute engine that
sits on top of Metal. It uses direct Objective-C FFI — no objc2
crate — keeping the dependency tree minimal and compile times fast.
§Quick start
use burn::prelude::*;
use burn_mpsgraph::prelude::*;
type B = MpsGraph;
let device = MpsGraphDevice::default();
let a: Tensor<B, 2> = Tensor::random([128, 64], burn::tensor::Distribution::Default, &device);
let b: Tensor<B, 2> = Tensor::random([64, 256], burn::tensor::Distribution::Default, &device);
let c = a.matmul(b); // runs on the Apple GPU
let data = c.into_data(); // copies result back to CPU
println!("shape: {:?}", data.shape); // [128, 256]
§Architecture
Burn Tensor API
│
▼
FloatTensorOps / IntTensorOps / BoolTensorOps / ModuleOps
│
▼
bridge.rs ── builds an MPSGraph per op, feeds MTLBuffer tensors, runs synchronously
│
▼
ffi.rs ── raw objc_msgSend calls to Metal, MPS, MPSGraph, Foundation
│
▼
Apple GPU (Metal)
§Tensor storage
Each MpsGraphTensor wraps a retained MTLBuffer pointer. On Apple
Silicon the buffer uses shared memory — both CPU and GPU access the same
physical pages with no PCIe copy. Data is only serialised to a Vec<u8>
when you call into_data().
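To make the serialisation step concrete, here is a minimal std-only sketch of copying `f32` values into a `Vec<u8>` and back. It is an illustration, not the crate's code: a plain slice stands in for the shared buffer, and `f32s_to_bytes` and `bytes_to_f32s` are hypothetical helper names. Native byte order is assumed because CPU and GPU read the same memory.

```rust
/// Copy f32 values into a freshly allocated byte vector (native byte order).
/// Hypothetical helper; a plain slice stands in for the shared buffer.
fn f32s_to_bytes(values: &[f32]) -> Vec<u8> {
    let mut out = Vec::with_capacity(values.len() * std::mem::size_of::<f32>());
    for v in values {
        out.extend_from_slice(&v.to_ne_bytes());
    }
    out
}

/// Reinterpret a byte slice as f32 values; trailing bytes that do not
/// fill a whole element are ignored.
fn bytes_to_f32s(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}

fn main() {
    let values = [1.0_f32, 2.5, -3.0];
    let bytes = f32s_to_bytes(&values);
    println!("{} values -> {} bytes", values.len(), bytes.len());

    // The round trip recovers the original values exactly.
    let back = bytes_to_f32s(&bytes);
    assert_eq!(back, values.to_vec());
}
```

The copy happens only at this boundary; until then the data can stay in the buffer that the GPU reads.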
§Supported dtypes
| Dtype | Storage | Arithmetic | Accelerated |
|---|---|---|---|
| F32 | ✓ | ✓ | ✓ |
| F16 | ✓ | ✓ | ✓ |
| BF16 | ✓ | ✓ | |
| I32 | ✓ | ✓ | |
| I64 | ✓ | ✓ | |
| Bool | ✓ | ✓ | |
| F64 | — | — | — (panic) |
§Implemented operations
All operations required by Burn’s Backend trait are implemented:
- Arithmetic: add, sub, mul, div, remainder, powf, matmul
- Unary math: exp, log, log1p, sqrt, abs, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh, atan2, erf, recip, floor, ceil, round, trunc
- Comparisons: equal, greater, greater_equal, lower, lower_equal
- Reductions: sum, prod, mean, max, min, argmax, argmin (all per-axis)
- Cumulative: cumsum, cumprod, cummin, cummax
- Sort: sort, argsort
- Shape: reshape, transpose, permute, flip, slice, slice_assign, cat, expand, unfold
- Masking: mask_where, mask_fill
- Gather / scatter: gather, scatter_add, select, select_add
- Convolution: conv1d, conv2d, conv3d, conv_transpose1d/2d/3d, deform_conv2d
- Pooling: avg_pool2d, max_pool2d (+ backward + with_indices), adaptive_avg_pool2d
- Interpolation: nearest, bilinear (+ backward)
- Attention: scaled dot-product (single-graph fused softmax)
- Embedding: forward + backward
- Int ops: full arithmetic, bitwise (and/or/xor/not/shift), casting
- Bool ops: and, or, not, equal, scatter_or, casting
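As a reference for what the attention entry computes, the following CPU sketch evaluates softmax(Q Kᵀ / √d) V for a single head in plain Rust. It is not the backend's implementation, which builds an MPSGraph. The function names and the `Vec<Vec<f32>>` layout are assumptions chosen for readability.

```rust
/// Numerically stable softmax over one row of scores.
fn softmax(row: &[f32]) -> Vec<f32> {
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Single-head scaled dot-product attention on the CPU.
/// q: [n][d], k: [m][d], v: [m][dv]; returns [n][dv].
fn attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let scale = 1.0 / (q[0].len() as f32).sqrt();
    let dv = v[0].len();
    q.iter()
        .map(|qi| {
            // Scaled dot product of this query against every key.
            let scores: Vec<f32> = k
                .iter()
                .map(|kj| qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>() * scale)
                .collect();
            let weights = softmax(&scores);

            // Weighted sum of the value rows.
            let mut out = vec![0.0_f32; dv];
            for (w, vj) in weights.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(vj) {
                    *o += w * x;
                }
            }
            out
        })
        .collect()
}

fn main() {
    let q = vec![vec![1.0, 0.0]];
    let k = vec![vec![1.0, 0.0], vec![1.0, 0.0]];
    let v = vec![vec![2.0, 4.0], vec![4.0, 8.0]];
    println!("{:?}", attention(&q, &k, &v));
}
```

A reference like this can be useful for checking GPU results on small inputs, within a floating-point tolerance.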
§Feature flags
None — the crate is macOS/iOS only and requires the Metal and
MetalPerformanceShadersGraph frameworks at link time (handled
automatically by build.rs).
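For readers unfamiliar with how a build script links Apple frameworks, here is a rough sketch of the mechanism. It is not the crate's actual build.rs: the `link_directives` helper is hypothetical, and only the two frameworks named above are listed. The `cargo:rustc-link-lib=framework=...` directive and the `CARGO_CFG_TARGET_OS` variable are standard Cargo features.

```rust
// build.rs (illustrative sketch, not the crate's actual script)

/// Return the Cargo link directives for the given target OS.
/// Hypothetical helper; emits nothing on non-Apple targets.
fn link_directives(target_os: &str) -> Vec<String> {
    const FRAMEWORKS: [&str; 2] = ["Metal", "MetalPerformanceShadersGraph"];
    if target_os == "macos" || target_os == "ios" {
        FRAMEWORKS
            .iter()
            .map(|f| format!("cargo:rustc-link-lib=framework={f}"))
            .collect()
    } else {
        Vec::new()
    }
}

fn main() {
    // Cargo sets CARGO_CFG_TARGET_OS when it runs a build script;
    // fall back to an empty string when run standalone.
    let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
    for line in link_directives(&target_os) {
        println!("{line}");
    }
}
```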
Modules§
- prelude
- Convenience prelude — import everything you need in one line.
Structs§
- MpsGraph
- Apple MPSGraph backend for Burn.
- MpsGraphDevice
- Device for the MPSGraph backend (Apple GPU).
- MpsGraphQTensor
- Quantized tensor.
- MpsGraphTensor
- GPU-resident tensor backed by an MTLBuffer.
Functions§
- elem_size
- Element size in bytes for a given DType.
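As a rough illustration of what an element-size lookup involves, the sketch below uses a local stand-in enum that mirrors the dtype table above. The enum, the one-byte size for `Bool`, and the `buffer_len` helper are assumptions for this example, not the crate's definitions. The panic on `F64` follows the table's "(panic)" note.

```rust
/// Local stand-in for the dtype enum (hypothetical, mirrors the dtype table).
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum DType {
    F32,
    F16,
    BF16,
    I32,
    I64,
    Bool,
    F64,
}

/// Element size in bytes for a given dtype.
fn elem_size(dtype: DType) -> usize {
    match dtype {
        DType::F32 | DType::I32 => 4,
        DType::F16 | DType::BF16 => 2,
        DType::I64 => 8,
        DType::Bool => 1, // assumption: one byte per boolean
        DType::F64 => panic!("F64 is not supported"),
    }
}

/// Bytes needed to store a dense tensor of the given shape (hypothetical helper).
fn buffer_len(shape: &[usize], dtype: DType) -> usize {
    shape.iter().product::<usize>() * elem_size(dtype)
}

fn main() {
    let bytes = buffer_len(&[128, 256], DType::F32);
    println!("[128, 256] of F32 needs {bytes} bytes");
}
```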