// burn-mpsgraph 0.0.1
//
// Apple MPSGraph backend for the Burn deep learning framework.
//! # burn-mpsgraph
//!
//! Apple Metal Performance Shaders Graph (MPSGraph) backend for the
//! [Burn](https://burn.dev) deep learning framework.
//!
//! This crate accelerates tensor operations on Apple GPUs (M1/M2/M3/M4 and
//! later) by dispatching to MPSGraph, Apple's graph-based compute engine that
//! sits on top of Metal.  It uses **direct Objective-C FFI** — no `objc2`
//! crate — keeping the dependency tree minimal and compile times fast.
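//!
//! The FFI layer can be pictured as a couple of hand-written `extern "C"`
//! declarations (an illustrative sketch; the real bindings live in `ffi.rs`
//! and cover many more entry points):
//!
//! ```rust,no_run
//! extern "C" {
//!     // Real C entry point exported by the Metal framework.
//!     fn MTLCreateSystemDefaultDevice() -> *mut core::ffi::c_void;
//!     // Untyped message-send; call sites cast it to the right fn type.
//!     fn objc_msgSend();
//! }
//!
//! let device = unsafe { MTLCreateSystemDefaultDevice() };
//! assert!(!device.is_null());
//! ```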
//!
//! ## Quick start
//!
//! ```rust,no_run
//! use burn::prelude::*;
//! use burn_mpsgraph::prelude::*;
//!
//! type B = MpsGraph;
//!
//! let device = MpsGraphDevice::default();
//!
//! let a: Tensor<B, 2> = Tensor::random([128, 64], burn::tensor::Distribution::Default, &device);
//! let b: Tensor<B, 2> = Tensor::random([64, 256], burn::tensor::Distribution::Default, &device);
//! let c = a.matmul(b);                   // runs on the Apple GPU
//! let data = c.into_data();              // copies result back to CPU
//! println!("shape: {:?}", data.shape);   // [128, 256]
//! ```
//!
//! ## Architecture
//!
//! ```text
//!  Burn Tensor API
//!        │
//!  FloatTensorOps / IntTensorOps / BoolTensorOps / ModuleOps
//!        │
//!  bridge.rs ── builds an MPSGraph per op, feeds MTLBuffer tensors, runs synchronously
//!        │
//!  ffi.rs ── raw objc_msgSend calls to Metal, MPS, MPSGraph, Foundation
//!        │
//!  Apple GPU (Metal)
//! ```
//!
//! ### Tensor storage
//!
//! Each [`MpsGraphTensor`] wraps a retained `MTLBuffer` pointer.  On Apple
//! Silicon the buffer uses **shared memory** — both CPU and GPU access the same
//! physical pages with no PCIe copy.  Data is only serialised to a `Vec<u8>`
//! when you call `into_data()`.
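//!
//! A round-trip looks like this (a sketch using Burn's standard tensor API;
//! `from_floats` allocates the shared `MTLBuffer`, `into_data` reads it back):
//!
//! ```rust,no_run
//! use burn::prelude::*;
//! use burn_mpsgraph::prelude::*;
//!
//! let device = MpsGraphDevice::default();
//! let t = Tensor::<MpsGraph, 1>::from_floats([1.0, 2.0, 3.0], &device);
//! let data = t.into_data();   // reads the shared pages, no GPU→CPU blit
//! assert_eq!(data.shape, [3]);
//! ```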
//!
//! ### Supported dtypes
//!
//! | Dtype | Storage | Arithmetic | Accelerated |
//! |-------|---------|------------|-------------|
//! | F32   | ✓       | ✓          | ✓           |
//! | F16   | ✓       | ✓          | ✓           |
//! | BF16  | ✓       |            | ✓           |
//! | I32   | ✓       | ✓          |             |
//! | I64   | ✓       | ✓          |             |
//! | Bool  | ✓       | ✓          |             |
//! | F64   | —       | —          | — (panic)   |
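//!
//! Internally each dtype maps to an `MPSDataType` raw value. A hypothetical
//! sketch of that mapping (constant names and values follow Apple's
//! `MPSCoreTypes.h`, reproduced here purely for illustration):
//!
//! ```rust,no_run
//! const MPS_FLOAT_BIT: u32 = 0x1000_0000;
//! const MPS_SIGNED_BIT: u32 = 0x2000_0000;
//! const MPS_ALTERNATE_ENCODING_BIT: u32 = 0x8000_0000;
//!
//! let f32_ty = MPS_FLOAT_BIT | 32;                               // MPSDataTypeFloat32
//! let f16_ty = MPS_FLOAT_BIT | 16;                               // MPSDataTypeFloat16
//! let bf16_ty = MPS_ALTERNATE_ENCODING_BIT | MPS_FLOAT_BIT | 16; // MPSDataTypeBFloat16
//! let i32_ty = MPS_SIGNED_BIT | 32;                              // MPSDataTypeInt32
//! let i64_ty = MPS_SIGNED_BIT | 64;                              // MPSDataTypeInt64
//! let bool_ty = MPS_ALTERNATE_ENCODING_BIT | 8;                  // MPSDataTypeBool
//! assert_eq!(f32_ty, 0x1000_0020);
//! ```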
//!
//! ## Implemented operations
//!
//! All operations required by Burn's `Backend` trait are implemented:
//!
//! - **Arithmetic**: add, sub, mul, div, remainder, powf, matmul
//! - **Unary math**: exp, log, log1p, sqrt, abs, sin, cos, tan, asin, acos,
//!   atan, sinh, cosh, tanh, asinh, acosh, atanh, atan2, erf, recip,
//!   floor, ceil, round, trunc
//! - **Comparisons**: equal, greater, greater_equal, lower, lower_equal
//! - **Reductions**: sum, prod, mean, max, min, argmax, argmin (all per-axis)
//! - **Cumulative**: cumsum, cumprod, cummin, cummax
//! - **Sort**: sort, argsort
//! - **Shape**: reshape, transpose, permute, flip, slice, slice_assign,
//!   cat, expand, unfold
//! - **Masking**: mask_where, mask_fill
//! - **Gather / scatter**: gather, scatter_add, select, select_add
//! - **Convolution**: conv1d, conv2d, conv3d, conv_transpose1d/2d/3d,
//!   deform_conv2d
//! - **Pooling**: avg_pool2d, max_pool2d (+ backward + with_indices),
//!   adaptive_avg_pool2d
//! - **Interpolation**: nearest, bilinear (+ backward)
//! - **Attention**: scaled dot-product (single-graph fused softmax)
//! - **Embedding**: forward + backward
//! - **Int ops**: full arithmetic, bitwise (and/or/xor/not/shift), casting
//! - **Bool ops**: and, or, not, equal, scatter_or, casting
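//!
//! These compose just as on any other Burn backend; for example, a per-axis
//! reduction chained with an elementwise comparison (a sketch using Burn's
//! standard tensor API):
//!
//! ```rust,no_run
//! use burn::prelude::*;
//! use burn_mpsgraph::prelude::*;
//!
//! let device = MpsGraphDevice::default();
//! let x: Tensor<MpsGraph, 2> =
//!     Tensor::random([4, 8], burn::tensor::Distribution::Default, &device);
//! let row_max = x.clone().max_dim(1);   // shape [4, 1], reduced on the GPU
//! let positive = x.greater_elem(0.0);   // elementwise compare → bool tensor
//! ```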
//!
//! ## Feature flags
//!
//! None — the crate is macOS/iOS only and requires the Metal and
//! MetalPerformanceShadersGraph frameworks at link time (handled
//! automatically by `build.rs`).
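//!
//! A minimal `build.rs` for that linkage might look like the following (a
//! sketch of the standard Cargo mechanism; the crate's actual script may
//! emit additional directives):
//!
//! ```rust,no_run
//! // build.rs: tell rustc to link the Apple frameworks.
//! fn main() {
//!     println!("cargo:rustc-link-lib=framework=Metal");
//!     println!("cargo:rustc-link-lib=framework=MetalPerformanceShaders");
//!     println!("cargo:rustc-link-lib=framework=MetalPerformanceShadersGraph");
//!     println!("cargo:rustc-link-lib=framework=Foundation");
//! }
//! ```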

#[macro_use]
pub(crate) mod ffi;
mod backend;
mod bridge;
mod device;
mod ops;
mod tensor;

pub use backend::*;
pub use device::*;
pub use tensor::*;

/// Convenience prelude — import everything you need in one line.
///
/// ```rust,no_run
/// use burn_mpsgraph::prelude::*;
/// ```
pub mod prelude {
    pub use crate::{MpsGraph, MpsGraphDevice, MpsGraphTensor, MpsGraphQTensor};
    pub use burn_backend::{Backend, DType};
    pub use burn_std::Shape;
}