use burn::tensor::{Distribution, Tensor};
use burn_mpsgraph::prelude::*;
use burn_backend::ops::{ConvOptions, ModuleOps};
type B = MpsGraph;
/// Runs a small conv → ReLU → maxpool → conv → global-average-pool pipeline
/// on the burn-mpsgraph backend and prints the tensor shape after each stage.
fn main() {
    // Trait that provides `.shape()` on raw backend primitives.
    use burn_backend::TensorMetadata;

    // Pipeline dimensions, named once so the final `reshape` cannot drift
    // out of sync with the tensor construction above it.
    const BATCH: usize = 2;
    const IN_CHANNELS: usize = 3;
    const MID_CHANNELS: usize = 16;
    const OUT_CHANNELS: usize = 32;
    const SIDE: usize = 32;

    let device = MpsGraphDevice::default();
    println!("=== burn-mpsgraph conv2d example ===\n");

    // Fixed seed so the randomly-initialized tensors are reproducible.
    MpsGraph::seed(&device, 99);

    let x: Tensor<B, 4> = Tensor::random(
        [BATCH, IN_CHANNELS, SIDE, SIDE],
        Distribution::Default,
        &device,
    );
    println!("Input shape: {:?}", x.shape());

    // First 3x3 convolution: 3 → 16 channels, stride 1, padding 1
    // (shape-preserving spatially), dilation 1, single group.
    let w1: Tensor<B, 4> = Tensor::random(
        [MID_CHANNELS, IN_CHANNELS, 3, 3],
        Distribution::Normal(0.0, 0.1),
        &device,
    );
    let conv_out = MpsGraph::conv2d(
        x.into_primitive().tensor(),
        w1.into_primitive().tensor(),
        None, // no bias
        ConvOptions::new([1, 1], [1, 1], [1, 1], 1),
    );
    println!("After conv2d: {:?}", conv_out.shape());

    // ReLU, expressed as clamp_min(0.0) on the high-level tensor API.
    let after_relu = Tensor::<B, 4>::from_primitive(
        burn_backend::TensorPrimitive::Float(conv_out),
    )
    .clamp_min(0.0);

    // 2x2 max pooling with stride 2: halves the spatial resolution.
    let pooled = MpsGraph::max_pool2d(
        after_relu.into_primitive().tensor(),
        [2, 2],
        [2, 2],
        [0, 0],
        [1, 1],
        false,
    );
    println!("After maxpool: {:?}", pooled.shape());

    // Second 3x3 convolution: 16 → 32 channels, same stride/padding.
    let w2: Tensor<B, 4> = Tensor::random(
        [OUT_CHANNELS, MID_CHANNELS, 3, 3],
        Distribution::Normal(0.0, 0.1),
        &device,
    );
    let conv2 = MpsGraph::conv2d(
        pooled,
        w2.into_primitive().tensor(),
        None,
        ConvOptions::new([1, 1], [1, 1], [1, 1], 1),
    );
    println!("After conv2: {:?}", conv2.shape());

    // Global average pooling: collapse each feature map to a single value.
    let gap = MpsGraph::adaptive_avg_pool2d(conv2, [1, 1]);
    println!("After GAP: {:?}", gap.shape());

    // [BATCH, OUT_CHANNELS, 1, 1] → [BATCH, OUT_CHANNELS]
    let flat = Tensor::<B, 4>::from_primitive(
        burn_backend::TensorPrimitive::Float(gap),
    )
    .reshape([BATCH, OUT_CHANNELS]);
    println!("Flattened: {:?}", flat.shape());

    println!("\nAll computed on Apple GPU via MPSGraph.");
}