use std::marker::PhantomData;
use crate::IntoKind;
use super::TchTensor;
use super::element::TchElement;
use burn_backend::backend::{Backend, BackendTypes, DeviceId, DeviceOps, ExecutionError};
use burn_backend::ops::IntTensorOps;
/// Device handle for the LibTorch backend.
///
/// Each variant corresponds to a device family exposed by `tch::Device`;
/// `Cpu` is the default device.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum LibTorchDevice {
    /// Host CPU (the default device).
    #[default]
    Cpu,
    /// CUDA GPU at the given ordinal index.
    Cuda(usize),
    /// Apple Metal Performance Shaders device.
    Mps,
    /// Vulkan device.
    Vulkan,
}
/// Converts the backend's device handle into the raw `tch` device.
impl From<LibTorchDevice> for tch::Device {
    #[allow(
        unreachable_code,
        reason = "CUDA branch always panics if the library is missing"
    )]
    fn from(device: LibTorchDevice) -> Self {
        match device {
            LibTorchDevice::Cpu => tch::Device::Cpu,
            LibTorchDevice::Cuda(_num) => {
                // Splices in code generated by the build script. Per the
                // `reason` on the `allow` above, this check panics when CUDA
                // support is missing, which makes the return below
                // unreachable in that configuration.
                include!(concat!(env!("OUT_DIR"), "/tch_gpu_check.rs"));
                tch::Device::Cuda(_num)
            }
            LibTorchDevice::Mps => tch::Device::Mps,
            LibTorchDevice::Vulkan => tch::Device::Vulkan,
        }
    }
}
impl From<tch::Device> for LibTorchDevice {
fn from(device: tch::Device) -> Self {
match device {
tch::Device::Cpu => LibTorchDevice::Cpu,
tch::Device::Cuda(num) => LibTorchDevice::Cuda(num),
tch::Device::Mps => LibTorchDevice::Mps,
tch::Device::Vulkan => LibTorchDevice::Vulkan,
}
}
}
impl burn_backend::Device for LibTorchDevice {
    /// Rebuilds a device from its stable identifier.
    ///
    /// The `type_id` mapping is the inverse of `to_id`; any unknown
    /// `type_id` falls back to the CPU device.
    fn from_id(device_id: DeviceId) -> Self {
        match device_id.type_id {
            0 => LibTorchDevice::Cuda(device_id.index_id as usize),
            1 => LibTorchDevice::Mps,
            2 => LibTorchDevice::Cpu,
            3 => LibTorchDevice::Vulkan,
            _ => LibTorchDevice::Cpu,
        }
    }

    /// Serializes this device into a stable identifier.
    ///
    /// Type ids: 0 = CUDA (index carried in `index_id`), 1 = MPS,
    /// 2 = CPU, 3 = Vulkan.
    fn to_id(&self) -> DeviceId {
        let (type_id, index_id) = match self {
            LibTorchDevice::Cuda(index) => (0, *index as u16),
            LibTorchDevice::Mps => (1, 0),
            LibTorchDevice::Cpu => (2, 0),
            LibTorchDevice::Vulkan => (3, 0),
        };
        DeviceId::new(type_id, index_id)
    }
}
// Empty marker impl: opts `LibTorchDevice` into `DeviceOps`, presumably
// relying entirely on the trait's provided defaults — confirm against the
// `DeviceOps` definition in `burn_backend`.
impl DeviceOps for LibTorchDevice {}
/// The LibTorch (`tch`) backend, parameterized by the float element type
/// `E` (defaults to `f32`).
#[derive(Clone, Copy, Default, Debug)]
pub struct LibTorch<E = f32> {
    // Zero-sized marker: the backend stores no data, it only carries the
    // float element type `E` at the type level.
    _e: PhantomData<E>,
}
/// Concrete type bindings for the LibTorch backend.
impl<E: TchElement> BackendTypes for LibTorch<E> {
    type Device = LibTorchDevice;
    // Float, int, bool, and quantized tensors all use the same `TchTensor`
    // primitive; presumably the runtime dtype lives inside the wrapped
    // `tch` tensor itself — TODO confirm against `TchTensor`'s definition.
    type FloatTensorPrimitive = TchTensor;
    // Float element type is the backend's generic parameter.
    type FloatElem = E;
    type IntTensorPrimitive = TchTensor;
    // Integer elements are fixed to `i64` regardless of `E`.
    type IntElem = i64;
    type BoolTensorPrimitive = TchTensor;
    type BoolElem = bool;
    type QuantizedTensorPrimitive = TchTensor;
}
impl<E: TchElement> Backend for LibTorch<E> {
    /// Seeds the RNG via `tch::manual_seed`; the device argument is unused
    /// because that call takes no device.
    fn seed(_device: &Self::Device, seed: u64) {
        // NOTE(review): `as i64` wraps for seeds above `i64::MAX` — confirm
        // that is acceptable to callers.
        tch::manual_seed(seed as i64);
    }

    /// This backend never reports autodiff as enabled.
    fn ad_enabled(_device: &Self::Device) -> bool {
        false
    }

    /// Human-readable backend name including the device family,
    /// e.g. `"libtorch<cuda>"`. The CUDA index is not included.
    fn name(device: &Self::Device) -> String {
        match device {
            LibTorchDevice::Cpu => "libtorch<cpu>",
            LibTorchDevice::Cuda(_) => "libtorch<cuda>",
            LibTorchDevice::Mps => "libtorch<metal>",
            LibTorchDevice::Vulkan => "libtorch<vulkan>",
        }
        .to_string()
    }

    /// Waits for pending work on `device` to complete.
    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
        match device {
            // Nothing to do for CPU.
            LibTorchDevice::Cpu => (),
            LibTorchDevice::Cuda(index) => {
                tch::Cuda::synchronize(*index as i64);
            }
            _ => {
                // Mps/Vulkan: no direct synchronize call is made here;
                // instead, allocate a 1-element int tensor on the device and
                // read it back synchronously, which presumably forces all
                // queued work to flush — confirm this is sufficient for
                // these device types.
                burn_backend::read_sync(Self::int_into_data(Self::int_zeros(
                    [1].into(),
                    device,
                    <Self::IntElem as burn_backend::Element>::dtype().into(),
                )))
                .unwrap();
            }
        };
        Ok(())
    }

    /// Reports supported usages for `dtype`: the general usage set when the
    /// dtype converts into a known kind, the empty set otherwise.
    fn dtype_usage(
        _device: &Self::Device,
        dtype: burn_backend::DType,
    ) -> burn_backend::DTypeUsageSet {
        if dtype.try_into_kind().is_ok() {
            burn_backend::DTypeUsage::general()
        } else {
            burn_backend::DTypeUsageSet::empty()
        }
    }

    /// NOTE(review): always reports exactly one device per type id, even for
    /// CUDA where several GPUs may be present — confirm this is intentional
    /// (e.g. `tch::Cuda::device_count` exists if enumeration is wanted).
    fn device_count(_: u16) -> usize {
        1
    }
}