burn-mlx
MLX backend for Burn — native Apple Silicon GPU acceleration for deep learning.
This crate provides a Burn backend using Apple's MLX framework, enabling high-performance machine learning on M1/M2/M3/M4 Macs.
Features
- Native Apple Silicon: Direct GPU acceleration via Metal
- Unified Memory: Zero-copy data sharing between CPU and GPU
- Lazy Evaluation: Automatic operation fusion and optimization
- Full Burn Backend: FloatTensorOps, IntTensorOps, BoolTensorOps, ModuleOps, ActivationOps
- Training Support: Pooling operations with backward passes for autodiff
Requirements
- macOS with Apple Silicon (M1/M2/M3/M4)
- Rust 1.75+
Installation
Add to your Cargo.toml:
```toml
[dependencies]
burn-mlx = "0.1"
burn = "0.16"
```
Quick Start
```rust
use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice}; // `MlxDevice` name assumed from `Gpu` below

// Create tensors on Apple Silicon GPU (shapes illustrative)
let device = MlxDevice::Gpu;
let a: Tensor<Mlx, 2> = Tensor::ones([2, 3], &device);
let b: Tensor<Mlx, 2> = Tensor::ones([2, 3], &device);
let c = a + b;
println!("{c}");
```
Using with Autodiff
```rust
use burn::backend::Autodiff;
use burn_mlx::Mlx;

type TrainBackend = Autodiff<Mlx>;

// Now use TrainBackend for training with automatic differentiation
```
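As a minimal sketch of what that enables (shapes and values illustrative, `MlxDevice` name assumed), gradients are computed on the wrapped Mlx backend:

```rust
use burn::backend::Autodiff;
use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice};

type TrainBackend = Autodiff<Mlx>;

let device = MlxDevice::Gpu;
let x: Tensor<TrainBackend, 2> = Tensor::ones([2, 2], &device).require_grad();
let y = (x.clone() * 3.0).sum(); // scalar "loss"
let grads = y.backward();        // run the backward pass
let grad_x = x.grad(&grads);     // gradient w.r.t. x, on the inner Mlx backend
```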
Pooling Operations
burn-mlx provides full support for pooling operations with both forward and backward passes, enabling their use in training workflows.
Average Pooling
```rust
use burn::nn::pool::AvgPool2dConfig;
use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice}; // `MlxDevice` name assumed

let device = MlxDevice::Gpu;

// Create a 4D tensor: [batch, channels, height, width]
let input: Tensor<Mlx, 4> = Tensor::ones([1, 3, 32, 32], &device);

// Create an avg pool layer with a 2x2 kernel and stride 2
let config = AvgPool2dConfig::new([2, 2]).with_strides([2, 2]);
let pool = config.init();
let output = pool.forward(input);
// Output shape: [1, 3, 16, 16]
```
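For reference, with kernel size k, stride s, and no padding, each spatial dimension n maps to floor((n - k) / s) + 1, so 32 becomes (32 - 2) / 2 + 1 = 16 here.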
Max Pooling
```rust
use burn::nn::pool::MaxPool2dConfig;
use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice}; // `MlxDevice` name assumed

let device = MlxDevice::Gpu;
let input: Tensor<Mlx, 4> = Tensor::ones([1, 3, 32, 32], &device);

// Create a max pool layer with a 2x2 kernel and stride 2
let config = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]);
let pool = config.init();
let output = pool.forward(input);
// Output shape: [1, 3, 16, 16]
```
1D Pooling
```rust
use burn::nn::pool::{AvgPool1dConfig, MaxPool1dConfig};
use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice}; // `MlxDevice` name assumed

let device = MlxDevice::Gpu;

// Create a 3D tensor: [batch, channels, length]
let input: Tensor<Mlx, 3> = Tensor::ones([1, 64, 64], &device);

// Average pooling (kernel 2, stride 2, inferred from the halved output length)
let avg_config = AvgPool1dConfig::new(2).with_stride(2);
let avg_pool = avg_config.init();
let avg_output = avg_pool.forward(input.clone());
// Output shape: [1, 64, 32]

// Max pooling
let max_config = MaxPool1dConfig::new(2).with_stride(2);
let max_pool = max_config.init();
let max_output = max_pool.forward(input);
// Output shape: [1, 64, 32]
```
Adaptive Pooling
```rust
use burn::nn::pool::AdaptiveAvgPool2dConfig;
use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice}; // `MlxDevice` name assumed

let device = MlxDevice::Gpu;
let input: Tensor<Mlx, 4> = Tensor::ones([1, 512, 7, 7], &device); // input spatial size illustrative

// Adaptive pool to a fixed output size (common before FC layers)
let config = AdaptiveAvgPool2dConfig::new([1, 1]);
let pool = config.init();
let output = pool.forward(input);
// Output shape: [1, 512, 1, 1]
```
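The usual follow-up is to flatten the pooled map into a feature vector before a linear layer:

```rust
// Collapse dims 1..=3 of [1, 512, 1, 1] into [1, 512]
let features = output.flatten::<2>(1, 3);
```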
Low-Level Tensor API
```rust
// `MlxTensor` name and exact signatures below are assumptions
// reconstructed from the calls shown in this example
use burn_mlx::{MlxDevice, MlxTensor};

let device = MlxDevice::Gpu;

// Create tensors
let a = MlxTensor::ones([1024, 1024], &device);
let b = MlxTensor::ones([1024, 1024], &device);

// Operations
let c = a.matmul(&b);
let d = c.relu();
let e = d.softmax(-1); // softmax over the last axis

// Evaluate the lazy computation
e.eval().expect("MLX evaluation failed");
```
Supported Operations
Tensor Operations
- Arithmetic: add, sub, mul, div, matmul
- Math: exp, log, sqrt, abs, neg, pow
- Reductions: sum, mean, max, min, argmax, argmin
- Shape: reshape, transpose, permute, expand, slice, flip, scatter
Activation Functions
- ReLU, Sigmoid, Tanh, GELU, LeakyReLU
- Softmax, LogSoftmax, HardSigmoid
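All of these are reachable through Burn's standard functional activation API, which dispatches to this backend's ActivationOps. A small sketch, reusing the (assumed) MlxDevice setup from Quick Start:

```rust
use burn::tensor::Tensor;
use burn::tensor::activation::{gelu, relu, softmax};
use burn_mlx::{Mlx, MlxDevice};

let device = MlxDevice::Gpu;
let x: Tensor<Mlx, 2> = Tensor::ones([2, 8], &device);

// Chain functional activations; softmax reduces over dim 1
let y = softmax(gelu(relu(x)), 1);
```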
Neural Network Layers
- Conv1d, Conv2d (with proper NCHW layout handling)
- Embedding lookup
- Pooling (full forward and backward support):
- AvgPool1d, AvgPool2d
- MaxPool1d, MaxPool2d
- MaxPool2d with indices
- AdaptiveAvgPool1d, AdaptiveAvgPool2d
Implementation Details
Pooling Operations
The pooling operations are implemented using MLX's as_strided function combined with reduction operations:
- Forward Pass: Uses as_strided to create sliding-window views over the input, then applies mean_axes (avg pool) or max_axes (max pool) for reduction (see the CPU sketch after this list).
- Backward Pass:
  - AvgPool: Distributes gradients evenly across each pooling window using scatter_add.
  - MaxPool: Uses the indices saved during the forward pass to scatter gradients to the max positions.
- Layout Handling: Automatically converts between Burn's NCHW format and MLX's native NHWC format.
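To make the forward-pass idea concrete, here is a minimal, dependency-free CPU sketch of 1D average pooling in the same shape: build kernel-sized strided windows, then reduce each with a mean. The backend does the equivalent on-GPU with MLX's as_strided and mean_axes; this is only an illustration.

```rust
/// 1D average pooling via explicit sliding windows: the CPU analogue of
/// `as_strided` + `mean_axes`. Windows of size `kernel` advance by `stride`.
fn avg_pool1d(input: &[f32], kernel: usize, stride: usize) -> Vec<f32> {
    let out_len = (input.len() - kernel) / stride + 1;
    (0..out_len)
        .map(|i| {
            // The "strided view": a window starting at i * stride
            let window = &input[i * stride..i * stride + kernel];
            // The reduction: mean over the window axis
            window.iter().sum::<f32>() / kernel as f32
        })
        .collect()
}

fn main() {
    let input = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    // Kernel 2, stride 2 -> [1.5, 3.5, 5.5]
    println!("{:?}", avg_pool1d(&input, 2, 2));
}
```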
Performance
On Apple M-series chips, burn-mlx leverages:
- Metal Performance Shaders for optimized GPU kernels
- Unified memory architecture for efficient data transfer
- Lazy evaluation for automatic operation fusion
Typical matmul performance (1024x1024):
- ~12ms per operation on M1/M2
- Scales well with larger matrices
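A rough way to check this on your own machine, reusing the (assumed) low-level API from above; the measurement must include eval(), since building the graph alone does no work:

```rust
use std::time::Instant;
use burn_mlx::{MlxDevice, MlxTensor}; // names assumed, as in the low-level example

let device = MlxDevice::Gpu;
let a = MlxTensor::ones([1024, 1024], &device);
let b = MlxTensor::ones([1024, 1024], &device);

let start = Instant::now();
let c = a.matmul(&b);
c.eval().expect("MLX evaluation failed"); // force the lazy graph to run
println!("1024x1024 matmul: {:?}", start.elapsed());
```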
Limitations
- macOS only (Apple Silicon required)
- Conv3d and ConvTranspose operations are placeholders
- Quantization support is minimal
- Dilation in pooling operations is not yet supported
License
Apache-2.0