//! Backend definition for the ndarray backend (`burn_ndarray/backend.rs`).
use crate::rand::NdArrayRng;
use crate::{NdArrayQTensor, NdArrayTensor};
use crate::{
    SharedArray,
    element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
};
use alloc::string::String;
use burn_backend::quantization::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue};
use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
use burn_backend::{Backend, BackendTypes, DType, DeviceId, DeviceOps};
use burn_ir::{BackendIr, HandleKind, TensorHandle};
use burn_std::BoolStore;
use burn_std::stub::Mutex;
use core::marker::PhantomData;
use rand::SeedableRng;

/// Process-wide RNG state shared by all ndarray devices.
///
/// `None` until `Backend::seed` is called; that method replaces the contents
/// with a deterministically seeded RNG.
pub(crate) static SEED: Mutex<Option<NdArrayRng>> = Mutex::new(None);

/// Device identifier for the ndarray backend.
///
/// Only a single device exists: the host CPU.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum NdArrayDevice {
    /// The (only) CPU device.
    #[default]
    Cpu,
}

// Marker impl: no device-specific operations are overridden for the single CPU device.
impl DeviceOps for NdArrayDevice {}

29impl burn_backend::Device for NdArrayDevice {
30    fn from_id(_device_id: DeviceId) -> Self {
31        Self::Cpu
32    }
33
34    fn to_id(&self) -> DeviceId {
35        DeviceId {
36            type_id: 0,
37            index_id: 0,
38        }
39    }
40}
41
/// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations.
///
/// This backend is compatible with CPUs and can be compiled for almost any platform, including
/// `wasm`, `arm`, and `x86`.
///
/// Type parameters:
/// - `E`: float element type (default `f32`).
/// - `I`: int element type (default `i64`).
/// - `Q`: quantized element type (default `i8`).
//
// NOTE(review): the `where` bounds mirror the ones on the impls below; bounds are
// usually better placed only on the impls — confirm nothing relies on them here.
#[derive(Clone, Copy, Default, Debug)]
pub struct NdArray<E = f32, I = i64, Q = i8>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    // Zero-sized markers: the element types only parameterize the ops; no data lives here.
    _e: PhantomData<E>,
    _i: PhantomData<I>,
    _q: PhantomData<Q>,
}
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BackendTypes
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    type Device = NdArrayDevice;

    // All dense tensors (float/int/bool) share the same primitive type; the
    // element type is carried separately by the `*Elem` associated types.
    type FloatTensorPrimitive = NdArrayTensor;
    type FloatElem = E;

    type IntTensorPrimitive = NdArrayTensor;
    type IntElem = I;

    type BoolTensorPrimitive = NdArrayTensor;
    type BoolElem = bool;

    // Quantized tensors use a dedicated primitive.
    type QuantizedTensorPrimitive = NdArrayQTensor;
}
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> Backend for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    /// This base backend does not record gradients; autodiff is presumably layered
    /// on top by a decorator backend — confirm at the call sites.
    fn ad_enabled(_device: &Self::Device) -> bool {
        false
    }

    /// Human-readable backend name used in diagnostics.
    fn name(_device: &Self::Device) -> String {
        String::from("ndarray")
    }

    /// Seeds the process-wide RNG stored in [`SEED`], making subsequent
    /// random operations deterministic.
    fn seed(_device: &Self::Device, seed: u64) {
        let rng = NdArrayRng::seed_from_u64(seed);
        // Shadowing: from here on, `seed` names the lock guard, not the `u64` argument.
        let mut seed = SEED.lock().unwrap();
        *seed = Some(rng);
    }

    /// Reports the usage set supported for `dtype` on this backend.
    fn dtype_usage(_device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
        match dtype {
            // Natively stored numeric types and native bools: full support.
            DType::F64
            | DType::F32
            | DType::Flex32
            | DType::I64
            | DType::I32
            | DType::I16
            | DType::I8
            | DType::U64
            | DType::U32
            | DType::U16
            | DType::U8
            | DType::Bool(BoolStore::Native) => burn_backend::DTypeUsage::general(),
            // Half-precision floats and non-native bool stores: unsupported.
            DType::F16 | DType::BF16 | DType::Bool(_) => burn_backend::DTypeUsageSet::empty(),
            DType::QFloat(scheme) => {
                // Only symmetric schemes with native storage are supported; the
                // accepted `value` set widens under the `export_tests` feature.
                match scheme {
                    QuantScheme {
                        level: QuantLevel::Tensor | QuantLevel::Block(_),
                        mode: QuantMode::Symmetric,
                        #[cfg(not(feature = "export_tests"))]
                            value: QuantValue::Q8F | QuantValue::Q8S,
                        // For tests, "native" sub-byte quant serves as a reference for value equality.
                        // Values are stored as i8 regardless.
                        #[cfg(feature = "export_tests")]
                            value:
                            QuantValue::Q8F
                            | QuantValue::Q8S
                            | QuantValue::Q4F
                            | QuantValue::Q4S
                            | QuantValue::Q2F
                            | QuantValue::Q2S,
                        store: QuantStore::Native,
                        ..
                    } => burn_backend::DTypeUsage::general(),
                    // Anything else (e.g. packed `QuantStore::U32`) is unsupported.
                    _scheme => burn_backend::DTypeUsageSet::empty(),
                }
            }
        }
    }

    /// The CPU is modelled as a single device.
    fn device_count(_: u16) -> usize {
        1
    }
}

140impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BackendIr for NdArray<E, I, Q>
141where
142    NdArrayTensor: From<SharedArray<E>>,
143    NdArrayTensor: From<SharedArray<I>>,
144{
145    type Handle = HandleKind<Self>;
146
147    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
148        match handle.handle {
149            HandleKind::Float(handle) => handle,
150            _ => panic!("Expected float handle, got {}", handle.handle.name()),
151        }
152    }
153
154    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
155        match handle.handle {
156            HandleKind::Int(handle) => handle,
157            _ => panic!("Expected int handle, got {}", handle.handle.name()),
158        }
159    }
160
161    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
162        match handle.handle {
163            HandleKind::Bool(handle) => handle,
164            _ => panic!("Expected bool handle, got {}", handle.handle.name()),
165        }
166    }
167
168    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
169        match handle.handle {
170            HandleKind::Quantized(handle) => handle,
171            _ => panic!("Expected quantized handle, got {}", handle.handle.name()),
172        }
173    }
174
175    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
176        HandleKind::Float(tensor)
177    }
178
179    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
180        HandleKind::Int(tensor)
181    }
182
183    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
184        HandleKind::Bool(tensor)
185    }
186
187    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
188        HandleKind::Quantized(tensor)
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::QTensorPrimitive;

    /// The backend should accept all natively stored dtypes (plus its default
    /// quant scheme) and reject half-precision floats and the packed default
    /// quant scheme.
    #[test]
    fn should_support_dtypes() {
        type B = NdArray<f32>;
        let device = Default::default();

        let supported = [
            DType::F64,
            DType::F32,
            DType::Flex32,
            DType::I64,
            DType::I32,
            DType::I16,
            DType::I8,
            DType::U64,
            DType::U32,
            DType::U16,
            DType::U8,
            DType::Bool(BoolStore::Native),
            DType::QFloat(NdArrayQTensor::default_scheme()),
        ];
        for dtype in supported {
            assert!(B::supports_dtype(&device, dtype));
        }

        // F16/BF16 are unsupported, and the default quant scheme uses
        // QuantStore::U32, which this backend does not handle.
        let unsupported = [DType::F16, DType::BF16, DType::QFloat(QuantScheme::default())];
        for dtype in unsupported {
            assert!(!B::supports_dtype(&device, dtype));
        }
    }
}