use std::marker::PhantomData;
use burn_tensor::{backend::Backend, Device};
use candle_core::DeviceLocation;
use crate::{
element::{CandleElement, FloatCandleElement, IntCandleElement},
CandleTensor, PrecisionBridge,
};
/// Burn backend marker type implemented on top of candle.
///
/// `F` is the float element type (defaults to `f32`) and `I` is the integer
/// element type (defaults to `i64`). The struct holds no runtime data; the
/// `PhantomData` fields only carry the element types at the type level.
#[derive(Clone, Copy, Default, Debug)]
pub struct Candle<F: FloatCandleElement = f32, I: IntCandleElement = i64> {
    _float: PhantomData<F>,
    _int: PhantomData<I>,
}
/// Device selector for the candle backend.
///
/// The `usize` payload of `Cuda` and `Metal` is the device ordinal passed to
/// candle when the corresponding device is created.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CandleDevice {
    /// Host CPU.
    Cpu,
    /// CUDA GPU with the given ordinal.
    Cuda(usize),
    /// Metal GPU with the given ordinal.
    Metal(usize),
}
impl From<CandleDevice> for candle_core::Device {
    /// Creates the candle device corresponding to `device`.
    ///
    /// `Cpu` maps directly to `candle_core::Device::Cpu`. `Cuda` and `Metal`
    /// create a new candle device for the stored ordinal.
    ///
    /// # Panics
    ///
    /// Panics if candle returns an error while creating the requested CUDA or
    /// Metal device. The message names the backend kind, the ordinal and the
    /// underlying candle error.
    fn from(device: CandleDevice) -> Self {
        match device {
            CandleDevice::Cpu => candle_core::Device::Cpu,
            CandleDevice::Cuda(ordinal) => candle_core::Device::new_cuda(ordinal)
                .unwrap_or_else(|err| panic!("Failed to create CUDA device {ordinal}: {err}")),
            CandleDevice::Metal(ordinal) => candle_core::Device::new_metal(ordinal)
                .unwrap_or_else(|err| panic!("Failed to create Metal device {ordinal}: {err}")),
        }
    }
}
impl From<candle_core::Device> for CandleDevice {
    /// Maps a candle device back to a [`CandleDevice`] using the location that
    /// candle reports for it; the GPU id becomes the variant's ordinal.
    fn from(device: candle_core::Device) -> Self {
        let location = device.location();
        match location {
            DeviceLocation::Cpu => Self::Cpu,
            DeviceLocation::Cuda { gpu_id: ordinal } => Self::Cuda(ordinal),
            DeviceLocation::Metal { gpu_id: ordinal } => Self::Metal(ordinal),
        }
    }
}
impl Default for CandleDevice {
fn default() -> Self {
Self::Cpu
}
}
impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
    type Device = CandleDevice;
    type FullPrecisionBridge = PrecisionBridge<f32>;

    type FloatTensorPrimitive<const D: usize> = CandleTensor<Self::FloatElem, D>;
    type FloatElem = F;

    type IntTensorPrimitive<const D: usize> = CandleTensor<Self::IntElem, D>;
    type IntElem = I;

    // Boolean tensors are stored as `u8` candle tensors.
    type BoolTensorPrimitive<const D: usize> = CandleTensor<u8, D>;

    /// Always `false`: this backend does not report autodiff support.
    fn ad_enabled() -> bool {
        false
    }

    /// Returns the backend name, `"candle"`.
    fn name() -> String {
        "candle".to_string()
    }

    /// Seeding is not implemented for this backend.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn seed(_seed: u64) {
        panic!("Manual seed not supported by Candle. ")
    }

    /// Waits for pending work on `device`.
    ///
    /// The CPU arm does nothing. The CUDA arm calls `synchronize` only when the
    /// `cuda` feature is enabled; otherwise it does nothing.
    ///
    /// # Panics
    ///
    /// - Always panics for Metal devices, where synchronization is not implemented here.
    /// - Panics if CUDA synchronization returns an error.
    /// - Panics if converting `device` into a candle device fails, since that
    ///   conversion panics on device-creation errors.
    fn sync(device: &Device<Self>) {
        let device: candle_core::Device = (*device).into();
        match device {
            candle_core::Device::Cpu => (),
            // The binding is only used when the `cuda` feature is enabled; silence the
            // unused-variable lint for the other configuration only.
            #[cfg_attr(not(feature = "cuda"), allow(unused_variables))]
            candle_core::Device::Cuda(device) => {
                #[cfg(feature = "cuda")]
                device
                    .synchronize()
                    .expect("CUDA device synchronization failed on Candle backend");
            }
            candle_core::Device::Metal(_) => {
                panic!("Device synchronization unavailable with Metal device on Candle backend")
            }
        }
    }
}