use burn::tensor::{Distribution, Tensor};
use burn_mpsgraph::prelude::*;
/// Shorthand for the backend every tensor in this example runs on.
type B = MpsGraph;
/// Runs a short tour of tensor operations on the MPSGraph backend and prints each result.
///
/// # Panics
///
/// Panics if a tensor's data cannot be read back as a `Vec<f32>` (see [`to_f32_vec`]).
fn main() {
    let device = MpsGraphDevice::default();
    println!("=== burn-mpsgraph basic example ===\n");

    // Element-wise arithmetic. `a` and `b` are each used twice: the first use
    // clones them, and the last use consumes the originals instead of cloning again.
    let a: Tensor<B, 1> = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device);
    let b: Tensor<B, 1> = Tensor::from_floats([10.0, 20.0, 30.0, 40.0], &device);
    let sum = a.clone() + b.clone();
    println!("a + b = {:?}", to_f32_vec(sum));
    let product = a * b;
    println!("a * b = {:?}", to_f32_vec(product));

    // Matrix multiplication of two 2x2 matrices.
    let m: Tensor<B, 2> = Tensor::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    let n: Tensor<B, 2> = Tensor::from_floats([[5.0, 6.0], [7.0, 8.0]], &device);
    let mm = m.matmul(n);
    println!("matmul 2x2 = {:?}", to_f32_vec(mm));

    // Full reductions down to a single scalar; the final reduction consumes `v`.
    let v: Tensor<B, 1> = Tensor::from_floats([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0], &device);
    println!("sum = {:?}", v.clone().sum().into_scalar());
    println!("max = {:?}", v.clone().max().into_scalar());
    println!("min = {:?}", v.min().into_scalar());

    // Unary element-wise math.
    let angles: Tensor<B, 1> = Tensor::from_floats([0.0, 0.5, 1.0, 1.5], &device);
    println!("sin([0,0.5,1,1.5]) = {:?}", to_f32_vec(angles.sin()));
    let exponents: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0], &device);
    println!("exp([0,1,2]) = {:?}", to_f32_vec(exponents.exp()));

    // Layout operations. `t` is cloned for the transpose because `reshape` consumes it.
    let t: Tensor<B, 2> = Tensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);
    let transposed = t.clone().transpose();
    println!("transpose shape = {}", transposed.shape());
    let reshaped = t.reshape([6]);
    println!("reshape [2,3]->[6] = {:?}", to_f32_vec(reshaped));

    // Seed the backend RNG for this device before drawing random values.
    MpsGraph::seed(&device, 42);
    let rand: Tensor<B, 2> = Tensor::random([4, 4], Distribution::Default, &device);
    println!(
        "random [4,4] min/max = {:.3} / {:.3}",
        rand.clone().min().into_scalar(),
        rand.max().into_scalar()
    );

    println!("\nAll operations ran on the Apple GPU via MPSGraph.");
}

/// Reads a float tensor back to the host as a flat `Vec<f32>`, consuming the tensor.
///
/// # Panics
///
/// Panics if the tensor's data cannot be converted to `f32` elements.
fn to_f32_vec<const D: usize>(tensor: Tensor<B, D>) -> Vec<f32> {
    tensor
        .into_data()
        .to_vec::<f32>()
        .expect("tensor data should be readable as f32")
}