use std::marker::PhantomData;
use burn_tensor::{
backend::{Backend, DeviceId, DeviceOps, SyncType},
quantization::{QTensorPrimitive, QuantizationStrategy},
Device,
};
use candle_core::DeviceLocation;
use crate::{
element::{CandleElement, FloatCandleElement, IntCandleElement},
CandleQTensor, CandleTensor, PrecisionBridge,
};
/// Burn backend whose tensors are Candle tensors (`CandleTensor` / `CandleQTensor`).
///
/// * `F` — float element type used for float tensors (defaults to `f32`).
/// * `I` — integer element type used for int tensors (defaults to `i64`).
///
/// The struct carries no runtime data: both type parameters are tracked only
/// through `PhantomData`, so values of this type are zero-sized markers.
#[derive(Clone, Copy, Default, Debug)]
pub struct Candle<F = f32, I = i64>
where
    F: FloatCandleElement,
    I: IntCandleElement,
{
    // Marker for the float element type; no `F` value is stored.
    _float: PhantomData<F>,
    // Marker for the integer element type; no `I` value is stored.
    _int: PhantomData<I>,
}
/// Device selector for the Candle backend.
///
/// This is a lightweight `Copy` descriptor; it converts to and from
/// `candle_core::Device`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CandleDevice {
    /// The CPU.
    Cpu,
    /// A CUDA GPU, identified by its device ordinal.
    Cuda(usize),
    /// A Metal GPU, identified by its device ordinal.
    Metal(usize),
}
impl From<CandleDevice> for candle_core::Device {
    /// Creates the Candle device described by `device`.
    ///
    /// # Panics
    ///
    /// Panics if `candle_core::Device::new_cuda` or `candle_core::Device::new_metal`
    /// returns an error. The panic message names the device kind and ordinal and
    /// includes Candle's error, so the failing device can be identified.
    fn from(device: CandleDevice) -> Self {
        match device {
            CandleDevice::Cpu => candle_core::Device::Cpu,
            CandleDevice::Cuda(ordinal) => {
                candle_core::Device::new_cuda(ordinal).unwrap_or_else(|err| {
                    panic!("Failed to create CUDA device {ordinal} on Candle backend: {err}")
                })
            }
            CandleDevice::Metal(ordinal) => {
                candle_core::Device::new_metal(ordinal).unwrap_or_else(|err| {
                    panic!("Failed to create Metal device {ordinal} on Candle backend: {err}")
                })
            }
        }
    }
}
impl From<candle_core::Device> for CandleDevice {
    /// Recovers the [`CandleDevice`] descriptor from a Candle device.
    ///
    /// The mapping is read from the device's reported location. The GPU id
    /// becomes the ordinal of the matching variant.
    fn from(device: candle_core::Device) -> Self {
        let location = device.location();
        match location {
            DeviceLocation::Cpu => Self::Cpu,
            DeviceLocation::Cuda { gpu_id: ordinal } => Self::Cuda(ordinal),
            DeviceLocation::Metal { gpu_id: ordinal } => Self::Metal(ordinal),
        }
    }
}
impl DeviceOps for CandleDevice {
    /// Returns a [`DeviceId`] built from two parts:
    /// - a device-kind code: 0 = CPU, 1 = CUDA, 2 = Metal;
    /// - the device ordinal, which is always 0 for the CPU.
    fn id(&self) -> burn_tensor::backend::DeviceId {
        // NOTE(review): `as u32` silently truncates ordinals above `u32::MAX`.
        let (kind, ordinal) = match *self {
            Self::Cpu => (0, 0),
            Self::Cuda(ordinal) => (1, ordinal as u32),
            Self::Metal(ordinal) => (2, ordinal as u32),
        };
        DeviceId::new(kind, ordinal)
    }
}
impl Default for CandleDevice {
fn default() -> Self {
Self::Cpu
}
}
impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
    type Device = CandleDevice;

    // Full-precision conversions go through `f32`.
    type FullPrecisionBridge = PrecisionBridge<f32>;

    type FloatTensorPrimitive<const D: usize> = CandleTensor<Self::FloatElem, D>;
    type FloatElem = F;

    type IntTensorPrimitive<const D: usize> = CandleTensor<Self::IntElem, D>;
    type IntElem = I;

    // Boolean tensors are stored with `u8` elements.
    type BoolTensorPrimitive<const D: usize> = CandleTensor<u8, D>;
    type QuantizedTensorPrimitive<const D: usize> = CandleQTensor<D>;

    /// Always `false`: this backend has no built-in autodiff.
    fn ad_enabled() -> bool {
        false
    }

    /// Returns the backend name, `"candle"`.
    fn name() -> String {
        "candle".to_string()
    }

    /// Seeding is not supported by this backend.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn seed(_seed: u64) {
        panic!("Manual seed not supported by Candle. ")
    }

    /// Synchronizes `device` according to `sync_type`.
    ///
    /// `SyncType::Flush` is a no-op. For `SyncType::Wait`:
    /// - CPU: no-op.
    /// - CUDA: waits on the device when built with the `cuda` feature; no-op otherwise.
    /// - Metal: unsupported and panics.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - a Metal device is asked to wait;
    /// - CUDA synchronization reports an error;
    /// - converting `device` into a `candle_core::Device` fails.
    fn sync(device: &Device<Self>, sync_type: SyncType) {
        match sync_type {
            SyncType::Wait => {
                // NOTE(review): this conversion creates a fresh candle device
                // (`new_cuda` / `new_metal`) on every call. Confirm that it
                // synchronizes the same stream that tensor work was submitted to.
                let device: candle_core::Device = (*device).into();
                match device {
                    candle_core::Device::Cpu => (),
                    candle_core::Device::Cuda(device) => {
                        #[cfg(feature = "cuda")]
                        device
                            .synchronize()
                            .expect("Failed to synchronize CUDA device on Candle backend");
                        // Without the `cuda` feature this arm is a no-op; the
                        // discard below only keeps the binding from being unused.
                        #[cfg(not(feature = "cuda"))]
                        let _ = device;
                    }
                    candle_core::Device::Metal(_) => {
                        panic!("Device synchronization unavailable with Metal device on Candle backend")
                    }
                }
            }
            SyncType::Flush => (),
        }
    }
}