burn-mpsgraph 0.0.1

Apple MPSGraph backend for the Burn deep learning framework.
Documentation: see the crate-level docs below and run the `inference` example.
//! Softmax and linear-layer forward pass.
//!
//! Run with: `cargo run --example inference`

use burn::tensor::{Distribution, Tensor};
use burn_mpsgraph::prelude::*;

type B = MpsGraph;

/// A minimal two-layer MLP: Linear → ReLU → Linear → Softmax.
///
/// Parameters are created in [`Mlp::new`]; the forward pass follows the
/// `x @ w + b` convention with `x` of shape `[batch, in_dim]`.
struct Mlp {
    // First-layer weights, shape [in_dim, hidden] (see `new`).
    w1: Tensor<B, 2>,
    // First-layer bias, shape [hidden], initialized to zeros.
    b1: Tensor<B, 1>,
    // Second-layer weights, shape [hidden, out_dim].
    w2: Tensor<B, 2>,
    // Second-layer bias, shape [out_dim], initialized to zeros.
    b2: Tensor<B, 1>,
}

impl Mlp {
    /// Builds the network with seeded, scaled-down Gaussian weights and
    /// zero biases so repeated runs of the example are reproducible.
    fn new(in_dim: usize, hidden: usize, out_dim: usize, device: &MpsGraphDevice) -> Self {
        MpsGraph::seed(device, 7);
        // Shrink the N(0, 1) draws so initial activations stay small.
        // Note: the two `random` calls happen in the same order as before,
        // so the seeded draw sequence is unchanged.
        let w1 = Tensor::random([in_dim, hidden], Distribution::Normal(0.0, 1.0), device) * 0.1;
        let w2 = Tensor::random([hidden, out_dim], Distribution::Normal(0.0, 1.0), device) * 0.1;
        Self {
            w1,
            b1: Tensor::zeros([hidden], device),
            w2,
            b2: Tensor::zeros([out_dim], device),
        }
    }

    /// Runs the forward pass on `x` of shape `[batch, in_dim]`, returning
    /// per-row class probabilities of shape `[batch, out_dim]`.
    fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        // Layer 1: affine transform, then ReLU via clamping at zero.
        // The 1-D bias is unsqueezed to [1, hidden] so it broadcasts
        // across the batch dimension.
        let pre_act = x.matmul(self.w1.clone()) + self.b1.clone().unsqueeze::<2>();
        let activated = pre_act.clamp_min(0.0);

        // Layer 2: affine transform producing unnormalized class scores.
        let logits = activated.matmul(self.w2.clone()) + self.b2.clone().unsqueeze::<2>();

        Self::softmax(logits)
    }

    /// Numerically stable softmax over the last dimension: subtract each
    /// row's maximum before exponentiating, then normalize by the row sum.
    /// `max_dim`/`sum_dim` keep the reduced dimension (size 1), so the
    /// subtraction and division broadcast row-wise.
    fn softmax(logits: Tensor<B, 2>) -> Tensor<B, 2> {
        let row_max = logits.clone().max_dim(1);
        let exps = (logits - row_max).exp();
        let denom = exps.clone().sum_dim(1);
        exps / denom
    }
}

/// Entry point: builds the MLP, runs a random batch through it, and prints
/// the output shape, per-row probability sums, and predicted classes.
fn main() {
    let device = MpsGraphDevice::default();
    println!("=== burn-mpsgraph inference example ===\n");

    // Batch of 8 samples, 16 input features → 64 hidden → 10 classes
    let mlp = Mlp::new(16, 64, 10, &device);

    let x: Tensor<B, 2> = Tensor::random([8, 16], Distribution::Default, &device);
    let probs = mlp.forward(x);

    // Burn's `Shape` implements `Debug`, not `Display`, so `{}` does not
    // compile here — use the debug formatter.
    println!("Output shape:  {:?}", probs.shape()); // [8, 10]

    // Softmax output: each row should sum to ~1.0.
    let row_sums = probs.clone().sum_dim(1);
    let sums = row_sums.into_data().to_vec::<f32>().unwrap();
    println!("Row sums (should all be ≈ 1.0):");
    for (i, s) in sums.iter().enumerate() {
        println!("  sample {i}: {s:.6}");
    }

    // Predicted class indices. `argmax` yields an int tensor whose element
    // type is the backend's IntElem — i64 for Burn's defaults, so extracting
    // as i32 would return Err and panic on unwrap. (Assumes this backend
    // keeps the i64 default — TODO confirm.)
    let preds = probs.argmax(1);
    println!("\nPredicted classes: {:?}", preds.into_data().to_vec::<i64>().unwrap());
}