
Crate burn_mpsgraph


§burn-mpsgraph

Apple Metal Performance Shaders Graph (MPSGraph) backend for the Burn deep learning framework.

This crate accelerates tensor operations on Apple GPUs (M1/M2/M3/M4 and later) by dispatching to MPSGraph, Apple’s graph-based compute engine that sits on top of Metal. It uses direct Objective-C FFI — no objc2 crate — keeping the dependency tree minimal and compile times fast.

§Quick start

use burn::prelude::*;
use burn_mpsgraph::prelude::*;

type B = MpsGraph;

let device = MpsGraphDevice::default();

let a: Tensor<B, 2> = Tensor::random([128, 64], burn::tensor::Distribution::Default, &device);
let b: Tensor<B, 2> = Tensor::random([64, 256], burn::tensor::Distribution::Default, &device);
let c = a.matmul(b);                   // runs on the Apple GPU
let data = c.into_data();              // copies result back to CPU
println!("shape: {:?}", data.shape);   // [128, 256]

§Architecture

 Burn Tensor API
      │
      ▼
 FloatTensorOps / IntTensorOps / BoolTensorOps / ModuleOps
      │
      ▼
 bridge.rs ── builds an MPSGraph per op, feeds MTLBuffer tensors, runs synchronously
      │
      ▼
 ffi.rs ── raw objc_msgSend calls to Metal, MPS, MPSGraph, Foundation
      │
      ▼
 Apple GPU (Metal)

§Tensor storage

Each MpsGraphTensor wraps a retained MTLBuffer pointer. On Apple Silicon the buffer uses shared memory — both CPU and GPU access the same physical pages with no PCIe copy. Data is only serialised to a Vec<u8> when you call into_data().

§Supported dtypes

  • F32 — storage, arithmetic, and GPU acceleration
  • F16 — storage, arithmetic, and GPU acceleration
  • BF16 — storage, arithmetic, and GPU acceleration
  • I32 — storage, arithmetic, and GPU acceleration
  • I64 — storage, arithmetic, and GPU acceleration
  • Bool — storage, arithmetic, and GPU acceleration
  • F64 — unsupported (panics)

§Implemented operations

All operations required by Burn’s Backend trait are implemented:

  • Arithmetic: add, sub, mul, div, remainder, powf, matmul
  • Unary math: exp, log, log1p, sqrt, abs, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh, atan2, erf, recip, floor, ceil, round, trunc
  • Comparisons: equal, greater, greater_equal, lower, lower_equal
  • Reductions: sum, prod, mean, max, min, argmax, argmin (all per-axis)
  • Cumulative: cumsum, cumprod, cummin, cummax
  • Sort: sort, argsort
  • Shape: reshape, transpose, permute, flip, slice, slice_assign, cat, expand, unfold
  • Masking: mask_where, mask_fill
  • Gather / scatter: gather, scatter_add, select, select_add
  • Convolution: conv1d, conv2d, conv3d, conv_transpose1d/2d/3d, deform_conv2d
  • Pooling: avg_pool2d, max_pool2d (+ backward + with_indices), adaptive_avg_pool2d
  • Interpolation: nearest, bilinear (+ backward)
  • Attention: scaled dot-product (single-graph fused softmax)
  • Embedding: forward + backward
  • Int ops: full arithmetic, bitwise (and/or/xor/not/shift), casting
  • Bool ops: and, or, not, equal, scatter_or, casting

§Feature flags

None — the crate is macOS/iOS only and requires the Metal and MetalPerformanceShadersGraph frameworks at link time (handled automatically by build.rs).

Modules§

prelude
Convenience prelude — import everything you need in one line.

Structs§

MpsGraph
Apple MPSGraph backend for Burn.
MpsGraphDevice
Device for the MPSGraph backend (Apple GPU).
MpsGraphQTensor
Quantized tensor.
MpsGraphTensor
GPU-resident tensor backed by an MTLBuffer.

Functions§

elem_size
Element size in bytes for a given DType.