use alloc::string::String;
use core::marker::PhantomData;
use burn_backend::{Backend, BackendTypes, DType, DTypeUsage, DTypeUsageSet, DeviceId, DeviceOps};
use burn_ir::{BackendIr, HandleKind, TensorHandle};
use burn_std::device::Device;
use burn_std::rand::{SeedableRng, StdRng};
use burn_std::stub::Mutex;
use crate::qtensor::FlexQTensor;
use crate::tensor::FlexTensor;
/// Random number generator type used by this backend; an alias for [`StdRng`].
pub type FlexRng = StdRng;
/// Crate-wide, mutex-guarded slot holding an optional [`FlexRng`].
///
/// Starts out as `None`; `Flex::seed` stores a freshly seeded generator here.
pub(crate) static SEED: Mutex<Option<FlexRng>> = Mutex::new(None);
/// Returns a [`FlexRng`] obtained from `burn_std::rand::get_seeded_rng`.
///
/// NOTE(review): this function does not read [`SEED`]; how the returned generator is
/// seeded is decided entirely by the `burn_std` helper — confirm against its docs.
pub(crate) fn get_seeded_rng() -> FlexRng {
burn_std::rand::get_seeded_rng()
}
/// Device marker type for the Flex backend.
///
/// This is a unit struct, so all values are equal, hash identically and are
/// trivially copyable.
#[derive(Default, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub struct FlexDevice;
impl Device for FlexDevice {
fn to_id(&self) -> DeviceId {
DeviceId::new(0, 0)
}
fn from_id(_id: DeviceId) -> Self {
Self
}
}
// `DeviceOps` is implemented with an empty body: no trait items are overridden here.
impl DeviceOps for FlexDevice {}
impl core::fmt::Display for FlexDevice {
    /// Writes the fixed label `Cpu`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Cpu")
    }
}
impl core::fmt::Debug for FlexDevice {
    /// Debug output is produced by the `Display` implementation, so both render alike.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        <Self as core::fmt::Display>::fmt(self, f)
    }
}
/// Backend marker type.
///
/// The struct stores no data: both type parameters are carried only through
/// [`PhantomData`]. They default to `f32` and `i32`, so a bare `Flex` means
/// `Flex<f32, i32>`.
#[derive(Debug, Default, Clone, Copy)]
pub struct Flex<E = f32, I = i32> {
    /// Marker for the `E` type parameter.
    _e: PhantomData<E>,
    /// Marker for the `I` type parameter.
    _i: PhantomData<I>,
}
impl BackendTypes for Flex {
    type Device = FlexDevice;

    // Float, int and bool tensors all share the `FlexTensor` primitive;
    // quantized tensors use `FlexQTensor`.
    type FloatTensorPrimitive = FlexTensor;
    type IntTensorPrimitive = FlexTensor;
    type BoolTensorPrimitive = FlexTensor;
    type QuantizedTensorPrimitive = FlexQTensor;

    // Element types associated with each tensor kind.
    type FloatElem = f32;
    type IntElem = i32;
    type BoolElem = bool;
}
impl Backend for Flex {
    /// Backend name; always `"flex"`, whatever device is passed.
    fn name(_device: &Self::Device) -> String {
        String::from("flex")
    }

    /// Replaces the contents of [`SEED`] with a generator seeded from `seed`.
    ///
    /// # Panics
    ///
    /// Panics if locking [`SEED`] returns an error.
    fn seed(_device: &Self::Device, seed: u64) {
        let seeded = Some(FlexRng::seed_from_u64(seed));
        *SEED.lock().unwrap() = seeded;
    }

    /// Always reports exactly one device, regardless of the type id.
    fn device_count(_type_id: u16) -> usize {
        1
    }

    /// Reports how `dtype` may be used on this backend.
    ///
    /// * Float (`F64`, `F32`, `F16`, `BF16`), signed and unsigned integer types, and
    ///   booleans stored as `Native` or `U8`: storage and arithmetic.
    /// * `QFloat`: storage only.
    /// * `Bool(U32)` and every other dtype: the empty set.
    fn dtype_usage(_device: &Self::Device, dtype: DType) -> DTypeUsageSet {
        match dtype {
            DType::F64
            | DType::F32
            | DType::F16
            | DType::BF16
            | DType::I64
            | DType::I32
            | DType::I16
            | DType::I8
            | DType::U64
            | DType::U32
            | DType::U16
            | DType::U8
            | DType::Bool(burn_std::BoolStore::Native | burn_std::BoolStore::U8) => {
                DTypeUsage::Storage | DTypeUsage::Arithmetic
            }
            DType::QFloat(_) => DTypeUsage::Storage.into(),
            // Listed explicitly so the unsupported bool layout is visible at a glance.
            DType::Bool(burn_std::BoolStore::U32) => DTypeUsageSet::empty(),
            _ => DTypeUsageSet::empty(),
        }
    }
}
impl BackendIr for Flex {
    type Handle = HandleKind<Self>;

    /// Unwraps a float tensor from `handle`.
    ///
    /// # Panics
    ///
    /// Panics if the handle is not the `Float` variant.
    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FlexTensor {
        match handle.handle {
            HandleKind::Float(tensor) => tensor,
            other => panic!("Expected float handle, got {}", other.name()),
        }
    }

    /// Unwraps an int tensor from `handle`.
    ///
    /// # Panics
    ///
    /// Panics if the handle is not the `Int` variant.
    fn int_tensor(handle: TensorHandle<Self::Handle>) -> FlexTensor {
        match handle.handle {
            HandleKind::Int(tensor) => tensor,
            other => panic!("Expected int handle, got {}", other.name()),
        }
    }

    /// Unwraps a bool tensor from `handle`.
    ///
    /// # Panics
    ///
    /// Panics if the handle is not the `Bool` variant.
    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> FlexTensor {
        match handle.handle {
            HandleKind::Bool(tensor) => tensor,
            other => panic!("Expected bool handle, got {}", other.name()),
        }
    }

    /// Unwraps a quantized tensor from `handle`.
    ///
    /// # Panics
    ///
    /// Panics if the handle is not the `Quantized` variant.
    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> FlexQTensor {
        match handle.handle {
            HandleKind::Quantized(tensor) => tensor,
            other => panic!("Expected quantized handle, got {}", other.name()),
        }
    }

    /// Wraps `tensor` in the `Float` handle variant.
    fn float_tensor_handle(tensor: FlexTensor) -> Self::Handle {
        Self::Handle::Float(tensor)
    }

    /// Wraps `tensor` in the `Int` handle variant.
    fn int_tensor_handle(tensor: FlexTensor) -> Self::Handle {
        Self::Handle::Int(tensor)
    }

    /// Wraps `tensor` in the `Bool` handle variant.
    fn bool_tensor_handle(tensor: FlexTensor) -> Self::Handle {
        Self::Handle::Bool(tensor)
    }

    /// Wraps `tensor` in the `Quantized` handle variant.
    fn quantized_tensor_handle(tensor: FlexQTensor) -> Self::Handle {
        Self::Handle::Quantized(tensor)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::ops::BoolTensorOps;
    use burn_backend::{Backend, DType, TensorData};
    use burn_std::{BoolDType, BoolStore, Shape};

    /// Three-element float operands shared by the comparison dtype tests.
    fn sample_operands() -> (FlexTensor, FlexTensor) {
        let lhs = FlexTensor::from_data(TensorData::from([1.0f32, 2.0, 3.0]));
        let rhs = FlexTensor::from_data(TensorData::from([2.0f32, 2.0, 1.0]));
        (lhs, rhs)
    }

    // --- dtype support reporting -------------------------------------------------

    #[test]
    fn supports_bool_native() {
        let supported = Flex::supports_dtype(&FlexDevice, DType::Bool(BoolStore::Native));
        assert!(supported);
    }

    #[test]
    fn supports_bool_u8() {
        let supported = Flex::supports_dtype(&FlexDevice, DType::Bool(BoolStore::U8));
        assert!(supported);
    }

    #[test]
    fn does_not_support_bool_u32() {
        let supported = Flex::supports_dtype(&FlexDevice, DType::Bool(BoolStore::U32));
        assert!(
            !supported,
            "Bool(U32) should not be supported: flex stores bools as 1 byte per element"
        );
    }

    // --- bool tensor creation ----------------------------------------------------

    #[test]
    fn bool_empty_preserves_native_dtype() {
        let dims = Shape::from(alloc::vec![3]);
        let created = Flex::bool_empty(dims, &FlexDevice, BoolDType::Native);
        assert_eq!(created.dtype(), DType::Bool(BoolStore::Native));
    }

    #[test]
    fn bool_empty_preserves_u8_dtype() {
        let dims = Shape::from(alloc::vec![3]);
        let created = Flex::bool_empty(dims, &FlexDevice, BoolDType::U8);
        assert_eq!(created.dtype(), DType::Bool(BoolStore::U8));
    }

    // --- device formatting -------------------------------------------------------

    #[test]
    fn device_prints_as_cpu() {
        use alloc::format;
        let debug_text = format!("{:?}", FlexDevice);
        let display_text = format!("{}", FlexDevice);
        assert_eq!(debug_text, "Cpu");
        assert_eq!(display_text, "Cpu");
    }

    // --- comparison output dtype -------------------------------------------------

    #[test]
    fn comparison_preserves_out_dtype_native() {
        let (lhs, rhs) = sample_operands();
        let mask = crate::ops::comparison::greater(lhs, rhs, BoolDType::Native);
        assert_eq!(mask.dtype(), DType::Bool(BoolStore::Native));
    }

    #[test]
    fn comparison_preserves_out_dtype_u8() {
        let (lhs, rhs) = sample_operands();
        let mask = crate::ops::comparison::greater(lhs, rhs, BoolDType::U8);
        assert_eq!(mask.dtype(), DType::Bool(BoolStore::U8));
    }

    #[test]
    #[should_panic(expected = "Bool(U32)")]
    fn comparison_u32_panics() {
        let lhs = FlexTensor::from_data(TensorData::from([1.0f32, 2.0]));
        let rhs = FlexTensor::from_data(TensorData::from([2.0f32, 1.0]));
        let _ = crate::ops::comparison::greater(lhs, rhs, BoolDType::U32);
    }

    // --- bool negation -----------------------------------------------------------

    #[test]
    fn bool_not_preserves_u8_dtype() {
        let input = crate::ops::comparison::make_bool_tensor(
            alloc::vec![1, 0, 1],
            Shape::from(alloc::vec![3]),
            BoolDType::U8,
        );
        let negated = Flex::bool_not(input);
        assert_eq!(negated.dtype(), DType::Bool(BoolStore::U8));
        // Each stored byte must be flipped.
        let raw: &[u8] = negated.bytes();
        assert_eq!(&raw[..3], &[0, 1, 0]);
    }
}