# burn-mpsgraph

Apple Metal Performance Shaders Graph (MPSGraph) backend for the Burn deep learning framework.

Runs tensor computations on Apple GPUs (M1, M2, M3, M4 and later) by dispatching directly to MPSGraph — Apple's graph-based GPU compute engine that powers Core ML and Create ML.
## Quick start

```toml
# Cargo.toml
[dependencies]
burn-mpsgraph = "0.0.1"
burn = "0.21.0-pre.2"
```

```rust
use burn_mpsgraph::MpsGraph;
use burn::tensor::*;

type B = MpsGraph;

let device = Default::default();
let a: Tensor<B, 2> = Tensor::random([512, 512], Distribution::Default, &device);
let b: Tensor<B, 2> = Tensor::random([512, 512], Distribution::Default, &device);

let c = a.matmul(b);       // runs on Apple GPU
let data = c.into_data();  // copy result back to CPU
```
## Design

```text
Burn Tensor API
      │
      ▼
FloatTensorOps / IntTensorOps / BoolTensorOps / ModuleOps
      │
      ▼
bridge.rs ── builds an MPSGraph per op, feeds MTLBuffer tensors,
             runs synchronously inside an ObjC autorelease pool
      │
      ▼
ffi.rs ── raw objc_msgSend calls to Metal, Foundation, MPS, MPSGraph
      │
      ▼
Apple GPU (Metal)
```
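The one-graph-per-op dispatch in `bridge.rs` can be sketched in simplified form. This is an illustration only: the stand-in types below are hypothetical, and a plain closure replaces the MPSGraph that the real backend builds and runs through FFI.

```rust
// Simplified sketch of the one-graph-per-op pattern: for each Burn
// op, a fresh "graph" is built and executed synchronously, producing
// a new output tensor. (The real backend feeds MTLBuffer-backed
// tensors and runs inside an Objective-C autorelease pool.)
struct Tensor {
    data: Vec<f32>,
}

fn run_unary_op(input: &Tensor, op: impl Fn(f32) -> f32) -> Tensor {
    // 1. "Build" a graph for exactly this op.
    // 2. "Run" it synchronously over the input buffer.
    Tensor {
        data: input.data.iter().map(|&x| op(x)).collect(),
    }
}

fn main() {
    let t = Tensor { data: vec![1.0, 4.0, 9.0] };
    let r = run_unary_op(&t, f32::sqrt);
    assert_eq!(r.data, vec![1.0, 2.0, 3.0]);
}
```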
### GPU-resident tensors

Each `MpsGraphTensor` wraps a retained `MTLBuffer` pointer. On Apple Silicon the buffer uses shared memory — CPU and GPU share the same physical pages with no copy overhead. Data is only serialised to `Vec<u8>` when you call `into_data()`.
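The ownership model can be sketched with hypothetical stand-in types: tensors share one buffer handle, and bytes are produced only at the `into_data` boundary (here `BufferHandle` is illustrative, not the backend's real type).

```rust
use std::sync::Arc;

// Sketch: a tensor handle shares the underlying buffer; cloning the
// handle copies nothing. Serialisation to raw bytes happens only on
// demand, mirroring how `into_data()` is the one point where GPU
// memory is read back.
struct BufferHandle(Arc<Vec<f32>>);

impl BufferHandle {
    fn into_data(self) -> Vec<u8> {
        // Serialise f32 contents to little-endian bytes only now.
        self.0.iter().flat_map(|v| v.to_le_bytes()).collect()
    }
}

fn main() {
    let t = BufferHandle(Arc::new(vec![1.0f32, 2.0]));
    let view = BufferHandle(t.0.clone()); // shares the buffer, no copy
    assert_eq!(view.into_data().len(), 8); // 2 elements × 4 bytes
}
```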
### No objc2 dependency

All Apple framework calls go through direct `objc_msgSend` FFI in `ffi.rs`. This removes the entire `objc2` / `objc2-foundation` / `objc2-metal` family from the dependency tree.

| | burn-mpsgraph | alternative (objc2-based) |
|---|---|---|
| Runtime deps | 4 | 30+ |
| `cargo build` time | ~1 s | ~20 s |
## Supported operations
| Category | Operations |
|---|---|
| Arithmetic | add, sub, mul, div, remainder, powf, matmul |
| Unary math | exp, log, log1p, sqrt, abs, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh, atan2, erf, recip, floor, ceil, round, trunc |
| Comparisons | equal, not_equal, greater, greater_equal, lower, lower_equal (tensor and scalar) |
| Reductions | sum, prod, mean, max, min, argmax, argmin — full tensor and per-axis |
| Cumulative | cumsum, cumprod, cummin, cummax |
| Sort | sort, argsort (ascending / descending) |
| Shape | reshape, transpose, permute, flip, slice, slice_assign, cat, expand, unfold |
| Masking | mask_where, mask_fill |
| Gather / scatter | gather, scatter_add, select, select_add |
| Convolution | conv1d, conv2d, conv3d, conv_transpose1d/2d/3d, deform_conv2d |
| Pooling | avg_pool2d (+ backward), max_pool2d (+ with_indices + backward), adaptive_avg_pool2d |
| Interpolation | nearest, bilinear (+ backward) |
| Attention | Fused scaled dot-product (single MPSGraph execution) |
| Embedding | forward + backward |
| Int | full arithmetic, cumulative, sort, bitwise (and, or, xor, not, left/right shift), cast |
| Bool | and, or, not, equal, scatter_or, cast-to-int/float |
| Quantization | Basic per-tensor symmetric quantize / dequantize |
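The per-tensor symmetric scheme in the last row can be illustrated with plain scalar math — a sketch of the general technique (one scale for the whole tensor, zero-point fixed at 0), not the backend's actual kernel:

```rust
// Per-tensor symmetric int8 quantization: scale = max|x| / 127,
// q = round(x / scale) clamped to [-127, 127].
fn quantize(x: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = x.iter().fold(0.0f32, |m, v| m.max(v.abs()));
    let scale = if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 };
    let q = x
        .iter()
        .map(|v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (q, scale)
}

fn dequantize(q: &[i8], scale: f32) -> Vec<f32> {
    q.iter().map(|&v| v as f32 * scale).collect()
}

fn main() {
    // max|x| = 63.5, so scale = 0.5 exactly.
    let (q, scale) = quantize(&[0.0, -31.75, 63.5]);
    assert_eq!(q, vec![0, -64, 127]);
    assert_eq!(dequantize(&q, scale), vec![0.0, -32.0, 63.5]);
}
```

Symmetric quantization keeps the zero-point at 0, which makes dequantization a single multiply — a common trade-off against the tighter range of asymmetric schemes.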
## Supported dtypes
| DType | Storage | Arithmetic | GPU-accelerated |
|---|---|---|---|
| F32 | ✓ | ✓ | ✓ |
| F16 | ✓ | ✓ | ✓ |
| BF16 | ✓ | ✓ | |
| I32 | ✓ | ✓ | |
| I16 | ✓ | ✓ | |
| I8 | ✓ | ✓ | |
| I64 | ✓ | ✓ | |
| U8–U64 | ✓ | ✓ | |
| Bool | ✓ | ✓ | |
| F64 | — | — | — (Metal unsupported; panics with a clear message) |
## Benchmarks

Benchmarks run on an Apple M3 Pro against:

- `burn-wgpu` with the `metal` feature (wgpu → MSL → Metal)
- `burn-ndarray` (CPU reference)
### Matmul (f32, square matrices)
| Size | MPSGraph | metal-wgpu | ndarray (CPU) |
|---|---|---|---|
| 64×64 | 1.9 ms | ~2 ms | 45 µs |
| 256×256 | 2.4 ms | ~3 ms | 250 µs |
| 512×512 | 4.6 ms | ~5 ms | 1.1 ms |
| 1024×1024 | 8.0 ms | ~9 ms | 6.6 ms |
| 2048×2048 | 14 ms | ~15 ms | 50 ms |
| 4096×4096 | 30 ms | ~29 ms | 410 ms |
At 4096×4096 both GPU backends are ~14× faster than CPU.
### Element-wise add
| Elements | MPSGraph | metal-wgpu | ndarray |
|---|---|---|---|
| 10K | 2.0 ms | ~1 ms | 6 µs |
| 100K | 2.4 ms | ~1 ms | 30 µs |
| 1M | 8.5 ms | ~2 ms | 300 µs |
Note: For small ops the per-graph-execution overhead (~2 ms) dominates. This is the primary area for future improvement — see Roadmap.
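The arithmetic behind that note is worth making explicit. Under the simplifying assumption of a fixed ~2 ms cost per graph execution, a chain of N element-wise ops pays the overhead N times when dispatched eagerly, but only once if the whole chain is fused into a single graph:

```rust
// Back-of-envelope cost model (assumed ~2 ms fixed overhead per
// graph execution; compute time for small tensors is negligible).
fn chain_cost_ms(n_ops: u32, fused: bool) -> f32 {
    let overhead_ms = 2.0;
    if fused {
        overhead_ms // one graph for the whole chain
    } else {
        n_ops as f32 * overhead_ms // one graph per op
    }
}

fn main() {
    assert_eq!(chain_cost_ms(10, false), 20.0); // eager: 10 graphs
    assert_eq!(chain_cost_ms(10, true), 2.0);   // fused: 1 graph
}
```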
## Examples

```sh
cargo run --example basic      # arithmetic, reductions, math, shapes
cargo run --example inference  # 2-layer MLP with softmax
cargo run --example conv       # conv2d + maxpool2d feature extractor
```
## Seeded RNG

```rust
B::seed(42);
let t: Tensor<B, 2> = Tensor::random([8, 8], Distribution::Default, &device);
```

The seed is stored globally per process. Two calls with the same seed produce identical tensors.
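The "global per-process seed" pattern can be sketched as follows. This is illustrative only — the backend's actual generator is not specified here; SplitMix64 is used as a stand-in to show why reseeding with the same value reproduces the same stream:

```rust
use std::sync::atomic::{AtomicU64, Ordering};

// Process-wide RNG state; `seed` resets it, so any two runs that
// start from the same seed draw identical values.
static STATE: AtomicU64 = AtomicU64::new(0);

fn seed(s: u64) {
    STATE.store(s, Ordering::SeqCst);
}

// SplitMix64 step: advance the global state and scramble it.
fn next_u64() -> u64 {
    let mut z = STATE
        .fetch_add(0x9E3779B97F4A7C15, Ordering::SeqCst)
        .wrapping_add(0x9E3779B97F4A7C15);
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D049BB133111EB);
    z ^ (z >> 31)
}

fn main() {
    seed(42);
    let a = next_u64();
    seed(42);
    let b = next_u64();
    assert_eq!(a, b); // same seed → identical stream
}
```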
## Platform requirements
| Requirement | Version |
|---|---|
| macOS | 12.3+ (Monterey) |
| Xcode Command Line Tools | 14+ |
| Apple Silicon or AMD GPU | M1 / M2 / M3 / M4 or supported AMD |
The `build.rs` links these frameworks automatically: `Foundation`, `Metal`, `MetalPerformanceShaders`, `MetalPerformanceShadersGraph`.
## Roadmap

- Lazy / deferred execution — accumulate ops into one graph before running. This would eliminate the ~2 ms per-op overhead and make element-wise chains as fast as native MPSGraph programs.
- MTLBuffer pool — reuse freed buffers instead of allocating new ones.
- Async execution — use `encodeToCommandBuffer` to pipeline work.
- BF16 arithmetic — currently storage-only.
- F64 via emulation — two F32 ops for `reduce_sum` / `mean`.
- Multi-device — map `MpsGraphDevice { index }` to discrete GPUs on Mac Pro.
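The lazy-execution idea at the top of the list can be sketched with a toy scalar "tensor" (hypothetical names; the real design would record Burn ops and build one MPSGraph at materialisation time):

```rust
// Sketch of deferred execution: ops are recorded into a pending
// list, and the whole chain runs as one "graph execution" when the
// result is materialised — paying the fixed dispatch overhead once
// instead of once per op.
enum Op {
    Add(f32),
    Mul(f32),
}

struct LazyTensor {
    value: f32,
    pending: Vec<Op>,
}

impl LazyTensor {
    fn new(v: f32) -> Self {
        Self { value: v, pending: Vec::new() }
    }
    fn add(mut self, s: f32) -> Self {
        self.pending.push(Op::Add(s)); // recorded, not executed
        self
    }
    fn mul(mut self, s: f32) -> Self {
        self.pending.push(Op::Mul(s)); // recorded, not executed
        self
    }
    // One execution for the whole recorded chain.
    fn materialize(self) -> f32 {
        self.pending.iter().fold(self.value, |v, op| match op {
            Op::Add(s) => v + s,
            Op::Mul(s) => v * s,
        })
    }
}

fn main() {
    let out = LazyTensor::new(2.0).add(3.0).mul(4.0).materialize();
    assert_eq!(out, 20.0); // (2 + 3) × 4
}
```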
## License

Licensed under either of

- Apache License, Version 2.0
- MIT license

at your option.