burn-flex

A fast, memory-efficient CPU backend for Burn with multi-threading, SIMD, and optimized matrix multiplication. Runs on std, no_std, and WebAssembly. Supports f16/bf16, zero-copy data loading, and is thread-safe by design.

Detailed comparison with burn-ndarray: Full architecture, feature coverage, operation-by-operation analysis, and migration path.

Features

  • Zero-Copy Operations: Many operations return strided views without copying data:
    • transpose, permute, flip, narrow, slice
    • unfold (sliding windows as a strided view instead of a materialized copy)
    • expand (broadcast via zero strides)
  • Arc-based Copy-on-Write: O(1) tensor cloning with automatic COW semantics; mutation happens in place when the buffer is uniquely owned (a minimal model follows this list).
  • Convolutions: Unified 3D implementation with im2col + gemm. Conv1d/2d delegate to conv3d. Direct conv path for small spatial 1D convs (decomposes into kw gemm calls on NCHW data, skipping NHWC conversion and im2col). 1x1 pointwise fast path. Supports groups, dilation, padding.
  • Attention: Automatically selects between two gemm-backed strategies based on sequence length (the heuristic is sketched after this list). Short sequences (score matrix of seq_q * seq_kv <= 256K elements) use naive attention with two large gemm calls for lower overhead. Larger shapes use tiled flash attention with online softmax for O(TILE_KV) memory per row. Both support causal masking, additive bias (ALiBi), softcap, custom scale, and cross-attention.
  • Pooling: Max pool, avg pool, adaptive avg pool. All via unified 3D with backward pass support.
  • Conv Transpose: Scatter-based transposed convolutions for upsampling.
  • Portable SIMD: Uses macerator for automatic dispatch:
    • aarch64: NEON
    • x86_64: AVX2, AVX512, SSE
    • wasm32: SIMD128
    • Embedded/other: Scalar fallback
  • Matrix Multiplication: Optimized via gemm with native f16 support. Strided batched matmul passes layout strides directly to gemm, so transposed views (e.g. q.matmul(k.swap_dims(-2,-1))) run at contiguous speed with no copy.
  • FFT: Forward (rfft) and inverse (irfft) real FFT via Cooley-Tukey with complex packing, compile-time twiddle tables, SIMD butterflies, and unrolled small kernels. Works in no_std.
  • Parallel Execution: Optional rayon for large tensors
  • Quantization: Full quantize/dequantize support with per-tensor and per-block symmetric schemes. All ~40 quantized ops (arithmetic, trig, reductions, sorting, etc.) work out of the box. Layout ops on quantized tensors (permute, flip, expand, slice, select) are zero-copy. Stores scales separately for direct scale * x_q dequantization instead of reparsing packed bytes.
  • Dtype Support: f32, f64, f16 (native), bf16 (via f32 conversion), i8-i64, u8-u64
  • Built on Burn: Leverages Burn's native infrastructure (Bytes, Shape, TensorData, Element trait) from burn-backend and burn-std
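
A minimal model of the Arc-based copy-on-write storage described in the list above (illustrative only, not the crate's actual types): cloning bumps a reference count, and mutation copies the buffer only when it is shared.

use std::sync::Arc;

// Illustrative COW storage: clone_cow is O(1); data_mut copies only if shared.
struct Storage(Arc<Vec<f32>>);

impl Storage {
    fn clone_cow(&self) -> Storage {
        Storage(Arc::clone(&self.0)) // refcount bump, no data copied
    }

    fn data_mut(&mut self) -> &mut Vec<f32> {
        // Uniquely owned: mutate in place. Shared: clone the buffer first.
        Arc::make_mut(&mut self.0)
    }
}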
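
And a sketch of the attention strategy selection, using the 256K-element threshold quoted above; the function and constant names are hypothetical:

// Small score matrices take the naive two-gemm path; larger shapes take
// tiled flash attention with online softmax.
const MAX_NAIVE_SCORE_ELEMENTS: usize = 256 * 1024;

fn use_naive_attention(seq_q: usize, seq_kv: usize) -> bool {
    seq_q * seq_kv <= MAX_NAIVE_SCORE_ELEMENTS
}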

Feature Flags

Default: std, simd, rayon.

Flag Default Description
std Yes Standard library support
simd Yes Portable SIMD via macerator; also enables gemm/wasm-simd128-enable
rayon Yes Parallel execution for large tensors (forwards gemm/rayon)
x86-v4 No AVX-512 kernels in gemm for x86_64 (Sapphire Rapids, Zen 4/5)
apple-amx No Apple Silicon AMX matrix coprocessor in gemm (experimental)
tracing No Propagate tracing instrumentation
critical-section No Support for no_std targets without atomic CAS

Enable opt-in features by listing them in Cargo, either directly on burn-flex or through the top-level burn crate:

# Direct
burn-flex = { version = "0.21", features = ["apple-amx"] }

# Via burn
burn = { version = "0.21", features = ["flex", "apple-amx"] }

See ARCHITECTURE.md#feature-flags for per-case benchmark impact.

Why replace burn-ndarray?

burn-ndarray depends on the ndarray crate, which has been slow to accept contributions and evolve.

burn-flex was built from scratch as a replacement that addresses those limitations while maintaining full compatibility with Burn's backend test suite.

Performance vs burn-ndarray (Apple M3 Max)

Compute Performance

burn-ndarray now uses macerator SIMD for f32 elementwise ops (tracel-ai/burn#2851), so contiguous f32 binary/unary ops are at parity. Flex advantages come from gemm, integer ops (i32 vs i64), structural zero-copy, and fused kernels.

Category Speedup Highlights
Binary ops (f32) ~1x Both use macerator SIMD for f32
Binary ops (i32) 1.8-5.3x Flex uses i32, NdArray uses i64
Matmul (square) 1.4-3.1x gemm at small/large; tied at mid-sizes
Matmul (batched) 1.3-2.2x Multi-head attention shapes
Matmul (int) 3.7-6.5x gemm vs matrixmultiply for integers
Conv2d (3x3) 1.1-3.7x Larger kernels and batches benefit most
Conv1d 4.3-9.8x
Conv transpose 9.2-84x Direct scatter vs im2col
Attention 1.2-3.0x Fused softmax, 2-8x lower peak memory
Pooling 1.1-3.1x
Interpolation 1.1-6.3x Nearest 4-6x, bilinear 1.7-2.8x
Reductions 1.3-5.4x Near-zero allocation for scalar results
Cumulative ops 2.1-95x 1D cumsum: 95x faster
Gather/scatter 1.2-6.4x
Unary (tanh, sin) 1.3-2.0x tanh 2x, sin/cos 1.3-1.5x
Sort 2.3-29x 2D sort up to 29x
Repeat dim 8.9-12x Single alloc + memcpy vs N slice_assign
Tensor creation 16-33x zeros/ones/full
Embedding 4.8-5.8x
Quantize 1.3-1.5x Fused 2-pass implementation
FFT (rfft/irfft) n/a Native implementation, works in no_std

Structural Improvements

These reflect better operation representation, not faster computation. burn-ndarray eagerly materializes data for these operations; burn-flex avoids the work entirely through zero-copy views and separated storage layouts.

Category Improvement What changed
Dequantize 122-238x Direct scale * x_q vs reparsing QuantizedBytes each call
Quantized ops 6.1-125x Dominated by fast dequantize path above
Slice/narrow 2.1-2,400x Zero-copy strided view vs data copy
Unfold 920-130,000x O(1) strided view vs O(n) full materialization
Expand 620-2,800x Zero-copy broadcast (stride=0) vs data copy
Int cast 6.3-26,000x Zero-copy reinterpret vs element-wise conversion
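
As a mental model for these numbers (not the crate's actual code): a strided view maps each logical index to a physical offset through the strides, so layout ops only rewrite shape/stride metadata and never touch the data.

// Logical index -> physical offset: a dot product with the strides.
fn offset(index: &[usize], strides: &[usize]) -> usize {
    index.iter().zip(strides).map(|(i, s)| i * s).sum()
}

// expand: broadcasting a size-1 dim to n sets its stride to 0, so every
// logical index along that dim reads the same element. No data is copied.
fn expand_dim(shape: &mut [usize], strides: &mut [usize], dim: usize, n: usize) {
    assert_eq!(shape[dim], 1, "only size-1 dims can broadcast");
    shape[dim] = n;
    strides[dim] = 0;
}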

Note on quantization: burn-ndarray simulates quantization by dequantizing to f32 for most operations. The quantized speedups reflect the difference between simulated and native execution, not equivalent algorithms running at different speeds.
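
The dequantize gap follows directly from the storage layout: with scales stored separately, recovery is one multiply per element. A sketch of the per-tensor symmetric scheme (illustrative, not the crate's API):

// Per-tensor symmetric dequantization: x = scale * x_q.
fn dequantize(values: &[i8], scale: f32) -> Vec<f32> {
    values.iter().map(|&q| scale * f32::from(q)).collect()
}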

See BENCHMARKS.md for the full breakdown.

Performance vs candle-core (Apple M3 Max, pure-Rust, no BLAS)

Per-op comparison against candle-core, pure-Rust on both sides. Across 11 bench files covering every flex op that intersects with candle's CPU API, flex is as fast or faster on every operation category.

Category Representative ratio Notes
Batched matmul 8-11x Strided gemm, no copy
Conv1d (wav2vec2) 1.4-7.6x Direct conv path
Conv2d (ResNet) 1.3-4.0x 1x1 pointwise 4x
Conv transpose 1.5-1.9x
Max/min reductions 3.8-5.1x SIMD + zero-alloc
Pooling (k=3 s=2) 1.8-2.5x
Layer norm (fused) 1.6-3.4x Two-pass Welford kernel
Softmax (fused) 1.4-1.7x Three-pass row kernel
Cat, gather, select 1.3-2.5x
Nearest2d interpolation 1.3-1.4x
Elementwise, matmul, gelu, view ops tied Both at memory bandwidth ceiling
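
The "three-pass row kernel" above is the standard numerically stable softmax: find the row max, exponentiate and accumulate the sum, then normalize. A scalar sketch (the fused SIMD version differs in implementation, not in the math):

// Numerically stable softmax over one row, in three passes.
fn softmax_row(row: &mut [f32]) {
    let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0;
    for x in row.iter_mut() {
        *x = (*x - max).exp();
        sum += *x;
    }
    for x in row.iter_mut() {
        *x /= sum;
    }
}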

Status

  • All burn-backend-tests pass across all feature flag combinations:
    • no-default-features (no_std, no SIMD, no rayon)
    • no-default-features + simd (no_std with SIMD)
    • std
    • std + simd
    • std + rayon
    • std + simd + rayon (default)
  • Burn's burn-no-std-tests integration suite passes (MNIST model inference in #![no_std])
  • Builds for embedded and WebAssembly targets (example command after this list):
    • thumbv6m-none-eabi (ARM Cortex-M0+, no atomic pointers)
    • thumbv7m-none-eabi (ARM Cortex-M3)
    • wasm32-unknown-unknown
  • Passes Miri (undefined behavior detector) on all burn-flex code (quantization tests skipped due to upstream UB in burn-std). Validates memory safety of unsafe pointer arithmetic, bytemuck casts, and Send/Sync implementations.
  • Tested for edge-case robustness: integer overflow at type boundaries, large-float rounding, invalid pooling parameters, zero-sized dimensions. Safe for embedded devices.
  • All ONNX model checks in burn-onnx pass
  • Real model inference verified:
    • ALBERT (masked language model, all v2 variants)
    • MiniLM (sentence embeddings, L6 and L12)
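
For example, an embedded target can be smoke-built like this (flags taken from the feature combinations above; -p burn-flex assumes a workspace checkout):

rustup target add thumbv7m-none-eabi
cargo build -p burn-flex --no-default-features --target thumbv7m-none-eabi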

Documentation