// burn_mpsgraph/lib.rs
//! # burn-mpsgraph
//!
//! Apple Metal Performance Shaders Graph (MPSGraph) backend for the
//! [Burn](https://burn.dev) deep learning framework.
//!
//! This crate accelerates tensor operations on Apple GPUs (M1/M2/M3/M4 and
//! later) by dispatching them to MPSGraph, Apple's graph-based compute engine
//! built on top of Metal. It calls Objective-C through **direct FFI** rather
//! than the `objc2` crate, which keeps the dependency tree minimal and compile
//! times short.
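//!
//! The following sketch illustrates what calling Metal through raw
//! Objective-C FFI looks like: it asks for the default GPU and reads its name
//! with hand-typed `objc_msgSend` calls. It is a minimal, hypothetical example
//! (`default_gpu_name` is not part of this crate, and `ffi.rs` may be organised
//! differently); it assumes macOS and simply returns `None` on other targets.
//!
//! ```rust
//! #[cfg(target_os = "macos")]
//! use std::ffi::{c_char, c_void, CStr};
//!
//! #[cfg(target_os = "macos")]
//! #[link(name = "objc")]
//! unsafe extern "C" {
//!     fn sel_registerName(name: *const c_char) -> *mut c_void;
//!     // Declared without a signature; cast to the real one at each call site.
//!     fn objc_msgSend();
//! }
//!
//! #[cfg(target_os = "macos")]
//! #[link(name = "Metal", kind = "framework")]
//! unsafe extern "C" {
//!     fn MTLCreateSystemDefaultDevice() -> *mut c_void;
//! }
//!
//! /// Name of the default Metal GPU, or `None` if no device is available.
//! #[cfg(target_os = "macos")]
//! fn default_gpu_name() -> Option<String> {
//!     type SendObj = unsafe extern "C" fn(*mut c_void, *mut c_void) -> *mut c_void;
//!     type SendCStr = unsafe extern "C" fn(*mut c_void, *mut c_void) -> *const c_char;
//!     type SendVoid = unsafe extern "C" fn(*mut c_void, *mut c_void);
//!     unsafe {
//!         // One typed view of objc_msgSend per message signature.
//!         let raw = objc_msgSend as unsafe extern "C" fn();
//!         let send_obj: SendObj = std::mem::transmute(raw);
//!         let send_cstr: SendCStr = std::mem::transmute(raw);
//!         let send_void: SendVoid = std::mem::transmute(raw);
//!
//!         let device = MTLCreateSystemDefaultDevice(); // +1 reference (Create rule)
//!         if device.is_null() {
//!             return None;
//!         }
//!         // [device name] returns an NSString*; [name UTF8String] a const char*.
//!         let ns_name = send_obj(device, sel_registerName(b"name\0".as_ptr().cast()));
//!         let utf8 = if ns_name.is_null() {
//!             std::ptr::null()
//!         } else {
//!             send_cstr(ns_name, sel_registerName(b"UTF8String\0".as_ptr().cast()))
//!         };
//!         let name = if utf8.is_null() {
//!             None
//!         } else {
//!             Some(CStr::from_ptr(utf8).to_string_lossy().into_owned())
//!         };
//!         // Balance the +1 reference from MTLCreateSystemDefaultDevice.
//!         send_void(device, sel_registerName(b"release\0".as_ptr().cast()));
//!         name
//!     }
//! }
//!
//! #[cfg(not(target_os = "macos"))]
//! fn default_gpu_name() -> Option<String> {
//!     None
//! }
//! ```
//!
//! Each message gets its own function-pointer type because Apple's
//! documentation requires `objc_msgSend` to be cast to the callee's exact
//! signature before it is called; wrapper crates such as `objc2` normally hide
//! this step.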
//!
//! ## Quick start
//!
//! ```rust,no_run
//! use burn::prelude::*;
//! use burn_mpsgraph::prelude::*;
//!
//! type B = MpsGraph;
//!
//! let device = MpsGraphDevice::default();
//!
//! let a: Tensor<B, 2> = Tensor::random([128, 64], burn::tensor::Distribution::Default, &device);
//! let b: Tensor<B, 2> = Tensor::random([64, 256], burn::tensor::Distribution::Default, &device);
//! let c = a.matmul(b); // runs on the Apple GPU
//! let data = c.into_data(); // copies the result into CPU-owned `TensorData`
//! println!("shape: {:?}", data.shape); // [128, 256]
//! ```
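//!
//! When validating GPU results, it can help to compare them with a naive CPU
//! matmul. The helpers below are a hypothetical sketch (`matmul_ref` and
//! `all_close` are not part of this crate); they assume row-major `f32`
//! buffers, such as the vectors `TensorData::to_vec::<f32>()` returns.
//!
//! ```rust
//! /// Naive row-major matmul: `a` is `m x k`, `b` is `k x n`, result is `m x n`.
//! fn matmul_ref(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
//!     assert_eq!(a.len(), m * k, "lhs must be m x k");
//!     assert_eq!(b.len(), k * n, "rhs must be k x n");
//!     let mut out = vec![0.0f32; m * n];
//!     for i in 0..m {
//!         for p in 0..k {
//!             let a_ip = a[i * k + p];
//!             // Accumulate row p of `b`, scaled by a[i][p], into row i of `out`.
//!             for j in 0..n {
//!                 out[i * n + j] += a_ip * b[p * n + j];
//!             }
//!         }
//!     }
//!     out
//! }
//!
//! /// True if both slices have the same length and every pair differs by at most `tol`.
//! fn all_close(x: &[f32], y: &[f32], tol: f32) -> bool {
//!     x.len() == y.len() && x.iter().zip(y).all(|(a, b)| (a - b).abs() <= tol)
//! }
//! ```
//!
//! A tolerance is used because the GPU may accumulate sums in a different
//! order, and F16 inputs lose precision, so bit-exact equality is not expected.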
//!
//! ## Architecture
//!
//! ```text
//! Burn Tensor API
//! │
//! ▼
//! FloatTensorOps / IntTensorOps / BoolTensorOps / ModuleOps
//! │
//! ▼
//! bridge.rs ── builds an MPSGraph per op, feeds MTLBuffer tensors, runs synchronously
//! │
//! ▼
//! ffi.rs ── raw objc_msgSend calls to Metal, MPS, MPSGraph, Foundation
//! │
//! ▼
//! Apple GPU (Metal)
//! ```
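//!
//! Because the bridge builds one MPSGraph per op and runs it synchronously,
//! each Burn call executes eagerly as its own small graph. The CPU toy model
//! below makes that cost model concrete. `SingleOpGraph` and `Bridge` are
//! hypothetical names, not this crate's types, and a plain function pointer
//! stands in for compiling and running a real MPSGraph.
//!
//! ```rust
//! /// Stand-in for a compiled single-operation graph.
//! struct SingleOpGraph {
//!     op: fn(f32) -> f32,
//! }
//!
//! impl SingleOpGraph {
//!     fn build(op: fn(f32) -> f32) -> Self {
//!         SingleOpGraph { op }
//!     }
//!
//!     /// Runs to completion before returning (synchronous execution).
//!     fn run(&self, input: &[f32]) -> Vec<f32> {
//!         input.iter().map(|&x| (self.op)(x)).collect()
//!     }
//! }
//!
//! /// Counts how many single-op graphs were built and run.
//! #[derive(Default)]
//! struct Bridge {
//!     graphs_run: usize,
//! }
//!
//! impl Bridge {
//!     fn run_unary(&mut self, op: fn(f32) -> f32, input: &[f32]) -> Vec<f32> {
//!         let graph = SingleOpGraph::build(op); // one graph per op call
//!         self.graphs_run += 1;
//!         graph.run(input)
//!     }
//! }
//! ```
//!
//! In this model, chaining N element-wise calls costs N graph builds and N
//! synchronous runs.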
//!
//! ### Tensor storage
//!
//! Each [`MpsGraphTensor`] wraps a retained `MTLBuffer` pointer. On Apple
//! Silicon the buffer lives in **shared memory**: the CPU and GPU access the
//! same physical pages, so no PCIe copy is needed. Data is serialised to a
//! `Vec<u8>` only when you call `into_data()`.
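//!
//! Serialising to `Vec<u8>` amounts to copying each element's bytes out of the
//! buffer. A minimal sketch for `f32` data, assuming little-endian byte order
//! (which Apple Silicon uses); `to_bytes` and `from_bytes` are illustrative
//! helpers, not this crate's API.
//!
//! ```rust
//! /// Copy `f32` elements into a byte vector, 4 little-endian bytes per element.
//! fn to_bytes(values: &[f32]) -> Vec<u8> {
//!     values.iter().flat_map(|v| v.to_le_bytes()).collect()
//! }
//!
//! /// Inverse of `to_bytes`; returns `None` if the length is not a multiple of 4.
//! fn from_bytes(bytes: &[u8]) -> Option<Vec<f32>> {
//!     if bytes.len() % 4 != 0 {
//!         return None;
//!     }
//!     Some(
//!         bytes
//!             .chunks_exact(4)
//!             .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
//!             .collect(),
//!     )
//! }
//! ```
//!
//! Under the shared-memory model described above, no GPU-to-host transfer
//! precedes this copy.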
//!
//! ### Supported dtypes
//!
//! | Dtype | Storage | Arithmetic | Accelerated |
//! |-------|---------|------------|-------------|
//! | F32   | ✓       | ✓          | ✓           |
//! | F16   | ✓       | ✓          | ✓           |
//! | BF16  | ✓       |            | ✓           |
//! | I32   | ✓       | ✓          |             |
//! | I64   | ✓       | ✓          |             |
//! | Bool  | ✓       | ✓          |             |
//! | F64   | ✗       | ✗          | ✗ (panics)  |
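//!
//! Since F64 panics, callers may prefer to reject it up front. The sketch
//! below mirrors the table with a hypothetical local `Dtype` enum (not Burn's
//! `DType`, which has more variants) and computes how many bytes a tensor's
//! buffer would need; the 1-byte width for `Bool` is an assumption.
//!
//! ```rust
//! /// Local stand-in for the dtypes in the table above.
//! #[derive(Clone, Copy, Debug, PartialEq)]
//! enum Dtype {
//!     F32,
//!     F16,
//!     BF16,
//!     I32,
//!     I64,
//!     Bool,
//!     F64,
//! }
//!
//! /// Bytes per element, or `None` for a dtype the backend cannot store.
//! fn byte_width(dtype: Dtype) -> Option<usize> {
//!     match dtype {
//!         Dtype::F32 | Dtype::I32 => Some(4),
//!         Dtype::F16 | Dtype::BF16 => Some(2),
//!         Dtype::I64 => Some(8),
//!         Dtype::Bool => Some(1), // assumption: one byte per element
//!         Dtype::F64 => None,     // unsupported: the backend panics on F64
//!     }
//! }
//!
//! /// Buffer size in bytes for `shape`, returning an error instead of panicking.
//! fn buffer_len(dtype: Dtype, shape: &[usize]) -> Result<usize, String> {
//!     let width = byte_width(dtype).ok_or_else(|| format!("{:?} is not supported", dtype))?;
//!     Ok(shape.iter().product::<usize>() * width)
//! }
//! ```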
//!
//! ## Implemented operations
//!
//! All operations required by Burn's `Backend` trait are implemented:
//!
//! - **Arithmetic**: add, sub, mul, div, remainder, powf, matmul
//! - **Unary math**: exp, log, log1p, sqrt, abs, sin, cos, tan, asin, acos,
//!   atan, sinh, cosh, tanh, asinh, acosh, atanh, atan2, erf, recip,
//!   floor, ceil, round, trunc
//! - **Comparisons**: equal, greater, greater_equal, lower, lower_equal
//! - **Reductions**: sum, prod, mean, max, min, argmax, argmin (all per-axis)
//! - **Cumulative**: cumsum, cumprod, cummin, cummax
//! - **Sort**: sort, argsort
//! - **Shape**: reshape, transpose, permute, flip, slice, slice_assign,
//!   cat, expand, unfold
//! - **Masking**: mask_where, mask_fill
//! - **Gather / scatter**: gather, scatter_add, select, select_add
//! - **Convolution**: conv1d, conv2d, conv3d, conv_transpose1d/2d/3d,
//!   deform_conv2d
//! - **Pooling**: avg_pool2d, max_pool2d (including backward and
//!   `with_indices` variants), adaptive_avg_pool2d
//! - **Interpolation**: nearest, bilinear (+ backward)
//! - **Attention**: scaled dot-product (single-graph fused softmax)
//! - **Embedding**: forward + backward
//! - **Int ops**: full arithmetic, bitwise (and/or/xor/not/shift), casting
//! - **Bool ops**: and, or, not, equal, scatter_or, casting
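//!
//! The attention op computes `softmax(Q * K^T / sqrt(d)) * V`. A naive CPU
//! reference is useful for checking the fused GPU result; the sketch below is
//! hypothetical (`attention_ref` is not this crate's API) and assumes a single
//! head, no mask, and row-major `f32` inputs. It subtracts each row's maximum
//! score before exponentiating, the standard trick for a numerically stable
//! softmax.
//!
//! ```rust
//! /// Single-head attention: `q` is n_q x d, `k` is n_k x d, `v` is n_k x d_v.
//! /// Returns an n_q x d_v row-major matrix.
//! fn attention_ref(
//!     q: &[f32], k: &[f32], v: &[f32],
//!     n_q: usize, n_k: usize, d: usize, d_v: usize,
//! ) -> Vec<f32> {
//!     assert!(n_k > 0 && d > 0, "need at least one key and a non-empty head");
//!     assert_eq!(q.len(), n_q * d);
//!     assert_eq!(k.len(), n_k * d);
//!     assert_eq!(v.len(), n_k * d_v);
//!     let scale = 1.0 / (d as f32).sqrt();
//!     let mut out = vec![0.0f32; n_q * d_v];
//!     let mut scores = vec![0.0f32; n_k];
//!     for i in 0..n_q {
//!         // scores[j] = (q_i . k_j) / sqrt(d)
//!         for j in 0..n_k {
//!             let dot: f32 = (0..d).map(|t| q[i * d + t] * k[j * d + t]).sum();
//!             scores[j] = dot * scale;
//!         }
//!         // Numerically stable softmax over the key axis.
//!         let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
//!         let mut denom = 0.0f32;
//!         for s in scores.iter_mut() {
//!             *s = (*s - max).exp();
//!             denom += *s;
//!         }
//!         // out_i = sum_j softmax_j * v_j
//!         for j in 0..n_k {
//!             let w = scores[j] / denom;
//!             for c in 0..d_v {
//!                 out[i * d_v + c] += w * v[j * d_v + c];
//!             }
//!         }
//!     }
//!     out
//! }
//! ```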
//!
//! ## Feature flags
//!
//! None. The crate is macOS/iOS-only and requires the Metal and
//! MetalPerformanceShadersGraph frameworks at link time; `build.rs` handles the
//! linking automatically.
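//!
//! A build script typically links Apple frameworks by printing Cargo's
//! `cargo:rustc-link-lib=framework=NAME` instruction. The sketch below shows
//! the kind of directives involved; it is not this crate's actual `build.rs`,
//! `link_directives` is a hypothetical helper, and the framework list is
//! inferred from the frameworks `ffi.rs` is described as calling (Metal, MPS,
//! MPSGraph, Foundation).
//!
//! ```rust
//! /// Cargo link directives for the Apple frameworks the FFI layer calls into.
//! /// Returns an empty list for non-Apple targets.
//! fn link_directives(target_os: &str) -> Vec<String> {
//!     if target_os != "macos" && target_os != "ios" {
//!         return Vec::new();
//!     }
//!     ["Foundation", "Metal", "MetalPerformanceShaders", "MetalPerformanceShadersGraph"]
//!         .iter()
//!         .map(|framework| format!("cargo:rustc-link-lib=framework={}", framework))
//!         .collect()
//! }
//!
//! // Inside a build script, Cargo sets CARGO_CFG_TARGET_OS for the target:
//! //     let os = std::env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
//! //     for line in link_directives(&os) {
//! //         println!("{}", line);
//! //     }
//! ```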

#[macro_use]
pub(crate) mod ffi;
mod backend;
mod bridge;
mod device;
mod ops;
mod tensor;

pub use backend::*;
pub use device::*;
pub use tensor::*;

/// Convenience prelude: import everything you need in one line.
///
/// ```rust,no_run
/// use burn_mpsgraph::prelude::*;
/// ```
pub mod prelude {
    pub use crate::{MpsGraph, MpsGraphDevice, MpsGraphTensor, MpsGraphQTensor};
    pub use burn_backend::{Backend, DType};
    pub use burn_std::Shape;
}