# burn-mpsgraph
Apple **Metal Performance Shaders Graph** (MPSGraph) backend for the
[Burn](https://burn.dev) deep learning framework.
Runs tensor computations on **Apple GPUs** (M1, M2, M3, M4 and later) by
dispatching directly to MPSGraph — Apple's graph-based GPU compute engine
that powers Core ML and Create ML.
[crates.io](https://crates.io/crates/burn-mpsgraph) · [docs.rs](https://docs.rs/burn-mpsgraph) · [License](#license)
---
## Quick start
```toml
# Cargo.toml
[dependencies]
burn-mpsgraph = "0.0.1"
burn = "0.21.0-pre.2"
```
```rust
use burn::tensor::{Distribution, Tensor};
use burn_mpsgraph::prelude::*;

type B = MpsGraph;

fn main() {
    let device = MpsGraphDevice::default();
    let a: Tensor<B, 2> = Tensor::random([128, 64], Distribution::Default, &device);
    let b: Tensor<B, 2> = Tensor::random([64, 256], Distribution::Default, &device);
    let c = a.matmul(b); // runs on the Apple GPU
    let data = c.into_data(); // copy the result back to the CPU
    println!("{data:?}");
}
```
---
## Design
```
Burn Tensor API
│
▼
FloatTensorOps / IntTensorOps / BoolTensorOps / ModuleOps
│
▼
bridge.rs ── builds an MPSGraph per op, feeds MTLBuffer tensors,
runs synchronously inside an ObjC autorelease pool
│
▼
ffi.rs ── raw objc_msgSend calls to Metal, Foundation, MPS, MPSGraph
│
▼
Apple GPU (Metal)
```
### GPU-resident tensors
Each `MpsGraphTensor` wraps a retained `MTLBuffer` pointer.
On **Apple Silicon** the buffer uses **shared memory** — CPU and GPU share the
same physical pages with no copy overhead.
Data is only serialised to `Vec<u8>` when you call `into_data()`.
### No objc2 dependency
All Apple framework calls go through direct `objc_msgSend` FFI in `ffi.rs`.
This removes the entire `objc2` / `objc2-foundation` / `objc2-metal` family
from the dependency tree.
| | burn-mpsgraph (direct FFI) | with objc2 |
|---|---|---|
| Runtime deps | 4 | 30+ |
| `cargo build` time | ~1 s | ~20 s |
---
## Supported operations
| Category | Operations |
|----------|------------|
| Arithmetic | add, sub, mul, div, remainder, powf, matmul |
| Unary math | exp, log, log1p, sqrt, abs, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh, atan2, erf, recip, floor, ceil, round, trunc |
| Comparisons | equal, not_equal, greater, greater_equal, lower, lower_equal (tensor and scalar) |
| Reductions | sum, prod, mean, max, min, argmax, argmin — full tensor and per-axis |
| Cumulative | cumsum, cumprod, cummin, cummax |
| Sort | sort, argsort (ascending / descending) |
| Shape | reshape, transpose, permute, flip, slice, slice_assign, cat, expand, unfold |
| Masking | mask_where, mask_fill |
| Gather / scatter | gather, scatter_add, select, select_add |
| Convolution | conv1d, conv2d, conv3d, conv_transpose1d/2d/3d, deform_conv2d |
| Pooling | avg_pool2d (+ backward), max_pool2d (+ with_indices + backward), adaptive_avg_pool2d |
| Interpolation | nearest, bilinear (+ backward) |
| Attention | Fused scaled dot-product (single MPSGraph execution) |
| Embedding | forward + backward |
| Int | full arithmetic, cumulative, sort, bitwise (and, or, xor, not, left/right shift), cast |
| Bool | and, or, not, equal, scatter_or, cast-to-int/float |
| Quantization | Basic per-tensor symmetric quantize / dequantize |
### Supported dtypes
| Dtype | Storage | Arithmetic | Float dtype |
|-------|---------|------------|-------------|
| F32 | ✓ | ✓ | ✓ |
| F16 | ✓ | ✓ | ✓ |
| BF16 | ✓ | | ✓ |
| I32 | ✓ | ✓ | |
| I16 | ✓ | ✓ | |
| I8 | ✓ | ✓ | |
| I64 | ✓ | ✓ | |
| U8–U64 | ✓ | ✓ | |
| Bool | ✓ | ✓ | |
| F64 | — | — | — |

F64 is unsupported by Metal; using it panics with a clear error message.
---
## Benchmarks
Benchmarks run on **Apple M3 Pro** against:
- `burn-wgpu` with the `metal` feature (wgpu → MSL → Metal)
- `burn-ndarray` (CPU reference)
### Matmul (`f32`, square matrices)
| Size | burn-mpsgraph | burn-wgpu (metal) | burn-ndarray |
|------|---------------|-------------------|--------------|
| 64×64 | 1.9 ms | ~2 ms | 45 µs |
| 256×256 | 2.4 ms | ~3 ms | 250 µs |
| 512×512 | 4.6 ms | ~5 ms | 1.1 ms |
| 1024×1024 | 8.0 ms | ~9 ms | 6.6 ms |
| 2048×2048 | 14 ms | ~15 ms | 50 ms |
| 4096×4096 | 30 ms | ~29 ms | 410 ms |
At 4096×4096 both GPU backends are **~14× faster than CPU**.
### Element-wise add
| Elements | burn-mpsgraph | burn-wgpu (metal) | burn-ndarray |
|----------|---------------|-------------------|--------------|
| 10K | 2.0 ms | ~1 ms | 6 µs |
| 100K | 2.4 ms | ~1 ms | 30 µs |
| 1M | 8.5 ms | ~2 ms | 300 µs |
> **Note:** For small ops the per-graph-execution overhead (~2 ms) dominates.
> This is the primary area for future improvement — see [Roadmap](#roadmap).
---
## Examples
```sh
cargo run --example basic # arithmetic, reductions, math, shapes
cargo run --example inference # 2-layer MLP with softmax
cargo run --example conv # conv2d + maxpool2d feature extractor
```
---
## Seeded RNG
```rust
MpsGraph::seed(&device, 42);
let t: Tensor<B, 2> = Tensor::random([4, 4], Distribution::Default, &device);
```
The seed is stored globally per process. Two calls with the same seed produce
identical tensors.
---
## Platform requirements
| Requirement | Version |
|-------------|---------|
| macOS | 12.3+ (Monterey) |
| Xcode Command Line Tools | 14+ |
| GPU | Apple Silicon (M1 / M2 / M3 / M4) or a supported AMD GPU |
The `build.rs` links these frameworks automatically:
`Foundation`, `Metal`, `MetalPerformanceShaders`, `MetalPerformanceShadersGraph`.
---
## Roadmap
- **Lazy / deferred execution** — accumulate ops into one graph before
running. This would eliminate the ~2 ms per-op overhead and make
element-wise chains as fast as native MPSGraph programs.
- **MTLBuffer pool** — reuse freed buffers instead of allocating new ones.
- **Async execution** — use `encodeToCommandBuffer` to pipeline work.
- **BF16 arithmetic** — currently storage-only.
- **F64 via emulation** — emulate each F64 with two F32 values, starting with `reduce_sum` / `mean`.
- **Multi-device** — map `MpsGraphDevice { index }` to discrete GPUs on Mac Pro.
---
## License
Licensed under either of
- [MIT License](LICENSE-MIT)
- [Apache License, Version 2.0](LICENSE-APACHE)
at your option.