burn-tch 0.21.0

LibTorch backend for the Burn framework, built on the `tch` bindings to LibTorch.
Documentation — example usage (requires a CUDA-enabled LibTorch installation):
use burn_backend::{TensorMetadata, ops::FloatTensorOps};
use burn_tch::{LibTorch, LibTorchDevice};

fn main() {
    // The example runs on the first CUDA device, so stop immediately with a clear
    // message when the linked LibTorch reports no usable CUDA support.
    let cuda_available = tch::utils::has_cuda();
    assert!(cuda_available, "Could not detect valid CUDA configuration");

    type Torch = LibTorch<f32>;
    let gpu = LibTorchDevice::Cuda(0);

    // Left operand: a 2x2 tensor built from explicit values.
    let lhs = Torch::float_from_data([[2f32, 3.], [4., 5.]].into(), &gpu);

    // Right operand: all ones, reusing the shape and dtype of the left operand.
    let (shape, dtype) = (lhs.shape(), lhs.dtype());
    let rhs = Torch::float_ones(shape, &gpu, dtype.into());

    // Element-wise sum of both operands, printed via the tensor's `Display` impl.
    let sum = Torch::float_add(lhs, rhs);
    println!("{}", sum);
}