gloss_burn_multibackend/
backend.rs

1use crate::global_backend;
2use burn::{
3    prelude::Backend,
4    tensor::{backend::DeviceOps, ops::Device},
5};
6
7use crate::tensor::MultiBoolTensor;
8use crate::tensor::MultiFloatTensor;
9use crate::tensor::MultiIntTensor;
10
//TODO maybe switch to i32 for all backends?
//IF YOU CHANGE THIS, CHANGE the IntTensorOps int_from_data also together with TensorMetadata for MultiIntTensor
/// Candle backend, f32 floats / i64 ints (note: int width differs from the other backends).
#[cfg(feature = "burn-candle")]
pub type CandleBackend = burn::backend::Candle<f32, i64>;
/// NdArray (pure-CPU) backend, f32 floats / i32 ints.
#[cfg(feature = "burn-ndarray")]
pub type NdArrayBackend = burn::backend::NdArray<f32, i32>;
/// Wgpu (GPU) backend, f32 floats / i32 ints.
#[cfg(feature = "burn-wgpu")]
pub type WgpuBackend = burn::backend::Wgpu<f32, i32>;
19
/// Zero-sized backend marker that dispatches at runtime to one of the
/// compiled-in concrete backends (candle / ndarray / wgpu), selected by
/// the `MultiDevice` a tensor lives on.
#[derive(Clone, Copy, Default, Debug)]
pub struct MultiBackend;
22
23impl Backend for MultiBackend {
24    type Device = MultiDevice;
25    type FloatTensorPrimitive = MultiFloatTensor;
26    type IntTensorPrimitive = MultiIntTensor;
27    type BoolTensorPrimitive = MultiBoolTensor;
28    type QuantizedTensorPrimitive = MultiIntTensor;
29
30    type FloatElem = f32;
31
32    // TODO this probably needs to be i64 if candle is used
33    type IntElem = i32;
34
35    type BoolElem = u8;
36
37    fn name(device: &Self::Device) -> String {
38        match device {
39            #[cfg(feature = "burn-candle")]
40            MultiDevice::Candle(_) => "candle",
41            #[cfg(feature = "burn-ndarray")]
42            MultiDevice::NdArray(_) => "ndarray",
43            #[cfg(feature = "burn-wgpu")]
44            MultiDevice::Wgpu(_) => "wgpu",
45        }
46        .to_string()
47    }
48
49    fn seed(_seed: u64) {
50        //with a newer version of burn we have here access to the device so we can use a match statement
51        todo!()
52    }
53
54    type QuantizedEncoding = f32;
55
56    fn ad_enabled() -> bool {
57        false
58    }
59
60    fn sync(_device: &Self::Device) {}
61}
62
/// Device handle for [`MultiBackend`]: wraps the device type of exactly one
/// of the compiled-in concrete backends.
///
/// `#[non_exhaustive]` so new backend variants can be added without breaking
/// downstream matches.
//
// Removed a dead `#[allow(non_snake_case)]`: enum variant names are checked
// by `non_camel_case_types`, never `non_snake_case`, and these variants are
// conventional UpperCamelCase anyway — the allow could never fire.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MultiDevice {
    #[cfg(feature = "burn-candle")]
    Candle(Device<CandleBackend>),
    #[cfg(feature = "burn-ndarray")]
    NdArray(Device<NdArrayBackend>),
    #[cfg(feature = "burn-wgpu")]
    Wgpu(Device<WgpuBackend>),
    // Planned variant, kept for reference:
    // #[cfg(feature = "autodiff")]
    // Autodiff(Box<Device<MultiBackend>>),
}
76impl Default for MultiDevice {
77    fn default() -> Self {
78        //if we set a global device, we select backend based on that
79        #[allow(unreachable_patterns)]
80        if let Some(global_device) = global_backend::get_global_burn_backend() {
81            match global_device {
82                #[cfg(feature = "burn-candle")]
83                global_backend::GlobalBackend::Candle => return Self::Candle(burn::backend::candle::CandleDevice::default()),
84                #[cfg(feature = "burn-ndarray")]
85                global_backend::GlobalBackend::NdArray => return Self::NdArray(burn::backend::ndarray::NdArrayDevice::default()),
86                #[cfg(feature = "burn-wgpu")]
87                global_backend::GlobalBackend::Wgpu => {
88                    //If the viewer has already been initialized, we want to use the same wgpu device, if not we create a new one
89                    let existing_wgpu_device = wgpu_burn_global_device::get_global_wgpu_device();
90                    return Self::Wgpu(existing_wgpu_device.unwrap_or_default());
91                }
92                _ => {
93                    panic!("This global device {global_device:?} is not available because the corresponding feature is not enabled. Please enable the feature in Cargo.toml.");
94                }
95            }
96        }
97
98        //if no global device is set, we default to candle if available, otherwise ndarray, otherwise wgpu
99        #[cfg(feature = "burn-candle")]
100        {
101            Self::Candle(burn::backend::candle::CandleDevice::default())
102        }
103        #[cfg(all(not(feature = "burn-candle"), feature = "burn-ndarray"))]
104        {
105            Self::NdArray(burn::backend::ndarray::NdArrayDevice::default());
106        }
107        #[cfg(all(not(feature = "burn-candle"), not(feature = "burn-ndarray"), feature = "burn-wgpu"))]
108        {
109            //If the viewer has already been initialized, we want to use the same wgpu device, if not we create a new one
110            let existing_wgpu_device = wgpu_burn_global_device::get_global_wgpu_device();
111            Self::Wgpu(existing_wgpu_device.unwrap_or_default())
112        }
113        #[cfg(all(not(feature = "burn-candle"), not(feature = "burn-ndarray"), not(feature = "burn-wgpu")))]
114        {
115            compile_error!("No backend feature enabled. Please enable at least one of the features: burn-candle, burn-ndarray, burn-wgpu");
116        }
117    }
118}
119
120#[allow(non_snake_case)]
121impl DeviceOps for MultiDevice {
122    fn id(&self) -> burn::tensor::backend::DeviceId {
123        match self {
124            #[cfg(feature = "burn-candle")]
125            MultiDevice::Candle(_) => burn::tensor::backend::DeviceId::new(0, 0),
126            #[cfg(feature = "burn-ndarray")]
127            MultiDevice::NdArray(_) => burn::tensor::backend::DeviceId::new(1, 0),
128            #[cfg(feature = "burn-wgpu")]
129            MultiDevice::Wgpu(_) => burn::tensor::backend::DeviceId::new(2, 0),
130        }
131    }
132}