use crate::{PrecisionBridge, QuantElement, TchQTensor};
use super::element::TchElement;
use super::TchTensor;
use burn_tensor::backend::{Backend, DeviceId, DeviceOps, SyncType};
use burn_tensor::ops::IntTensorOps;
use burn_tensor::{Int, Tensor};
/// The device struct when using the `tch` backend.
///
/// Note that you need to provide the device index when using CUDA.
///
/// # Example
///
/// ```no_run
/// use burn_tch::LibTorchDevice;
///
/// let device_gpu_1 = LibTorchDevice::Cuda(0); // First GPU
/// let device_gpu_2 = LibTorchDevice::Cuda(1); // Second GPU
/// let device_cpu = LibTorchDevice::Cpu; // CPU
/// let device_mps = LibTorchDevice::Mps; // Metal Performance Shaders
/// let device_vulkan = LibTorchDevice::Vulkan; // Vulkan
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LibTorchDevice {
/// CPU device.
Cpu,
    /// CUDA device with the given index. The index corresponds to the device's position in
    /// the list of all CUDA devices available on the system.
Cuda(usize),
/// Metal Performance Shaders device.
Mps,
/// Vulkan device.
Vulkan,
}
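/// Conversion into the underlying [`tch::Device`] handle.
///
/// # Example
///
/// A minimal sketch: the two `From` implementations below make the conversion a
/// lossless round trip.
///
/// ```no_run
/// use burn_tch::LibTorchDevice;
///
/// let device: tch::Device = LibTorchDevice::Cuda(0).into();
/// let back: LibTorchDevice = device.into();
/// assert_eq!(back, LibTorchDevice::Cuda(0));
/// ```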
impl From<LibTorchDevice> for tch::Device {
fn from(device: LibTorchDevice) -> Self {
match device {
LibTorchDevice::Cpu => tch::Device::Cpu,
LibTorchDevice::Cuda(num) => tch::Device::Cuda(num),
LibTorchDevice::Mps => tch::Device::Mps,
LibTorchDevice::Vulkan => tch::Device::Vulkan,
}
}
}
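/// The inverse conversion, recovering a [`LibTorchDevice`] from a raw [`tch::Device`].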
impl From<tch::Device> for LibTorchDevice {
fn from(device: tch::Device) -> Self {
match device {
tch::Device::Cpu => LibTorchDevice::Cpu,
tch::Device::Cuda(num) => LibTorchDevice::Cuda(num),
tch::Device::Mps => LibTorchDevice::Mps,
tch::Device::Vulkan => LibTorchDevice::Vulkan,
}
}
}
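/// Maps each device variant to a stable [`DeviceId`]: `type_id` identifies the kind of
/// device (0 = CPU, 1 = CUDA, 2 = MPS, 3 = Vulkan) and `index_id` carries the CUDA
/// device index (0 for the single-device kinds).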
impl DeviceOps for LibTorchDevice {
    fn id(&self) -> DeviceId {
match self {
LibTorchDevice::Cpu => DeviceId::new(0, 0),
LibTorchDevice::Cuda(index) => DeviceId::new(1, *index as u32),
LibTorchDevice::Mps => DeviceId::new(2, 0),
LibTorchDevice::Vulkan => DeviceId::new(3, 0),
}
}
}
impl Default for LibTorchDevice {
fn default() -> Self {
Self::Cpu
}
}
/// Tensor backend that uses `LibTorch` through the [tch] crate to execute tensor operations.
///
/// This backend is compatible with a wide range of hardware, from CPUs to GPUs, but
/// requires `LibTorch` to be installed correctly. The CPU distribution can be downloaded
/// automatically, and so can the CUDA distribution when the `TORCH_CUDA_VERSION` environment
/// variable is set. For more complex configurations, check out the manual installation
/// instructions for [burn-tch](https://github.com/tracel-ai/burn/tree/main/burn-tch).
///
/// Refer to the [tch] crate for more information.
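///
/// # Example
///
/// A minimal sketch, assuming the default `f32` float element type; it creates a small
/// tensor on the CPU through this backend.
///
/// ```no_run
/// use burn_tch::{LibTorch, LibTorchDevice};
/// use burn_tensor::Tensor;
///
/// type Backend = LibTorch<f32>;
///
/// let device = LibTorchDevice::Cpu;
/// let tensor = Tensor::<Backend, 1>::from_floats([1.0, 2.0, 3.0], &device);
/// ```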
#[derive(Clone, Copy, Default, Debug)]
pub struct LibTorch<E = f32, Q = i8> {
_e: E,
_q: Q,
}
impl<E: TchElement, Q: QuantElement> Backend for LibTorch<E, Q> {
type Device = LibTorchDevice;
type FullPrecisionBridge = PrecisionBridge<f32>;
type FloatTensorPrimitive<const D: usize> = TchTensor<E, D>;
type FloatElem = E;
type IntTensorPrimitive<const D: usize> = TchTensor<i64, D>;
type IntElem = i64;
type BoolTensorPrimitive<const D: usize> = TchTensor<bool, D>;
type QuantizedTensorPrimitive<const D: usize> = TchQTensor<Q, D>;
fn seed(seed: u64) {
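        // `tch` only accepts an `i64` seed, so seeds above `i64::MAX` wrap to negative values.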
tch::manual_seed(seed as i64);
}
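    // This backend does not track gradients itself; wrap it in `burn_autodiff::Autodiff`
    // to enable automatic differentiation.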
fn ad_enabled() -> bool {
false
}
fn name() -> String {
"tch".to_string()
}
fn sync(device: &Self::Device, sync_type: SyncType) {
if sync_type == SyncType::Wait {
match device {
LibTorchDevice::Cpu => (),
LibTorchDevice::Cuda(index) => {
tch::Cuda::synchronize(*index as i64);
}
_ => {
                    // This device has no explicit synchronization API, so we force a sync
                    // by writing one value and reading it back.
Tensor::<Self, 1, Int>::from_primitive(
<Self as IntTensorOps<Self>>::int_zeros([1].into(), device),
)
.into_data();
}
}
}
}
}