1use crate::rand::NdArrayRng;
2use crate::{NdArrayQTensor, NdArrayTensor};
3use crate::{
4 SharedArray,
5 element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
6};
7use alloc::string::String;
8use burn_common::stub::Mutex;
9use burn_ir::{BackendIr, HandleKind, TensorHandle};
10use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
11use burn_tensor::ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
12use core::marker::PhantomData;
13use rand::SeedableRng;
14
/// Process-wide random number generator state for the ndarray backend.
///
/// Initialized to `None`. `NdArray`'s `Backend::seed` implementation replaces the contents
/// with a freshly seeded [`NdArrayRng`]. NOTE(review): presumably consumed by the
/// random-tensor ops elsewhere in the crate (not visible in this file) — confirm.
pub(crate) static SEED: Mutex<Option<NdArrayRng>> = Mutex::new(None);
16
/// Device descriptor for the ndarray backend.
///
/// The backend exposes exactly one device, the CPU, which is also the value returned by
/// [`Default::default`].
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub enum NdArrayDevice {
    /// The host CPU (the only, and therefore default, device).
    #[default]
    Cpu,
}
24
// Marker implementation: the empty body means no `DeviceOps` methods are overridden
// for `NdArrayDevice`.
impl DeviceOps for NdArrayDevice {}
26
27impl burn_common::device::Device for NdArrayDevice {
28 fn from_id(_device_id: DeviceId) -> Self {
29 Self::Cpu
30 }
31
32 fn to_id(&self) -> DeviceId {
33 DeviceId {
34 type_id: 0,
35 index_id: 0,
36 }
37 }
38
39 fn device_count(_type_id: u16) -> usize {
40 1
41 }
42}
43
44#[derive(Clone, Copy, Default, Debug)]
49pub struct NdArray<E = f32, I = i64, Q = i8>
50where
51 NdArrayTensor: From<SharedArray<E>>,
52 NdArrayTensor: From<SharedArray<I>>,
53{
54 _e: PhantomData<E>,
55 _i: PhantomData<I>,
56 _q: PhantomData<Q>,
57}
58
59impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> Backend for NdArray<E, I, Q>
60where
61 NdArrayTensor: From<SharedArray<E>>,
62 NdArrayTensor: From<SharedArray<I>>,
63{
64 type Device = NdArrayDevice;
65
66 type FloatTensorPrimitive = NdArrayTensor;
67 type FloatElem = E;
68
69 type IntTensorPrimitive = NdArrayTensor;
70 type IntElem = I;
71
72 type BoolTensorPrimitive = NdArrayTensor;
73 type BoolElem = bool;
74
75 type QuantizedTensorPrimitive = NdArrayQTensor;
76
77 fn ad_enabled() -> bool {
78 false
79 }
80
81 fn name(_device: &Self::Device) -> String {
82 String::from("ndarray")
83 }
84
85 fn seed(_device: &Self::Device, seed: u64) {
86 let rng = NdArrayRng::seed_from_u64(seed);
87 let mut seed = SEED.lock().unwrap();
88 *seed = Some(rng);
89 }
90}
91
92impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BackendIr for NdArray<E, I, Q>
93where
94 NdArrayTensor: From<SharedArray<E>>,
95 NdArrayTensor: From<SharedArray<I>>,
96{
97 type Handle = HandleKind<Self>;
98
99 fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
100 match handle.handle {
101 HandleKind::Float(handle) => handle,
102 _ => panic!("Expected float handle, got {}", handle.handle.name()),
103 }
104 }
105
106 fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
107 match handle.handle {
108 HandleKind::Int(handle) => handle,
109 _ => panic!("Expected int handle, got {}", handle.handle.name()),
110 }
111 }
112
113 fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
114 match handle.handle {
115 HandleKind::Bool(handle) => handle,
116 _ => panic!("Expected bool handle, got {}", handle.handle.name()),
117 }
118 }
119
120 fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
121 match handle.handle {
122 HandleKind::Quantized(handle) => handle,
123 _ => panic!("Expected quantized handle, got {}", handle.handle.name()),
124 }
125 }
126
127 fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
128 HandleKind::Float(tensor)
129 }
130
131 fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
132 HandleKind::Int(tensor)
133 }
134
135 fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
136 HandleKind::Bool(tensor)
137 }
138
139 fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
140 HandleKind::Quantized(tensor)
141 }
142}