1use crate::rand::NdArrayRng;
2use crate::{NdArrayQTensor, NdArrayTensor};
3use crate::{
4 SharedArray,
5 element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
6};
7use alloc::string::String;
8use burn_backend::quantization::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue};
9use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
10use burn_backend::{Backend, BackendTypes, DType, DeviceId, DeviceOps};
11use burn_ir::{BackendIr, HandleKind, TensorHandle};
12use burn_std::BoolStore;
13use burn_std::stub::Mutex;
14use core::marker::PhantomData;
15use rand::SeedableRng;
16
/// Process-wide random number generator state for the ndarray backend.
///
/// Starts out as `None`. The `Backend::seed` implementation for `NdArray` replaces it with a
/// freshly seeded `NdArrayRng`, ignoring the device argument, so there is one generator shared
/// by everything that locks this mutex.
/// NOTE(review): presumably consulted by random tensor ops elsewhere in the crate (it is
/// `pub(crate)`); confirm how those ops behave while it is still `None`.
pub(crate) static SEED: Mutex<Option<NdArrayRng>> = Mutex::new(None);
18
/// Device type for the ndarray backend.
///
/// The enum has a single variant, `Cpu`, which is also the `Default` device.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum NdArrayDevice {
    /// The CPU device (the only device this backend exposes).
    #[default]
    Cpu,
}
26
// Empty impl: `NdArrayDevice` overrides nothing and relies entirely on the items `DeviceOps`
// provides by default.
impl DeviceOps for NdArrayDevice {}
28
29impl burn_backend::Device for NdArrayDevice {
30 fn from_id(_device_id: DeviceId) -> Self {
31 Self::Cpu
32 }
33
34 fn to_id(&self) -> DeviceId {
35 DeviceId {
36 type_id: 0,
37 index_id: 0,
38 }
39 }
40}
41
/// Tensor backend whose float, int and bool tensors are all represented as `NdArrayTensor`.
///
/// # Type parameters
/// - `E`: float element type (default `f32`), exposed as `BackendTypes::FloatElem`.
/// - `I`: int element type (default `i64`), exposed as `BackendTypes::IntElem`.
/// - `Q`: quantized element type (default `i8`), bounded by `QuantElement` in the trait impls.
///
/// The struct holds no data; it only carries its type parameters through `PhantomData`.
/// The `where` clauses require that `SharedArray`s of both `E` and `I` convert into
/// `NdArrayTensor`.
#[derive(Clone, Copy, Default, Debug)]
pub struct NdArray<E = f32, I = i64, Q = i8>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    _e: PhantomData<E>,
    _i: PhantomData<I>,
    _q: PhantomData<Q>,
}
// Associates the ndarray backend with its device, tensor primitive and element types.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BackendTypes
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    // The backend exposes a single device type.
    type Device = NdArrayDevice;

    // Float, int and bool tensors all share the same primitive, `NdArrayTensor`;
    // only the element types differ.
    type FloatTensorPrimitive = NdArrayTensor;
    type FloatElem = E;

    type IntTensorPrimitive = NdArrayTensor;
    type IntElem = I;

    type BoolTensorPrimitive = NdArrayTensor;
    type BoolElem = bool;

    // Quantized tensors use their own primitive type. Note that `Q` is not referenced here.
    type QuantizedTensorPrimitive = NdArrayQTensor;
}
75impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> Backend for NdArray<E, I, Q>
76where
77 NdArrayTensor: From<SharedArray<E>>,
78 NdArrayTensor: From<SharedArray<I>>,
79{
80 fn ad_enabled(_device: &Self::Device) -> bool {
81 false
82 }
83
84 fn name(_device: &Self::Device) -> String {
85 String::from("ndarray")
86 }
87
88 fn seed(_device: &Self::Device, seed: u64) {
89 let rng = NdArrayRng::seed_from_u64(seed);
90 let mut seed = SEED.lock().unwrap();
91 *seed = Some(rng);
92 }
93
94 fn dtype_usage(_device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
95 match dtype {
96 DType::F64
97 | DType::F32
98 | DType::Flex32
99 | DType::I64
100 | DType::I32
101 | DType::I16
102 | DType::I8
103 | DType::U64
104 | DType::U32
105 | DType::U16
106 | DType::U8
107 | DType::Bool(BoolStore::Native) => burn_backend::DTypeUsage::general(),
108 DType::F16 | DType::BF16 | DType::Bool(_) => burn_backend::DTypeUsageSet::empty(),
109 DType::QFloat(scheme) => {
110 match scheme {
111 QuantScheme {
112 level: QuantLevel::Tensor | QuantLevel::Block(_),
113 mode: QuantMode::Symmetric,
114 #[cfg(not(feature = "export_tests"))]
115 value: QuantValue::Q8F | QuantValue::Q8S,
116 #[cfg(feature = "export_tests")]
119 value:
120 QuantValue::Q8F
121 | QuantValue::Q8S
122 | QuantValue::Q4F
123 | QuantValue::Q4S
124 | QuantValue::Q2F
125 | QuantValue::Q2S,
126 store: QuantStore::Native,
127 ..
128 } => burn_backend::DTypeUsage::general(),
129 _scheme => burn_backend::DTypeUsageSet::empty(),
130 }
131 }
132 }
133 }
134
135 fn device_count(_: u16) -> usize {
136 1
137 }
138}
139
140impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BackendIr for NdArray<E, I, Q>
141where
142 NdArrayTensor: From<SharedArray<E>>,
143 NdArrayTensor: From<SharedArray<I>>,
144{
145 type Handle = HandleKind<Self>;
146
147 fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
148 match handle.handle {
149 HandleKind::Float(handle) => handle,
150 _ => panic!("Expected float handle, got {}", handle.handle.name()),
151 }
152 }
153
154 fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
155 match handle.handle {
156 HandleKind::Int(handle) => handle,
157 _ => panic!("Expected int handle, got {}", handle.handle.name()),
158 }
159 }
160
161 fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
162 match handle.handle {
163 HandleKind::Bool(handle) => handle,
164 _ => panic!("Expected bool handle, got {}", handle.handle.name()),
165 }
166 }
167
168 fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
169 match handle.handle {
170 HandleKind::Quantized(handle) => handle,
171 _ => panic!("Expected quantized handle, got {}", handle.handle.name()),
172 }
173 }
174
175 fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
176 HandleKind::Float(tensor)
177 }
178
179 fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
180 HandleKind::Int(tensor)
181 }
182
183 fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
184 HandleKind::Bool(tensor)
185 }
186
187 fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
188 HandleKind::Quantized(tensor)
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::QTensorPrimitive;

    #[test]
    fn should_support_dtypes() {
        type B = NdArray<f32>;
        let device = Default::default();
        // Shorthand so each assertion below reads as a statement about a single dtype.
        let supports = |dtype: DType| B::supports_dtype(&device, dtype);

        // Dtypes the backend must accept.
        assert!(supports(DType::F64));
        assert!(supports(DType::F32));
        assert!(supports(DType::Flex32));
        assert!(supports(DType::I64));
        assert!(supports(DType::I32));
        assert!(supports(DType::I16));
        assert!(supports(DType::I8));
        assert!(supports(DType::U64));
        assert!(supports(DType::U32));
        assert!(supports(DType::U16));
        assert!(supports(DType::U8));
        assert!(supports(DType::Bool(BoolStore::Native)));
        assert!(supports(DType::QFloat(NdArrayQTensor::default_scheme())));

        // Dtypes the backend must reject, including the generic default quantization scheme.
        assert!(!supports(DType::F16));
        assert!(!supports(DType::BF16));
        assert!(!supports(DType::QFloat(QuantScheme::default())));
    }
}