burn_ndarray/backend.rs

1use crate::rand::NdArrayRng;
2use crate::{NdArrayQTensor, NdArrayTensor};
3use crate::{
4    SharedArray,
5    element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
6};
7use alloc::string::String;
8use burn_common::stub::Mutex;
9use burn_ir::{BackendIr, HandleKind, TensorHandle};
10use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
11use burn_tensor::ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
12use core::marker::PhantomData;
13use rand::SeedableRng;
14
15pub(crate) static SEED: Mutex<Option<NdArrayRng>> = Mutex::new(None);
16
/// The device type for the ndarray backend.
///
/// The enum has a single variant, so there is exactly one device and it is also the
/// default. Besides the traits the backend requires, it derives `Hash` so a device
/// can be used as a key in hash-based collections (`HashMap`, `HashSet`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum NdArrayDevice {
    /// The CPU device (the default).
    #[default]
    Cpu,
}
24
25impl DeviceOps for NdArrayDevice {}
26
27impl burn_common::device::Device for NdArrayDevice {
28    fn from_id(_device_id: DeviceId) -> Self {
29        Self::Cpu
30    }
31
32    fn to_id(&self) -> DeviceId {
33        DeviceId {
34            type_id: 0,
35            index_id: 0,
36        }
37    }
38
39    fn device_count(_type_id: u16) -> usize {
40        1
41    }
42}
43
44/// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations.
45///
46/// This backend is compatible with CPUs and can be compiled for almost any platform, including
47/// `wasm`, `arm`, and `x86`.
48#[derive(Clone, Copy, Default, Debug)]
49pub struct NdArray<E = f32, I = i64, Q = i8>
50where
51    NdArrayTensor: From<SharedArray<E>>,
52    NdArrayTensor: From<SharedArray<I>>,
53{
54    _e: PhantomData<E>,
55    _i: PhantomData<I>,
56    _q: PhantomData<Q>,
57}
58
59impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> Backend for NdArray<E, I, Q>
60where
61    NdArrayTensor: From<SharedArray<E>>,
62    NdArrayTensor: From<SharedArray<I>>,
63{
64    type Device = NdArrayDevice;
65
66    type FloatTensorPrimitive = NdArrayTensor;
67    type FloatElem = E;
68
69    type IntTensorPrimitive = NdArrayTensor;
70    type IntElem = I;
71
72    type BoolTensorPrimitive = NdArrayTensor;
73    type BoolElem = bool;
74
75    type QuantizedTensorPrimitive = NdArrayQTensor;
76
77    fn ad_enabled() -> bool {
78        false
79    }
80
81    fn name(_device: &Self::Device) -> String {
82        String::from("ndarray")
83    }
84
85    fn seed(_device: &Self::Device, seed: u64) {
86        let rng = NdArrayRng::seed_from_u64(seed);
87        let mut seed = SEED.lock().unwrap();
88        *seed = Some(rng);
89    }
90}
91
92impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BackendIr for NdArray<E, I, Q>
93where
94    NdArrayTensor: From<SharedArray<E>>,
95    NdArrayTensor: From<SharedArray<I>>,
96{
97    type Handle = HandleKind<Self>;
98
99    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
100        match handle.handle {
101            HandleKind::Float(handle) => handle,
102            _ => panic!("Expected float handle, got {}", handle.handle.name()),
103        }
104    }
105
106    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
107        match handle.handle {
108            HandleKind::Int(handle) => handle,
109            _ => panic!("Expected int handle, got {}", handle.handle.name()),
110        }
111    }
112
113    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
114        match handle.handle {
115            HandleKind::Bool(handle) => handle,
116            _ => panic!("Expected bool handle, got {}", handle.handle.name()),
117        }
118    }
119
120    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
121        match handle.handle {
122            HandleKind::Quantized(handle) => handle,
123            _ => panic!("Expected quantized handle, got {}", handle.handle.name()),
124        }
125    }
126
127    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
128        HandleKind::Float(tensor)
129    }
130
131    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
132        HandleKind::Int(tensor)
133    }
134
135    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
136        HandleKind::Bool(tensor)
137    }
138
139    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
140        HandleKind::Quantized(tensor)
141    }
142}