burn-tch 0.21.0

LibTorch backend for the Burn framework using the tch bindings.
Documentation
use burn_backend::{TensorMetadata, ops::FloatTensorOps};
use burn_tch::{LibTorch, LibTorchDevice};

/// Example: element-wise addition of two `f32` tensors on the LibTorch MPS device.
///
/// Builds a 2x2 tensor from literal values, builds a tensor of ones with the
/// same shape and dtype, adds them, and prints the result to stdout.
///
/// # Panics
///
/// Panics with "Could not detect MPS" when `tch::utils::has_mps()` returns `false`.
fn main() {
    type TorchF32 = LibTorch<f32>;

    let mps_available = tch::utils::has_mps();
    assert!(mps_available, "Could not detect MPS");

    let device = LibTorchDevice::Mps;

    // Left operand: a tensor built from explicit values.
    let lhs = TorchF32::float_from_data([[2f32, 3.], [4., 5.]].into(), &device);

    // Right operand: all ones, matching the left operand's shape and dtype.
    let shape = lhs.shape();
    let dtype = lhs.dtype();
    let rhs = TorchF32::float_ones(shape, &device, dtype.into());

    // Element-wise sum of the two operands, printed via its `Display` impl.
    let sum = TorchF32::float_add(lhs, rhs);
    println!("{sum}");
}