gloss_burn_multibackend/backend.rs

use crate::global_backend;
use burn::{
    prelude::Backend,
    tensor::{backend::DeviceOps, ops::Device},
};

use crate::tensor::MultiBoolTensor;
use crate::tensor::MultiFloatTensor;
use crate::tensor::MultiIntTensor;

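// Concrete per-feature backend instantiations. The element types follow each
// backend's conventional defaults (note that Candle is parameterized with i64
// integers while NdArray and Wgpu use i32).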
#[cfg(feature = "burn-candle")]
pub type CandleBackend = burn::backend::Candle<f32, i64>;
#[cfg(feature = "burn-ndarray")]
pub type NdArrayBackend = burn::backend::NdArray<f32, i32>;
#[cfg(feature = "burn-wgpu")]
pub type WgpuBackend = burn::backend::Wgpu<f32, i32>;

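/// A [`Backend`] whose tensor operations are routed to whichever concrete
/// backend the active [`MultiDevice`] variant belongs to.
///
/// A minimal usage sketch, assuming the crate re-exports `MultiBackend` and
/// `MultiDevice` at its root (adjust paths to the actual public API):
///
/// ```ignore
/// use burn::tensor::Tensor;
///
/// let device = MultiDevice::default();
/// let t = Tensor::<MultiBackend, 1>::from_floats([1.0, 2.0, 3.0], &device);
/// println!("{}", MultiBackend::name(&device)); // "candle", "ndarray", or "wgpu"
/// ```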
#[derive(Clone, Copy, Default, Debug)]
pub struct MultiBackend;

impl Backend for MultiBackend {
    type Device = MultiDevice;
    type FloatTensorPrimitive = MultiFloatTensor;
    type IntTensorPrimitive = MultiIntTensor;
    type BoolTensorPrimitive = MultiBoolTensor;
    type QuantizedTensorPrimitive = MultiIntTensor;
    type QuantizedEncoding = f32;

    type FloatElem = f32;
    type IntElem = i32;
    type BoolElem = u8;

    fn name(device: &Self::Device) -> String {
        match device {
            #[cfg(feature = "burn-candle")]
            MultiDevice::Candle(_) => "candle",
            #[cfg(feature = "burn-ndarray")]
            MultiDevice::NdArray(_) => "ndarray",
            #[cfg(feature = "burn-wgpu")]
            MultiDevice::Wgpu(_) => "wgpu",
        }
        .to_string()
    }

    fn seed(_seed: u64) {
        // TODO: forward the seed to every enabled backend's RNG.
        todo!()
    }

    fn ad_enabled() -> bool {
        false
    }

    fn sync(_device: &Self::Device) {}
}

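/// Device handle for [`MultiBackend`]: one variant per backend compiled in via
/// feature flags. The enum is `#[non_exhaustive]`, so downstream matches need a
/// catch-all arm. Code that wants the concrete device can match on the variant;
/// a sketch assuming the `burn-ndarray` feature:
///
/// ```ignore
/// match MultiDevice::default() {
///     #[cfg(feature = "burn-ndarray")]
///     MultiDevice::NdArray(device) => { /* use the NdArrayDevice directly */ }
///     _ => {}
/// }
/// ```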
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MultiDevice {
    #[cfg(feature = "burn-candle")]
    Candle(Device<CandleBackend>),
    #[cfg(feature = "burn-ndarray")]
    NdArray(Device<NdArrayBackend>),
    #[cfg(feature = "burn-wgpu")]
    Wgpu(Device<WgpuBackend>),
}

impl Default for MultiDevice {
    fn default() -> Self {
        // A globally registered backend takes priority over the feature-based
        // fallback below.
        if let Some(global_device) = global_backend::get_global_burn_backend() {
            #[allow(unreachable_patterns)]
            match global_device {
                #[cfg(feature = "burn-candle")]
                global_backend::GlobalBackend::Candle => {
                    return Self::Candle(burn::backend::candle::CandleDevice::default())
                }
                #[cfg(feature = "burn-ndarray")]
                global_backend::GlobalBackend::NdArray => {
                    return Self::NdArray(burn::backend::ndarray::NdArrayDevice::default())
                }
                #[cfg(feature = "burn-wgpu")]
                global_backend::GlobalBackend::Wgpu => {
                    let existing_wgpu_device = wgpu_burn_global_device::get_global_wgpu_device();
                    return Self::Wgpu(existing_wgpu_device.unwrap_or_default());
                }
                _ => {
                    panic!(
                        "The global device {global_device:?} is not available because the \
                         corresponding feature is not enabled. Please enable the feature in \
                         Cargo.toml."
                    );
                }
            }
        }

        // Otherwise fall back to the first enabled backend, in the order
        // candle -> ndarray -> wgpu.
        #[cfg(feature = "burn-candle")]
        {
            Self::Candle(burn::backend::candle::CandleDevice::default())
        }
        #[cfg(all(not(feature = "burn-candle"), feature = "burn-ndarray"))]
        {
            Self::NdArray(burn::backend::ndarray::NdArrayDevice::default())
        }
        #[cfg(all(
            not(feature = "burn-candle"),
            not(feature = "burn-ndarray"),
            feature = "burn-wgpu"
        ))]
        {
            let existing_wgpu_device = wgpu_burn_global_device::get_global_wgpu_device();
            Self::Wgpu(existing_wgpu_device.unwrap_or_default())
        }
        #[cfg(all(
            not(feature = "burn-candle"),
            not(feature = "burn-ndarray"),
            not(feature = "burn-wgpu")
        ))]
        {
            compile_error!(
                "No backend feature enabled. Please enable at least one of the features: \
                 burn-candle, burn-ndarray, burn-wgpu"
            );
        }
    }
}

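// A minimal sketch of how the fallback above resolves, assuming a build with
// only the `burn-wgpu` feature enabled and no global backend registered:
//
//     let device = MultiDevice::default();
//     assert!(matches!(device, MultiDevice::Wgpu(_)));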
impl DeviceOps for MultiDevice {
    fn id(&self) -> burn::tensor::backend::DeviceId {
        // The type id encodes the backend variant; the index id is always 0,
        // so multiple physical devices of the same backend are not
        // distinguished here.
        match self {
            #[cfg(feature = "burn-candle")]
            MultiDevice::Candle(_) => burn::tensor::backend::DeviceId::new(0, 0),
            #[cfg(feature = "burn-ndarray")]
            MultiDevice::NdArray(_) => burn::tensor::backend::DeviceId::new(1, 0),
            #[cfg(feature = "burn-wgpu")]
            MultiDevice::Wgpu(_) => burn::tensor::backend::DeviceId::new(2, 0),
        }
    }
}
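
// A minimal sanity-check sketch: exercises only what is defined above, and
// assumes at least one backend feature is enabled when the tests run.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_device_reports_a_known_backend_name() {
        let device = MultiDevice::default();
        // `name` must resolve to one of the backends this file knows about.
        let name = MultiBackend::name(&device);
        assert!(["candle", "ndarray", "wgpu"].contains(&name.as_str()));
    }

    #[test]
    fn device_id_is_stable_for_the_same_device() {
        let device = MultiDevice::default();
        // `id` is a pure function of the enum variant, so repeated calls on
        // the same device must agree.
        assert_eq!(device.id(), device.id());
    }
}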