// burn_ndarray — backend.rs

1use crate::rand::NdArrayRng;
2use crate::{NdArrayQTensor, NdArrayTensor};
3use crate::{
4    SharedArray,
5    element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
6};
7use alloc::string::String;
8use burn_backend::quantization::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue};
9use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
10use burn_backend::{Backend, DType, DeviceId, DeviceOps};
11use burn_ir::{BackendIr, HandleKind, TensorHandle};
12use burn_std::stub::Mutex;
13use core::marker::PhantomData;
14use rand::SeedableRng;
15
/// Backend-global RNG state, shared by all ndarray operations.
/// `None` until `Backend::seed` is called; afterwards holds the seeded generator.
pub(crate) static SEED: Mutex<Option<NdArrayRng>> = Mutex::new(None);
17
/// The device type for the ndarray backend.
///
/// Only a single device exists (the CPU), so this is a unit marker rather
/// than a handle to distinct hardware.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum NdArrayDevice {
    /// The CPU device.
    #[default]
    Cpu,
}
25
// Marker impl: the empty body suffices here, so presumably all `DeviceOps`
// items have default implementations — confirm against the trait definition.
impl DeviceOps for NdArrayDevice {}
27
impl burn_backend::Device for NdArrayDevice {
    /// Reconstructs a device from an id. Since this backend exposes exactly
    /// one device, every id maps back to [`NdArrayDevice::Cpu`].
    fn from_id(_device_id: DeviceId) -> Self {
        Self::Cpu
    }

    /// Returns the id of this device — always `{ type_id: 0, index_id: 0 }`,
    /// mirroring the single-device assumption of `from_id`.
    fn to_id(&self) -> DeviceId {
        DeviceId {
            type_id: 0,
            index_id: 0,
        }
    }

    /// There is always exactly one ndarray device, regardless of `type_id`.
    fn device_count(_type_id: u16) -> usize {
        1
    }
}
44
/// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations.
///
/// This backend is compatible with CPUs and can be compiled for almost any platform, including
/// `wasm`, `arm`, and `x86`.
///
/// Type parameters select the element types: `E` for floats (default `f32`),
/// `I` for integers (default `i64`), and `Q` for quantized values (default `i8`).
#[derive(Clone, Copy, Default, Debug)]
pub struct NdArray<E = f32, I = i64, Q = i8>
where
    // Element types are only usable if their shared arrays convert into the
    // common `NdArrayTensor` primitive.
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    // Zero-sized markers: the struct stores no data; the parameters only
    // pin the element types at the type level.
    _e: PhantomData<E>,
    _i: PhantomData<I>,
    _q: PhantomData<Q>,
}
59
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> Backend for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    type Device = NdArrayDevice;

    // Float, int, and bool tensors all use the same `NdArrayTensor` primitive.
    type FloatTensorPrimitive = NdArrayTensor;
    type FloatElem = E;

    type IntTensorPrimitive = NdArrayTensor;
    type IntElem = I;

    type BoolTensorPrimitive = NdArrayTensor;
    type BoolElem = bool;

    type QuantizedTensorPrimitive = NdArrayQTensor;

    /// This backend does not perform automatic differentiation on its own.
    fn ad_enabled() -> bool {
        false
    }

    fn name(_device: &Self::Device) -> String {
        String::from("ndarray")
    }

    /// Replaces the backend-global RNG ([`SEED`]) with a generator
    /// deterministically derived from `seed`.
    fn seed(_device: &Self::Device, seed: u64) {
        let rng = NdArrayRng::seed_from_u64(seed);
        // Shadowing: from here on, `seed` names the mutex guard, not the `u64`.
        let mut seed = SEED.lock().unwrap();
        *seed = Some(rng);
    }

    /// Reports whether this backend can execute tensors of the given `dtype`.
    fn supports_dtype(_device: &Self::Device, dtype: DType) -> bool {
        match dtype {
            // Full-precision floats plus all integer widths and bool.
            DType::F64
            | DType::F32
            | DType::Flex32
            | DType::I64
            | DType::I32
            | DType::I16
            | DType::I8
            | DType::U64
            | DType::U32
            | DType::U16
            | DType::U8
            | DType::Bool => true,
            // Half-precision floats are not implemented for this backend.
            DType::F16 | DType::BF16 => false,
            DType::QFloat(scheme) => {
                match scheme {
                    // Only symmetric schemes with native storage are accepted;
                    // the allowed `value` set widens under `export_tests`.
                    QuantScheme {
                        level: QuantLevel::Tensor | QuantLevel::Block(_),
                        mode: QuantMode::Symmetric,
                        #[cfg(not(feature = "export_tests"))]
                            value: QuantValue::Q8F | QuantValue::Q8S,
                        // For tests, "native" sub-byte quant serves as a reference for value equality.
                        // Values are stored as i8 regardless.
                        #[cfg(feature = "export_tests")]
                            value:
                            QuantValue::Q8F
                            | QuantValue::Q8S
                            | QuantValue::Q4F
                            | QuantValue::Q4S
                            | QuantValue::Q2F
                            | QuantValue::Q2S,
                        store: QuantStore::Native,
                        ..
                    } => true,
                    _scheme => false,
                }
            }
        }
    }
}
133
134impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BackendIr for NdArray<E, I, Q>
135where
136    NdArrayTensor: From<SharedArray<E>>,
137    NdArrayTensor: From<SharedArray<I>>,
138{
139    type Handle = HandleKind<Self>;
140
141    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
142        match handle.handle {
143            HandleKind::Float(handle) => handle,
144            _ => panic!("Expected float handle, got {}", handle.handle.name()),
145        }
146    }
147
148    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
149        match handle.handle {
150            HandleKind::Int(handle) => handle,
151            _ => panic!("Expected int handle, got {}", handle.handle.name()),
152        }
153    }
154
155    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
156        match handle.handle {
157            HandleKind::Bool(handle) => handle,
158            _ => panic!("Expected bool handle, got {}", handle.handle.name()),
159        }
160    }
161
162    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
163        match handle.handle {
164            HandleKind::Quantized(handle) => handle,
165            _ => panic!("Expected quantized handle, got {}", handle.handle.name()),
166        }
167    }
168
169    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
170        HandleKind::Float(tensor)
171    }
172
173    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
174        HandleKind::Int(tensor)
175    }
176
177    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
178        HandleKind::Bool(tensor)
179    }
180
181    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
182        HandleKind::Quantized(tensor)
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::QTensorPrimitive;

    #[test]
    fn should_support_dtypes() {
        type B = NdArray<f32>;
        let device = Default::default();

        // Every dtype the backend claims to support, including the default
        // quantization scheme of the ndarray quantized tensor.
        let supported = [
            DType::F64,
            DType::F32,
            DType::Flex32,
            DType::I64,
            DType::I32,
            DType::I16,
            DType::I8,
            DType::U64,
            DType::U32,
            DType::U16,
            DType::U8,
            DType::Bool,
            DType::QFloat(NdArrayQTensor::default_scheme()),
        ];
        for dtype in supported {
            assert!(B::supports_dtype(&device, dtype));
        }

        // Half-precision floats are unsupported.
        for dtype in [DType::F16, DType::BF16] {
            assert!(!B::supports_dtype(&device, dtype));
        }
        // QuantStore::U32 not supported
        assert!(!B::supports_dtype(
            &device,
            DType::QFloat(QuantScheme::default())
        ));
    }
}
221}