//! # burn-mpsgraph
//!
//! Apple Metal Performance Shaders Graph (MPSGraph) backend for the
//! [Burn](https://burn.dev) deep learning framework.
//!
//! This crate accelerates tensor operations on Apple GPUs (M1/M2/M3/M4 and
//! later) by dispatching to MPSGraph, Apple's graph-based compute engine that
//! sits on top of Metal. It uses **direct Objective-C FFI** — no `objc2`
//! crate — keeping the dependency tree minimal and compile times fast.
//!
//! ## Quick start
//!
//! ```rust,no_run
//! use burn::prelude::*;
//! use burn_mpsgraph::prelude::*;
//!
//! type B = MpsGraph;
//!
//! let device = MpsGraphDevice::default();
//!
//! let a: Tensor<B, 2> = Tensor::random([128, 64], burn::tensor::Distribution::Default, &device);
//! let b: Tensor<B, 2> = Tensor::random([64, 256], burn::tensor::Distribution::Default, &device);
//! let c = a.matmul(b); // runs on the Apple GPU
//! let data = c.into_data(); // copies result back to CPU
//! println!("shape: {:?}", data.shape); // [128, 256]
//! ```
//!
//! ## Architecture
//!
//! ```text
//! Burn Tensor API
//! │
//! ▼
//! FloatTensorOps / IntTensorOps / BoolTensorOps / ModuleOps
//! │
//! ▼
//! bridge.rs ── builds an MPSGraph per op, feeds MTLBuffer tensors, runs synchronously
//! │
//! ▼
//! ffi.rs ── raw objc_msgSend calls to Metal, MPS, MPSGraph, Foundation
//! │
//! ▼
//! Apple GPU (Metal)
//! ```
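//!
//! The untyped-`objc_msgSend` pattern behind `ffi.rs` can be sketched as
//! follows. `objc_msgSend` really is declared without a fixed signature in
//! the Objective-C runtime and callers cast it per message; the stand-in
//! function and all names below are illustrative, not this crate's API:
//!
//! ```rust
//! use std::ffi::c_void;
//! use std::mem::transmute;
//!
//! // Stand-in for the libobjc trampoline; the real ffi.rs would declare
//! // `extern "C" { fn objc_msgSend(); }` and link against libobjc.
//! unsafe extern "C" fn fake_msg_send(obj: *mut c_void, _sel: *mut c_void) -> *mut c_void {
//!     obj // echo the receiver, like `-retain` does
//! }
//!
//! // Erase the signature, as the linker-provided symbol appears to Rust...
//! let untyped: unsafe extern "C" fn() = unsafe {
//!     transmute(fake_msg_send as unsafe extern "C" fn(*mut c_void, *mut c_void) -> *mut c_void)
//! };
//! // ...then cast back to the concrete signature of the message and call.
//! let msg: unsafe extern "C" fn(*mut c_void, *mut c_void) -> *mut c_void =
//!     unsafe { transmute(untyped) };
//! let receiver = 0x1usize as *mut c_void;
//! assert_eq!(unsafe { msg(receiver, std::ptr::null_mut()) }, receiver);
//! ```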
//!
//! ### Tensor storage
//!
//! Each [`MpsGraphTensor`] wraps a retained `MTLBuffer` pointer. On Apple
//! Silicon the buffer uses **shared memory** — both CPU and GPU access the same
//! physical pages with no PCIe copy. Data is only serialised to a `Vec<u8>`
//! when you call `into_data()`.
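//!
//! Conceptually, `into_data()` is just a typed read of those shared bytes.
//! A minimal sketch of that CPU-side view, with illustrative helper names
//! (the real crate reads the `MTLBuffer` contents pointer directly):
//!
//! ```rust
//! // What the shared MTLBuffer holds: little-endian f32 bytes.
//! fn f32s_to_bytes(values: &[f32]) -> Vec<u8> {
//!     values.iter().flat_map(|v| v.to_le_bytes()).collect()
//! }
//!
//! // What into_data() hands back: the same bytes, re-typed.
//! fn bytes_to_f32s(bytes: &[u8]) -> Vec<f32> {
//!     bytes
//!         .chunks_exact(4)
//!         .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
//!         .collect()
//! }
//!
//! let data = vec![1.0_f32, 2.5, -3.0];
//! let buffer = f32s_to_bytes(&data);        // shared CPU/GPU pages
//! assert_eq!(bytes_to_f32s(&buffer), data); // no device copy needed
//! ```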
//!
//! ### Supported dtypes
//!
//! | Dtype | Storage | Arithmetic | Accelerated |
//! |-------|---------|------------|-------------|
//! | F32 | ✓ | ✓ | ✓ |
//! | F16 | ✓ | ✓ | ✓ |
//! | BF16 | ✓ | | ✓ |
//! | I32 | ✓ | ✓ | |
//! | I64 | ✓ | ✓ | |
//! | Bool | ✓ | ✓ | |
//! | F64 | — | — | — (panic) |
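//!
//! Internally each dtype is translated to a raw `MPSDataType` value.
//! Apple encodes these as a kind bit OR'd with the bit width; the mapping
//! function below is an illustrative sketch, not this crate's actual API:
//!
//! ```rust
//! const MPS_FLOAT_BIT: u32 = 0x1000_0000;  // MPSDataTypeFloatBit
//! const MPS_SIGNED_BIT: u32 = 0x2000_0000; // MPSDataTypeSignedBit
//!
//! #[derive(Clone, Copy)]
//! enum DType { F32, F16, I32, I64 }
//!
//! fn to_mps_data_type(dtype: DType) -> u32 {
//!     match dtype {
//!         DType::F32 => MPS_FLOAT_BIT | 32,  // MPSDataTypeFloat32
//!         DType::F16 => MPS_FLOAT_BIT | 16,  // MPSDataTypeFloat16
//!         DType::I32 => MPS_SIGNED_BIT | 32, // MPSDataTypeInt32
//!         DType::I64 => MPS_SIGNED_BIT | 64, // MPSDataTypeInt64
//!     }
//! }
//!
//! assert_eq!(to_mps_data_type(DType::F32), 268_435_488); // 0x1000_0020
//! ```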
//!
//! ## Implemented operations
//!
//! All operations required by Burn's `Backend` trait are implemented:
//!
//! - **Arithmetic**: add, sub, mul, div, remainder, powf, matmul
//! - **Unary math**: exp, log, log1p, sqrt, abs, sin, cos, tan, asin, acos,
//! atan, sinh, cosh, tanh, asinh, acosh, atanh, atan2, erf, recip,
//! floor, ceil, round, trunc
//! - **Comparisons**: equal, greater, greater_equal, lower, lower_equal
//! - **Reductions**: sum, prod, mean, max, min, argmax, argmin (all per-axis)
//! - **Cumulative**: cumsum, cumprod, cummin, cummax
//! - **Sort**: sort, argsort
//! - **Shape**: reshape, transpose, permute, flip, slice, slice_assign,
//! cat, expand, unfold
//! - **Masking**: mask_where, mask_fill
//! - **Gather / scatter**: gather, scatter_add, select, select_add
//! - **Convolution**: conv1d, conv2d, conv3d, conv_transpose1d/2d/3d,
//! deform_conv2d
//! - **Pooling**: avg_pool2d, max_pool2d (+ backward + with_indices),
//! adaptive_avg_pool2d
//! - **Interpolation**: nearest, bilinear (+ backward)
//! - **Attention**: scaled dot-product (single-graph fused softmax)
//! - **Embedding**: forward + backward
//! - **Int ops**: full arithmetic, bitwise (and/or/xor/not/shift), casting
//! - **Bool ops**: and, or, not, equal, scatter_or, casting
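//!
//! For intuition, the fused attention op computes, per query row,
//! `softmax(q . K^T / sqrt(d)) . V`. A plain-Rust reference sketch of
//! that math (the real op builds it as a single MPSGraph; the helper
//! name here is illustrative):
//!
//! ```rust
//! fn attention_row(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
//!     let scale = (q.len() as f32).sqrt();
//!     // scores_i = q . k_i / sqrt(d)
//!     let scores: Vec<f32> = keys
//!         .iter()
//!         .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() / scale)
//!         .collect();
//!     // numerically stable softmax over the scores
//!     let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
//!     let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
//!     let sum: f32 = exps.iter().sum();
//!     // weighted sum of the value rows
//!     (0..values[0].len())
//!         .map(|j| exps.iter().zip(values).map(|(w, v)| w / sum * v[j]).sum())
//!         .collect()
//! }
//!
//! let out = attention_row(
//!     &[1.0, 0.0],
//!     &[vec![1.0, 0.0], vec![0.0, 1.0]],
//!     &[vec![1.0, 0.0], vec![0.0, 1.0]],
//! );
//! // attention weights sum to 1, so the output components do too
//! assert!((out[0] + out[1] - 1.0).abs() < 1e-6);
//! ```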
//!
//! ## Feature flags
//!
//! None — the crate is macOS/iOS only and requires the Metal and
//! MetalPerformanceShadersGraph frameworks at link time (handled
//! automatically by `build.rs`).
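//!
//! A `build.rs` that links such frameworks typically looks like the
//! sketch below (`cargo:rustc-link-lib=framework=` is standard Cargo
//! directive syntax; the exact set of directives this crate emits may
//! differ):
//!
//! ```rust,no_run
//! // build.rs (sketch)
//! fn main() {
//!     println!("cargo:rustc-link-lib=framework=Metal");
//!     println!("cargo:rustc-link-lib=framework=MetalPerformanceShadersGraph");
//!     println!("cargo:rustc-link-lib=framework=Foundation");
//! }
//! ```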
pub mod bridge;
pub mod ffi;

pub use bridge::*;
pub use ffi::*;
/// Convenience prelude — import everything you need in one line.
///
/// ```rust,no_run
/// use burn_mpsgraph::prelude::*;
/// ```
pub mod prelude;