1use alloc::string::String;
2use core::marker::PhantomData;
3
4use burn_backend::{Backend, BackendTypes, DType, DTypeUsage, DTypeUsageSet, DeviceId, DeviceOps};
5use burn_ir::{BackendIr, HandleKind, TensorHandle};
6use burn_std::device::Device;
7use burn_std::rand::{SeedableRng, StdRng};
8use burn_std::stub::Mutex;
9
10use crate::qtensor::FlexQTensor;
11use crate::tensor::FlexTensor;
12
13pub type FlexRng = StdRng;
15
16pub(crate) static SEED: Mutex<Option<FlexRng>> = Mutex::new(None);
19
20pub(crate) fn get_seeded_rng() -> FlexRng {
27 burn_std::rand::get_seeded_rng()
28}
29
/// The device type for the flex backend.
///
/// The backend exposes exactly one device (it reports a device count of 1 and
/// displays as "Cpu"), so this is a unit struct with no state.
#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct FlexDevice;
35
impl Device for FlexDevice {
    /// There is only one flex device, so its id is always `(0, 0)`.
    fn to_id(&self) -> DeviceId {
        DeviceId::new(0, 0)
    }

    /// Every id maps back to the single device; the id itself is ignored.
    fn from_id(_id: DeviceId) -> Self {
        Self
    }
}

// No device-specific operations beyond the defaults are needed.
impl DeviceOps for FlexDevice {}
47
48impl core::fmt::Display for FlexDevice {
49 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
50 write!(f, "Cpu")
51 }
52}
53
54impl core::fmt::Debug for FlexDevice {
55 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
56 core::fmt::Display::fmt(self, f)
57 }
58}
59
/// The flex backend.
///
/// `E` and `I` are marker type parameters (defaulting to `f32` / `i32`) held
/// only through [`PhantomData`] — no runtime data is stored. Presumably they
/// name the float and int element types; note that the trait impls in this
/// file are written for the default `Flex` instantiation only.
#[derive(Clone, Copy, Debug, Default)]
pub struct Flex<E = f32, I = i32> {
    // Zero-sized marker for the `E` type parameter.
    _e: PhantomData<E>,
    // Zero-sized marker for the `I` type parameter.
    _i: PhantomData<I>,
}
102
impl BackendTypes for Flex {
    type Device = FlexDevice;

    // All dense tensor kinds (float, int, bool) share the same `FlexTensor`
    // primitive; element types are fixed to f32 / i32 / bool here.
    type FloatTensorPrimitive = FlexTensor;
    type FloatElem = f32;

    type IntTensorPrimitive = FlexTensor;
    type IntElem = i32;

    type BoolTensorPrimitive = FlexTensor;
    type BoolElem = bool;

    // Quantized tensors use a dedicated primitive.
    type QuantizedTensorPrimitive = FlexQTensor;
}
124
125impl Backend for Flex {
126 fn name(_device: &Self::Device) -> String {
127 "flex".into()
128 }
129
130 fn seed(_device: &Self::Device, seed: u64) {
131 let rng = FlexRng::seed_from_u64(seed);
132 let mut seed_lock = SEED.lock().unwrap();
133 *seed_lock = Some(rng);
134 }
135
136 fn device_count(_type_id: u16) -> usize {
137 1
138 }
139
140 fn dtype_usage(_device: &Self::Device, dtype: DType) -> DTypeUsageSet {
141 match dtype {
142 DType::F64 | DType::F32 | DType::F16 | DType::BF16 => {
144 DTypeUsage::Storage | DTypeUsage::Arithmetic
145 }
146 DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
147 DTypeUsage::Storage | DTypeUsage::Arithmetic
148 }
149 DType::U64 | DType::U32 | DType::U16 | DType::U8 => {
150 DTypeUsage::Storage | DTypeUsage::Arithmetic
151 }
152 DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8) => {
157 DTypeUsage::Storage | DTypeUsage::Arithmetic
158 }
159 DType::Bool(burn_std::BoolStore::U32) => DTypeUsageSet::empty(),
160 DType::QFloat(_) => DTypeUsage::Storage.into(),
162 _ => DTypeUsageSet::empty(),
163 }
164 }
165}
166
167impl BackendIr for Flex {
168 type Handle = HandleKind<Self>;
169
170 fn float_tensor(handle: TensorHandle<Self::Handle>) -> FlexTensor {
171 match handle.handle {
172 HandleKind::Float(t) => t,
173 _ => panic!("Expected float handle, got {}", handle.handle.name()),
174 }
175 }
176
177 fn int_tensor(handle: TensorHandle<Self::Handle>) -> FlexTensor {
178 match handle.handle {
179 HandleKind::Int(t) => t,
180 _ => panic!("Expected int handle, got {}", handle.handle.name()),
181 }
182 }
183
184 fn bool_tensor(handle: TensorHandle<Self::Handle>) -> FlexTensor {
185 match handle.handle {
186 HandleKind::Bool(t) => t,
187 _ => panic!("Expected bool handle, got {}", handle.handle.name()),
188 }
189 }
190
191 fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> FlexQTensor {
192 match handle.handle {
193 HandleKind::Quantized(t) => t,
194 _ => panic!("Expected quantized handle, got {}", handle.handle.name()),
195 }
196 }
197
198 fn float_tensor_handle(tensor: FlexTensor) -> Self::Handle {
199 HandleKind::Float(tensor)
200 }
201
202 fn int_tensor_handle(tensor: FlexTensor) -> Self::Handle {
203 HandleKind::Int(tensor)
204 }
205
206 fn bool_tensor_handle(tensor: FlexTensor) -> Self::Handle {
207 HandleKind::Bool(tensor)
208 }
209
210 fn quantized_tensor_handle(tensor: FlexQTensor) -> Self::Handle {
211 HandleKind::Quantized(tensor)
212 }
213}
214
#[cfg(test)]
mod tests {
    use burn_backend::{Backend, DType};
    use burn_std::BoolStore;

    use super::*;

    // --- Dtype support: which bool representations the backend accepts ---

    #[test]
    fn supports_bool_native() {
        let device = FlexDevice;
        assert!(Flex::supports_dtype(
            &device,
            DType::Bool(BoolStore::Native)
        ));
    }

    #[test]
    fn supports_bool_u8() {
        let device = FlexDevice;
        assert!(Flex::supports_dtype(&device, DType::Bool(BoolStore::U8)));
    }

    #[test]
    fn does_not_support_bool_u32() {
        let device = FlexDevice;
        assert!(
            !Flex::supports_dtype(&device, DType::Bool(BoolStore::U32)),
            "Bool(U32) should not be supported: flex stores bools as 1 byte per element"
        );
    }

    // --- Dtype propagation: ops must keep the requested bool storage ---

    #[test]
    fn bool_empty_preserves_native_dtype() {
        use burn_backend::ops::BoolTensorOps;
        let shape = burn_std::Shape::from(alloc::vec![3]);
        let t = Flex::bool_empty(shape, &FlexDevice, burn_std::BoolDType::Native);
        assert_eq!(t.dtype(), DType::Bool(BoolStore::Native));
    }

    #[test]
    fn bool_empty_preserves_u8_dtype() {
        use burn_backend::ops::BoolTensorOps;
        let shape = burn_std::Shape::from(alloc::vec![3]);
        let t = Flex::bool_empty(shape, &FlexDevice, burn_std::BoolDType::U8);
        assert_eq!(t.dtype(), DType::Bool(BoolStore::U8));
    }

    // Display and Debug are expected to agree (Debug delegates to Display).
    #[test]
    fn device_prints_as_cpu() {
        use alloc::format;
        assert_eq!(format!("{:?}", FlexDevice), "Cpu");
        assert_eq!(format!("{}", FlexDevice), "Cpu");
    }

    #[test]
    fn comparison_preserves_out_dtype_native() {
        let lhs = FlexTensor::from_data(burn_backend::TensorData::from([1.0f32, 2.0, 3.0]));
        let rhs = FlexTensor::from_data(burn_backend::TensorData::from([2.0f32, 2.0, 1.0]));
        let result = crate::ops::comparison::greater(lhs, rhs, burn_std::BoolDType::Native);
        assert_eq!(result.dtype(), DType::Bool(BoolStore::Native));
    }

    #[test]
    fn comparison_preserves_out_dtype_u8() {
        let lhs = FlexTensor::from_data(burn_backend::TensorData::from([1.0f32, 2.0, 3.0]));
        let rhs = FlexTensor::from_data(burn_backend::TensorData::from([2.0f32, 2.0, 1.0]));
        let result = crate::ops::comparison::greater(lhs, rhs, burn_std::BoolDType::U8);
        assert_eq!(result.dtype(), DType::Bool(BoolStore::U8));
    }

    // The unsupported u32 representation must be rejected loudly, not
    // silently coerced to a supported one.
    #[test]
    #[should_panic(expected = "Bool(U32)")]
    fn comparison_u32_panics() {
        let lhs = FlexTensor::from_data(burn_backend::TensorData::from([1.0f32, 2.0]));
        let rhs = FlexTensor::from_data(burn_backend::TensorData::from([2.0f32, 1.0]));
        let _ = crate::ops::comparison::greater(lhs, rhs, burn_std::BoolDType::U32);
    }

    // bool_not must both keep the u8 dtype and flip each byte (1 <-> 0).
    #[test]
    fn bool_not_preserves_u8_dtype() {
        use burn_backend::ops::BoolTensorOps;
        let t_u8 = crate::ops::comparison::make_bool_tensor(
            alloc::vec![1, 0, 1],
            burn_std::Shape::from(alloc::vec![3]),
            burn_std::BoolDType::U8,
        );
        let result = Flex::bool_not(t_u8);
        assert_eq!(result.dtype(), DType::Bool(BoolStore::U8));
        let data: &[u8] = result.bytes();
        assert_eq!(&data[..3], &[0, 1, 0]);
    }
}
311}