//! burn-mpsgraph 0.0.1
//!
//! Apple MPSGraph backend for the Burn deep learning framework.
//!
//! Conv2d + MaxPool2d forward pass (image classification feature extractor).
//!
//! Run with: `cargo run --example conv`

use burn::tensor::{Distribution, Tensor};
use burn_mpsgraph::prelude::*;
use burn_backend::ops::{ConvOptions, ModuleOps};

/// Backend alias used throughout this example: every tensor below runs on the
/// Apple MPSGraph backend.
type B = MpsGraph;

/// Builds a tiny conv → ReLU → maxpool → conv → global-average-pool feature
/// extractor on random data and prints the tensor shape after each stage.
/// Everything executes on the MPSGraph backend selected by the `B` alias.
fn main() {
    // `use` items inside a block are in scope for the entire block, so declare
    // them up front rather than mid-function (the original had
    // `use burn_backend::TensorMetadata;` buried after its first use).
    // `TensorMetadata` provides `.shape()` on backend primitives;
    // `TensorPrimitive` wraps primitives back into the high-level `Tensor` API.
    use burn_backend::{TensorMetadata, TensorPrimitive};

    let device = MpsGraphDevice::default();
    println!("=== burn-mpsgraph conv2d example ===\n");

    // Fixed seed so the randomly initialized tensors are reproducible per run.
    MpsGraph::seed(&device, 99);

    // ── Fake "image" input: batch=2, RGB, 32×32 ──────────────────────────
    let x: Tensor<B, 4> = Tensor::random([2, 3, 32, 32], Distribution::Default, &device);
    println!("Input  shape: {:?}", x.shape()); // [2, 3, 32, 32]

    // ── Conv2d: 3 → 16 channels, 3×3 kernel, padding=1 ──────────────────
    // stride=1 + padding=1 with a 3×3 kernel preserves the spatial dims.
    let w1: Tensor<B, 4> = Tensor::random([16, 3, 3, 3], Distribution::Normal(0.0, 0.1), &device);
    let conv_out = MpsGraph::conv2d(
        x.into_primitive().tensor(),
        w1.into_primitive().tensor(),
        None, // no bias term
        ConvOptions::new([1, 1], [1, 1], [1, 1], 1), // stride, padding, dilation, groups
    );
    println!("After conv2d:  {:?}", conv_out.shape()); // [2, 16, 32, 32]

    // ReLU implemented as clamp_min(0.0) on the high-level tensor API.
    let after_relu =
        Tensor::<B, 4>::from_primitive(TensorPrimitive::Float(conv_out)).clamp_min(0.0);

    // ── MaxPool2d: 2×2, stride=2 — halves H and W ────────────────────────
    let pooled = MpsGraph::max_pool2d(
        after_relu.into_primitive().tensor(),
        [2, 2], // kernel
        [2, 2], // stride
        [0, 0], // padding
        [1, 1], // dilation
        false,  // NOTE(review): presumably ceil_mode or a padding-count flag — confirm against the backend's max_pool2d signature
    );
    println!("After maxpool: {:?}", pooled.shape()); // [2, 16, 16, 16]

    // ── Second conv: 16 → 32 channels, same shape-preserving config ──────
    let w2: Tensor<B, 4> = Tensor::random([32, 16, 3, 3], Distribution::Normal(0.0, 0.1), &device);
    let conv2 = MpsGraph::conv2d(
        pooled,
        w2.into_primitive().tensor(),
        None,
        ConvOptions::new([1, 1], [1, 1], [1, 1], 1),
    );
    println!("After conv2:   {:?}", conv2.shape()); // [2, 32, 16, 16]

    // ── Global average pool: reduce H×W down to 1×1 per channel ─────────
    let gap = MpsGraph::adaptive_avg_pool2d(conv2, [1, 1]);
    println!("After GAP:     {:?}", gap.shape()); // [2, 32, 1, 1]

    // Flatten [2, 32, 1, 1] → [2, 32], the shape a classifier head would take.
    let flat = Tensor::<B, 4>::from_primitive(TensorPrimitive::Float(gap)).reshape([2, 32]);
    println!("Flattened:     {:?}", flat.shape()); // [2, 32]

    println!("\nAll computed on Apple GPU via MPSGraph.");
}