// burn_flex/backend.rs

use alloc::string::String;
use core::marker::PhantomData;

use burn_backend::{Backend, BackendTypes, DType, DTypeUsage, DTypeUsageSet, DeviceId, DeviceOps};
use burn_ir::{BackendIr, HandleKind, TensorHandle};
use burn_std::device::Device;
use burn_std::rand::{SeedableRng, StdRng};
use burn_std::stub::Mutex;

use crate::qtensor::FlexQTensor;
use crate::tensor::FlexTensor;

/// Type alias for the RNG used by Flex.
pub type FlexRng = StdRng;

/// Global seed storage for reproducible random number generation.
/// Uses Mutex for thread-safe RNG state management.
pub(crate) static SEED: Mutex<Option<FlexRng>> = Mutex::new(None);

/// Fallback RNG when `SEED` is empty (consumed or never set).
///
/// The seeding flow is: `Backend::seed()` stores a `FlexRng` in `SEED`. Random
/// ops (`float_random`, `int_random`) call `SEED.lock().unwrap().take()`,
/// consuming the stored RNG for that op and falling back to this function on
/// subsequent calls. This function delegates to burn_std's own entropy source.
pub(crate) fn get_seeded_rng() -> FlexRng {
    burn_std::rand::get_seeded_rng()
}
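
// A minimal sketch of the consume-then-fallback pattern described above. This
// helper is hypothetical (the real random ops inline the pattern in their own
// bodies); it exists only to make the seeding flow concrete.
#[allow(dead_code)]
fn rng_take_or_fallback() -> FlexRng {
    // `take()` empties `SEED`, so only the first random op after
    // `Backend::seed()` observes the seeded RNG; subsequent ops fall back to
    // `get_seeded_rng()` and fresh entropy.
    SEED.lock().unwrap().take().unwrap_or_else(get_seeded_rng)
}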

/// CPU device for the Flex backend.
///
/// Unit struct since there's only one CPU device.
#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct FlexDevice;

impl Device for FlexDevice {
    fn to_id(&self) -> DeviceId {
        DeviceId::new(0, 0)
    }

    fn from_id(_id: DeviceId) -> Self {
        Self
    }
}

impl DeviceOps for FlexDevice {}

impl core::fmt::Display for FlexDevice {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "Cpu")
    }
}

impl core::fmt::Debug for FlexDevice {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        core::fmt::Display::fmt(self, f)
    }
}

/// The Flex backend, a fast, portable CPU backend for Burn.
///
/// The `E` and `I` type parameters exist purely to match the shape of other Burn
/// backends (e.g. `NdArray<E, I, Q>`) so `Flex` slots into `burn-dispatch`'s
/// generic dispatch macros. The body of `Flex` uses runtime `DType` dispatch, so
/// both parameters are phantom and unused at runtime.
///
/// # Limitations of the phantom generics
///
/// The `Backend` impl is provided only for the default instantiation
/// `Flex<f32, i32>`. Writing `Flex` (with no arguments) resolves to the default
/// and works exactly as before. Writing `Flex<f64, i64>` or any other non-default
/// combination is a valid Rust type but will not satisfy trait bounds requiring
/// `Backend`, producing errors like:
///
/// ```text
/// the trait bound `Flex<f64, i64>: Backend` is not satisfied
/// ```
///
/// This is a deliberate compromise for the initial migration: making `Flex`
/// generic over element types at the trait-impl level is a follow-up that would
/// require rewriting all `impl FooOps<Flex> for Flex` blocks plus internal
/// `Flex::method()` calls (tracked in
/// [#4762](https://github.com/tracel-ai/burn/issues/4762)). Until then, treat
/// the generic parameters as opaque shape placeholders; real element-type
/// selection happens at runtime via `DType`.
///
/// The bound is locked in by a compile-fail doctest so that if someone later
/// makes the `Backend` impl generic over `E`/`I`, this documentation gets
/// flagged as out of date:
///
/// ```compile_fail
/// use burn_backend::Backend;
/// use burn_flex::Flex;
/// fn requires_backend<B: Backend>() {}
/// requires_backend::<Flex<f64, i64>>();
/// ```
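///
/// Conversely, the default instantiation satisfies the bound; a minimal
/// positive check (sketch added for illustration):
///
/// ```
/// use burn_backend::Backend;
/// use burn_flex::Flex;
/// fn requires_backend<B: Backend>() {}
/// requires_backend::<Flex>(); // `Flex` defaults to `Flex<f32, i32>`.
/// ```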
#[derive(Clone, Copy, Debug, Default)]
pub struct Flex<E = f32, I = i32> {
    _e: PhantomData<E>,
    _i: PhantomData<I>,
}

impl BackendTypes for Flex {
    type Device = FlexDevice;

    type FloatTensorPrimitive = FlexTensor;
    /// Default float element type. Determines the dtype for `.float()` conversions and
    /// `Tensor::from_data` when no explicit dtype is provided.
    /// Prefer explicit dtypes via `(&device, DType::F32)`.
    type FloatElem = f32;

    type IntTensorPrimitive = FlexTensor;
    /// Default int element type. Determines the dtype for `.int()` conversions and
    /// `Tensor::from_data` when no explicit dtype is provided.
    /// Set to i32 to match burn's ecosystem default (test suite, record settings, burn-remote).
    /// Prefer explicit dtypes via `(&device, DType::I32)`.
    type IntElem = i32;

    type BoolTensorPrimitive = FlexTensor;
    type BoolElem = bool;

    type QuantizedTensorPrimitive = FlexQTensor;
}
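
// Sketch of what the element-type defaults above mean at the user level.
// Illustrative only and not compiled here; it assumes burn's high-level
// `Tensor` API, which lives outside this crate:
//
//     let t = Tensor::<Flex, 1>::from_data([1.0, 2.0, 3.0], &FlexDevice);
//     // `t` is stored as f32 because `FloatElem = f32`; prefer passing an
//     // explicit `DType` when a different precision is required.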

impl Backend for Flex {
    fn name(_device: &Self::Device) -> String {
        "flex".into()
    }

    fn seed(_device: &Self::Device, seed: u64) {
        let rng = FlexRng::seed_from_u64(seed);
        let mut seed_lock = SEED.lock().unwrap();
        *seed_lock = Some(rng);
    }

    fn device_count(_type_id: u16) -> usize {
        1
    }

    fn dtype_usage(_device: &Self::Device, dtype: DType) -> DTypeUsageSet {
        match dtype {
            // Full support for standard types
            DType::F64 | DType::F32 | DType::F16 | DType::BF16 => {
                DTypeUsage::Storage | DTypeUsage::Arithmetic
            }
            DType::I64 | DType::I32 | DType::I16 | DType::I8 => {
                DTypeUsage::Storage | DTypeUsage::Arithmetic
            }
            DType::U64 | DType::U32 | DType::U16 | DType::U8 => {
                DTypeUsage::Storage | DTypeUsage::Arithmetic
            }
            // Bool storage: flex stores bools as 1 byte per element, so Native and
            // U8 are both supported (they share the same layout, only the tag
            // differs). Bool(U32) would require 4-byte-per-element storage
            // throughout the backend and is not yet implemented.
            DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8) => {
                DTypeUsage::Storage | DTypeUsage::Arithmetic
            }
            DType::Bool(burn_std::BoolStore::U32) => DTypeUsageSet::empty(),
            // Quantized types: storage only for now
            DType::QFloat(_) => DTypeUsage::Storage.into(),
            _ => DTypeUsageSet::empty(),
        }
    }
}

impl BackendIr for Flex {
    type Handle = HandleKind<Self>;

    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FlexTensor {
        match handle.handle {
            HandleKind::Float(t) => t,
            _ => panic!("Expected float handle, got {}", handle.handle.name()),
        }
    }

    fn int_tensor(handle: TensorHandle<Self::Handle>) -> FlexTensor {
        match handle.handle {
            HandleKind::Int(t) => t,
            _ => panic!("Expected int handle, got {}", handle.handle.name()),
        }
    }

    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> FlexTensor {
        match handle.handle {
            HandleKind::Bool(t) => t,
            _ => panic!("Expected bool handle, got {}", handle.handle.name()),
        }
    }

    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> FlexQTensor {
        match handle.handle {
            HandleKind::Quantized(t) => t,
            _ => panic!("Expected quantized handle, got {}", handle.handle.name()),
        }
    }

    fn float_tensor_handle(tensor: FlexTensor) -> Self::Handle {
        HandleKind::Float(tensor)
    }

    fn int_tensor_handle(tensor: FlexTensor) -> Self::Handle {
        HandleKind::Int(tensor)
    }

    fn bool_tensor_handle(tensor: FlexTensor) -> Self::Handle {
        HandleKind::Bool(tensor)
    }

    fn quantized_tensor_handle(tensor: FlexQTensor) -> Self::Handle {
        HandleKind::Quantized(tensor)
    }
}
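
// Handle round-trip (illustrative sketch, not compiled here): wrapping via
// `float_tensor_handle` and unwrapping via `float_tensor` recovers the tensor,
// while unwrapping under the wrong kind hits the panic arms above.
//
//     let handle = Flex::float_tensor_handle(tensor); // HandleKind::Float(..)
//     // `Flex::float_tensor(..)` on a Float handle returns the tensor;
//     // `Flex::int_tensor(..)` on the same handle panics with the kind name.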

// Ops traits are implemented in the ops module

#[cfg(test)]
mod tests {
    use burn_backend::{Backend, DType};
    use burn_std::BoolStore;

    use super::*;

    #[test]
    fn supports_bool_native() {
        let device = FlexDevice;
        assert!(Flex::supports_dtype(
            &device,
            DType::Bool(BoolStore::Native)
        ));
    }

    #[test]
    fn supports_bool_u8() {
        let device = FlexDevice;
        assert!(Flex::supports_dtype(&device, DType::Bool(BoolStore::U8)));
    }

    #[test]
    fn does_not_support_bool_u32() {
        let device = FlexDevice;
        assert!(
            !Flex::supports_dtype(&device, DType::Bool(BoolStore::U32)),
            "Bool(U32) should not be supported: flex stores bools as 1 byte per element"
        );
    }
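
    // Illustrative addition (a hypothetical extra test, mirroring the ones
    // above): the dtype_usage match grants full support to every standard
    // float dtype, which this sketch exercises via `supports_dtype`.
    #[test]
    fn supports_standard_floats() {
        let device = FlexDevice;
        for dtype in [DType::F64, DType::F32, DType::F16, DType::BF16] {
            assert!(Flex::supports_dtype(&device, dtype));
        }
    }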

    #[test]
    fn bool_empty_preserves_native_dtype() {
        use burn_backend::ops::BoolTensorOps;
        let shape = burn_std::Shape::from(alloc::vec![3]);
        let t = Flex::bool_empty(shape, &FlexDevice, burn_std::BoolDType::Native);
        assert_eq!(t.dtype(), DType::Bool(BoolStore::Native));
    }

    #[test]
    fn bool_empty_preserves_u8_dtype() {
        use burn_backend::ops::BoolTensorOps;
        let shape = burn_std::Shape::from(alloc::vec![3]);
        let t = Flex::bool_empty(shape, &FlexDevice, burn_std::BoolDType::U8);
        assert_eq!(t.dtype(), DType::Bool(BoolStore::U8));
    }

    #[test]
    fn device_prints_as_cpu() {
        use alloc::format;
        assert_eq!(format!("{:?}", FlexDevice), "Cpu");
        assert_eq!(format!("{}", FlexDevice), "Cpu");
    }
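
    // Illustrative addition (hypothetical test, not in the original suite):
    // `FlexDevice` is a unit struct, so converting to an id and back should
    // be the identity.
    #[test]
    fn device_id_round_trips() {
        let id = FlexDevice.to_id();
        assert_eq!(FlexDevice::from_id(id), FlexDevice);
    }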

    #[test]
    fn comparison_preserves_out_dtype_native() {
        let lhs = FlexTensor::from_data(burn_backend::TensorData::from([1.0f32, 2.0, 3.0]));
        let rhs = FlexTensor::from_data(burn_backend::TensorData::from([2.0f32, 2.0, 1.0]));
        let result = crate::ops::comparison::greater(lhs, rhs, burn_std::BoolDType::Native);
        assert_eq!(result.dtype(), DType::Bool(BoolStore::Native));
    }

    #[test]
    fn comparison_preserves_out_dtype_u8() {
        let lhs = FlexTensor::from_data(burn_backend::TensorData::from([1.0f32, 2.0, 3.0]));
        let rhs = FlexTensor::from_data(burn_backend::TensorData::from([2.0f32, 2.0, 1.0]));
        let result = crate::ops::comparison::greater(lhs, rhs, burn_std::BoolDType::U8);
        assert_eq!(result.dtype(), DType::Bool(BoolStore::U8));
    }

    #[test]
    #[should_panic(expected = "Bool(U32)")]
    fn comparison_u32_panics() {
        let lhs = FlexTensor::from_data(burn_backend::TensorData::from([1.0f32, 2.0]));
        let rhs = FlexTensor::from_data(burn_backend::TensorData::from([2.0f32, 1.0]));
        let _ = crate::ops::comparison::greater(lhs, rhs, burn_std::BoolDType::U32);
    }

    #[test]
    fn bool_not_preserves_u8_dtype() {
        use burn_backend::ops::BoolTensorOps;
        // Construct a Bool(U8) tensor directly to verify bool_not preserves
        // the dtype tag across the op. from_data would produce Bool(Native),
        // so we use make_bool_tensor to get the U8 tag.
        let t_u8 = crate::ops::comparison::make_bool_tensor(
            alloc::vec![1, 0, 1],
            burn_std::Shape::from(alloc::vec![3]),
            burn_std::BoolDType::U8,
        );
        let result = Flex::bool_not(t_u8);
        assert_eq!(result.dtype(), DType::Bool(BoolStore::U8));
        let data: &[u8] = result.bytes();
        assert_eq!(&data[..3], &[0, 1, 0]);
    }
}
311}