burn-mpsgraph 0.0.1

Apple MPSGraph backend for the Burn deep learning framework
Documentation
//! Basic tensor operations using the MPSGraph backend.
//!
//! Run with: `cargo run --example basic`

use burn::tensor::{Distribution, Tensor};
use burn_mpsgraph::prelude::*;

/// Backend used by every tensor in this example (`Tensor<B, D>`).
type B = MpsGraph;

/// Copies a tensor's contents back to the host as a flat `Vec<f32>`.
///
/// # Panics
///
/// Panics if `TensorData::to_vec::<f32>` returns an error for this tensor's data.
fn to_f32_vec<const D: usize>(tensor: Tensor<B, D>) -> Vec<f32> {
    tensor
        .into_data()
        .to_vec::<f32>()
        .expect("tensor data should be readable as f32")
}

/// Runs a short tour of basic tensor operations on the MPSGraph backend and
/// prints each result to stdout.
///
/// Sections, in order: element-wise arithmetic, a 2x2 matmul, full-tensor
/// reductions, unary math ops, shape ops, and seeded random sampling.
fn main() {
    let device = MpsGraphDevice::default();

    println!("=== burn-mpsgraph basic example ===\n");

    // ── Arithmetic ──────────────────────────────────────────────────────────
    let a: Tensor<B, 1> = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device);
    let b: Tensor<B, 1> = Tensor::from_floats([10.0, 20.0, 30.0, 40.0], &device);

    // Operators take their operands by value, so clone while `a`/`b` are still
    // needed below.
    let sum = a.clone() + b.clone();
    println!("a + b          = {:?}", to_f32_vec(sum));

    // Last use of `a` and `b`: move them instead of cloning.
    let product = a * b;
    println!("a * b          = {:?}", to_f32_vec(product));

    // ── Matmul ──────────────────────────────────────────────────────────────
    let m: Tensor<B, 2> = Tensor::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    let n: Tensor<B, 2> = Tensor::from_floats([[5.0, 6.0], [7.0, 8.0]], &device);
    let mm = m.matmul(n);
    // Expected [[19, 22], [43, 50]]; printed flattened by `to_f32_vec`.
    println!("matmul 2x2     = {:?}", to_f32_vec(mm));

    // ── Reductions ──────────────────────────────────────────────────────────
    let v: Tensor<B, 1> = Tensor::from_floats([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0], &device);
    // Expected: sum = 31, max = 9, min = 1.
    println!("sum            = {:?}", v.clone().sum().into_scalar());
    println!("max            = {:?}", v.clone().max().into_scalar());
    println!("min            = {:?}", v.min().into_scalar());

    // ── Math ────────────────────────────────────────────────────────────────
    let angles: Tensor<B, 1> = Tensor::from_floats([0.0, 0.5, 1.0, 1.5], &device);
    let sines = angles.sin();
    println!("sin([0,0.5,1,1.5]) = {:?}", to_f32_vec(sines));

    let exponents: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0], &device);
    println!("exp([0,1,2])   = {:?}", to_f32_vec(exponents.exp()));

    // ── Shape ops ───────────────────────────────────────────────────────────
    let t: Tensor<B, 2> = Tensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    // `t` is reused by the reshape below, so transpose a clone.
    let transposed = t.clone().transpose();
    // `dims()` returns the dimensions as a plain array, whose Debug output is `[3, 2]`.
    println!("transpose shape = {:?}", transposed.dims());

    let reshaped = t.reshape([6]);
    println!("reshape [2,3]->[6] = {:?}", to_f32_vec(reshaped));

    // ── Random tensors ──────────────────────────────────────────────────────
    // Seed the backend's RNG for this device before sampling.
    MpsGraph::seed(&device, 42);
    let rand: Tensor<B, 2> = Tensor::random([4, 4], Distribution::Default, &device);
    println!(
        "random [4,4] min/max = {:.3} / {:.3}",
        rand.clone().min().into_scalar(),
        rand.max().into_scalar()
    );

    println!("\nAll operations ran on the Apple GPU via MPSGraph.");
}