# burn-mpsgraph

Apple **Metal Performance Shaders Graph** (MPSGraph) backend for the
[Burn](https://burn.dev) deep learning framework.

Runs tensor computations on **Apple GPUs** (M1, M2, M3, M4 and later) by
dispatching directly to MPSGraph — Apple's graph-based GPU compute engine
that powers Core ML and Create ML.

[![Crates.io](https://img.shields.io/crates/v/burn-mpsgraph.svg)](https://crates.io/crates/burn-mpsgraph)
[![docs.rs](https://img.shields.io/docsrs/burn-mpsgraph)](https://docs.rs/burn-mpsgraph)
[![License: MIT OR Apache-2.0](https://img.shields.io/badge/license-MIT%20OR%20Apache--2.0-blue.svg)](#license)

---

## Quick start

```toml
# Cargo.toml
[dependencies]
burn-mpsgraph = "0.0.1"
burn           = "0.21.0-pre.2"
```

```rust
use burn::tensor::{Distribution, Tensor};
use burn_mpsgraph::prelude::*;

type B = MpsGraph;

fn main() {
    let device = MpsGraphDevice::default();

    let a: Tensor<B, 2> = Tensor::random([128, 64], Distribution::Default, &device);
    let b: Tensor<B, 2> = Tensor::random([64, 256], Distribution::Default, &device);
    let c = a.matmul(b);          // [128, 256] result, computed on the Apple GPU
    let data = c.into_data();     // read the result back into host-side TensorData
}
```

---

## Design

```
Burn Tensor API
  ↓
FloatTensorOps / IntTensorOps / BoolTensorOps / ModuleOps
  ↓
bridge.rs ── builds an MPSGraph per op, feeds MTLBuffer tensors,
             runs synchronously inside an ObjC autorelease pool
  ↓
ffi.rs ── raw objc_msgSend calls to Metal, Foundation, MPS, MPSGraph
  ↓
Apple GPU (Metal)
```

### GPU-resident tensors

Each `MpsGraphTensor` wraps a retained `MTLBuffer` pointer.  
On **Apple Silicon** the buffer uses **shared memory** — CPU and GPU share the
same physical pages with no copy overhead.  
Data is only serialised to `Vec<u8>` when you call `into_data()`.
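In Rust, a common way to hold a retained Objective-C object is to map `retain` onto `Clone` and `release` onto `Drop`. The sketch below models that pattern with a hypothetical `TensorHandle` and an atomic counter standing in for the Objective-C retain count. It does not call Metal and is not the crate's actual `MpsGraphTensor` type.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Stand-in for an MTLBuffer: a byte store plus an Objective-C style retain count.
// A real Metal object keeps this count inside the runtime; the atomic only models it.
struct FakeBuffer {
    retain_count: AtomicUsize,
    bytes: Vec<u8>,
}

// Hypothetical tensor handle that owns exactly one retain on its buffer.
struct TensorHandle {
    buffer: Arc<FakeBuffer>,
}

impl TensorHandle {
    // A freshly allocated Objective-C object starts with a retain count of 1.
    fn new(bytes: Vec<u8>) -> Self {
        TensorHandle {
            buffer: Arc::new(FakeBuffer { retain_count: AtomicUsize::new(1), bytes }),
        }
    }

    fn retain_count(&self) -> usize {
        self.buffer.retain_count.load(Ordering::SeqCst)
    }

    // Copying bytes out plays the role of serialising on `into_data()`.
    fn to_vec(&self) -> Vec<u8> {
        self.buffer.bytes.clone()
    }
}

impl Clone for TensorHandle {
    // Cloning a handle corresponds to sending `retain`.
    fn clone(&self) -> Self {
        self.buffer.retain_count.fetch_add(1, Ordering::SeqCst);
        TensorHandle { buffer: Arc::clone(&self.buffer) }
    }
}

impl Drop for TensorHandle {
    // Dropping a handle corresponds to sending `release`.
    fn drop(&mut self) {
        self.buffer.retain_count.fetch_sub(1, Ordering::SeqCst);
    }
}

fn main() {
    let a = TensorHandle::new(vec![1, 2, 3, 4]);
    let b = a.clone();
    println!("after clone: {}", a.retain_count());
    drop(b);
    println!("after drop: {}", a.retain_count());
    println!("{:?}", a.to_vec());
}
```

Tying the count to Rust ownership means a buffer can never be released while a tensor still refers to it, and no manual bookkeeping leaks into the tensor-op code.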

### No objc2 dependency

All Apple framework calls go through direct `objc_msgSend` FFI in `ffi.rs`.
This removes the entire `objc2` / `objc2-foundation` / `objc2-metal` family
from the dependency tree.

| Metric | burn-mpsgraph | Alternative (objc2-based) |
|---|---|---|
| Runtime deps | 4 | 30+ |
| `cargo build` time | ~1 s | ~20 s |

---

## Supported operations

| Category | Operations |
|---|---|
| Arithmetic | add, sub, mul, div, remainder, powf, atan2, matmul |
| Unary math | exp, log, log1p, sqrt, abs, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh, erf, recip, floor, ceil, round, trunc |
| Comparisons | equal, not_equal, greater, greater_equal, lower, lower_equal (tensor and scalar) |
| Reductions | sum, prod, mean, max, min, argmax, argmin — full tensor and per-axis |
| Cumulative | cumsum, cumprod, cummin, cummax |
| Sort | sort, argsort (ascending / descending) |
| Shape | reshape, transpose, permute, flip, slice, slice_assign, cat, expand, unfold |
| Masking | mask_where, mask_fill |
| Gather / scatter | gather, scatter_add, select, select_add |
| Convolution | conv1d, conv2d, conv3d, conv_transpose1d/2d/3d, deform_conv2d |
| Pooling | avg_pool2d (+ backward), max_pool2d (+ with_indices + backward), adaptive_avg_pool2d |
| Interpolation | nearest, bilinear (+ backward) |
| Attention | Fused scaled dot-product (single MPSGraph execution) |
| Embedding | forward + backward |
| Int | full arithmetic, cumulative, sort, bitwise (and, or, xor, not, left/right shift), cast |
| Bool | and, or, not, equal, scatter_or, cast-to-int/float |
| Quantization | Basic per-tensor symmetric quantize / dequantize |
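Per-tensor symmetric quantization uses a single scale for the whole tensor and a zero-point of 0. A minimal Rust sketch of the technique, assuming an int8 target range of [-127, 127] and round-half-away-from-zero; the crate's exact range and rounding mode are not stated here, so treat these as assumptions.

```rust
/// Quantize with one scale for the whole tensor and zero-point 0.
/// Assumes an int8 range of [-127, 127] (symmetric, so -128 is unused).
fn quantize_symmetric(x: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = x.iter().fold(0.0f32, |m, v| m.max(v.abs()));
    // Map [-max_abs, max_abs] onto [-127, 127]; avoid dividing by zero for all-zero input.
    let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
    let q = x
        .iter()
        .map(|v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (q, scale)
}

/// Dequantize by multiplying each integer back by the shared scale.
fn dequantize_symmetric(q: &[i8], scale: f32) -> Vec<f32> {
    q.iter().map(|&v| v as f32 * scale).collect()
}

fn main() {
    let x = [-1.0f32, -0.25, 0.0, 0.5, 1.0];
    let (q, scale) = quantize_symmetric(&x);
    let y = dequantize_symmetric(&q, scale);
    println!("scale = {scale}, q = {q:?}");
    println!("round trip = {y:?}");
}
```

Each dequantized value lands within half a scale step of the original, which is the error budget this scheme trades for 4× smaller storage than `f32`.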

### Supported dtypes

| DType | Storage | Arithmetic | GPU-accelerated |
|-------|:-------:|:----------:|:---------------:|
| F32   ||||
| F16   ||||
| BF16  ||            ||
| I32   |||                 |
| I16   |||                 |
| I8    |||                 |
| I64   |||                 |
| U8–U64|||                 |
| Bool  |||                 |
| F64   ||| — (Metal unsupported; panics with a clear message) |

---

## Benchmarks

Benchmarks run on **Apple M3 Pro** against:
- `burn-wgpu` with the `metal` feature (wgpu → MSL → Metal)
- `burn-ndarray` (CPU reference)

### Matmul (`f32`, square matrices)

| Size | MPSGraph | metal-wgpu | ndarray (CPU) |
|------|----------|------------|---------------|
| 64×64   | 1.9 ms  | ~2 ms   | 45 µs   |
| 256×256 | 2.4 ms  | ~3 ms   | 250 µs  |
| 512×512 | 4.6 ms  | ~5 ms   | 1.1 ms  |
| 1024×1024 | 8.0 ms | ~9 ms | 6.6 ms  |
| 2048×2048 | 14 ms  | ~15 ms | 50 ms   |
| 4096×4096 | 30 ms  | ~29 ms | 410 ms  |

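The CPU is still faster at 1024×1024 and below. From 2048×2048 the GPU backends pull ahead, and at 4096×4096 both are about **14× faster than CPU** (29–30 ms vs. 410 ms).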
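The speedups and the crossover point can be recomputed directly from the matmul table. This small Rust sketch uses only the MPSGraph and ndarray columns; the approximate metal-wgpu figures are left out.

```rust
// Matmul timings copied from the table above (f32, square N×N), in milliseconds.
const SIZES: [usize; 6] = [64, 256, 512, 1024, 2048, 4096];
const MPSGRAPH_MS: [f64; 6] = [1.9, 2.4, 4.6, 8.0, 14.0, 30.0];
const NDARRAY_MS: [f64; 6] = [0.045, 0.25, 1.1, 6.6, 50.0, 410.0];

/// CPU time divided by GPU time for each size; values above 1.0 mean the GPU wins.
fn speedups() -> Vec<(usize, f64)> {
    SIZES
        .iter()
        .zip(MPSGRAPH_MS.iter().zip(NDARRAY_MS.iter()))
        .map(|(&n, (&gpu, &cpu))| (n, cpu / gpu))
        .collect()
}

/// Smallest measured size at which MPSGraph beats the CPU reference.
fn crossover() -> Option<usize> {
    speedups().into_iter().find(|&(_, s)| s > 1.0).map(|(n, _)| n)
}

fn main() {
    for (n, s) in speedups() {
        println!("{n}×{n}: {s:.2}×");
    }
    println!("GPU overtakes CPU at {:?}", crossover());
}
```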

### Element-wise add

| Elements | MPSGraph | metal-wgpu | ndarray |
|----------|----------|------------|---------|
| 10K  | 2.0 ms  | ~1 ms  | 6 µs   |
| 100K | 2.4 ms  | ~1 ms  | 30 µs  |
| 1M   | 8.5 ms  | ~2 ms  | 300 µs |

> **Note:** For small tensors, the fixed per-execution overhead (~2 ms, from building and running one MPSGraph per op) dominates.
> This is the primary area for future improvement; see [Roadmap](#roadmap).
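The effect of that overhead on op chains can be estimated with a back-of-the-envelope model. This sketch assumes every graph execution pays a fixed overhead (the ~2 ms quoted above) plus its own compute time, and that recording several ops into one graph leaves compute time unchanged while paying the overhead once. Real fusion can also reduce memory traffic, which the model ignores. The 0.1 ms compute figure in `main` is an illustrative assumption, not a measurement.

```rust
/// Estimated wall time for `k` chained ops when each op builds and runs its own graph.
fn eager_ms(k: u32, overhead_ms: f64, compute_ms: f64) -> f64 {
    k as f64 * (overhead_ms + compute_ms)
}

/// Estimated wall time when the same `k` ops are recorded into one graph and run once.
fn deferred_ms(k: u32, overhead_ms: f64, compute_ms: f64) -> f64 {
    overhead_ms + k as f64 * compute_ms
}

fn main() {
    let (overhead, compute) = (2.0, 0.1); // compute time per op is an assumed value
    for k in [1, 5, 10] {
        println!(
            "{k} ops: eager ≈ {:.1} ms, deferred ≈ {:.1} ms",
            eager_ms(k, overhead, compute),
            deferred_ms(k, overhead, compute)
        );
    }
}
```

Under these assumptions a 10-op element-wise chain drops from about 21 ms to about 3 ms, which is why overhead rather than arithmetic throughput dominates the small-tensor rows.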

---

## Examples

```
cargo run --example basic      # arithmetic, reductions, math, shapes
cargo run --example inference  # 2-layer MLP with softmax
cargo run --example conv       # conv2d + maxpool2d feature extractor
```

---

## Seeded RNG

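```rust
use burn::tensor::{backend::Backend, Distribution, Tensor};
use burn_mpsgraph::prelude::*;

fn main() {
    let device = MpsGraphDevice::default();
    MpsGraph::seed(&device, 42); // Backend::seed sets the process-wide seed
    let t: Tensor<MpsGraph, 2> = Tensor::random([4, 4], Distribution::Default, &device);
}
```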

The seed lives in process-wide global state. Seeding again with the same value before sampling produces an identical tensor.
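To illustrate what process-wide seeding implies, here is a std-only Rust sketch built on a hypothetical global `RNG_STATE` and the SplitMix64 generator. It is not how burn-mpsgraph generates random tensors; it only shows the semantics: re-seeding resets one shared stream, and every draw advances that stream for all callers.

```rust
use std::sync::Mutex;

// Hypothetical process-wide RNG state shared by every device and thread.
static RNG_STATE: Mutex<u64> = Mutex::new(0);

/// Reset the shared stream to a known starting point.
fn seed(value: u64) {
    *RNG_STATE.lock().unwrap() = value;
}

/// SplitMix64 step: advance the shared state and return the next pseudo-random u64.
fn next_u64() -> u64 {
    let mut state = RNG_STATE.lock().unwrap();
    *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
    let mut z = *state;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

/// Uniform floats in [0, 1), built from the top 53 bits of each draw.
fn random_vec(n: usize) -> Vec<f64> {
    (0..n)
        .map(|_| (next_u64() >> 11) as f64 / (1u64 << 53) as f64)
        .collect()
}

fn main() {
    seed(42);
    let first = random_vec(4);
    seed(42);
    let again = random_vec(4);
    println!("{first:?}\n{again:?}");
}
```

A consequence of a single global stream is that reproducibility depends on the order of random draws across the whole process, not just within one tensor.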

---

## Platform requirements

| Requirement | Version |
|---|---|
| macOS | 12.3+ (Monterey) |
| Xcode Command Line Tools | 14+ |
| Apple Silicon or AMD GPU | M1 / M2 / M3 / M4 or supported AMD |

The `build.rs` links these frameworks automatically:
`Foundation`, `Metal`, `MetalPerformanceShaders`, `MetalPerformanceShadersGraph`.
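A build script that links system frameworks typically does so by printing Cargo `rustc-link-lib` directives with the `framework` kind. The following is a sketch of what such a `build.rs` can look like; the crate's actual script may differ, for example in how it gates on the target OS.

```rust
// build.rs (sketch): emit one link directive per Apple framework.
const FRAMEWORKS: [&str; 4] = [
    "Foundation",
    "Metal",
    "MetalPerformanceShaders",
    "MetalPerformanceShadersGraph",
];

/// Build the `cargo:` directives that ask rustc to link each framework.
fn link_directives() -> Vec<String> {
    FRAMEWORKS
        .iter()
        .map(|fw| format!("cargo:rustc-link-lib=framework={fw}"))
        .collect()
}

fn main() {
    // Cargo sets CARGO_CFG_TARGET_OS for build scripts; only link on macOS targets.
    if std::env::var("CARGO_CFG_TARGET_OS").as_deref() == Ok("macos") {
        for line in link_directives() {
            println!("{line}");
        }
    }
}
```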

---

## Roadmap

- **Lazy / deferred execution** — accumulate ops into one graph before
  running. This would eliminate the ~2 ms per-op overhead and make
  element-wise chains as fast as native MPSGraph programs.
- **MTLBuffer pool** — reuse freed buffers instead of allocating new ones.
- **Async execution** — use `encodeToCommandBuffer` to pipeline work.
- **BF16 arithmetic** — currently storage-only.
- **F64 via emulation** — two F32 ops for `reduce_sum` / `mean`.
- **Multi-device** — map `MpsGraphDevice { index }` to discrete GPUs on Mac Pro.
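One common shape for a buffer pool is a free list keyed by rounded-up allocation size. The sketch below is hypothetical: `Buffer` stands in for a retained `MTLBuffer`, sizes are rounded to powers of two so nearby requests share a bucket, and recycled buffers are handed back without being cleared.

```rust
use std::collections::HashMap;

// Stand-in for an MTLBuffer; a real pool would hold retained Metal buffer pointers.
struct Buffer {
    bytes: Vec<u8>,
}

// Hypothetical size-bucketed pool of freed buffers.
#[derive(Default)]
struct BufferPool {
    free: HashMap<usize, Vec<Buffer>>,
    allocations: usize, // number of fresh allocations, for inspection
}

impl BufferPool {
    // Round up to the next power of two so nearby sizes share a bucket.
    fn bucket(len: usize) -> usize {
        len.max(1).next_power_of_two()
    }

    // Reuse a parked buffer of the right bucket if one exists, else allocate.
    // Note: a reused buffer still holds whatever bytes it held before.
    fn acquire(&mut self, len: usize) -> Buffer {
        let cap = Self::bucket(len);
        if let Some(buf) = self.free.get_mut(&cap).and_then(|v| v.pop()) {
            return buf;
        }
        self.allocations += 1;
        Buffer { bytes: vec![0; cap] }
    }

    // Park a buffer under its capacity instead of freeing it.
    fn release(&mut self, buf: Buffer) {
        let cap = buf.bytes.len();
        self.free.entry(cap).or_default().push(buf);
    }
}

fn main() {
    let mut pool = BufferPool::default();
    let a = pool.acquire(1000);
    pool.release(a);
    let b = pool.acquire(900); // same 1024-byte bucket, so no new allocation
    println!("len = {}, allocations = {}", b.bytes.len(), pool.allocations);
    pool.release(b);
}
```

Power-of-two buckets waste up to half of each buffer in exchange for a high reuse rate; a finer bucketing scheme would trade the other way.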

---

## License

Licensed under either of

- [MIT License](LICENSE-MIT)
- [Apache License, Version 2.0](LICENSE-APACHE)

at your option.