use burn_backend::{TensorMetadata, ops::FloatTensorOps};
use burn_tch::{LibTorch, LibTorchDevice};
/// Small smoke test for the LibTorch backend running on Apple's MPS device.
///
/// Panics with "Could not detect MPS" when `tch` reports that MPS is unavailable.
/// Otherwise it builds a 2x2 `f32` tensor on the MPS device, adds a tensor of
/// ones with the same shape and dtype, and prints the result to stdout.
fn main() {
    assert!(tch::utils::has_mps(), "Could not detect MPS");

    type B = LibTorch<f32>;
    let device = LibTorchDevice::Mps;

    let lhs = B::float_from_data([[2f32, 3.], [4., 5.]].into(), &device);

    // Mirror the left operand's shape and dtype so the element-wise add lines up.
    let shape = lhs.shape();
    let dtype = lhs.dtype();
    let rhs = B::float_ones(shape, &device, dtype.into());

    let sum = B::float_add(lhs, rhs);
    println!("{}", sum);
}