use burn::tensor::{Distribution, Tensor};
use burn_mpsgraph::prelude::*;
// Backend alias: every tensor in this example runs on the MPSGraph backend.
type B = MpsGraph;
/// A minimal two-layer perceptron whose parameters are held as plain tensors
/// on the `MpsGraph` backend (no optimizer/module machinery — inference only).
struct Mlp {
// First-layer weights, shape `[in_dim, hidden]` (see `new`).
w1: Tensor<B, 2>,
// First-layer bias, shape `[hidden]`.
b1: Tensor<B, 1>,
// Second-layer weights, shape `[hidden, out_dim]`.
w2: Tensor<B, 2>,
// Second-layer bias, shape `[out_dim]`.
b2: Tensor<B, 1>,
}
impl Mlp {
    /// Builds the network: seeded Gaussian weights scaled by 0.1, zero biases.
    ///
    /// The fixed seed (7) makes the random initialization reproducible
    /// across runs on the same device.
    fn new(in_dim: usize, hidden: usize, out_dim: usize, device: &MpsGraphDevice) -> Self {
        MpsGraph::seed(device, 7);
        // Shared initializer for both weight matrices: N(0, 1) samples, scaled down.
        let init = |rows: usize, cols: usize| {
            Tensor::<B, 2>::random([rows, cols], Distribution::Normal(0.0, 1.0), device) * 0.1
        };
        Self {
            w1: init(in_dim, hidden),
            b1: Tensor::zeros([hidden], device),
            w2: init(hidden, out_dim),
            b2: Tensor::zeros([out_dim], device),
        }
    }

    /// Forward pass: linear → ReLU → linear → softmax over dim 1.
    ///
    /// Takes a `[batch, in_dim]` input and returns `[batch, out_dim]`
    /// per-row probability distributions.
    fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        // Hidden layer; `clamp_min(0.0)` is ReLU.
        let hidden = (x.matmul(self.w1.clone()) + self.b1.clone().unsqueeze::<2>()).clamp_min(0.0);
        let logits = hidden.matmul(self.w2.clone()) + self.b2.clone().unsqueeze::<2>();
        // Numerically stable softmax: shift by the per-row max before exp.
        // The code relies on max_dim/sum_dim keeping the reduced dimension
        // (result shape `[batch, 1]`) so the subtraction/division broadcast.
        let row_max = logits.clone().max_dim(1);
        let exps = (logits - row_max).exp();
        let totals = exps.clone().sum_dim(1);
        exps / totals
    }
}
/// Inference demo: run a random batch through the MLP, verify each softmax
/// row sums to ~1.0, and print the argmax class predictions.
fn main() {
    let device = MpsGraphDevice::default();
    println!("=== burn-mpsgraph inference example ===\n");

    // 16 input features, 64 hidden units, 10 output classes; batch of 8.
    let mlp = Mlp::new(16, 64, 10, &device);
    let x: Tensor<B, 2> = Tensor::random([8, 16], Distribution::Default, &device);
    let probs = mlp.forward(x);

    // Debug formatting: burn's `Shape` derives `Debug`, not `Display`,
    // so `{}` here would not compile.
    println!("Output shape: {:?}", probs.shape());

    // Sanity check: softmax output rows must each sum to approximately 1.
    let row_sums = probs.clone().sum_dim(1);
    let sums = row_sums.into_data().to_vec::<f32>().unwrap();
    println!("Row sums (should all be ≈ 1.0):");
    for (i, s) in sums.iter().enumerate() {
        println!(" sample {i}: {s:.6}");
    }

    // NOTE(review): `to_vec::<i32>()` assumes this backend's int element type
    // is i32; if it is i64 (burn's common default), this unwrap panics — confirm.
    let preds = probs.argmax(1);
    println!("\nPredicted classes: {:?}", preds.into_data().to_vec::<i32>().unwrap());
}