burn-mlx

MLX backend for Burn — native Apple Silicon GPU acceleration for deep learning.

This crate provides a Burn backend using Apple's MLX framework, enabling high-performance machine learning on M1/M2/M3/M4 Macs.

Features

  • Native Apple Silicon: Direct GPU acceleration via Metal
  • Unified Memory: Zero-copy data sharing between CPU and GPU
  • Lazy Evaluation: Automatic operation fusion and optimization
  • Full Burn Backend: FloatTensorOps, IntTensorOps, BoolTensorOps, ModuleOps, ActivationOps
  • Training Support: Pooling operations with backward passes for autodiff

Requirements

  • macOS with Apple Silicon (M1/M2/M3/M4)
  • Rust 1.75+

Installation

Add to your Cargo.toml:

[dependencies]
burn-mlx = "0.1"
burn = "0.16"

Quick Start

use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice};

// Create tensors on Apple Silicon GPU
let device = MlxDevice::Gpu;
let a: Tensor<Mlx, 2> = Tensor::ones([2, 3], &device);
let b: Tensor<Mlx, 2> = Tensor::ones([2, 3], &device);
let c = a + b;

println!("Result shape: {:?}", c.shape());

Using with Autodiff

use burn::backend::Autodiff;
use burn_mlx::Mlx;

type TrainBackend = Autodiff<Mlx>;

// Now use TrainBackend for training with automatic differentiation
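
For example, a single backward pass might look like this (a minimal sketch; the values and shapes are illustrative):

use burn::backend::Autodiff;
use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice};

type TrainBackend = Autodiff<Mlx>;

let device = MlxDevice::Gpu;

// Mark the input as requiring gradients.
let x: Tensor<TrainBackend, 2> = Tensor::ones([2, 3], &device).require_grad();

// A toy scalar "loss": the sum of 2 * x.
let loss = (x.clone() * 2.0).sum();

// Run the backward pass and read the gradient of x.
let grads = loss.backward();
let x_grad = x.grad(&grads).expect("x participated in the graph");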

Pooling Operations

burn-mlx provides full support for pooling operations with both forward and backward passes, enabling their use in training workflows.

Average Pooling

use burn::tensor::Tensor;
use burn::nn::pool::{AvgPool2d, AvgPool2dConfig};
use burn_mlx::{Mlx, MlxDevice};

let device = MlxDevice::Gpu;

// Create a 4D tensor: [batch, channels, height, width]
let input: Tensor<Mlx, 4> = Tensor::ones([1, 3, 32, 32], &device);

// Create avg pool layer with 2x2 kernel and stride 2
let config = AvgPool2dConfig::new([2, 2]).with_strides([2, 2]);
let pool: AvgPool2d = config.init();

let output = pool.forward(input);
// Output shape: [1, 3, 16, 16]

Max Pooling

use burn::tensor::Tensor;
use burn::nn::pool::{MaxPool2d, MaxPool2dConfig};
use burn_mlx::{Mlx, MlxDevice};

let device = MlxDevice::Gpu;

let input: Tensor<Mlx, 4> = Tensor::ones([1, 3, 32, 32], &device);

// Create max pool layer with 2x2 kernel and stride 2
let config = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]);
let pool: MaxPool2d = config.init();

let output = pool.forward(input);
// Output shape: [1, 3, 16, 16]

1D Pooling

use burn::tensor::Tensor;
use burn::nn::pool::{AvgPool1d, AvgPool1dConfig, MaxPool1d, MaxPool1dConfig};
use burn_mlx::{Mlx, MlxDevice};

let device = MlxDevice::Gpu;

// Create a 3D tensor: [batch, channels, length]
let input: Tensor<Mlx, 3> = Tensor::ones([1, 64, 128], &device);

// Average pooling
let avg_config = AvgPool1dConfig::new(4).with_stride(4);
let avg_pool: AvgPool1d = avg_config.init();
let avg_output = avg_pool.forward(input.clone());
// Output shape: [1, 64, 32]

// Max pooling
let max_config = MaxPool1dConfig::new(4).with_stride(4);
let max_pool: MaxPool1d = max_config.init();
let max_output = max_pool.forward(input);
// Output shape: [1, 64, 32]

Adaptive Pooling

use burn::tensor::Tensor;
use burn::nn::pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig};
use burn_mlx::{Mlx, MlxDevice};

let device = MlxDevice::Gpu;

let input: Tensor<Mlx, 4> = Tensor::ones([1, 512, 14, 14], &device);

// Adaptive pool to fixed output size (common before FC layers)
let config = AdaptiveAvgPool2dConfig::new([1, 1]);
let pool: AdaptiveAvgPool2d = config.init();

let output = pool.forward(input);
// Output shape: [1, 512, 1, 1]

Low-Level Tensor API

use burn_mlx::{MlxTensor, MlxDevice};

let device = MlxDevice::Gpu;

// Create tensors
let a = MlxTensor::<f32>::ones(&[1024, 1024], device);
let b = MlxTensor::<f32>::ones(&[1024, 1024], device);

// Operations
let c = a.matmul(&b);
let d = c.relu();
let e = d.softmax();

// Evaluate lazy computation
e.eval().expect("evaluation failed");

Supported Operations

Tensor Operations

  • Arithmetic: add, sub, mul, div, matmul
  • Math: exp, log, sqrt, abs, neg, pow
  • Reductions: sum, mean, max, min, argmax, argmin
  • Shape: reshape, transpose, permute, expand, slice, flip, scatter
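
A few of these combined through the high-level Tensor API (a brief sketch; the shapes are arbitrary):

use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice};

let device = MlxDevice::Gpu;
let x: Tensor<Mlx, 2> = Tensor::ones([4, 6], &device);

let reshaped = x.clone().reshape([2, 12]);   // shape: [2, 12]
let transposed = x.clone().transpose();      // shape: [6, 4]
let sliced = x.clone().slice([0..2, 0..3]);  // shape: [2, 3]
let row_means = x.clone().mean_dim(1);       // shape: [4, 1]
let total = x.sum();                         // Tensor<Mlx, 1> with one element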

Activation Functions

  • ReLU, Sigmoid, Tanh, GELU, LeakyReLU
  • Softmax, LogSoftmax, HardSigmoid
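
Most of these are exposed as free functions in burn::tensor::activation; a brief sketch:

use burn::tensor::{activation, Tensor};
use burn_mlx::{Mlx, MlxDevice};

let device = MlxDevice::Gpu;
let x: Tensor<Mlx, 2> = Tensor::ones([2, 4], &device);

let r = activation::relu(x.clone());
let g = activation::gelu(x.clone());
let p = activation::softmax(x, 1); // normalize over dimension 1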

Neural Network Layers

  • Conv1d, Conv2d (with proper NCHW layout handling)
  • Embedding lookup
  • Pooling (full forward and backward support):
    • AvgPool1d, AvgPool2d
    • MaxPool1d, MaxPool2d
    • MaxPool2d with indices
    • AdaptiveAvgPool1d, AdaptiveAvgPool2d
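
For instance, a convolution followed by pooling (a minimal sketch; the channel counts and sizes are illustrative):

use burn::nn::conv::{Conv2d, Conv2dConfig};
use burn::nn::pool::{MaxPool2d, MaxPool2dConfig};
use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice};

let device = MlxDevice::Gpu;

// 3 input channels -> 8 output channels with a 3x3 kernel.
let conv: Conv2d<Mlx> = Conv2dConfig::new([3, 8], [3, 3]).init(&device);
let pool: MaxPool2d = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();

let input: Tensor<Mlx, 4> = Tensor::ones([1, 3, 32, 32], &device);
let out = pool.forward(conv.forward(input));
// With the default (valid) padding: conv gives [1, 8, 30, 30], pooling gives [1, 8, 15, 15]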

Implementation Details

Pooling Operations

The pooling operations are implemented using MLX's as_strided function combined with reduction operations:

  1. Forward Pass: Uses as_strided to create sliding-window views over the input, then applies mean_axes (avg pool) or max_axes (max pool) for the reduction; the stride arithmetic behind these views is sketched after this list.

  2. Backward Pass:
    • AvgPool: Distributes gradients evenly across each pooling window using scatter_add
    • MaxPool: Uses the indices saved during the forward pass to scatter gradients to the max positions

  3. Layout Handling: Automatically converts between Burn's NCHW format and MLX's native NHWC format.
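
To make the forward pass concrete, here is a minimal sketch of the sliding-window stride arithmetic that as_strided relies on (plain Rust, independent of mlx-rs; the function name is illustrative):

/// Given an NHWC input's shape and element strides, compute the shape and
/// strides of a windowed view [N, H_out, W_out, kH, kW, C] for a given
/// kernel and stride. Reducing over the (kH, kW) axes then yields the
/// pooled output, mirroring the mean_axes / max_axes step above.
fn window_view(
    shape: [usize; 4],   // [N, H, W, C]
    strides: [usize; 4], // element strides of a contiguous NHWC buffer
    kernel: [usize; 2],
    stride: [usize; 2],
) -> ([usize; 6], [usize; 6]) {
    let h_out = (shape[1] - kernel[0]) / stride[0] + 1;
    let w_out = (shape[2] - kernel[1]) / stride[1] + 1;
    let out_shape = [shape[0], h_out, w_out, kernel[0], kernel[1], shape[3]];
    // Stepping along H_out/W_out jumps `stride` rows/columns; stepping
    // within a window reuses the original row/column strides.
    let out_strides = [
        strides[0],
        strides[1] * stride[0],
        strides[2] * stride[1],
        strides[1],
        strides[2],
        strides[3],
    ];
    (out_shape, out_strides)
}

For a contiguous [1, 32, 32, 3] NHWC input (strides [3072, 96, 3, 1]) with a 2x2 kernel and stride 2, this produces the shape [1, 16, 16, 2, 2, 3]; reducing over the two kernel axes yields the pooled [1, 16, 16, 3] output.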

Performance

On Apple M-series chips, burn-mlx leverages:

  • Metal Performance Shaders for optimized GPU kernels
  • Unified memory architecture for efficient data transfer
  • Lazy evaluation for automatic operation fusion

Typical matmul performance (1024x1024):

  • ~12ms per operation on M1/M2
  • Scales well with larger matrices
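
Because evaluation is lazy, any timing has to force the computation. A sketch of such a measurement using the low-level API shown earlier (the warm-up run is there so one-time setup costs are not measured):

use std::time::Instant;
use burn_mlx::{MlxDevice, MlxTensor};

let device = MlxDevice::Gpu;
let a = MlxTensor::<f32>::ones(&[1024, 1024], device);
let b = MlxTensor::<f32>::ones(&[1024, 1024], device);

// Warm up once so setup costs are excluded from the measurement.
a.matmul(&b).eval().expect("evaluation failed");

let start = Instant::now();
let c = a.matmul(&b);
c.eval().expect("evaluation failed"); // force the lazy graph to run
println!("matmul took {:?}", start.elapsed());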

Limitations

  • macOS only (Apple Silicon required)
  • Conv3d and ConvTranspose operations are placeholders
  • Quantization support is minimal
  • Dilation in pooling operations is not yet supported

License

Apache-2.0

Acknowledgments

  • Burn - Rust deep learning framework
  • MLX - Apple's machine learning framework
  • mlx-rs - Rust bindings for MLX