1use std::marker::PhantomData;
2
3use burn_backend::{
4 BackTrace, Backend, DType, DeviceId, DeviceOps, ExecutionError, QTensorPrimitive,
5 tensor::Device,
6};
7use burn_std::{
8 rand::{SeedableRng, StdRng},
9 stub::Mutex,
10};
11use candle_core::{DeviceLocation, backend::BackendDevice};
12
13use crate::{
14 CandleTensor, IntoDType,
15 element::{CandleElement, FloatCandleElement, IntCandleElement},
16};
17
18#[derive(Clone, Default, Debug)]
23pub struct Candle<F = f32, I = i64>
24where
25 F: FloatCandleElement,
26 I: IntCandleElement,
27{
28 _float: PhantomData<F>,
29 _int: PhantomData<I>,
30}
31
/// Crate-wide slot holding an optional seeded RNG behind a mutex.
///
/// Starts out as `None`; `set_seeded_rng` stores a generator here and
/// `get_seeded_rng` reads it back.
pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
34
35pub(crate) fn get_seeded_rng() -> StdRng {
36 let mut seed = SEED.lock().unwrap();
37 match seed.as_ref() {
38 Some(rng_seeded) => rng_seeded.clone(),
39 None => burn_std::rand::get_seeded_rng(),
40 }
41}
42
43pub(crate) fn set_seeded_rng(rng_seeded: StdRng) {
44 let mut seed = SEED.lock().unwrap();
45 *seed = Some(rng_seeded);
46}
47
48#[derive(Clone, Debug, PartialEq, Eq)]
50#[derive(Default)]
62pub enum CandleDevice {
63 #[default]
65 Cpu,
66
67 Cuda(CudaDevice),
70
71 Metal(MetalDevice),
74}
75
76impl CandleDevice {
77 pub fn cuda(index: usize) -> Self {
80 CandleDevice::Cuda(CudaDevice {
81 device: candle_core::CudaDevice::new(index).unwrap(),
82 index,
83 })
84 }
85
86 pub fn metal(index: usize) -> Self {
89 CandleDevice::Metal(MetalDevice {
90 device: candle_core::MetalDevice::new(index).unwrap(),
91 index,
92 })
93 }
94
95 pub(crate) fn set_seed(&self, seed: u64) {
96 match self {
97 CandleDevice::Cpu => {
98 let rng = StdRng::seed_from_u64(seed);
101 set_seeded_rng(rng);
102 }
103 CandleDevice::Cuda(cuda_device) => cuda_device.device.set_seed(seed).unwrap(),
104 CandleDevice::Metal(metal_device) => metal_device.device.set_seed(seed).unwrap(),
105 }
106 }
107}
108
/// A CUDA device: the `candle_core::CudaDevice` handle together with its index.
#[derive(Clone, Debug)]
pub struct CudaDevice {
    /// Underlying candle device handle (crate-private).
    pub(crate) device: candle_core::CudaDevice,
    /// Device index: the value passed to `CandleDevice::cuda`, or the `gpu_id`
    /// reported by the device location when converted from `candle_core::Device`.
    pub index: usize,
}
116
117impl PartialEq for CudaDevice {
118 fn eq(&self, other: &Self) -> bool {
119 self.device.same_device(&other.device) && self.index == other.index
120 }
121}
122
// Marker impl: required so that `CandleDevice`, which contains a `CudaDevice`
// variant, can derive `Eq`.
// NOTE(review): this presumes `same_device` is reflexive — confirm.
impl Eq for CudaDevice {}
124
/// A Metal device: the `candle_core::MetalDevice` handle together with its index.
#[derive(Clone, Debug)]
pub struct MetalDevice {
    /// Underlying candle device handle (crate-private).
    pub(crate) device: candle_core::MetalDevice,
    /// Device index: the value passed to `CandleDevice::metal`, or the `gpu_id`
    /// reported by the device location when converted from `candle_core::Device`.
    pub index: usize,
}
132
133impl PartialEq for MetalDevice {
134 fn eq(&self, other: &Self) -> bool {
135 self.device.same_device(&other.device) && self.index == other.index
136 }
137}
138
// Marker impl: required so that `CandleDevice`, which contains a `MetalDevice`
// variant, can derive `Eq`.
// NOTE(review): this presumes `same_device` is reflexive — confirm.
impl Eq for MetalDevice {}
140
141impl From<CandleDevice> for candle_core::Device {
142 fn from(device: CandleDevice) -> Self {
143 match device {
144 CandleDevice::Cpu => candle_core::Device::Cpu,
145 CandleDevice::Cuda(device) => candle_core::Device::Cuda(device.device),
146 CandleDevice::Metal(device) => candle_core::Device::Metal(device.device),
147 }
148 }
149}
150
151impl From<candle_core::Device> for CandleDevice {
152 fn from(device: candle_core::Device) -> Self {
153 match device.location() {
154 DeviceLocation::Cpu => CandleDevice::Cpu,
155 DeviceLocation::Cuda { gpu_id } => {
156 if let candle_core::Device::Cuda(device) = device {
157 CandleDevice::Cuda(CudaDevice {
158 device,
159 index: gpu_id,
160 })
161 } else {
162 panic!("Expected CUDA device.");
163 }
164 }
165 DeviceLocation::Metal { gpu_id } => {
166 if let candle_core::Device::Metal(device) = device {
167 CandleDevice::Metal(MetalDevice {
168 device,
169 index: gpu_id,
170 })
171 } else {
172 panic!("Expected Metal device.");
173 }
174 }
175 }
176 }
177}
178
179impl burn_backend::Device for CandleDevice {
180 fn to_id(&self) -> burn_backend::DeviceId {
181 match self {
182 CandleDevice::Cuda(device) => DeviceId::new(0, device.index as u32),
183 CandleDevice::Metal(device) => DeviceId::new(1, device.index as u32),
184 CandleDevice::Cpu => DeviceId::new(2, 0),
185 }
186 }
187
188 fn from_id(device_id: DeviceId) -> Self {
189 match device_id.type_id {
190 0 => CandleDevice::cuda(device_id.index_id as usize),
191 1 => CandleDevice::metal(device_id.index_id as usize),
192 _ => CandleDevice::Cpu,
193 }
194 }
195
196 fn device_count(type_id: u16) -> usize {
197 1
199 }
200}
// Opts `CandleDevice` into `DeviceOps` without overriding any items.
impl DeviceOps for CandleDevice {}
202
203impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
204 type Device = CandleDevice;
205
206 type FloatTensorPrimitive = CandleTensor;
207 type FloatElem = F;
208
209 type IntTensorPrimitive = CandleTensor;
210 type IntElem = I;
211
212 type BoolTensorPrimitive = CandleTensor;
213 type BoolElem = u8;
214
215 type QuantizedTensorPrimitive = CandleTensor;
216
217 fn ad_enabled() -> bool {
218 false
219 }
220
221 fn name(device: &Self::Device) -> String {
222 match device {
223 CandleDevice::Cpu => "candle<cpu>",
224 CandleDevice::Cuda(..) => "candle<cuda>",
225 CandleDevice::Metal(..) => "candle<metal>",
226 }
227 .to_string()
228 }
229
230 fn seed(device: &CandleDevice, seed: u64) {
231 device.set_seed(seed);
232 }
233
234 fn sync(device: &Device<Self>) -> Result<(), ExecutionError> {
235 let device: candle_core::Device = (device.clone()).into();
236
237 match device {
238 candle_core::Device::Cpu => (),
239 candle_core::Device::Cuda(device) => {
240 #[cfg(feature = "cuda")]
241 device
242 .synchronize()
243 .map_err(|err| ExecutionError::Generic {
244 reason: format!("Can't sync the cuda device: {err}"),
245 backtrace: BackTrace::capture(),
246 })?;
247 }
248 candle_core::Device::Metal(device) => {
249 return Err(ExecutionError::Generic {
252 reason:
253 "Device synchronization unavailable with Metal device on Candle backend"
254 .into(),
255 backtrace: BackTrace::capture(),
256 });
257 }
258 }
259
260 Ok(())
261 }
262
263 fn supports_dtype(_device: &Device<Self>, dtype: DType) -> bool {
264 dtype.try_into_dtype().is_ok()
265 }
266}
267
#[cfg(test)]
mod tests {
    use burn_std::QuantScheme;

    use super::*;

    /// Checks `supports_dtype` on the default device against a fixed list of
    /// dtypes expected to be accepted and a fixed list expected to be rejected.
    #[test]
    fn should_support_dtypes() {
        type B = Candle<f32>;
        let device = CandleDevice::default();

        let supported = [
            DType::F64,
            DType::F32,
            DType::Flex32,
            DType::F16,
            DType::BF16,
            DType::I64,
            DType::U32,
            DType::U8,
        ];
        for dtype in supported {
            assert!(
                B::supports_dtype(&device, dtype),
                "expected {dtype:?} to be supported"
            );
        }

        let unsupported = [
            DType::U64,
            DType::U16,
            DType::I32,
            DType::I16,
            DType::I8,
            DType::Bool,
            DType::QFloat(QuantScheme::default()),
        ];
        for dtype in unsupported {
            assert!(
                !B::supports_dtype(&device, dtype),
                "expected {dtype:?} to be unsupported"
            );
        }
    }
}