Skip to main content

burn_candle/
backend.rs

1use std::marker::PhantomData;
2
3use burn_backend::{
4    BackTrace, Backend, DType, DeviceId, DeviceOps, ExecutionError, QTensorPrimitive,
5    tensor::Device,
6};
7use burn_std::{
8    rand::{SeedableRng, StdRng},
9    stub::Mutex,
10};
11use candle_core::{DeviceLocation, backend::BackendDevice};
12
13use crate::{
14    CandleTensor, IntoDType,
15    element::{CandleElement, FloatCandleElement, IntCandleElement},
16};
17
18/// Tensor backend that uses the [candle](candle_core) crate for executing tensor operations.
19///
20/// It is compatible with a wide range of hardware configurations, including CPUs and GPUs
21/// that support CUDA or Metal. Additionally, the backend can be compiled to `wasm` when using the CPU.
22#[derive(Clone, Default, Debug)]
23pub struct Candle<F = f32, I = i64>
24where
25    F: FloatCandleElement,
26    I: IntCandleElement,
27{
28    _float: PhantomData<F>,
29    _int: PhantomData<I>,
30}
31
32// Seed for CPU device
33pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
34
35pub(crate) fn get_seeded_rng() -> StdRng {
36    let mut seed = SEED.lock().unwrap();
37    match seed.as_ref() {
38        Some(rng_seeded) => rng_seeded.clone(),
39        None => burn_std::rand::get_seeded_rng(),
40    }
41}
42
43pub(crate) fn set_seeded_rng(rng_seeded: StdRng) {
44    let mut seed = SEED.lock().unwrap();
45    *seed = Some(rng_seeded);
46}
47
48/// The device type for the candle backend.
49#[derive(Clone, Debug, PartialEq, Eq)]
50/// The device struct when using the `candle` backend.
51///
52/// To create a Cuda or Metal device from the index, use the associated methods to create the variant:
53/// ```no_run
54/// use burn_candle::CandleDevice;
55///
56/// // Create a Cuda device from its index
57/// let device = CandleDevice::cuda(0);
58/// // Create a Metal device from its index
59/// let device = CandleDevice::metal(0);
60/// ```
61#[derive(Default)]
62pub enum CandleDevice {
63    /// CPU device.
64    #[default]
65    Cpu,
66
67    /// Cuda device with the given index. The index is the index of the Cuda device in the list of
68    /// all Cuda devices found on the system.
69    Cuda(CudaDevice),
70
71    /// Metal device with the given index. The index is the index of the Metal device in the list of
72    /// all Metal devices found on the system.
73    Metal(MetalDevice),
74}
75
76impl CandleDevice {
77    /// Create a Cuda device with the given index.
78    /// The index is the index of the Cuda device in the list of all Cuda devices found on the system.
79    pub fn cuda(index: usize) -> Self {
80        CandleDevice::Cuda(CudaDevice {
81            device: candle_core::CudaDevice::new(index).unwrap(),
82            index,
83        })
84    }
85
86    /// Create a Metal device with the given index.
87    /// The index is the index of the Metal device in the list of all Metal devices found on the system.
88    pub fn metal(index: usize) -> Self {
89        CandleDevice::Metal(MetalDevice {
90            device: candle_core::MetalDevice::new(index).unwrap(),
91            index,
92        })
93    }
94
95    pub(crate) fn set_seed(&self, seed: u64) {
96        match self {
97            CandleDevice::Cpu => {
98                // candle_core::cpu_backend::CpuDevice.set_seed(seed).unwrap();
99                // Candle does not support seeding the CPU rng so we use a global seed
100                let rng = StdRng::seed_from_u64(seed);
101                set_seeded_rng(rng);
102            }
103            CandleDevice::Cuda(cuda_device) => cuda_device.device.set_seed(seed).unwrap(),
104            CandleDevice::Metal(metal_device) => metal_device.device.set_seed(seed).unwrap(),
105        }
106    }
107}
108
109#[derive(Clone, Debug)]
110/// A Cuda device for the `candle` backend.
111pub struct CudaDevice {
112    pub(crate) device: candle_core::CudaDevice,
113    /// The index of the Cuda device in the list of all devices on the system.
114    pub index: usize,
115}
116
117impl PartialEq for CudaDevice {
118    fn eq(&self, other: &Self) -> bool {
119        self.device.same_device(&other.device) && self.index == other.index
120    }
121}
122
123impl Eq for CudaDevice {}
124
125#[derive(Clone, Debug)]
126/// A Metal device for the `candle` backend.
127pub struct MetalDevice {
128    pub(crate) device: candle_core::MetalDevice,
129    /// The index of the Metal device in the list of all devices on the system.
130    pub index: usize,
131}
132
133impl PartialEq for MetalDevice {
134    fn eq(&self, other: &Self) -> bool {
135        self.device.same_device(&other.device) && self.index == other.index
136    }
137}
138
139impl Eq for MetalDevice {}
140
141impl From<CandleDevice> for candle_core::Device {
142    fn from(device: CandleDevice) -> Self {
143        match device {
144            CandleDevice::Cpu => candle_core::Device::Cpu,
145            CandleDevice::Cuda(device) => candle_core::Device::Cuda(device.device),
146            CandleDevice::Metal(device) => candle_core::Device::Metal(device.device),
147        }
148    }
149}
150
151impl From<candle_core::Device> for CandleDevice {
152    fn from(device: candle_core::Device) -> Self {
153        match device.location() {
154            DeviceLocation::Cpu => CandleDevice::Cpu,
155            DeviceLocation::Cuda { gpu_id } => {
156                if let candle_core::Device::Cuda(device) = device {
157                    CandleDevice::Cuda(CudaDevice {
158                        device,
159                        index: gpu_id,
160                    })
161                } else {
162                    panic!("Expected CUDA device.");
163                }
164            }
165            DeviceLocation::Metal { gpu_id } => {
166                if let candle_core::Device::Metal(device) = device {
167                    CandleDevice::Metal(MetalDevice {
168                        device,
169                        index: gpu_id,
170                    })
171                } else {
172                    panic!("Expected Metal device.");
173                }
174            }
175        }
176    }
177}
178
179impl burn_backend::Device for CandleDevice {
180    fn to_id(&self) -> burn_backend::DeviceId {
181        match self {
182            CandleDevice::Cuda(device) => DeviceId::new(0, device.index as u32),
183            CandleDevice::Metal(device) => DeviceId::new(1, device.index as u32),
184            CandleDevice::Cpu => DeviceId::new(2, 0),
185        }
186    }
187
188    fn from_id(device_id: DeviceId) -> Self {
189        match device_id.type_id {
190            0 => CandleDevice::cuda(device_id.index_id as usize),
191            1 => CandleDevice::metal(device_id.index_id as usize),
192            _ => CandleDevice::Cpu,
193        }
194    }
195
196    fn device_count(type_id: u16) -> usize {
197        // TODO: Fix that
198        1
199    }
200}
201impl DeviceOps for CandleDevice {}
202
203impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
204    type Device = CandleDevice;
205
206    type FloatTensorPrimitive = CandleTensor;
207    type FloatElem = F;
208
209    type IntTensorPrimitive = CandleTensor;
210    type IntElem = I;
211
212    type BoolTensorPrimitive = CandleTensor;
213    type BoolElem = u8;
214
215    type QuantizedTensorPrimitive = CandleTensor;
216
217    fn ad_enabled() -> bool {
218        false
219    }
220
221    fn name(device: &Self::Device) -> String {
222        match device {
223            CandleDevice::Cpu => "candle<cpu>",
224            CandleDevice::Cuda(..) => "candle<cuda>",
225            CandleDevice::Metal(..) => "candle<metal>",
226        }
227        .to_string()
228    }
229
230    fn seed(device: &CandleDevice, seed: u64) {
231        device.set_seed(seed);
232    }
233
234    fn sync(device: &Device<Self>) -> Result<(), ExecutionError> {
235        let device: candle_core::Device = (device.clone()).into();
236
237        match device {
238            candle_core::Device::Cpu => (),
239            candle_core::Device::Cuda(device) => {
240                #[cfg(feature = "cuda")]
241                device
242                    .synchronize()
243                    .map_err(|err| ExecutionError::Generic {
244                        reason: format!("Can't sync the cuda device: {err}"),
245                        backtrace: BackTrace::capture(),
246                    })?;
247            }
248            candle_core::Device::Metal(device) => {
249                // For some reason, device.wait_until_completed() does not seem to work,
250                // and neither does writing and reading a value with into_data
251                return Err(ExecutionError::Generic {
252                    reason:
253                        "Device synchronization unavailable with Metal device on Candle backend"
254                            .into(),
255                    backtrace: BackTrace::capture(),
256                });
257            }
258        }
259
260        Ok(())
261    }
262
263    fn supports_dtype(_device: &Device<Self>, dtype: DType) -> bool {
264        dtype.try_into_dtype().is_ok()
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use burn_std::QuantScheme;

    /// Checks which `DType`s the candle backend reports as supported.
    #[test]
    fn should_support_dtypes() {
        type B = Candle<f32>;
        let device = Default::default();

        let supported = [
            DType::F64,
            DType::F32,
            DType::Flex32,
            DType::F16,
            DType::BF16,
            DType::I64,
            DType::U32,
            DType::U8,
        ];
        for dtype in supported {
            assert!(B::supports_dtype(&device, dtype));
        }

        let unsupported = [
            DType::U64,
            DType::U16,
            DType::I32,
            DType::I16,
            DType::I8,
            DType::Bool,
            DType::QFloat(QuantScheme::default()),
        ];
        for dtype in unsupported {
            assert!(!B::supports_dtype(&device, dtype));
        }
    }
}